Merge branch 'main' into feat/mcp-06-18

This commit is contained in:
Novice
2025-10-20 10:29:09 +08:00
542 changed files with 11548 additions and 7438 deletions

View File

@ -447,6 +447,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"message_id": message.id,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
@ -466,8 +468,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user),
)
@ -483,6 +483,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id: str,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
"""
Generate worker in a new thread.
@ -538,6 +540,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow=workflow,
system_user_id=system_user_id,
app=app,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
try:
@ -570,8 +574,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@ -584,7 +586,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param message: message
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
@ -596,8 +597,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message=message,
user=user,
dialogue_count=self._dialogue_count,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
draft_var_saver_factory=draft_var_saver_factory,
)

View File

@ -23,8 +23,12 @@ from core.app.features.annotation_reply.annotation_reply import AnnotationReplyF
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@ -55,6 +59,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow: Workflow,
system_user_id: str,
app: App,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
super().__init__(
queue_manager=queue_manager,
@ -68,11 +74,24 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._workflow = workflow
self.system_user_id = system_user_id
self._app = app
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def run(self):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
system_inputs = SystemVariable(
query=self.application_generate_entity.query,
files=self.application_generate_entity.files,
conversation_id=self.conversation.id,
user_id=self.system_user_id,
dialogue_count=self._dialogue_count,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_run_id,
)
with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
@ -89,7 +108,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
else:
inputs = self.application_generate_entity.inputs
query = self.application_generate_entity.query
files = self.application_generate_entity.files
# moderation
if self.handle_input_moderation(
@ -114,17 +132,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = self._initialize_conversation_variables()
# Create a variable pool.
system_inputs = SystemVariable(
query=query,
files=files,
conversation_id=self.conversation.id,
user_id=self.system_user_id,
dialogue_count=self._dialogue_count,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_run_id,
)
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
@ -172,6 +179,23 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
command_channel=command_channel,
)
self._queue_manager.graph_runtime_state = graph_runtime_state
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()
for event in generator:

View File

@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
@ -60,25 +61,21 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
from models.account import Account
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class AdvancedChatAppGenerateTaskPipeline:
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
@ -93,8 +90,6 @@ class AdvancedChatAppGenerateTaskPipeline:
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
@ -114,31 +109,20 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables=SystemVariable(
query=message.query,
files=application_generate_entity.files,
conversation_id=conversation.id,
user_id=user_session_id,
dialogue_count=dialogue_count,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_run_id,
),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
self._workflow_system_variables = SystemVariable(
query=message.query,
files=application_generate_entity.files,
conversation_id=conversation.id,
user_id=user_session_id,
dialogue_count=dialogue_count,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_run_id,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=self._workflow_system_variables,
)
self._task_state = WorkflowTaskState()
@ -157,6 +141,8 @@ class AdvancedChatAppGenerateTaskPipeline:
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
self._seed_graph_runtime_state_from_queue_manager()
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@ -289,12 +275,6 @@ class AdvancedChatAppGenerateTaskPipeline:
if not self._workflow_run_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
return graph_runtime_state
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline.ping_stream_response()
@ -305,21 +285,28 @@ class AdvancedChatAppGenerateTaskPipeline:
err = self._base_task_pipeline.handle_error(event=event, session=session, message_id=self._message_id)
yield self._base_task_pipeline.error_to_stream_response(err)
def _handle_workflow_started_event(self, *args, **kwargs) -> Generator[StreamResponse, None, None]:
def _handle_workflow_started_event(
self,
event: QueueWorkflowStartedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
runtime_state = self._resolve_graph_runtime_state()
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_run_id = run_id
with self._database_session() as session:
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_execution.id_
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
message.workflow_run_id = run_id
workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow_id,
)
yield workflow_start_resp
@ -327,13 +314,9 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id, event=event
)
node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_retry_resp:
@ -345,14 +328,9 @@ class AdvancedChatAppGenerateTaskPipeline:
"""Handle node started events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_resp:
@ -368,14 +346,12 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
)
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_finish_resp:
yield node_finish_resp
@ -386,16 +362,13 @@ class AdvancedChatAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(event=event)
node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_finish_resp:
yield node_finish_resp
@ -505,29 +478,19 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.SUCCEEDED,
graph_runtime_state=validated_state,
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
@ -536,30 +499,20 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count,
)
yield workflow_finish_resp
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
@ -568,32 +521,25 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: QueueWorkflowFailedEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.FAILED,
graph_runtime_state=validated_state,
error=event.error,
exceptions_count=event.exceptions_count,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED,
error_message=event.error,
conversation_id=self._conversation_id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {event.error}"))
err = self._base_task_pipeline.handle_error(event=err_event, session=session, message_id=self._message_id)
yield workflow_finish_resp
@ -608,25 +554,23 @@ class AdvancedChatAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle stop events."""
if self._workflow_run_id and graph_runtime_state:
_ = trace_manager
resolved_state = None
if self._workflow_run_id:
resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
if self._workflow_run_id and resolved_state:
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow_id,
status=WorkflowExecutionStatus.STOPPED,
graph_runtime_state=resolved_state,
error=event.get_stop_reason(),
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowExecutionStatus.STOPPED,
error_message=event.get_stop_reason(),
conversation_id=self._conversation_id,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
# Save message
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
self._save_message(session=session, graph_runtime_state=resolved_state)
yield workflow_finish_resp
elif event.stopped_by in (
@ -648,7 +592,7 @@ class AdvancedChatAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
self._ensure_graph_runtime_initialized(graph_runtime_state)
resolved_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
output_moderation_answer = self._base_task_pipeline.handle_output_moderation_when_task_finished(
self._task_state.answer
@ -662,7 +606,7 @@ class AdvancedChatAppGenerateTaskPipeline:
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
@ -671,10 +615,6 @@ class AdvancedChatAppGenerateTaskPipeline:
) -> Generator[StreamResponse, None, None]:
"""Handle retriever resources events."""
self._message_cycle_manager.handle_retriever_resources(event)
with self._database_session() as session:
message = self._get_message(session=session)
message.message_metadata = self._task_state.metadata.model_dump_json()
return
yield # Make this a generator
@ -683,10 +623,6 @@ class AdvancedChatAppGenerateTaskPipeline:
) -> Generator[StreamResponse, None, None]:
"""Handle annotation reply events."""
self._message_cycle_manager.handle_annotation_reply(event)
with self._database_session() as session:
message = self._get_message(session=session)
message.message_metadata = self._task_state.metadata.model_dump_json()
return
yield # Make this a generator
@ -740,7 +676,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self,
event: Any,
*,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
@ -753,7 +688,6 @@ class AdvancedChatAppGenerateTaskPipeline:
if handler := handlers.get(event_type):
yield from handler(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@ -770,7 +704,6 @@ class AdvancedChatAppGenerateTaskPipeline:
):
yield from self._handle_node_failed_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@ -789,15 +722,12 @@ class AdvancedChatAppGenerateTaskPipeline:
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 57-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state: GraphRuntimeState | None = None
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
match event:
case QueueWorkflowStartedEvent():
graph_runtime_state = event.graph_runtime_state
self._resolve_graph_runtime_state()
yield from self._handle_workflow_started_event(event)
case QueueErrorEvent():
@ -805,15 +735,11 @@ class AdvancedChatAppGenerateTaskPipeline:
break
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_event(
event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
)
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
break
case QueueStopEvent():
yield from self._handle_stop_event(
event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
)
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
break
# Handle all other events through elegant dispatch
@ -821,7 +747,6 @@ class AdvancedChatAppGenerateTaskPipeline:
if responses := list(
self._dispatch_event(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@ -879,6 +804,12 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
if candidate is not None:
self._graph_runtime_state = candidate
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""
Message end to stream response.

View File

@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent,
WorkflowQueueMessage,
)
from core.workflow.runtime import GraphRuntimeState
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@ -47,6 +48,7 @@ class AppQueueManager:
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self._q = q
self._graph_runtime_state: GraphRuntimeState | None = None
self._stopped_cache: TTLCache[tuple, bool] = TTLCache(maxsize=1, ttl=1)
self._cache_lock = threading.Lock()
@ -109,6 +111,16 @@ class AppQueueManager:
"""
self.publish(QueueErrorEvent(error=e), pub_from)
@property
def graph_runtime_state(self) -> GraphRuntimeState | None:
"""Retrieve the attached graph runtime state, if available."""
return self._graph_runtime_state
@graph_runtime_state.setter
def graph_runtime_state(self, graph_runtime_state: GraphRuntimeState | None) -> None:
"""Attach the live graph runtime state reference for downstream consumers."""
self._graph_runtime_state = graph_runtime_state
def publish(self, event: AppQueueEvent, pub_from: PublishFrom):
"""
Publish event to queue

View File

@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
from models import Account
from models.model import App, EndUser
from services.conversation_service import ConversationService

View File

@ -0,0 +1,55 @@
"""Shared helpers for managing GraphRuntimeState across task pipelines."""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.workflow.runtime import GraphRuntimeState
if TYPE_CHECKING:
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
class GraphRuntimeStateSupport:
"""
Mixin that centralises common GraphRuntimeState access patterns used by task pipelines.
Subclasses are expected to provide:
* `_base_task_pipeline` exposing the queue manager with an optional cached runtime state.
* `_graph_runtime_state` attribute used as the local cache for the runtime state.
"""
_base_task_pipeline: BasedGenerateTaskPipeline
_graph_runtime_state: GraphRuntimeState | None = None
def _ensure_graph_runtime_initialized(
self,
graph_runtime_state: GraphRuntimeState | None = None,
) -> GraphRuntimeState:
"""Validate and return the active graph runtime state."""
return self._resolve_graph_runtime_state(graph_runtime_state)
def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str:
system_variables = graph_runtime_state.variable_pool.system_variables
if not system_variables or not system_variables.workflow_execution_id:
raise ValueError("workflow_execution_id missing from runtime state")
return str(system_variables.workflow_execution_id)
def _resolve_graph_runtime_state(
self,
graph_runtime_state: GraphRuntimeState | None = None,
) -> GraphRuntimeState:
"""Return the cached runtime state or bootstrap it from the queue manager."""
if graph_runtime_state is not None:
self._graph_runtime_state = graph_runtime_state
return graph_runtime_state
if self._graph_runtime_state is None:
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
if candidate is not None:
self._graph_runtime_state = candidate
if self._graph_runtime_state is None:
raise ValueError("graph runtime state not initialized.")
return self._graph_runtime_state

View File

@ -1,9 +1,8 @@
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Union
from sqlalchemy.orm import Session
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
@ -39,16 +38,36 @@ from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.enums import (
NodeType,
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from libs.datetime_utils import naive_utc_now
from models import (
Account,
EndUser,
)
from models import Account, EndUser
from services.variable_truncator import VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
@dataclass(slots=True)
class _NodeSnapshot:
"""In-memory cache for node metadata between start and completion events."""
title: str
index: int
start_at: datetime
iteration_id: str = ""
"""Empty string means the node is not executing inside an iteration."""
loop_id: str = ""
"""Empty string means the node is not executing inside a loop."""
class WorkflowResponseConverter:
def __init__(
@ -56,37 +75,151 @@ class WorkflowResponseConverter:
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
user: Union[Account, EndUser],
system_variables: SystemVariable,
):
self._application_generate_entity = application_generate_entity
self._user = user
self._system_variables = system_variables
self._workflow_inputs = self._prepare_workflow_inputs()
self._truncator = VariableTruncator.default()
self._node_snapshots: dict[NodeExecutionId, _NodeSnapshot] = {}
self._workflow_execution_id: str | None = None
self._workflow_started_at: datetime | None = None
# ------------------------------------------------------------------
# Workflow lifecycle helpers
# ------------------------------------------------------------------
def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
inputs = dict(self._application_generate_entity.inputs)
for field_name, value in self._system_variables.to_dict().items():
# TODO(@future-refactor): store system variables separately from user inputs so we don't
# need to flatten `sys.*` entries into the input payload just for rerun/export tooling.
if field_name == SystemVariableKey.CONVERSATION_ID:
# Conversation IDs are session-scoped; omitting them keeps workflow inputs
# reusable without pinning new runs to a prior conversation.
continue
inputs[f"sys.{field_name}"] = value
handled = WorkflowEntry.handle_special_values(inputs)
return dict(handled or {})
def _ensure_workflow_run_id(self, workflow_run_id: str | None = None) -> str:
"""Return the memoized workflow run id, optionally seeding it during start events."""
if workflow_run_id is not None:
self._workflow_execution_id = workflow_run_id
if not self._workflow_execution_id:
raise ValueError("workflow_run_id missing before streaming workflow events")
return self._workflow_execution_id
# ------------------------------------------------------------------
# Node snapshot helpers
# ------------------------------------------------------------------
def _store_snapshot(self, event: QueueNodeStartedEvent) -> _NodeSnapshot:
snapshot = _NodeSnapshot(
title=event.node_title,
index=event.node_run_index,
start_at=event.start_at,
iteration_id=event.in_iteration_id or "",
loop_id=event.in_loop_id or "",
)
node_execution_id = NodeExecutionId(event.node_execution_id)
self._node_snapshots[node_execution_id] = snapshot
return snapshot
def _get_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
return self._node_snapshots.get(NodeExecutionId(node_execution_id))
def _pop_snapshot(self, node_execution_id: str) -> _NodeSnapshot | None:
return self._node_snapshots.pop(NodeExecutionId(node_execution_id), None)
@staticmethod
def _merge_metadata(
base_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None,
snapshot: _NodeSnapshot | None,
) -> Mapping[WorkflowNodeExecutionMetadataKey, Any] | None:
if not base_metadata and not snapshot:
return base_metadata
merged: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
if base_metadata:
merged.update(base_metadata)
if snapshot:
if snapshot.iteration_id:
merged[WorkflowNodeExecutionMetadataKey.ITERATION_ID] = snapshot.iteration_id
if snapshot.loop_id:
merged[WorkflowNodeExecutionMetadataKey.LOOP_ID] = snapshot.loop_id
return merged or None
def _truncate_mapping(
self,
mapping: Mapping[str, Any] | None,
) -> tuple[Mapping[str, Any] | None, bool]:
if mapping is None:
return None, False
if not mapping:
return {}, False
normalized = WorkflowEntry.handle_special_values(dict(mapping))
if normalized is None:
return None, False
truncated, is_truncated = self._truncator.truncate_variable_mapping(dict(normalized))
return truncated, is_truncated
@staticmethod
def _encode_outputs(outputs: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
if outputs is None:
return None
converter = WorkflowRuntimeTypeConverter()
return converter.to_json_encodable(outputs)
def workflow_start_to_stream_response(
self,
*,
task_id: str,
workflow_execution: WorkflowExecution,
workflow_run_id: str,
workflow_id: str,
) -> WorkflowStartStreamResponse:
run_id = self._ensure_workflow_run_id(workflow_run_id)
started_at = naive_utc_now()
self._workflow_started_at = started_at
return WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution.id_,
workflow_run_id=run_id,
data=WorkflowStartStreamResponse.Data(
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
inputs=workflow_execution.inputs,
created_at=int(workflow_execution.started_at.timestamp()),
id=run_id,
workflow_id=workflow_id,
inputs=self._workflow_inputs,
created_at=int(started_at.timestamp()),
),
)
def workflow_finish_to_stream_response(
self,
*,
session: Session,
task_id: str,
workflow_execution: WorkflowExecution,
workflow_id: str,
status: WorkflowExecutionStatus,
graph_runtime_state: GraphRuntimeState,
error: str | None = None,
exceptions_count: int = 0,
) -> WorkflowFinishStreamResponse:
created_by = None
run_id = self._ensure_workflow_run_id()
started_at = self._workflow_started_at
if started_at is None:
raise ValueError(
"workflow_finish_to_stream_response called before workflow_start_to_stream_response",
)
finished_at = naive_utc_now()
elapsed_time = (finished_at - started_at).total_seconds()
outputs_mapping = graph_runtime_state.outputs or {}
encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping)
created_by: Mapping[str, object] | None
user = self._user
if isinstance(user, Account):
created_by = {
@ -94,38 +227,29 @@ class WorkflowResponseConverter:
"name": user.name,
"email": user.email,
}
elif isinstance(user, EndUser):
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
# Handle the case where finished_at is None by using current time as default
finished_at_timestamp = (
int(workflow_execution.finished_at.timestamp())
if workflow_execution.finished_at
else int(datetime.now(UTC).timestamp())
)
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution.id_,
workflow_run_id=run_id,
data=WorkflowFinishStreamResponse.Data(
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
status=workflow_execution.status,
outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens,
total_steps=workflow_execution.total_steps,
id=run_id,
workflow_id=workflow_id,
status=status.value,
outputs=encoded_outputs,
error=error,
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
created_by=created_by,
created_at=int(workflow_execution.started_at.timestamp()),
finished_at=finished_at_timestamp,
files=self.fetch_files_from_node_outputs(workflow_execution.outputs),
exceptions_count=workflow_execution.exceptions_count,
created_at=int(started_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(outputs_mapping),
exceptions_count=exceptions_count,
),
)
@ -134,38 +258,28 @@ class WorkflowResponseConverter:
*,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> NodeStartStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._store_snapshot(event)
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
workflow_run_id=run_id,
data=NodeStartStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
title=workflow_node_execution.title,
index=workflow_node_execution.index,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
created_at=int(workflow_node_execution.created_at.timestamp()),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=snapshot.title,
index=snapshot.index,
created_at=int(snapshot.start_at.timestamp()),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
parallel_run_id=event.parallel_mode_run_id,
agent_strategy=event.agent_strategy,
),
)
# extras logic
if event.node_type == NodeType.TOOL:
response.data.extras["icon"] = ToolManager.get_tool_icon(
tenant_id=self._application_generate_entity.app_config.tenant_id,
@ -189,41 +303,54 @@ class WorkflowResponseConverter:
*,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> NodeFinishStreamResponse | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
return None
if not workflow_node_execution.finished_at:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
snapshot = self._pop_snapshot(event.node_execution_id)
json_converter = WorkflowRuntimeTypeConverter()
start_at = snapshot.start_at if snapshot else event.start_at
finished_at = naive_utc_now()
elapsed_time = (finished_at - start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
process_data, process_data_truncated = self._truncate_mapping(event.process_data)
encoded_outputs = self._encode_outputs(event.outputs)
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
metadata = self._merge_metadata(event.execution_metadata, snapshot)
if isinstance(event, QueueNodeSucceededEvent):
status = WorkflowNodeExecutionStatus.SUCCEEDED.value
error_message = event.error
elif isinstance(event, QueueNodeFailedEvent):
status = WorkflowNodeExecutionStatus.FAILED.value
error_message = event.error
else:
status = WorkflowNodeExecutionStatus.EXCEPTION.value
error_message = event.error
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
workflow_run_id=run_id,
data=NodeFinishStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
process_data=workflow_node_execution.get_response_process_data(),
process_data_truncated=workflow_node_execution.process_data_truncated,
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
outputs_truncated=workflow_node_execution.outputs_truncated,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
index=snapshot.index if snapshot else 0,
title=snapshot.title if snapshot else "",
inputs=inputs,
inputs_truncated=inputs_truncated,
process_data=process_data,
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=status,
error=error_message,
elapsed_time=elapsed_time,
execution_metadata=metadata,
created_at=int(start_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
),
@ -234,44 +361,45 @@ class WorkflowResponseConverter:
*,
event: QueueNodeRetryEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Union[NodeRetryStreamResponse, NodeFinishStreamResponse] | None:
if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
if not workflow_node_execution.workflow_execution_id:
return None
if not workflow_node_execution.finished_at:
) -> NodeRetryStreamResponse | None:
if event.node_type in {NodeType.ITERATION, NodeType.LOOP}:
return None
run_id = self._ensure_workflow_run_id()
json_converter = WorkflowRuntimeTypeConverter()
snapshot = self._get_snapshot(event.node_execution_id)
if snapshot is None:
raise AssertionError("node retry event arrived without a stored snapshot")
finished_at = naive_utc_now()
elapsed_time = (finished_at - event.start_at).total_seconds()
inputs, inputs_truncated = self._truncate_mapping(event.inputs)
process_data, process_data_truncated = self._truncate_mapping(event.process_data)
encoded_outputs = self._encode_outputs(event.outputs)
outputs, outputs_truncated = self._truncate_mapping(encoded_outputs)
metadata = self._merge_metadata(event.execution_metadata, snapshot)
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
workflow_run_id=run_id,
data=NodeRetryStreamResponse.Data(
id=workflow_node_execution.id,
node_id=workflow_node_execution.node_id,
node_type=workflow_node_execution.node_type,
index=workflow_node_execution.index,
title=workflow_node_execution.title,
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.get_response_inputs(),
inputs_truncated=workflow_node_execution.inputs_truncated,
process_data=workflow_node_execution.get_response_process_data(),
process_data_truncated=workflow_node_execution.process_data_truncated,
outputs=json_converter.to_json_encodable(workflow_node_execution.get_response_outputs()),
outputs_truncated=workflow_node_execution.outputs_truncated,
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
execution_metadata=workflow_node_execution.metadata,
created_at=int(workflow_node_execution.created_at.timestamp()),
finished_at=int(workflow_node_execution.finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}),
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
index=snapshot.index,
title=snapshot.title,
inputs=inputs,
inputs_truncated=inputs_truncated,
process_data=process_data,
process_data_truncated=process_data_truncated,
outputs=outputs,
outputs_truncated=outputs_truncated,
status=WorkflowNodeExecutionStatus.RETRY.value,
error=event.error,
elapsed_time=elapsed_time,
execution_metadata=metadata,
created_at=int(snapshot.start_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
retry_index=event.retry_index,
@ -379,8 +507,6 @@ class WorkflowResponseConverter:
inputs=new_inputs,
inputs_truncated=truncated,
metadata=event.metadata or {},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)
@ -405,9 +531,6 @@ class WorkflowResponseConverter:
pre_loop_output={},
created_at=int(time.time()),
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
),
)
@ -446,8 +569,6 @@ class WorkflowResponseConverter:
execution_metadata=event.metadata,
finished_at=int(time.time()),
steps=event.steps,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
),
)

View File

@ -112,7 +112,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
metadata = {}
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
if isinstance(sub_stream_response, ErrorStreamResponse):
elif isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:

View File

@ -207,6 +207,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
from_source=from_source,
from_end_user_id=end_user_id,
from_account_id=account_id,
app_mode=app_config.app_mode,
)
db.session.add(message)

View File

@ -352,6 +352,8 @@ class PipelineGenerator(BaseAppGenerator):
"application_generate_entity": application_generate_entity,
"workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
@ -367,8 +369,6 @@ class PipelineGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
draft_var_saver_factory=draft_var_saver_factory,
)
@ -573,6 +573,8 @@ class PipelineGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
@ -620,6 +622,8 @@ class PipelineGenerator(BaseAppGenerator):
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
runner.run()
@ -648,8 +652,6 @@ class PipelineGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -660,7 +662,6 @@ class PipelineGenerator(BaseAppGenerator):
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
:param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
@ -670,8 +671,6 @@ class PipelineGenerator(BaseAppGenerator):
queue_manager=queue_manager,
user=user,
stream=stream,
workflow_node_execution_repository=workflow_node_execution_repository,
workflow_execution_repository=workflow_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
)

View File

@ -11,11 +11,14 @@ from core.app.entities.app_invoke_entities import (
)
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import WorkflowType
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@ -40,6 +43,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
@ -56,6 +61,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
self.workflow_thread_pool_id = workflow_thread_pool_id
self._workflow = workflow
self._sys_user_id = system_user_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
@ -163,6 +170,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
variable_pool=variable_pool,
)
self._queue_manager.graph_runtime_state = graph_runtime_state
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()
for event in generator:

View File

@ -231,6 +231,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
"queue_manager": queue_manager,
"context": context,
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
},
)
@ -244,8 +246,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=streaming,
)
@ -424,6 +424,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
"""
Generate worker in a new thread.
@ -465,6 +467,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
try:
@ -493,8 +497,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@ -514,8 +516,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
draft_var_saver_factory=draft_var_saver_factory,
stream=stream,
)

View File

@ -5,12 +5,13 @@ from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.workflow.entities import GraphRuntimeState, VariablePool
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
@ -34,6 +35,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
):
super().__init__(
queue_manager=queue_manager,
@ -43,6 +46,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self.application_generate_entity = application_generate_entity
self._workflow = workflow
self._sys_user_id = system_user_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
def run(self):
"""
@ -51,6 +56,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
@ -60,18 +73,9 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
else:
inputs = self.application_generate_entity.inputs
files = self.application_generate_entity.files
# Create a variable pool.
system_inputs = SystemVariable(
files=files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
@ -96,6 +100,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
channel_key = f"workflow:{task_id}:commands"
command_channel = RedisChannel(redis_client, channel_key)
self._queue_manager.graph_runtime_state = graph_runtime_state
workflow_entry = WorkflowEntry(
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
@ -115,6 +121,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
command_channel=command_channel,
)
persistence_layer = WorkflowPersistenceLayer(
application_generate_entity=self.application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id=self._workflow.id,
workflow_type=WorkflowType(self._workflow.type),
version=self._workflow.version,
graph_data=self._workflow.graph_dict,
),
workflow_execution_repository=self._workflow_execution_repository,
workflow_node_execution_repository=self._workflow_node_execution_repository,
trace_manager=self.application_generate_entity.trace_manager,
)
workflow_entry.graph_engine.layer(persistence_layer)
generator = workflow_entry.run()
for event in generator:

View File

@ -8,11 +8,9 @@ from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import (
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
@ -53,27 +51,20 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities import GraphRuntimeState, WorkflowExecution
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
from models import Account
from models.enums import CreatorUserRole
from models.model import EndUser
from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
)
from models.workflow import Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom
logger = logging.getLogger(__name__)
class WorkflowAppGenerateTaskPipeline:
class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
@ -85,8 +76,6 @@ class WorkflowAppGenerateTaskPipeline:
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
draft_var_saver_factory: DraftVariableSaverFactory,
):
self._base_task_pipeline = BasedGenerateTaskPipeline(
@ -99,42 +88,30 @@ class WorkflowAppGenerateTaskPipeline:
self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatorUserRole.END_USER
elif isinstance(user, Account):
else:
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatorUserRole.ACCOUNT
else:
raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables=SystemVariable(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_execution_id,
),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
version=workflow.version,
graph_data=workflow.graph_dict,
),
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
)
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
self._workflow_execution_id = ""
self._invoke_from = queue_manager.invoke_from
self._draft_var_saver_factory = draft_var_saver_factory
self._workflow = workflow
self._workflow_system_variables = SystemVariable(
files=application_generate_entity.files,
user_id=user_session_id,
app_id=application_generate_entity.app_config.app_id,
workflow_id=workflow.id,
workflow_execution_id=application_generate_entity.workflow_execution_id,
)
self._workflow_response_converter = WorkflowResponseConverter(
application_generate_entity=application_generate_entity,
user=user,
system_variables=self._workflow_system_variables,
)
self._graph_runtime_state: GraphRuntimeState | None = self._base_task_pipeline.queue_manager.graph_runtime_state
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@ -261,15 +238,9 @@ class WorkflowAppGenerateTaskPipeline:
def _ensure_workflow_initialized(self):
"""Fluent validation for workflow state."""
if not self._workflow_run_id:
if not self._workflow_execution_id:
raise ValueError("workflow run not initialized.")
def _ensure_graph_runtime_initialized(self, graph_runtime_state: GraphRuntimeState | None) -> GraphRuntimeState:
"""Fluent validation for graph runtime state."""
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
return graph_runtime_state
def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
"""Handle ping events."""
yield self._base_task_pipeline.ping_stream_response()
@ -283,12 +254,14 @@ class WorkflowAppGenerateTaskPipeline:
self, event: QueueWorkflowStartedEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle workflow started events."""
# init workflow run
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
self._workflow_run_id = workflow_execution.id_
runtime_state = self._resolve_graph_runtime_state()
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
workflow_run_id=run_id,
workflow_id=self._workflow.id,
)
yield start_resp
@ -296,14 +269,9 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node retry events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
workflow_execution_id=self._workflow_run_id,
event=event,
)
response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response:
@ -315,13 +283,9 @@ class WorkflowAppGenerateTaskPipeline:
"""Handle node started events."""
self._ensure_workflow_initialized()
workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
workflow_execution_id=self._workflow_run_id, event=event
)
node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_start_response:
@ -331,14 +295,12 @@ class WorkflowAppGenerateTaskPipeline:
self, event: QueueNodeSucceededEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle node succeeded events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_success_response:
yield node_success_response
@ -349,17 +311,13 @@ class WorkflowAppGenerateTaskPipeline:
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle various node failure events."""
workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
event=event,
)
node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if isinstance(event, QueueNodeExceptionEvent):
self._save_output_for_event(event, workflow_node_execution.id)
self._save_output_for_event(event, event.node_execution_id)
if node_failed_response:
yield node_failed_response
@ -372,7 +330,7 @@ class WorkflowAppGenerateTaskPipeline:
iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_start_resp
@ -385,7 +343,7 @@ class WorkflowAppGenerateTaskPipeline:
iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_next_resp
@ -398,7 +356,7 @@ class WorkflowAppGenerateTaskPipeline:
iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield iter_finish_resp
@ -409,7 +367,7 @@ class WorkflowAppGenerateTaskPipeline:
loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_start_resp
@ -420,7 +378,7 @@ class WorkflowAppGenerateTaskPipeline:
loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_next_resp
@ -433,7 +391,7 @@ class WorkflowAppGenerateTaskPipeline:
loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_execution_id=self._workflow_run_id,
workflow_execution_id=self._workflow_execution_id,
event=event,
)
yield loop_finish_resp
@ -442,33 +400,22 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowSucceededEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow succeeded events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.SUCCEEDED,
graph_runtime_state=validated_state,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
@ -476,34 +423,23 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: QueueWorkflowPartialSuccessEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow partial success events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
graph_runtime_state=validated_state,
exceptions_count=event.exceptions_count,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
@ -511,37 +447,33 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow failed and stop events."""
_ = trace_manager
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
validated_state = self._ensure_graph_runtime_initialized()
if isinstance(event, QueueWorkflowFailedEvent):
status = WorkflowExecutionStatus.FAILED
error = event.error
exceptions_count = event.exceptions_count
else:
status = WorkflowExecutionStatus.STOPPED
error = event.get_stop_reason()
exceptions_count = 0
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_id=self._workflow.id,
status=status,
graph_runtime_state=validated_state,
error=error,
exceptions_count=exceptions_count,
)
with self._database_session() as session:
workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
workflow_run_id=self._workflow_run_id,
total_tokens=validated_state.total_tokens,
total_steps=validated_state.node_run_steps,
status=WorkflowExecutionStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowExecutionStatus.STOPPED,
error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_execution=workflow_execution,
)
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
yield workflow_finish_resp
@ -601,7 +533,6 @@ class WorkflowAppGenerateTaskPipeline:
self,
event: AppQueueEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
tts_publisher: AppGeneratorTTSPublisher | None = None,
trace_manager: TraceQueueManager | None = None,
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
@ -614,7 +545,6 @@ class WorkflowAppGenerateTaskPipeline:
if handler := handlers.get(event_type):
yield from handler(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@ -631,7 +561,6 @@ class WorkflowAppGenerateTaskPipeline:
):
yield from self._handle_node_failed_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@ -642,7 +571,6 @@ class WorkflowAppGenerateTaskPipeline:
if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
yield from self._handle_workflow_failed_and_stop_events(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@ -661,15 +589,12 @@ class WorkflowAppGenerateTaskPipeline:
Process stream response using elegant Fluent Python patterns.
Maintains exact same functionality as original 44-if-statement version.
"""
# Initialize graph runtime state
graph_runtime_state = None
for queue_message in self._base_task_pipeline.queue_manager.listen():
event = queue_message.event
match event:
case QueueWorkflowStartedEvent():
graph_runtime_state = event.graph_runtime_state
self._resolve_graph_runtime_state()
yield from self._handle_workflow_started_event(event)
case QueueTextChunkEvent():
@ -681,12 +606,19 @@ class WorkflowAppGenerateTaskPipeline:
yield from self._handle_error_event(event)
break
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
case QueueStopEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
# Handle all other events through elegant dispatch
case _:
if responses := list(
self._dispatch_event(
event,
graph_runtime_state=graph_runtime_state,
tts_publisher=tts_publisher,
trace_manager=trace_manager,
queue_message=queue_message,
@ -697,7 +629,7 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution):
def _save_workflow_app_log(self, *, session: Session, workflow_run_id: str | None):
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
@ -709,11 +641,14 @@ class WorkflowAppGenerateTaskPipeline:
# not save log for debugging
return
if not workflow_run_id:
return
workflow_app_log = WorkflowAppLog()
workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
workflow_app_log.workflow_id = workflow_execution.workflow_id
workflow_app_log.workflow_run_id = workflow_execution.id_
workflow_app_log.workflow_id = self._workflow.id
workflow_app_log.workflow_run_id = workflow_run_id
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id

View File

@ -25,7 +25,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphEngineEvent,
@ -54,6 +54,7 @@ from core.workflow.graph_events.graph import GraphRunAbortedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
@ -346,9 +347,7 @@ class WorkflowBasedAppRunner:
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(
QueueWorkflowStartedEvent(graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state)
)
self._publish_event(QueueWorkflowStartedEvent())
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
@ -372,7 +371,6 @@ class WorkflowBasedAppRunner:
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
inputs=inputs,
@ -393,7 +391,6 @@ class WorkflowBasedAppRunner:
node_title=event.node_title,
node_type=event.node_type,
start_at=event.start_at,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
agent_strategy=event.agent_strategy,
@ -494,7 +491,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata,
)
)
@ -536,7 +532,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
inputs=event.inputs,
predecessor_node_id=event.predecessor_node_id,
metadata=event.metadata,
)
)

View File

@ -3,11 +3,11 @@ from datetime import datetime
from enum import StrEnum, auto
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit, GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@ -54,6 +54,7 @@ class AppQueueEvent(BaseModel):
"""
event: QueueEvent
model_config = ConfigDict(arbitrary_types_allowed=True)
class QueueLLMChunkEvent(AppQueueEvent):
@ -80,7 +81,6 @@ class QueueIterationStartEvent(AppQueueEvent):
node_run_index: int
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
@ -132,19 +132,10 @@ class QueueLoopStartEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
inputs: Mapping[str, object] = Field(default_factory=dict)
predecessor_node_id: str | None = None
metadata: Mapping[str, object] = Field(default_factory=dict)
@ -160,16 +151,6 @@ class QueueLoopNextEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: str | None = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Any = None # output for the current loop
@ -185,14 +166,6 @@ class QueueLoopCompletedEvent(AppQueueEvent):
node_id: str
node_type: NodeType
node_title: str
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
start_at: datetime
node_run_index: int
@ -285,12 +258,9 @@ class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
class QueueWorkflowStartedEvent(AppQueueEvent):
"""
QueueWorkflowStartedEvent entity
"""
"""QueueWorkflowStartedEvent entity."""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
graph_runtime_state: GraphRuntimeState
class QueueWorkflowSucceededEvent(AppQueueEvent):
@ -334,15 +304,9 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_title: str
node_type: NodeType
node_run_index: int = 1 # FIXME(-LAN-): may not used
predecessor_node_id: str | None = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
in_iteration_id: str | None = None
in_loop_id: str | None = None
start_at: datetime
parallel_mode_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
# FIXME(-LAN-): only for ToolNode, need to refactor
@ -360,14 +324,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
@ -423,14 +379,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
"""parallel id if node is in parallel"""
parallel_start_node_id: str | None = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: str | None = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: str | None = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
@ -455,7 +403,6 @@ class QueueNodeFailedEvent(AppQueueEvent):
node_execution_id: str
node_id: str
node_type: NodeType
parallel_id: str | None = None
in_iteration_id: str | None = None
"""iteration id if node is in iteration"""
in_loop_id: str | None = None

View File

@ -257,13 +257,8 @@ class NodeStartStreamResponse(StreamResponse):
inputs_truncated: bool = False
created_at: int
extras: dict[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
parallel_run_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
@ -285,10 +280,6 @@ class NodeStartStreamResponse(StreamResponse):
"inputs": None,
"created_at": self.data.created_at,
"extras": {},
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
},
@ -324,10 +315,6 @@ class NodeFinishStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
@ -357,10 +344,6 @@ class NodeFinishStreamResponse(StreamResponse):
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
},
@ -396,10 +379,6 @@ class NodeRetryStreamResponse(StreamResponse):
created_at: int
finished_at: int
files: Sequence[Mapping[str, Any]] | None = []
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parent_parallel_id: str | None = None
parent_parallel_start_node_id: str | None = None
iteration_id: str | None = None
loop_id: str | None = None
retry_index: int = 0
@ -430,10 +409,6 @@ class NodeRetryStreamResponse(StreamResponse):
"created_at": self.data.created_at,
"finished_at": self.data.finished_at,
"files": [],
"parallel_id": self.data.parallel_id,
"parallel_start_node_id": self.data.parallel_start_node_id,
"parent_parallel_id": self.data.parent_parallel_id,
"parent_parallel_start_node_id": self.data.parent_parallel_start_node_id,
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"retry_index": self.data.retry_index,
@ -541,8 +516,6 @@ class LoopNodeStartStreamResponse(StreamResponse):
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_STARTED
workflow_run_id: str
@ -567,9 +540,6 @@ class LoopNodeNextStreamResponse(StreamResponse):
created_at: int
pre_loop_output: Any = None
extras: Mapping[str, object] = Field(default_factory=dict)
parallel_id: str | None = None
parallel_start_node_id: str | None = None
parallel_mode_run_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_NEXT
workflow_run_id: str
@ -603,8 +573,6 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
execution_metadata: Mapping[str, object] = Field(default_factory=dict)
finished_at: int
steps: int
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_COMPLETED
workflow_run_id: str

View File

@ -1,11 +1,9 @@
import logging
from threading import Lock
from typing import Union
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
@ -18,11 +16,6 @@ logger = logging.getLogger(__name__)
class DatasourceManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_datasource_plugin_provider(
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType

View File

@ -1148,6 +1148,15 @@ class ProviderConfiguration(BaseModel):
raise ValueError("Can't add same credential")
provider_model_record.credential_id = credential_record.id
provider_model_record.updated_at = naive_utc_now()
# clear cache
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL,
)
provider_model_credentials_cache.delete()
session.add(provider_model_record)
session.commit()
@ -1181,6 +1190,14 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record)
session.commit()
# clear cache
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_model_record.id,
cache_type=ProviderCredentialsCacheType.MODEL,
)
provider_model_credentials_cache.delete()
def delete_custom_model(self, model_type: ModelType, model: str):
"""
Delete custom model.

View File

@ -913,4 +913,4 @@ class TraceQueueManager:
"file_id": file_id,
"app_id": task.app_id,
}
process_trace_tasks.delay(file_info)
process_trace_tasks.delay(file_info) # type: ignore

View File

@ -14,7 +14,7 @@ from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from extensions.ext_database import db
from models.account import Account
from models import Account
from models.model import App, AppMode, EndUser

View File

@ -18,7 +18,7 @@ from core.model_runtime.entities.message_entities import ImagePromptMessageConte
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool
class AdvancedPromptTransform(PromptTransform):

View File

@ -250,7 +250,6 @@ class WeaviateVector(BaseVector):
)
)
batch_size = max(1, int(dify_config.WEAVIATE_BATCH_SIZE or 100))
with col.batch.dynamic() as batch:
for obj in objs:
batch.add_object(properties=obj.properties, uuid=obj.uuid, vector=obj.vector)
@ -348,7 +347,10 @@ class WeaviateVector(BaseVector):
for obj in res.objects:
properties = dict(obj.properties or {})
text = properties.pop(Field.TEXT_KEY.value, "")
distance = (obj.metadata.distance if obj.metadata else None) or 1.0
if obj.metadata and obj.metadata.distance is not None:
distance = obj.metadata.distance
else:
distance = 1.0
score = 1.0 - distance
if score > score_threshold:

View File

@ -25,7 +25,7 @@ class FirecrawlApp:
}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/scrape", json_data, headers)
response = self._post_request(f"{self.base_url}/v2/scrape", json_data, headers)
if response.status_code == 200:
response_data = response.json()
data = response_data["data"]
@ -42,7 +42,7 @@ class FirecrawlApp:
json_data = {"url": url}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
response = self._post_request(f"{self.base_url}/v2/crawl", json_data, headers)
if response.status_code == 200:
# There's also another two fields in the response: "success" (bool) and "url" (str)
job_id = response.json().get("id")
@ -51,9 +51,25 @@ class FirecrawlApp:
self._handle_error(response, "start crawl job")
return "" # unreachable
def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map
headers = self._prepare_headers()
json_data: dict[str, Any] = {"url": url, "integration": "dify"}
if params:
# Pass through provided params, including optional "sitemap": "only" | "include" | "skip"
json_data.update(params)
response = self._post_request(f"{self.base_url}/v2/map", json_data, headers)
if response.status_code == 200:
return cast(dict[str, Any], response.json())
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "start map job")
return {}
else:
raise Exception(f"Failed to start map job. Status code: {response.status_code}")
def check_crawl_status(self, job_id) -> dict[str, Any]:
headers = self._prepare_headers()
response = self._get_request(f"{self.base_url}/v1/crawl/{job_id}", headers)
response = self._get_request(f"{self.base_url}/v2/crawl/{job_id}", headers)
if response.status_code == 200:
crawl_status_response = response.json()
if crawl_status_response.get("status") == "completed":
@ -135,12 +151,16 @@ class FirecrawlApp:
"lang": "en",
"country": "us",
"timeout": 60000,
"ignoreInvalidURLs": False,
"ignoreInvalidURLs": True,
"scrapeOptions": {},
"sources": [
{"type": "web"},
],
"integration": "dify",
}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/search", json_data, headers)
response = self._post_request(f"{self.base_url}/v2/search", json_data, headers)
if response.status_code == 200:
response_data = response.json()
if not response_data.get("success"):

View File

@ -108,7 +108,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
execution_data = execution.model_dump()
# Queue the save operation as a Celery task (fire and forget)
save_workflow_execution_task.delay(
save_workflow_execution_task.delay( # type: ignore
execution_data=execution_data,
tenant_id=self._tenant_id,
app_id=self._app_id or "",

View File

@ -104,7 +104,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER
# Initialize in-memory cache for node executions
# Key: node_execution_id, Value: WorkflowNodeExecution (DB model)
self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {}
# Initialize FileService for handling offloaded data
@ -332,17 +331,10 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
Args:
execution: The NodeExecution domain entity to persist
"""
# NOTE: As per the implementation of `WorkflowCycleManager`,
# the `save` method is invoked multiple times during the node's execution lifecycle, including:
#
# - When the node starts execution
# - When the node retries execution
# - When the node completes execution (either successfully or with failure)
#
# Only the final invocation will have `inputs` and `outputs` populated.
#
# This simplifies the logic for saving offloaded variables but introduces a tight coupling
# between this module and `WorkflowCycleManager`.
# NOTE: The workflow engine triggers `save` multiple times for a single node execution:
# when the node starts, any time it retries, and once more when it reaches a terminal state.
# Only the final call contains the complete inputs and outputs payloads, so earlier invocations
# must tolerate missing data without attempting to offload variables.
# Convert domain model to database model using tenant context and other attributes
db_model = self._to_db_model(execution)

View File

@ -395,11 +395,13 @@ class ApiTool(Tool):
parsed_response = self.validate_and_parse_response(response)
# assemble invoke message based on response type
if parsed_response.is_json and isinstance(parsed_response.content, dict):
yield self.create_json_message(parsed_response.content)
if parsed_response.is_json:
if isinstance(parsed_response.content, dict):
yield self.create_json_message(parsed_response.content)
# FIXES: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088
# We need never break the original flows
# The yield below must be preserved to keep backward compatibility.
#
# ref: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088
yield self.create_text_message(response.text)
else:
# Convert to string if needed and create text message

View File

@ -189,6 +189,11 @@ class ToolInvokeMessage(BaseModel):
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
@field_validator("metadata", mode="before")
@classmethod
def _normalize_metadata(cls, value: Mapping[str, Any] | None) -> Mapping[str, Any]:
    """Coerce a missing or falsy metadata payload into an empty mapping."""
    if not value:
        return {}
    return value
class RetrieverResourceMessage(BaseModel):
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
@ -376,6 +381,11 @@ class ToolEntity(BaseModel):
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[ToolParameter]:
return v or []
@field_validator("output_schema", mode="before")
@classmethod
def _normalize_output_schema(cls, value: Mapping[str, object] | None) -> Mapping[str, object]:
    """Coerce a missing or falsy output schema into an empty mapping."""
    if not value:
        return {}
    return value
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(

View File

@ -63,8 +63,8 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
from services.tools.tools_transform_service import ToolTransformService
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
from core.workflow.nodes.tool.entities import ToolEntity
from core.workflow.runtime import VariablePool
logger = logging.getLogger(__name__)

View File

@ -12,7 +12,7 @@ from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool_file_manager import ToolFileManager
from libs.login import current_user
from models.account import Account
from models import Account
logger = logging.getLogger(__name__)

View File

@ -3,6 +3,7 @@ import logging
from collections.abc import Generator
from typing import Any
from flask import has_request_context
from sqlalchemy import select
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
@ -18,7 +19,8 @@ from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models.model import App
from models import Account, Tenant
from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -79,11 +81,16 @@ class WorkflowTool(Tool):
generator = WorkflowAppGenerator()
assert self.runtime is not None
assert self.runtime.invoke_from is not None
assert current_user is not None
user = self._resolve_user(user_id=user_id)
if user is None:
raise ToolInvokeError("User not found")
result = generator.generate(
app_model=app,
workflow=workflow,
user=current_user,
user=user,
args={"inputs": tool_parameters, "files": files},
invoke_from=self.runtime.invoke_from,
streaming=False,
@ -123,6 +130,51 @@ class WorkflowTool(Tool):
label=self.label,
)
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
    """Resolve the acting user in both HTTP and worker contexts.

    Inside a Flask request the ``current_user`` proxy is dereferenced (may
    yield an Account or an EndUser). Outside a request (e.g. a Celery
    worker) the Account is loaded from the database by *user_id*.

    Returns:
        Account | EndUser | None: The resolved user, or None if resolution fails.
    """
    if has_request_context():
        return self._resolve_user_from_request()
    return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_request(self) -> Account | EndUser | None:
    """Resolve the user from the active Flask request context.

    ``current_user`` is a LocalProxy; it is unwrapped via
    ``_get_current_object`` rather than compared with None directly.
    """
    try:
        unwrap = getattr(current_user, "_get_current_object", None)
        if unwrap is None:
            # No proxy machinery present; use the object as-is.
            return current_user
        return unwrap()
    except Exception as e:
        logger.warning("Failed to resolve user from request context: %s", e)
        return None
def _resolve_user_from_database(self, user_id: str) -> Account | None:
    """Resolve an Account from the database (worker/Celery context).

    Also attaches the runtime's tenant to the account so downstream code
    can rely on ``current_tenant`` being populated.
    """
    account = db.session.scalar(select(Account).where(Account.id == user_id))
    if not account:
        return None
    # NOTE(review): assumes self.runtime is already set by the caller — confirm.
    tenant = db.session.scalar(select(Tenant).where(Tenant.id == self.runtime.tenant_id))
    if not tenant:
        return None
    account.current_tenant = tenant
    return account
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version

View File

@ -1,18 +1,11 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .graph_runtime_state import GraphRuntimeState
from .run_condition import RunCondition
from .variable_pool import VariablePool, VariableValue
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"RunCondition",
"VariablePool",
"VariableValue",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@ -1,160 +0,0 @@
from copy import deepcopy
from pydantic import BaseModel, PrivateAttr
from core.model_runtime.entities.llm_entities import LLMUsage
from .variable_pool import VariablePool
class GraphRuntimeState(BaseModel):
    """Mutable runtime state shared across a single graph execution.

    All fields are pydantic private attributes exposed through validated
    properties, so callers cannot bypass the invariants (non-negative token
    and step counters) or mutate shared containers in place: ``outputs`` and
    ``llm_usage`` are copied on both read and write.
    """

    # Private attributes to prevent direct modification
    _variable_pool: VariablePool = PrivateAttr()
    _start_at: float = PrivateAttr()
    _total_tokens: int = PrivateAttr(default=0)
    _llm_usage: LLMUsage = PrivateAttr(default_factory=LLMUsage.empty_usage)
    _outputs: dict[str, object] = PrivateAttr(default_factory=dict[str, object])
    _node_run_steps: int = PrivateAttr(default=0)
    # Serialized subsystem snapshots ("" means no saved state to restore).
    _ready_queue_json: str = PrivateAttr()
    _graph_execution_json: str = PrivateAttr()
    _response_coordinator_json: str = PrivateAttr()

    def __init__(
        self,
        *,
        variable_pool: VariablePool,
        start_at: float,
        total_tokens: int = 0,
        llm_usage: LLMUsage | None = None,
        outputs: dict[str, object] | None = None,
        node_run_steps: int = 0,
        ready_queue_json: str = "",
        graph_execution_json: str = "",
        response_coordinator_json: str = "",
        **kwargs: object,
    ):
        """Initialize the GraphRuntimeState with validation."""
        # Pydantic's own init runs first; private attrs are then set manually.
        super().__init__(**kwargs)
        # Initialize private attributes with validation
        self._variable_pool = variable_pool
        self._start_at = start_at
        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens
        if llm_usage is None:
            llm_usage = LLMUsage.empty_usage()
        self._llm_usage = llm_usage
        if outputs is None:
            outputs = {}
        # Deep-copied so the caller's dict cannot alias internal state.
        self._outputs = deepcopy(outputs)
        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps
        self._ready_queue_json = ready_queue_json
        self._graph_execution_json = graph_execution_json
        self._response_coordinator_json = response_coordinator_json

    @property
    def variable_pool(self) -> VariablePool:
        """Get the variable pool."""
        return self._variable_pool

    @property
    def start_at(self) -> float:
        """Get the start time."""
        return self._start_at

    @start_at.setter
    def start_at(self, value: float) -> None:
        """Set the start time."""
        self._start_at = value

    @property
    def total_tokens(self) -> int:
        """Get the total tokens count."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: int):
        """Set the total tokens count."""
        if value < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = value

    @property
    def llm_usage(self) -> LLMUsage:
        """Get the LLM usage info."""
        # Return a copy to prevent external modification
        return self._llm_usage.model_copy()

    @llm_usage.setter
    def llm_usage(self, value: LLMUsage):
        """Set the LLM usage info."""
        # Stored as a copy so later caller-side mutation is not observed here.
        self._llm_usage = value.model_copy()

    @property
    def outputs(self) -> dict[str, object]:
        """Get a copy of the outputs dictionary."""
        return deepcopy(self._outputs)

    @outputs.setter
    def outputs(self, value: dict[str, object]) -> None:
        """Set the outputs dictionary."""
        self._outputs = deepcopy(value)

    def set_output(self, key: str, value: object) -> None:
        """Set a single output value."""
        self._outputs[key] = deepcopy(value)

    def get_output(self, key: str, default: object = None) -> object:
        """Get a single output value."""
        # Copied on read as well, so callers cannot mutate stored values.
        return deepcopy(self._outputs.get(key, default))

    def update_outputs(self, updates: dict[str, object]) -> None:
        """Update multiple output values."""
        for key, value in updates.items():
            self._outputs[key] = deepcopy(value)

    @property
    def node_run_steps(self) -> int:
        """Get the node run steps count."""
        return self._node_run_steps

    @node_run_steps.setter
    def node_run_steps(self, value: int) -> None:
        """Set the node run steps count."""
        if value < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = value

    def increment_node_run_steps(self) -> None:
        """Increment the node run steps by 1."""
        self._node_run_steps += 1

    def add_tokens(self, tokens: int) -> None:
        """Add tokens to the total count."""
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        self._total_tokens += tokens

    @property
    def ready_queue_json(self) -> str:
        """Get a copy of the ready queue state."""
        return self._ready_queue_json

    @property
    def graph_execution_json(self) -> str:
        """Get a copy of the serialized graph execution state."""
        return self._graph_execution_json

    @property
    def response_coordinator_json(self) -> str:
        """Get a copy of the serialized response coordinator state."""
        return self._response_coordinator_json

View File

@ -1,21 +0,0 @@
import hashlib
from typing import Literal
from pydantic import BaseModel
from core.workflow.utils.condition.entities import Condition
class RunCondition(BaseModel):
    """Declarative condition deciding whether a node may run.

    Two variants exist: ``branch_identify`` matches an upstream branch's
    source handle, while ``condition`` evaluates a list of ``Condition``
    objects.
    """

    type: Literal["branch_identify", "condition"]
    """condition type"""

    branch_identify: str | None = None
    """branch identify like: sourceHandle, required when type is branch_identify"""

    conditions: list[Condition] | None = None
    """conditions to run the node, required when type is condition"""

    @property
    def hash(self) -> str:
        # Stable digest of the serialized model, usable for comparing/deduplicating conditions.
        return hashlib.sha256(self.model_dump_json().encode()).hexdigest()

View File

@ -58,6 +58,7 @@ class NodeType(StrEnum):
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
HUMAN_INPUT = "human-input"
class NodeExecutionType(StrEnum):
@ -96,6 +97,7 @@ class WorkflowExecutionStatus(StrEnum):
FAILED = "failed"
STOPPED = "stopped"
PARTIAL_SUCCEEDED = "partial-succeeded"
PAUSED = "paused"
class WorkflowNodeExecutionMetadataKey(StrEnum):

View File

@ -1,16 +1,11 @@
from .edge import Edge
from .graph import Graph, NodeFactory
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .graph import Graph, GraphBuilder, NodeFactory
from .graph_template import GraphTemplate
from .read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
__all__ = [
"Edge",
"Graph",
"GraphBuilder",
"GraphTemplate",
"NodeFactory",
"ReadOnlyGraphRuntimeState",
"ReadOnlyGraphRuntimeStateWrapper",
"ReadOnlyVariablePool",
"ReadOnlyVariablePoolWrapper",
]

View File

@ -195,6 +195,12 @@ class Graph:
return nodes
@classmethod
def new(cls) -> "GraphBuilder":
    """Create a fluent builder for assembling a graph programmatically.

    Returns a ``GraphBuilder`` bound to this graph class; primarily
    intended for constructing small graphs in tests.
    """
    return GraphBuilder(graph_cls=cls)
@classmethod
def _mark_inactive_root_branches(
cls,
@ -344,3 +350,96 @@ class Graph:
"""
edge_ids = self.in_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
@final
class GraphBuilder:
    """Fluent helper for constructing simple graphs, primarily for tests.

    Usage: ``Graph.new().add_root(a).add_node(b).connect(tail=..., head=...).build()``.
    Node ids must be unique; edges get sequential ids within one builder.
    """

    def __init__(self, *, graph_cls: type[Graph]):
        self._graph_cls = graph_cls
        self._nodes: list[Node] = []
        self._nodes_by_id: dict[str, Node] = {}
        self._edges: list[Edge] = []
        self._edge_counter = 0

    def add_root(self, node: Node) -> "GraphBuilder":
        """Register the root node. Must be called exactly once."""
        if self._nodes:
            raise ValueError("Root node has already been added")
        self._register_node(node)
        self._nodes.append(node)
        return self

    def add_node(
        self,
        node: Node,
        *,
        from_node_id: str | None = None,
        source_handle: str = "source",
    ) -> "GraphBuilder":
        """Append a node and connect it from the specified predecessor."""
        if not self._nodes:
            raise ValueError("Root node must be added before adding other nodes")
        # Default to chaining off the most recently added node.
        predecessor_id = from_node_id or self._nodes[-1].id
        if predecessor_id not in self._nodes_by_id:
            raise ValueError(f"Predecessor node '{predecessor_id}' not found")
        self._register_node(node)
        self._nodes.append(node)
        self._link(predecessor_id, node.id, source_handle)
        return self

    def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
        """Connect two existing nodes without adding a new node."""
        if tail not in self._nodes_by_id:
            raise ValueError(f"Tail node '{tail}' not found")
        if head not in self._nodes_by_id:
            raise ValueError(f"Head node '{head}' not found")
        self._link(tail, head, source_handle)
        return self

    def build(self) -> Graph:
        """Materialize the graph instance from the accumulated nodes and edges."""
        if not self._nodes:
            raise ValueError("Cannot build an empty graph")
        node_map = {node.id: node for node in self._nodes}
        edge_map = {edge.id: edge for edge in self._edges}
        incoming: dict[str, list[str]] = defaultdict(list)
        outgoing: dict[str, list[str]] = defaultdict(list)
        for edge in self._edges:
            outgoing[edge.tail].append(edge.id)
            incoming[edge.head].append(edge.id)
        return self._graph_cls(
            nodes=node_map,
            edges=edge_map,
            in_edges=dict(incoming),
            out_edges=dict(outgoing),
            root_node=self._nodes[0],
        )

    def _link(self, tail: str, head: str, source_handle: str) -> None:
        # Create and record an edge; ids are sequential within this builder.
        edge = Edge(id=f"edge_{self._edge_counter}", tail=tail, head=head, source_handle=source_handle)
        self._edge_counter += 1
        self._edges.append(edge)

    def _register_node(self, node: Node) -> None:
        # Enforce non-empty, unique node ids before indexing.
        if not node.id:
            raise ValueError("Node must have a non-empty id")
        if node.id in self._nodes_by_id:
            raise ValueError(f"Duplicate node id detected: {node.id}")
        self._nodes_by_id[node.id] = node

View File

@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue.
import json
from typing import TYPE_CHECKING, Any, final
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand
if TYPE_CHECKING:
from extensions.ext_redis import RedisClientWrapper
@ -41,6 +41,7 @@ class RedisChannel:
self._redis = redis_client
self._key = channel_key
self._command_ttl = command_ttl
self._pending_key = f"{channel_key}:pending"
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
@ -49,6 +50,9 @@ class RedisChannel:
Returns:
List of pending commands (drains the Redis list)
"""
if not self._has_pending_commands():
return []
commands: list[GraphEngineCommand] = []
# Use pipeline for atomic operations
@ -85,6 +89,7 @@ class RedisChannel:
with self._redis.pipeline() as pipe:
pipe.rpush(self._key, command_json)
pipe.expire(self._key, self._command_ttl)
pipe.set(self._pending_key, "1", ex=self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
@ -106,9 +111,25 @@ class RedisChannel:
if command_type == CommandType.ABORT:
return AbortCommand.model_validate(data)
else:
# For other command types, use base class
return GraphEngineCommand.model_validate(data)
if command_type == CommandType.PAUSE:
return PauseCommand.model_validate(data)
# For other command types, use base class
return GraphEngineCommand.model_validate(data)
except (ValueError, TypeError):
return None
def _has_pending_commands(self) -> bool:
    """Check and consume the pending marker to avoid unnecessary list reads.

    Returns:
        True if commands should be fetched from Redis.

    NOTE(review): the marker key and the command list carry independent
    TTLs; if the marker expires before the list, queued commands could be
    skipped — confirm both keys are always refreshed together on send.
    """
    with self._redis.pipeline() as pipe:
        pipe.get(self._pending_key)
        pipe.delete(self._pending_key)
        results = pipe.execute()
    # GET result comes first in pipeline order; the DELETE count is ignored.
    return results[0] is not None

View File

@ -5,10 +5,11 @@ This package handles external commands sent to the engine
during execution.
"""
from .command_handlers import AbortCommandHandler
from .command_handlers import AbortCommandHandler, PauseCommandHandler
from .command_processor import CommandProcessor
__all__ = [
"AbortCommandHandler",
"CommandProcessor",
"PauseCommandHandler",
]

View File

@ -1,14 +1,10 @@
"""
Command handler implementations.
"""
import logging
from typing import final
from typing_extensions import override
from ..domain.graph_execution import GraphExecution
from ..entities.commands import AbortCommand, GraphEngineCommand
from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from .command_processor import CommandHandler
logger = logging.getLogger(__name__)
@ -16,17 +12,17 @@ logger = logging.getLogger(__name__)
@final
class AbortCommandHandler(CommandHandler):
    """Handles abort commands."""

    @override
    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        """
        Handle an abort command.

        Args:
            command: The abort command
            execution: Graph execution to abort
        """
        assert isinstance(command, AbortCommand)
        logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason)
        # Fall back to a generic reason when the command carries none.
        execution.abort(command.reason or "User requested abort")
@final
class PauseCommandHandler(CommandHandler):
    """Handles pause commands by moving the execution into a paused state."""

    @override
    def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None:
        """Pause the execution; the command's reason (possibly None) is recorded on it."""
        assert isinstance(command, PauseCommand)
        logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason)
        execution.pause(command.reason)

View File

@ -40,6 +40,8 @@ class GraphExecutionState(BaseModel):
started: bool = Field(default=False)
completed: bool = Field(default=False)
aborted: bool = Field(default=False)
paused: bool = Field(default=False)
pause_reason: str | None = Field(default=None)
error: GraphExecutionErrorState | None = Field(default=None)
exceptions_count: int = Field(default=0)
node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState])
@ -103,6 +105,8 @@ class GraphExecution:
started: bool = False
completed: bool = False
aborted: bool = False
paused: bool = False
pause_reason: str | None = None
error: Exception | None = None
node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution])
exceptions_count: int = 0
@ -126,6 +130,17 @@ class GraphExecution:
self.aborted = True
self.error = RuntimeError(f"Aborted: {reason}")
def pause(self, reason: str | None = None) -> None:
    """Pause the graph execution without marking it complete.

    Idempotent: pausing an already-paused execution is a no-op and keeps
    the original reason.

    Raises:
        RuntimeError: If the execution has already completed or been aborted.
    """
    for already_finished, message in (
        (self.completed, "Cannot pause execution that has completed"),
        (self.aborted, "Cannot pause execution that has been aborted"),
    ):
        if already_finished:
            raise RuntimeError(message)
    if not self.paused:
        self.paused = True
        self.pause_reason = reason
def fail(self, error: Exception) -> None:
"""Mark the graph execution as failed."""
self.error = error
@ -140,7 +155,12 @@ class GraphExecution:
@property
def is_running(self) -> bool:
"""Check if the execution is currently running."""
return self.started and not self.completed and not self.aborted
return self.started and not self.completed and not self.aborted and not self.paused
@property
def is_paused(self) -> bool:
"""Check if the execution is currently paused."""
return self.paused
@property
def has_error(self) -> bool:
@ -173,6 +193,8 @@ class GraphExecution:
started=self.started,
completed=self.completed,
aborted=self.aborted,
paused=self.paused,
pause_reason=self.pause_reason,
error=_serialize_error(self.error),
exceptions_count=self.exceptions_count,
node_executions=node_states,
@ -197,6 +219,8 @@ class GraphExecution:
self.started = state.started
self.completed = state.completed
self.aborted = state.aborted
self.paused = state.paused
self.pause_reason = state.pause_reason
self.error = _deserialize_error(state.error)
self.exceptions_count = state.exceptions_count
self.node_executions = {

View File

@ -16,7 +16,6 @@ class CommandType(StrEnum):
ABORT = "abort"
PAUSE = "pause"
RESUME = "resume"
class GraphEngineCommand(BaseModel):
@ -31,3 +30,10 @@ class AbortCommand(GraphEngineCommand):
command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command")
reason: str | None = Field(default=None, description="Optional reason for abort")
class PauseCommand(GraphEngineCommand):
"""Command to pause a running workflow execution."""
command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command")
reason: str | None = Field(default=None, description="Optional reason for pause")

View File

@ -7,8 +7,8 @@ from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphNodeEventBase,
@ -23,11 +23,13 @@ from core.workflow.graph_events import (
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.runtime import GraphRuntimeState
from ..domain.graph_execution import GraphExecution
from ..response_coordinator import ResponseStreamCoordinator
@ -125,6 +127,7 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id)
self._graph_runtime_state.increment_node_run_steps()
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
@ -163,6 +166,8 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Store outputs in variable pool
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
@ -199,6 +204,18 @@ class EventHandler:
# Collect the event
self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunPauseRequestedEvent) -> None:
    """Handle pause requests emitted by nodes."""
    # Substitute a default reason so downstream consumers never see None.
    pause_reason = event.reason or "Awaiting human input"
    self._graph_execution.pause(pause_reason)
    # The pausing node is finished from the state manager's point of view.
    self._state_manager.finish_execution(event.node_id)
    if event.node_id in self._graph.nodes:
        # NOTE(review): resets the node to UNKNOWN, presumably so it is
        # re-dispatchable after resume — confirm the dispatcher treats
        # UNKNOWN as eligible.
        self._graph.nodes[event.node_id].state = NodeState.UNKNOWN
    self._graph_runtime_state.register_paused_node(event.node_id)
    # Forward the pause event to collectors/layers like any other event.
    self._event_collector.collect(event)
@_dispatch.register
def _(self, event: NodeRunFailedEvent) -> None:
"""
@ -212,6 +229,8 @@ class EventHandler:
node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()
self._accumulate_node_usage(event.node_run_result.llm_usage)
result = self._error_handler.handle_node_failure(event)
if result:
@ -235,6 +254,8 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
@ -286,6 +307,19 @@ class EventHandler:
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)
def _accumulate_node_usage(self, usage: LLMUsage) -> None:
    """Accumulate token usage into the shared runtime state.

    Zero-token usages are ignored. The first non-zero usage replaces the
    empty baseline; subsequent usages are merged via ``plus``.
    """
    if usage.total_tokens <= 0:
        return
    self._graph_runtime_state.add_tokens(usage.total_tokens)
    existing = self._graph_runtime_state.llm_usage
    merged = usage if existing.total_tokens == 0 else existing.plus(usage)
    self._graph_runtime_state.llm_usage = merged
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
"""
Store node outputs in the variable pool.

View File

@ -97,6 +97,10 @@ class EventManager:
"""
self._layers = layers
def notify_layers(self, event: GraphEngineEvent) -> None:
    """Notify registered layers about an event without buffering it.

    Unlike ``collect``, the event is delivered to layers only and is not
    queued for emission to the caller.
    """
    self._notify_layers(event)
def collect(self, event: GraphEngineEvent) -> None:
"""
Thread-safe method to collect an event.

View File

@ -9,28 +9,29 @@ import contextvars
import logging
import queue
from collections.abc import Generator
from typing import final
from typing import TYPE_CHECKING, cast, final
from flask import Flask, current_app
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph.read_only_state_wrapper import ReadOnlyGraphRuntimeStateWrapper
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.graph_events import (
GraphEngineEvent,
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
from .command_processing import AbortCommandHandler, CommandProcessor
from .domain import GraphExecution
from .entities.commands import AbortCommand
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
from core.workflow.runtime.graph_runtime_state import GraphProtocol
from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler
from .entities.commands import AbortCommand, PauseCommand
from .error_handler import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_state_manager import GraphStateManager
@ -38,10 +39,13 @@ from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import GraphEngineLayer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .ready_queue import ReadyQueue, ReadyQueueState, create_ready_queue_from_state
from .response_coordinator import ResponseStreamCoordinator
from .ready_queue import ReadyQueue
from .worker_management import WorkerPool
if TYPE_CHECKING:
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator
logger = logging.getLogger(__name__)
@ -67,17 +71,16 @@ class GraphEngine:
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# Graph execution tracks the overall execution state
self._graph_execution = GraphExecution(workflow_id=workflow_id)
if graph_runtime_state.graph_execution_json != "":
self._graph_execution.loads(graph_runtime_state.graph_execution_json)
# === Core Dependencies ===
# Graph structure and configuration
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
# Graph execution tracks the overall execution state
self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
self._graph_execution.workflow_id = workflow_id
# === Worker Management Parameters ===
# Parameters for dynamic worker pool scaling
self._min_workers = min_workers
@ -86,13 +89,7 @@ class GraphEngine:
self._scale_down_idle_time = scale_down_idle_time
# === Execution Queues ===
# Create ready queue from saved state or initialize new one
self._ready_queue: ReadyQueue
if self._graph_runtime_state.ready_queue_json == "":
self._ready_queue = InMemoryReadyQueue()
else:
ready_queue_state = ReadyQueueState.model_validate_json(self._graph_runtime_state.ready_queue_json)
self._ready_queue = create_ready_queue_from_state(ready_queue_state)
self._ready_queue = cast(ReadyQueue, self._graph_runtime_state.ready_queue)
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
@ -103,11 +100,7 @@ class GraphEngine:
# === Response Coordination ===
# Coordinates response streaming from response nodes
self._response_coordinator = ResponseStreamCoordinator(
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
)
if graph_runtime_state.response_coordinator_json != "":
self._response_coordinator.loads(graph_runtime_state.response_coordinator_json)
self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator)
# === Event Management ===
# Event manager handles both collection and emission of events
@ -133,19 +126,6 @@ class GraphEngine:
skip_propagator=self._skip_propagator,
)
# === Event Handler Registry ===
# Central registry for handling all node execution events
self._event_handler_registry = EventHandler(
graph=self._graph,
graph_runtime_state=self._graph_runtime_state,
graph_execution=self._graph_execution,
response_coordinator=self._response_coordinator,
event_collector=self._event_manager,
edge_processor=self._edge_processor,
state_manager=self._state_manager,
error_handler=self._error_handler,
)
# === Command Processing ===
# Processes external commands (e.g., abort requests)
self._command_processor = CommandProcessor(
@ -153,12 +133,12 @@ class GraphEngine:
graph_execution=self._graph_execution,
)
# Register abort command handler
# Register command handlers
abort_handler = AbortCommandHandler()
self._command_processor.register_handler(
AbortCommand,
abort_handler,
)
self._command_processor.register_handler(AbortCommand, abort_handler)
pause_handler = PauseCommandHandler()
self._command_processor.register_handler(PauseCommand, pause_handler)
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
@ -191,12 +171,23 @@ class GraphEngine:
self._execution_coordinator = ExecutionCoordinator(
graph_execution=self._graph_execution,
state_manager=self._state_manager,
event_handler=self._event_handler_registry,
event_collector=self._event_manager,
command_processor=self._command_processor,
worker_pool=self._worker_pool,
)
# === Event Handler Registry ===
# Central registry for handling all node execution events
self._event_handler_registry = EventHandler(
graph=self._graph,
graph_runtime_state=self._graph_runtime_state,
graph_execution=self._graph_execution,
response_coordinator=self._response_coordinator,
event_collector=self._event_manager,
edge_processor=self._edge_processor,
state_manager=self._state_manager,
error_handler=self._error_handler,
)
# Dispatches events and manages execution flow
self._dispatcher = Dispatcher(
event_queue=self._event_queue,
@ -237,26 +228,41 @@ class GraphEngine:
# Initialize layers
self._initialize_layers()
# Start execution
self._graph_execution.start()
is_resume = self._graph_execution.started
if not is_resume:
self._graph_execution.start()
else:
self._graph_execution.paused = False
self._graph_execution.pause_reason = None
start_event = GraphRunStartedEvent()
self._event_manager.notify_layers(start_event)
yield start_event
# Start subsystems
self._start_execution()
self._start_execution(resume=is_resume)
# Yield events as they occur
yield from self._event_manager.emit_events()
# Handle completion
if self._graph_execution.aborted:
if self._graph_execution.is_paused:
paused_event = GraphRunPausedEvent(
reason=self._graph_execution.pause_reason,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(paused_event)
yield paused_event
elif self._graph_execution.aborted:
abort_reason = "Workflow execution aborted by user command"
if self._graph_execution.error:
abort_reason = str(self._graph_execution.error)
yield GraphRunAbortedEvent(
aborted_event = GraphRunAbortedEvent(
reason=abort_reason,
outputs=self._graph_runtime_state.outputs,
)
self._event_manager.notify_layers(aborted_event)
yield aborted_event
elif self._graph_execution.has_error:
if self._graph_execution.error:
raise self._graph_execution.error
@ -264,20 +270,26 @@ class GraphEngine:
outputs = self._graph_runtime_state.outputs
exceptions_count = self._graph_execution.exceptions_count
if exceptions_count > 0:
yield GraphRunPartialSucceededEvent(
partial_event = GraphRunPartialSucceededEvent(
exceptions_count=exceptions_count,
outputs=outputs,
)
self._event_manager.notify_layers(partial_event)
yield partial_event
else:
yield GraphRunSucceededEvent(
succeeded_event = GraphRunSucceededEvent(
outputs=outputs,
)
self._event_manager.notify_layers(succeeded_event)
yield succeeded_event
except Exception as e:
yield GraphRunFailedEvent(
failed_event = GraphRunFailedEvent(
error=str(e),
exceptions_count=self._graph_execution.exceptions_count,
)
self._event_manager.notify_layers(failed_event)
yield failed_event
raise
finally:
@ -299,8 +311,12 @@ class GraphEngine:
except Exception as e:
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
def _start_execution(self) -> None:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
paused_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
# Start worker pool (it calculates initial workers internally)
self._worker_pool.start()
@ -309,10 +325,15 @@ class GraphEngine:
if node.execution_type == NodeExecutionType.RESPONSE:
self._response_coordinator.register(node.id)
# Enqueue root node
root_node = self._graph.root_node
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
if not resume:
# Enqueue root node
root_node = self._graph.root_node
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
else:
for node_id in paused_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Start dispatcher
self._dispatcher.start()

View File

@ -7,9 +7,9 @@ intercept and respond to GraphEngine events.
from abc import ABC, abstractmethod
from core.workflow.graph.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.runtime import ReadOnlyGraphRuntimeState
class GraphEngineLayer(ABC):

View File

@ -0,0 +1,410 @@
"""Workflow persistence layer for GraphEngine.
This layer mirrors the former ``WorkflowCycleManager`` responsibilities by
listening to ``GraphEngineEvent`` instances directly and persisting workflow
and node execution state via the injected repositories.
The design keeps domain persistence concerns inside the engine thread, while
allowing presentation layers to remain read-only observers of repository
state.
"""
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import (
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
WorkflowType,
)
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
@dataclass(slots=True)
class PersistenceWorkflowInfo:
    """Static workflow metadata required for persistence."""

    # Identifier of the workflow definition being executed.
    workflow_id: str
    # Kind of workflow, as defined by the WorkflowType enum.
    workflow_type: WorkflowType
    # Version string of the workflow definition.
    version: str
    # Raw graph definition persisted alongside the execution record.
    graph_data: Mapping[str, Any]
@dataclass(slots=True)
class _NodeRuntimeSnapshot:
    """Lightweight cache to keep node metadata across event phases."""

    # Graph node identifier (distinct from the execution id used as cache key).
    node_id: str
    # Human-readable node title captured when the node started.
    title: str
    # Id of the node that triggered this one, if any.
    predecessor_node_id: str | None
    # Enclosing iteration id when the node runs inside an iteration, else None.
    iteration_id: str | None
    # Enclosing loop id when the node runs inside a loop, else None.
    loop_id: str | None
    # Start timestamp; used later to compute elapsed time on completion.
    created_at: datetime
class WorkflowPersistenceLayer(GraphEngineLayer):
    """GraphEngine layer that persists workflow and node execution state.

    Listens to ``GraphEngineEvent`` instances and writes ``WorkflowExecution``
    and ``WorkflowNodeExecution`` records through the injected repositories,
    mirroring the former ``WorkflowCycleManager`` responsibilities.
    """

    def __init__(
        self,
        *,
        application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
        workflow_info: PersistenceWorkflowInfo,
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
        trace_manager: TraceQueueManager | None = None,
    ) -> None:
        """Store collaborators; no persistence happens until events arrive.

        :param application_generate_entity: invoke entity providing user inputs
        :param workflow_info: static workflow metadata stamped on each record
        :param workflow_execution_repository: sink for workflow-level records
        :param workflow_node_execution_repository: sink for node-level records
        :param trace_manager: optional tracing queue; when None, no trace
            tasks are enqueued on run completion
        """
        super().__init__()
        self._application_generate_entity = application_generate_entity
        self._workflow_info = workflow_info
        self._workflow_execution_repository = workflow_execution_repository
        self._workflow_node_execution_repository = workflow_node_execution_repository
        self._trace_manager = trace_manager

        # Mutable per-run state; reset in on_graph_start() so a reused layer
        # instance starts each run clean.
        self._workflow_execution: WorkflowExecution | None = None
        self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
        self._node_snapshots: dict[str, _NodeRuntimeSnapshot] = {}
        self._node_sequence: int = 0

    # ------------------------------------------------------------------
    # GraphEngineLayer lifecycle
    # ------------------------------------------------------------------
    def on_graph_start(self) -> None:
        """Reset all per-run caches at the start of a graph run."""
        self._workflow_execution = None
        self._node_execution_cache.clear()
        self._node_snapshots.clear()
        self._node_sequence = 0

    def on_event(self, event: GraphEngineEvent) -> None:
        """Route *event* to the matching graph- or node-level handler.

        Event types without a handler are silently ignored.
        """
        if isinstance(event, GraphRunStartedEvent):
            self._handle_graph_run_started()
            return
        if isinstance(event, GraphRunSucceededEvent):
            self._handle_graph_run_succeeded(event)
            return
        if isinstance(event, GraphRunPartialSucceededEvent):
            self._handle_graph_run_partial_succeeded(event)
            return
        if isinstance(event, GraphRunFailedEvent):
            self._handle_graph_run_failed(event)
            return
        if isinstance(event, GraphRunAbortedEvent):
            self._handle_graph_run_aborted(event)
            return
        if isinstance(event, GraphRunPausedEvent):
            self._handle_graph_run_paused(event)
            return
        if isinstance(event, NodeRunStartedEvent):
            self._handle_node_started(event)
            return
        # NOTE: NodeRunRetryEvent subclasses NodeRunStartedEvent, but the
        # started check above returns first, so retries only reach this
        # branch via the explicit isinstance order here.
        if isinstance(event, NodeRunRetryEvent):
            self._handle_node_retry(event)
            return
        if isinstance(event, NodeRunSucceededEvent):
            self._handle_node_succeeded(event)
            return
        if isinstance(event, NodeRunFailedEvent):
            self._handle_node_failed(event)
            return
        if isinstance(event, NodeRunExceptionEvent):
            self._handle_node_exception(event)
            return
        if isinstance(event, NodeRunPauseRequestedEvent):
            self._handle_node_pause_requested(event)

    def on_graph_end(self, error: Exception | None) -> None:
        """No-op: all persistence is driven by explicit terminal events."""
        return

    # ------------------------------------------------------------------
    # Graph-level handlers
    # ------------------------------------------------------------------
    def _handle_graph_run_started(self) -> None:
        """Create and save the workflow execution record for this run."""
        execution_id = self._get_execution_id()
        workflow_execution = WorkflowExecution.new(
            id_=execution_id,
            workflow_id=self._workflow_info.workflow_id,
            workflow_type=self._workflow_info.workflow_type,
            workflow_version=self._workflow_info.version,
            graph=self._workflow_info.graph_data,
            inputs=self._prepare_workflow_inputs(),
            started_at=naive_utc_now(),
        )
        self._workflow_execution_repository.save(workflow_execution)
        self._workflow_execution = workflow_execution

    def _handle_graph_run_succeeded(self, event: GraphRunSucceededEvent) -> None:
        """Persist final outputs and SUCCEEDED status, then enqueue tracing."""
        execution = self._get_workflow_execution()
        execution.outputs = event.outputs
        execution.status = WorkflowExecutionStatus.SUCCEEDED
        self._populate_completion_statistics(execution)
        self._workflow_execution_repository.save(execution)
        self._enqueue_trace_task(execution)

    def _handle_graph_run_partial_succeeded(self, event: GraphRunPartialSucceededEvent) -> None:
        """Persist outputs with PARTIAL_SUCCEEDED status and exception count."""
        execution = self._get_workflow_execution()
        execution.outputs = event.outputs
        execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
        execution.exceptions_count = event.exceptions_count
        self._populate_completion_statistics(execution)
        self._workflow_execution_repository.save(execution)
        self._enqueue_trace_task(execution)

    def _handle_graph_run_failed(self, event: GraphRunFailedEvent) -> None:
        """Mark the run FAILED and fail any node executions still RUNNING."""
        execution = self._get_workflow_execution()
        execution.status = WorkflowExecutionStatus.FAILED
        execution.error_message = event.error
        execution.exceptions_count = event.exceptions_count
        self._populate_completion_statistics(execution)
        self._fail_running_node_executions(error_message=event.error)
        self._workflow_execution_repository.save(execution)
        self._enqueue_trace_task(execution)

    def _handle_graph_run_aborted(self, event: GraphRunAbortedEvent) -> None:
        """Mark the run STOPPED and fail any node executions still RUNNING."""
        execution = self._get_workflow_execution()
        execution.status = WorkflowExecutionStatus.STOPPED
        execution.error_message = event.reason or "Workflow execution aborted"
        self._populate_completion_statistics(execution)
        self._fail_running_node_executions(error_message=execution.error_message or "")
        self._workflow_execution_repository.save(execution)
        self._enqueue_trace_task(execution)

    def _handle_graph_run_paused(self, event: GraphRunPausedEvent) -> None:
        """Persist PAUSED status without stamping a finish time.

        ``update_finished=False`` keeps ``finished_at`` unset so the run can
        be resumed later; running node executions are deliberately left
        untouched (unlike failure/abort).
        """
        execution = self._get_workflow_execution()
        execution.status = WorkflowExecutionStatus.PAUSED
        execution.error_message = event.reason or "Workflow execution paused"
        execution.outputs = event.outputs
        self._populate_completion_statistics(execution, update_finished=False)
        self._workflow_execution_repository.save(execution)

    # ------------------------------------------------------------------
    # Node-level handlers
    # ------------------------------------------------------------------
    def _handle_node_started(self, event: NodeRunStartedEvent) -> None:
        """Create the RUNNING node execution record and cache its snapshot."""
        execution = self._get_workflow_execution()
        metadata = {
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
            WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
        }
        domain_execution = WorkflowNodeExecution(
            id=event.id,
            node_execution_id=event.id,
            workflow_id=execution.workflow_id,
            workflow_execution_id=execution.id_,
            predecessor_node_id=event.predecessor_node_id,
            index=self._next_node_sequence(),
            node_id=event.node_id,
            node_type=event.node_type,
            title=event.node_title,
            status=WorkflowNodeExecutionStatus.RUNNING,
            metadata=metadata,
            created_at=event.start_at,
        )
        # Cache keyed by the node execution event id; later phase handlers
        # look the record up by the same id.
        self._node_execution_cache[event.id] = domain_execution
        self._workflow_node_execution_repository.save(domain_execution)

        snapshot = _NodeRuntimeSnapshot(
            node_id=event.node_id,
            title=event.node_title,
            predecessor_node_id=event.predecessor_node_id,
            iteration_id=event.in_iteration_id,
            loop_id=event.in_loop_id,
            created_at=event.start_at,
        )
        self._node_snapshots[event.id] = snapshot

    def _handle_node_retry(self, event: NodeRunRetryEvent) -> None:
        """Record RETRY status and error on the cached node execution."""
        domain_execution = self._get_node_execution(event.id)
        domain_execution.status = WorkflowNodeExecutionStatus.RETRY
        domain_execution.error = event.error
        self._workflow_node_execution_repository.save(domain_execution)
        self._workflow_node_execution_repository.save_execution_data(domain_execution)

    def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None:
        """Finalize the node execution as SUCCEEDED with its run result."""
        domain_execution = self._get_node_execution(event.id)
        self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED)

    def _handle_node_failed(self, event: NodeRunFailedEvent) -> None:
        """Finalize the node execution as FAILED, recording the error."""
        domain_execution = self._get_node_execution(event.id)
        self._update_node_execution(
            domain_execution,
            event.node_run_result,
            WorkflowNodeExecutionStatus.FAILED,
            error=event.error,
        )

    def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None:
        """Finalize the node execution as EXCEPTION (handled failure)."""
        domain_execution = self._get_node_execution(event.id)
        self._update_node_execution(
            domain_execution,
            event.node_run_result,
            WorkflowNodeExecutionStatus.EXCEPTION,
            error=event.error,
        )

    def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None:
        """Mark the node execution PAUSED; outputs are not overwritten."""
        domain_execution = self._get_node_execution(event.id)
        self._update_node_execution(
            domain_execution,
            event.node_run_result,
            WorkflowNodeExecutionStatus.PAUSED,
            error=event.reason,
            update_outputs=False,
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _get_execution_id(self) -> str:
        """Return the workflow execution id taken from system variables.

        :raises ValueError: when the id is absent — pause/resume flows need a
            stable, caller-provided execution id.
        """
        # NOTE(review): this looks up by the enum member while
        # _enqueue_trace_task uses `.value`; presumably SystemVariableKey is a
        # str-based enum so both key forms hash equally — confirm.
        workflow_execution_id = self._system_variables().get(SystemVariableKey.WORKFLOW_EXECUTION_ID)
        if not workflow_execution_id:
            raise ValueError("workflow_execution_id must be provided in system variables for pause/resume flows")
        return str(workflow_execution_id)

    def _prepare_workflow_inputs(self) -> Mapping[str, Any]:
        """Merge user inputs with ``sys.*``-prefixed system variables."""
        inputs = {**self._application_generate_entity.inputs}
        for field_name, value in self._system_variables().items():
            if field_name == SystemVariableKey.CONVERSATION_ID.value:
                # Conversation IDs are tied to the current session; omit them so persisted
                # workflow inputs stay reusable without binding future runs to this conversation.
                continue
            inputs[f"sys.{field_name}"] = value
        handled = WorkflowEntry.handle_special_values(inputs)
        return handled or {}

    def _get_workflow_execution(self) -> WorkflowExecution:
        """Return the current run's execution record.

        :raises ValueError: if no GraphRunStartedEvent has been handled yet.
        """
        if self._workflow_execution is None:
            raise ValueError("workflow execution not initialized")
        return self._workflow_execution

    def _get_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution:
        """Return the cached node execution for *node_execution_id*.

        :raises ValueError: if no NodeRunStartedEvent was seen for that id.
        """
        if node_execution_id not in self._node_execution_cache:
            raise ValueError(f"Node execution not found for id={node_execution_id}")
        return self._node_execution_cache[node_execution_id]

    def _next_node_sequence(self) -> int:
        """Return the next 1-based index for node executions in this run."""
        self._node_sequence += 1
        return self._node_sequence

    def _populate_completion_statistics(self, execution: WorkflowExecution, *, update_finished: bool = True) -> None:
        """Copy runtime totals onto *execution*; optionally stamp finished_at.

        ``update_finished=False`` is used for pauses, where the run is not
        actually finished.
        """
        if update_finished:
            execution.finished_at = naive_utc_now()
        # graph_runtime_state is provided by the GraphEngineLayer base and may
        # be unset before the engine attaches it; skip totals in that case.
        runtime_state = self.graph_runtime_state
        if runtime_state is None:
            return
        execution.total_tokens = runtime_state.total_tokens
        execution.total_steps = runtime_state.node_run_steps
        # Prefer outputs already set by the terminal event handler; fall back
        # to whatever the runtime state has accumulated.
        execution.outputs = execution.outputs or runtime_state.outputs
        execution.exceptions_count = runtime_state.exceptions_count

    def _update_node_execution(
        self,
        domain_execution: WorkflowNodeExecution,
        node_result: NodeRunResult,
        status: WorkflowNodeExecutionStatus,
        *,
        error: str | None = None,
        update_outputs: bool = True,
    ) -> None:
        """Apply a terminal *status* and run-result data to a node execution.

        :param domain_execution: cached record created on node start
        :param node_result: run result whose inputs/outputs/metadata to copy
        :param status: terminal status to record
        :param error: optional error text; only written when truthy
        :param update_outputs: when False (pause), keep existing data fields
        """
        finished_at = naive_utc_now()
        # Prefer the snapshot's start time; falls back to the record's own
        # created_at if the snapshot is missing.
        snapshot = self._node_snapshots.get(domain_execution.id)
        start_at = snapshot.created_at if snapshot else domain_execution.created_at
        domain_execution.status = status
        domain_execution.finished_at = finished_at
        # Clamp at zero to guard against clock skew between timestamps.
        domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0)
        if error:
            domain_execution.error = error
        if update_outputs:
            domain_execution.update_from_mapping(
                inputs=node_result.inputs,
                process_data=node_result.process_data,
                outputs=node_result.outputs,
                metadata=node_result.metadata,
            )
        self._workflow_node_execution_repository.save(domain_execution)
        self._workflow_node_execution_repository.save_execution_data(domain_execution)

    def _fail_running_node_executions(self, *, error_message: str) -> None:
        """Transition every still-RUNNING node execution to FAILED.

        Called on run failure/abort so no record is left dangling in RUNNING.
        """
        now = naive_utc_now()
        for execution in self._node_execution_cache.values():
            if execution.status == WorkflowNodeExecutionStatus.RUNNING:
                execution.status = WorkflowNodeExecutionStatus.FAILED
                execution.error = error_message
                execution.finished_at = now
                execution.elapsed_time = max((now - execution.created_at).total_seconds(), 0.0)
                self._workflow_node_execution_repository.save(execution)

    def _enqueue_trace_task(self, execution: WorkflowExecution) -> None:
        """Queue a workflow trace task when a trace manager is configured."""
        if not self._trace_manager:
            return
        conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
        external_trace_id = None
        if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
            external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
        trace_task = TraceTask(
            TraceTaskName.WORKFLOW_TRACE,
            workflow_execution=execution,
            conversation_id=conversation_id,
            user_id=self._trace_manager.user_id,
            external_trace_id=external_trace_id,
        )
        self._trace_manager.add_trace_task(trace_task)

    def _system_variables(self) -> Mapping[str, Any]:
        """Return the ``sys``-scoped variables from the variable pool, or {}."""
        runtime_state = self.graph_runtime_state
        if runtime_state is None:
            return {}
        return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)

View File

@ -9,7 +9,7 @@ Supports stop, pause, and resume operations.
from typing import final
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand
from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand
from extensions.ext_redis import redis_client
@ -20,7 +20,7 @@ class GraphEngineManager:
This class provides a simple interface for controlling workflow executions
by sending commands through Redis channels, without user validation.
Supports stop, pause, and resume operations.
Supports stop and pause operations.
"""
@staticmethod
@ -32,19 +32,29 @@ class GraphEngineManager:
task_id: The task ID of the workflow to stop
reason: Optional reason for stopping (defaults to "User requested stop")
"""
abort_command = AbortCommand(reason=reason or "User requested stop")
GraphEngineManager._send_command(task_id, abort_command)
@staticmethod
def send_pause_command(task_id: str, reason: str | None = None) -> None:
"""Send a pause command to a running workflow."""
pause_command = PauseCommand(reason=reason or "User requested pause")
GraphEngineManager._send_command(task_id, pause_command)
@staticmethod
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
"""Send a command to the workflow-specific Redis channel."""
if not task_id:
return
# Create Redis channel for this task
channel_key = f"workflow:{task_id}:commands"
channel = RedisChannel(redis_client, channel_key)
# Create and send abort command
abort_command = AbortCommand(reason=reason or "User requested stop")
try:
channel.send_command(abort_command)
channel.send_command(command)
except Exception:
# Silently fail if Redis is unavailable
# The legacy stop flag mechanism will still work
# The legacy control mechanisms will still work
pass

View File

@ -8,7 +8,12 @@ import threading
import time
from typing import TYPE_CHECKING, final
from core.workflow.graph_events.base import GraphNodeEventBase
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunSucceededEvent,
)
from ..event_management import EventManager
from .execution_coordinator import ExecutionCoordinator
@ -28,6 +33,12 @@ class Dispatcher:
with timeout and completion detection.
"""
_COMMAND_TRIGGER_EVENTS = (
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunExceptionEvent,
)
def __init__(
self,
event_queue: queue.Queue[GraphNodeEventBase],
@ -76,22 +87,37 @@ class Dispatcher:
"""Main dispatcher loop."""
try:
while not self._stop_event.is_set():
# Check for commands
self._execution_coordinator.check_commands()
commands_checked = False
should_check_commands = False
should_break = False
# Check for scaling
self._execution_coordinator.check_scaling()
if self._execution_coordinator.is_execution_complete():
should_check_commands = True
should_break = True
else:
# Check for scaling
self._execution_coordinator.check_scaling()
# Process events
try:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
# Check if execution is complete
if self._execution_coordinator.is_execution_complete():
break
# Process events
try:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
should_check_commands = self._should_check_commands(event)
self._event_queue.task_done()
except queue.Empty:
# Process commands even when no new events arrive so abort requests are not missed
should_check_commands = True
time.sleep(0.1)
if should_check_commands and not commands_checked:
self._execution_coordinator.check_commands()
commands_checked = True
if should_break:
if not commands_checked:
self._execution_coordinator.check_commands()
break
except Exception as e:
logger.exception("Dispatcher error")
@ -102,3 +128,7 @@ class Dispatcher:
# Signal the event emitter that execution is complete
if self._event_emitter:
self._event_emitter.mark_complete()
def _should_check_commands(self, event: GraphNodeEventBase) -> bool:
"""Return True if the event represents a node completion."""
return isinstance(event, self._COMMAND_TRIGGER_EVENTS)

View File

@ -2,17 +2,13 @@
Execution coordinator for managing overall workflow execution.
"""
from typing import TYPE_CHECKING, final
from typing import final
from ..command_processing import CommandProcessor
from ..domain import GraphExecution
from ..event_management import EventManager
from ..graph_state_manager import GraphStateManager
from ..worker_management import WorkerPool
if TYPE_CHECKING:
from ..event_management import EventHandler
@final
class ExecutionCoordinator:
@ -27,8 +23,6 @@ class ExecutionCoordinator:
self,
graph_execution: GraphExecution,
state_manager: GraphStateManager,
event_handler: "EventHandler",
event_collector: EventManager,
command_processor: CommandProcessor,
worker_pool: WorkerPool,
) -> None:
@ -38,15 +32,11 @@ class ExecutionCoordinator:
Args:
graph_execution: Graph execution aggregate
state_manager: Unified state manager
event_handler: Event handler registry for processing events
event_collector: Event manager for collecting events
command_processor: Processor for commands
worker_pool: Pool of workers
"""
self._graph_execution = graph_execution
self._state_manager = state_manager
self._event_handler = event_handler
self._event_collector = event_collector
self._command_processor = command_processor
self._worker_pool = worker_pool
@ -65,15 +55,24 @@ class ExecutionCoordinator:
Returns:
True if execution is complete
"""
# Check if aborted or failed
# Treat paused, aborted, or failed executions as terminal states
if self._graph_execution.is_paused:
return True
if self._graph_execution.aborted or self._graph_execution.has_error:
return True
# Complete if no work remains
return self._state_manager.is_execution_complete()
    @property
    def is_paused(self) -> bool:
        """Expose whether the underlying graph execution is paused.

        Read-only delegation to the GraphExecution aggregate.
        """
        return self._graph_execution.is_paused
def mark_complete(self) -> None:
"""Mark execution as complete."""
if self._graph_execution.is_paused:
return
if not self._graph_execution.completed:
self._graph_execution.complete()
@ -85,3 +84,21 @@ class ExecutionCoordinator:
error: The error that caused failure
"""
self._graph_execution.fail(error)
def handle_pause_if_needed(self) -> None:
"""If the execution has been paused, stop workers immediately."""
if not self._graph_execution.is_paused:
return
self._worker_pool.stop()
self._state_manager.clear_executing()
def handle_abort_if_needed(self) -> None:
"""If the execution has been aborted, stop workers immediately."""
if not self._graph_execution.aborted:
return
self._worker_pool.stop()
self._state_manager.clear_executing()

View File

@ -14,11 +14,11 @@ from uuid import uuid4
from pydantic import BaseModel, Field
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
from .path import Path
from .session import ResponseSession

View File

@ -13,6 +13,7 @@ from .graph import (
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
@ -37,6 +38,7 @@ from .loop import (
from .node import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
NodeRunStartedEvent,
@ -51,6 +53,7 @@ __all__ = [
"GraphRunAbortedEvent",
"GraphRunFailedEvent",
"GraphRunPartialSucceededEvent",
"GraphRunPausedEvent",
"GraphRunStartedEvent",
"GraphRunSucceededEvent",
"NodeRunAgentLogEvent",
@ -64,6 +67,7 @@ __all__ = [
"NodeRunLoopNextEvent",
"NodeRunLoopStartedEvent",
"NodeRunLoopSucceededEvent",
"NodeRunPauseRequestedEvent",
"NodeRunRetrieverResourceEvent",
"NodeRunRetryEvent",
"NodeRunStartedEvent",

View File

@ -8,7 +8,12 @@ class GraphRunStartedEvent(BaseGraphEvent):
class GraphRunSucceededEvent(BaseGraphEvent):
outputs: dict[str, object] = Field(default_factory=dict)
"""Event emitted when a run completes successfully with final outputs."""
outputs: dict[str, object] = Field(
default_factory=dict,
description="Final workflow outputs keyed by output selector.",
)
class GraphRunFailedEvent(BaseGraphEvent):
@ -17,12 +22,30 @@ class GraphRunFailedEvent(BaseGraphEvent):
class GraphRunPartialSucceededEvent(BaseGraphEvent):
"""Event emitted when a run finishes with partial success and failures."""
exceptions_count: int = Field(..., description="exception count")
outputs: dict[str, object] = Field(default_factory=dict)
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs that were materialised before failures occurred.",
)
class GraphRunAbortedEvent(BaseGraphEvent):
"""Event emitted when a graph run is aborted by user command."""
reason: str | None = Field(default=None, description="reason for abort")
outputs: dict[str, object] = Field(default_factory=dict, description="partial outputs if any")
outputs: dict[str, object] = Field(
default_factory=dict,
description="Outputs produced before the abort was requested.",
)
class GraphRunPausedEvent(BaseGraphEvent):
    """Event emitted when a graph run is paused by user command."""

    # Human-readable pause reason supplied with the pause command, if any.
    reason: str | None = Field(default=None, description="reason for pause")
    # Snapshot of outputs produced so far, so clients can show partial results.
    outputs: dict[str, object] = Field(
        default_factory=dict,
        description="Outputs available to the client while the run is paused.",
    )

View File

@ -51,3 +51,7 @@ class NodeRunExceptionEvent(GraphNodeEventBase):
class NodeRunRetryEvent(NodeRunStartedEvent):
error: str = Field(..., description="error")
retry_index: int = Field(..., description="which retry attempt is about to be performed")
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
    """Graph-level event signalling that a node asked the run to pause."""

    # Optional human-readable explanation forwarded from the node.
    reason: str | None = Field(default=None, description="Optional pause reason")

View File

@ -14,6 +14,7 @@ from .loop import (
)
from .node import (
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
StreamChunkEvent,
@ -33,6 +34,7 @@ __all__ = [
"ModelInvokeCompletedEvent",
"NodeEventBase",
"NodeRunResult",
"PauseRequestedEvent",
"RunRetrieverResourceEvent",
"RunRetryEvent",
"StreamChunkEvent",

View File

@ -40,3 +40,7 @@ class StreamChunkEvent(NodeEventBase):
class StreamCompletedEvent(NodeEventBase):
node_run_result: NodeRunResult = Field(..., description="run result")
class PauseRequestedEvent(NodeEventBase):
    """Internal node event requesting that graph execution be paused."""

    # Optional human-readable explanation for the pause request.
    reason: str | None = Field(default=None, description="Optional pause reason")

View File

@ -25,7 +25,6 @@ from core.tools.entities.tool_entities import (
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeType,
@ -44,6 +43,7 @@ from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionMod
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy

View File

@ -6,7 +6,7 @@ from typing import Any, ClassVar
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams, GraphRuntimeState
from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_events import (
GraphNodeEventBase,
@ -20,6 +20,7 @@ from core.workflow.graph_events import (
NodeRunLoopNextEvent,
NodeRunLoopStartedEvent,
NodeRunLoopSucceededEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
@ -37,10 +38,12 @@ from core.workflow.node_events import (
LoopSucceededEvent,
NodeEventBase,
NodeRunResult,
PauseRequestedEvent,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
)
from core.workflow.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
@ -385,6 +388,16 @@ class Node:
f"Node {self._node_id} does not support status {event.node_run_result.status}"
)
@_dispatch.register
def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent:
    # Translate a node-level pause request into its graph-level counterpart,
    # attaching a PAUSED run result so downstream consumers can record the
    # node's state before execution is suspended.
    return NodeRunPauseRequestedEvent(
        id=self._node_execution_id,
        node_id=self._node_id,
        node_type=self.node_type,
        node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED),
        reason=event.reason,
    )
@_dispatch.register
def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent:
return NodeRunAgentLogEvent(

View File

@ -19,7 +19,6 @@ from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@ -27,6 +26,7 @@ from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.tool.exc import ToolFileError
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from factories import file_factory
from models.model import UploadFile

View File

@ -15,7 +15,7 @@ from core.file import file_manager
from core.file.enums import FileTransferMethod
from core.helper import ssrf_proxy
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from .entities import (
HttpRequestNodeAuthorization,

View File

@ -0,0 +1,3 @@
"""Public interface of the human-input workflow node package."""

from .human_input_node import HumanInputNode

__all__ = ["HumanInputNode"]

View File

@ -0,0 +1,10 @@
from pydantic import Field
from core.workflow.nodes.base import BaseNodeData
class HumanInputNodeData(BaseNodeData):
    """Configuration schema for the HumanInput node."""

    # Variable selectors (dotted "node_id.variable" strings) that must all be
    # present in the variable pool before the node is considered complete.
    required_variables: list[str] = Field(default_factory=list)
    # Optional human-readable reason surfaced when the node pauses the workflow.
    pause_reason: str | None = Field(default=None)

View File

@ -0,0 +1,132 @@
from collections.abc import Mapping
from typing import Any
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from .entities import HumanInputNodeData
class HumanInputNode(Node):
    """Workflow node that suspends execution until a human supplies input.

    The node pauses (yields a ``PauseRequestedEvent``) until every selector in
    ``required_variables`` resolves in the variable pool; it then succeeds and
    routes to the branch selected by the human, falling back to ``"source"``.
    """

    node_type = NodeType.HUMAN_INPUT
    execution_type = NodeExecutionType.BRANCH

    # Keys probed, in priority order, to discover which branch was selected.
    _BRANCH_SELECTION_KEYS: tuple[str, ...] = (
        "edge_source_handle",
        "edgeSourceHandle",
        "source_handle",
        "selected_branch",
        "selectedBranch",
        "branch",
        "branch_id",
        "branchId",
        "handle",
    )

    _node_data: HumanInputNodeData

    def init_node_data(self, data: Mapping[str, Any]) -> None:
        """Validate and store the node configuration payload."""
        self._node_data = HumanInputNodeData(**data)

    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data

    @classmethod
    def version(cls) -> str:
        """Schema version of this node implementation."""
        return "1"

    def _get_error_strategy(self) -> ErrorStrategy | None:
        return self._node_data.error_strategy

    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config

    def _get_title(self) -> str:
        return self._node_data.title

    def _get_description(self) -> str | None:
        return self._node_data.desc

    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict

    def _run(self):  # type: ignore[override]
        # Pause until the human has provided every required input.
        if not self._is_completion_ready():
            return self._pause_generator()
        selected = self._resolve_branch_selection()
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={},
            edge_source_handle=selected or "source",
        )

    def _pause_generator(self):
        # Single-event generator: request a pause with the configured reason.
        yield PauseRequestedEvent(reason=self._node_data.pause_reason)

    def _is_completion_ready(self) -> bool:
        """Determine whether all required inputs are satisfied."""
        required = self._node_data.required_variables
        if not required:
            # No declared requirements means the node always pauses first.
            return False
        pool = self.graph_runtime_state.variable_pool
        for selector in required:
            segments = selector.split(".")
            # Only two-part "node_id.variable" selectors are accepted.
            if len(segments) != 2 or pool.get(segments) is None:
                return False
        return True

    def _resolve_branch_selection(self) -> str | None:
        """Determine the branch handle selected by human input if available."""
        pool = self.graph_runtime_state.variable_pool
        # Prefer live values from the variable pool over configured defaults.
        for candidate_key in self._BRANCH_SELECTION_KEYS:
            resolved = self._extract_branch_handle(pool.get((self.id, candidate_key)))
            if resolved:
                return resolved
        defaults = self._node_data.default_value_dict
        for candidate_key in self._BRANCH_SELECTION_KEYS:
            resolved = self._normalize_branch_value(defaults.get(candidate_key))
            if resolved:
                return resolved
        return None

    @staticmethod
    def _extract_branch_handle(segment: Any) -> str | None:
        """Unwrap a variable-pool segment and normalize it to a branch handle."""
        if segment is None:
            return None
        to_object = getattr(segment, "to_object", None)
        if callable(to_object):
            raw = to_object()
        else:
            raw = getattr(segment, "value", None)
        if raw is None:
            return None
        return HumanInputNode._normalize_branch_value(raw)

    @staticmethod
    def _normalize_branch_value(value: Any) -> str | None:
        """Coerce a raw selection value into a non-empty handle string, if any."""
        if isinstance(value, str):
            trimmed = value.strip()
            return trimmed if trimmed else None
        if isinstance(value, Mapping):
            for field in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"):
                entry = value.get(field)
                if isinstance(entry, str) and entry:
                    return entry
        return None

View File

@ -3,12 +3,12 @@ from typing import Any, Literal
from typing_extensions import deprecated
from core.workflow.entities import VariablePool
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.runtime import VariablePool
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor

View File

@ -12,7 +12,6 @@ from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import VariablePool
from core.workflow.enums import (
ErrorStrategy,
NodeExecutionType,
@ -38,6 +37,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
@ -557,11 +557,12 @@ class IterationNode(Node):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(

View File

@ -9,13 +9,13 @@ from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment

View File

@ -67,7 +67,7 @@ from .exc import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)

View File

@ -15,9 +15,9 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.entities import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation

View File

@ -52,7 +52,7 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.enums import (
ErrorStrategy,
NodeType,
@ -71,6 +71,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from . import llm_utils
from .entities import (
@ -93,7 +94,7 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)

View File

@ -406,11 +406,12 @@ class LoopNode(Node):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.runtime import GraphRuntimeState
# Create GraphInitParams from node attributes
graph_init_params = GraphInitParams(

View File

@ -10,7 +10,8 @@ from libs.typing import is_str, is_str_dict
from .node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams, GraphRuntimeState
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
@final

View File

@ -9,6 +9,7 @@ from core.workflow.nodes.datasource.datasource_node import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.if_else import IfElseNode
from core.workflow.nodes.iteration import IterationNode, IterationStartNode
from core.workflow.nodes.knowledge_index import KnowledgeIndexNode
@ -134,6 +135,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
"2": AgentNode,
"1": AgentNode,
},
NodeType.HUMAN_INPUT: {
LATEST_VERSION: HumanInputNode,
"1": HumanInputNode,
},
NodeType.DATASOURCE: {
LATEST_VERSION: DatasourceNode,
"1": DatasourceNode,

View File

@ -27,13 +27,13 @@ from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, Comp
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables.types import ArrayValidation, SegmentType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.runtime import VariablePool
from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData

View File

@ -1,4 +1,5 @@
import json
import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
@ -40,7 +41,7 @@ from .template_prompts import (
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
class QuestionClassifierNode(Node):
@ -194,6 +195,8 @@ class QuestionClassifierNode(Node):
category_name = node_data.classes[0].name
category_id = node_data.classes[0].id
if "<think>" in result_text:
result_text = re.sub(r"<think[^>]*>[\s\S]*?</think>", "", result_text, flags=re.IGNORECASE)
result_text_json = parse_and_check_json_markdown(result_text, [])
# result_text_json = json.loads(result_text.strip('```JSON\n'))
if "category_name" in result_text_json and "category_id" in result_text_json:

View File

@ -36,7 +36,7 @@ from .exc import (
)
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
class ToolNode(Node):

View File

@ -18,7 +18,7 @@ from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
if TYPE_CHECKING:
from core.workflow.entities import GraphRuntimeState
from core.workflow.runtime import GraphRuntimeState
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]

View File

@ -0,0 +1,14 @@
"""Runtime-state primitives shared by the workflow graph engine."""

from .graph_runtime_state import GraphRuntimeState
from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool
from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper
from .variable_pool import VariablePool, VariableValue

__all__ = [
    "GraphRuntimeState",
    "ReadOnlyGraphRuntimeState",
    "ReadOnlyGraphRuntimeStateWrapper",
    "ReadOnlyVariablePool",
    "ReadOnlyVariablePoolWrapper",
    "VariablePool",
    "VariableValue",
]

View File

@ -0,0 +1,393 @@
from __future__ import annotations
import importlib
import json
from collections.abc import Mapping, Sequence
from collections.abc import Mapping as TypingMapping
from copy import deepcopy
from typing import Any, Protocol
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.runtime.variable_pool import VariablePool
class ReadyQueueProtocol(Protocol):
    """Structural interface required from ready queue implementations.

    Mirrors the subset of a task queue the engine relies on (put/get/
    task_done/empty/qsize) plus ``dumps``/``loads`` so the queue can be
    persisted and restored as part of a runtime-state snapshot.
    """

    def put(self, item: str) -> None:
        """Enqueue the identifier of a node that is ready to run."""
        ...

    def get(self, timeout: float | None = None) -> str:
        """Return the next node identifier, blocking until available or timeout expires."""
        ...

    def task_done(self) -> None:
        """Signal that the most recently dequeued node has completed processing."""
        ...

    def empty(self) -> bool:
        """Return True when the queue contains no pending nodes."""
        ...

    def qsize(self) -> int:
        """Approximate the number of pending nodes awaiting execution."""
        ...

    def dumps(self) -> str:
        """Serialize the queue contents for persistence."""
        ...

    def loads(self, data: str) -> None:
        """Restore the queue contents from a serialized payload."""
        ...
class GraphExecutionProtocol(Protocol):
    """Structural interface for graph execution aggregate."""

    workflow_id: str
    # Lifecycle flags toggled by start()/complete()/abort() respectively.
    started: bool
    completed: bool
    aborted: bool
    # Unrecoverable error recorded via fail(), if any.
    error: Exception | None
    # Number of node execution exceptions recorded during the run.
    exceptions_count: int

    def start(self) -> None:
        """Transition execution into the running state."""
        ...

    def complete(self) -> None:
        """Mark execution as successfully completed."""
        ...

    def abort(self, reason: str) -> None:
        """Abort execution in response to an external stop request."""
        ...

    def fail(self, error: Exception) -> None:
        """Record an unrecoverable error and end execution."""
        ...

    def dumps(self) -> str:
        """Serialize execution state into a JSON payload."""
        ...

    def loads(self, data: str) -> None:
        """Restore execution state from a previously serialized payload."""
        ...
class ResponseStreamCoordinatorProtocol(Protocol):
    """Structural interface for response stream coordinator.

    Only the registration and snapshot round-trip operations are required by
    the runtime state; everything else is an implementation detail.
    """

    def register(self, response_node_id: str) -> None:
        """Register a response node so its outputs can be streamed."""
        ...

    def loads(self, data: str) -> None:
        """Restore coordinator state from a serialized payload."""
        ...

    def dumps(self) -> str:
        """Serialize coordinator state for persistence."""
        ...
class GraphProtocol(Protocol):
    """Structural interface required from graph instances attached to the runtime state."""

    # Identifier-to-object mappings of the materialized graph.
    nodes: TypingMapping[str, object]
    edges: TypingMapping[str, object]
    # Entry-point node of the graph.
    root_node: object

    def get_outgoing_edges(self, node_id: str) -> Sequence[object]:
        """Return the edges leaving the node identified by ``node_id``."""
        ...
class GraphRuntimeState:
    """Mutable runtime state shared across graph execution components.

    Owns the variable pool, token/step counters, workflow outputs, and the
    lazily-built collaborators (ready queue, graph execution aggregate, and
    response stream coordinator). Supports snapshot round-tripping through
    ``dumps``/``loads`` so a paused workflow can be persisted and resumed.
    """

    def __init__(
        self,
        *,
        variable_pool: VariablePool,
        start_at: float,
        total_tokens: int = 0,
        llm_usage: LLMUsage | None = None,
        outputs: dict[str, object] | None = None,
        node_run_steps: int = 0,
        ready_queue: ReadyQueueProtocol | None = None,
        graph_execution: GraphExecutionProtocol | None = None,
        response_coordinator: ResponseStreamCoordinatorProtocol | None = None,
        graph: GraphProtocol | None = None,
    ) -> None:
        self._variable_pool = variable_pool
        self._start_at = start_at
        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens
        # Copy so this state is isolated from the caller's LLMUsage instance.
        self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy()
        # Deep-copied so later external mutation of the passed dict cannot leak in.
        self._outputs = deepcopy(outputs) if outputs is not None else {}
        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps
        # Graph is attached below (via attach_graph) so the identity check runs.
        self._graph: GraphProtocol | None = None
        self._ready_queue = ready_queue
        self._graph_execution = graph_execution
        self._response_coordinator = response_coordinator
        # Coordinator state deserialized before a graph existed; replayed by
        # attach_graph() once a graph becomes available.
        self._pending_response_coordinator_dump: str | None = None
        # Workflow id recovered from a snapshot, consumed by _build_graph_execution().
        self._pending_graph_execution_workflow_id: str | None = None
        # Node ids that paused mid-run and should resume on continuation.
        self._paused_nodes: set[str] = set()
        if graph is not None:
            self.attach_graph(graph)

    # ------------------------------------------------------------------
    # Context binding helpers
    # ------------------------------------------------------------------
    def attach_graph(self, graph: GraphProtocol) -> None:
        """Attach the materialized graph to the runtime state.

        Idempotent for the same graph instance; raises if a different graph
        is already attached.
        """
        if self._graph is not None and self._graph is not graph:
            raise ValueError("GraphRuntimeState already attached to a different graph instance")
        self._graph = graph
        if self._response_coordinator is None:
            self._response_coordinator = self._build_response_coordinator(graph)
        # Replay coordinator state restored by loads() before attachment.
        if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
            self._response_coordinator.loads(self._pending_response_coordinator_dump)
            self._pending_response_coordinator_dump = None

    def configure(self, *, graph: GraphProtocol | None = None) -> None:
        """Ensure core collaborators are initialized with the provided context."""
        if graph is not None:
            self.attach_graph(graph)
        # Ensure collaborators are instantiated
        _ = self.ready_queue
        _ = self.graph_execution
        if self._graph is not None:
            _ = self.response_coordinator

    # ------------------------------------------------------------------
    # Primary collaborators
    # ------------------------------------------------------------------
    @property
    def variable_pool(self) -> VariablePool:
        return self._variable_pool

    @property
    def ready_queue(self) -> ReadyQueueProtocol:
        # Built lazily so the runtime domain stays decoupled from graph_engine.
        if self._ready_queue is None:
            self._ready_queue = self._build_ready_queue()
        return self._ready_queue

    @property
    def graph_execution(self) -> GraphExecutionProtocol:
        if self._graph_execution is None:
            self._graph_execution = self._build_graph_execution()
        return self._graph_execution

    @property
    def response_coordinator(self) -> ResponseStreamCoordinatorProtocol:
        if self._response_coordinator is None:
            # The coordinator needs the graph topology to stream responses.
            if self._graph is None:
                raise ValueError("Graph must be attached before accessing response coordinator")
            self._response_coordinator = self._build_response_coordinator(self._graph)
        return self._response_coordinator

    # ------------------------------------------------------------------
    # Scalar state
    # ------------------------------------------------------------------
    @property
    def start_at(self) -> float:
        return self._start_at

    @start_at.setter
    def start_at(self, value: float) -> None:
        self._start_at = value

    @property
    def total_tokens(self) -> int:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: int) -> None:
        if value < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = value

    @property
    def llm_usage(self) -> LLMUsage:
        # Return a copy so callers cannot mutate internal usage accounting.
        return self._llm_usage.model_copy()

    @llm_usage.setter
    def llm_usage(self, value: LLMUsage) -> None:
        self._llm_usage = value.model_copy()

    @property
    def outputs(self) -> dict[str, Any]:
        # Defensive copy: callers receive a snapshot, not the live dict.
        return deepcopy(self._outputs)

    @outputs.setter
    def outputs(self, value: dict[str, Any]) -> None:
        self._outputs = deepcopy(value)

    def set_output(self, key: str, value: object) -> None:
        self._outputs[key] = deepcopy(value)

    def get_output(self, key: str, default: object = None) -> object:
        return deepcopy(self._outputs.get(key, default))

    def update_outputs(self, updates: dict[str, object]) -> None:
        for key, value in updates.items():
            self._outputs[key] = deepcopy(value)

    @property
    def node_run_steps(self) -> int:
        return self._node_run_steps

    @node_run_steps.setter
    def node_run_steps(self, value: int) -> None:
        if value < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = value

    def increment_node_run_steps(self) -> None:
        self._node_run_steps += 1

    def add_tokens(self, tokens: int) -> None:
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        self._total_tokens += tokens

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------
    def dumps(self) -> str:
        """Serialize runtime state into a JSON string."""
        snapshot: dict[str, Any] = {
            "version": "1.0",
            "start_at": self._start_at,
            "total_tokens": self._total_tokens,
            "node_run_steps": self._node_run_steps,
            "llm_usage": self._llm_usage.model_dump(mode="json"),
            "outputs": self.outputs,
            "variable_pool": self.variable_pool.model_dump(mode="json"),
            "ready_queue": self.ready_queue.dumps(),
            "graph_execution": self.graph_execution.dumps(),
            "paused_nodes": list(self._paused_nodes),
        }
        # The coordinator is only serializable once a graph has been attached.
        if self._response_coordinator is not None and self._graph is not None:
            snapshot["response_coordinator"] = self._response_coordinator.dumps()
        return json.dumps(snapshot, default=pydantic_encoder)

    def loads(self, data: str | Mapping[str, Any]) -> None:
        """Restore runtime state from a serialized snapshot.

        Accepts either the JSON string produced by :meth:`dumps` or an
        already-parsed mapping. Raises ``ValueError`` on an unsupported
        snapshot version or negative counters.
        """
        payload: dict[str, Any]
        if isinstance(data, str):
            payload = json.loads(data)
        else:
            payload = dict(data)
        version = payload.get("version")
        if version != "1.0":
            raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")
        self._start_at = float(payload.get("start_at", 0.0))
        total_tokens = int(payload.get("total_tokens", 0))
        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens
        node_run_steps = int(payload.get("node_run_steps", 0))
        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps
        llm_usage_payload = payload.get("llm_usage", {})
        self._llm_usage = LLMUsage.model_validate(llm_usage_payload)
        self._outputs = deepcopy(payload.get("outputs", {}))
        variable_pool_payload = payload.get("variable_pool")
        if variable_pool_payload is not None:
            self._variable_pool = VariablePool.model_validate(variable_pool_payload)
        ready_queue_payload = payload.get("ready_queue")
        if ready_queue_payload is not None:
            self._ready_queue = self._build_ready_queue()
            self._ready_queue.loads(ready_queue_payload)
        else:
            # Drop any existing queue; it will be lazily rebuilt on access.
            self._ready_queue = None
        graph_execution_payload = payload.get("graph_execution")
        self._graph_execution = None
        self._pending_graph_execution_workflow_id = None
        if graph_execution_payload is not None:
            # Peek at the workflow id so _build_graph_execution() can seed the
            # aggregate before its own loads() runs; best-effort on bad payloads.
            try:
                execution_payload = json.loads(graph_execution_payload)
                self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
            except (json.JSONDecodeError, TypeError, AttributeError):
                self._pending_graph_execution_workflow_id = None
            self.graph_execution.loads(graph_execution_payload)
        response_payload = payload.get("response_coordinator")
        if response_payload is not None:
            if self._graph is not None:
                self.response_coordinator.loads(response_payload)
            else:
                # Defer until attach_graph() supplies the graph topology.
                self._pending_response_coordinator_dump = response_payload
        else:
            self._pending_response_coordinator_dump = None
            self._response_coordinator = None
        paused_nodes_payload = payload.get("paused_nodes", [])
        self._paused_nodes = set(map(str, paused_nodes_payload))

    def register_paused_node(self, node_id: str) -> None:
        """Record a node that should resume when execution is continued."""
        self._paused_nodes.add(node_id)

    def consume_paused_nodes(self) -> list[str]:
        """Retrieve and clear the list of paused nodes awaiting resume."""
        nodes = list(self._paused_nodes)
        self._paused_nodes.clear()
        return nodes

    # ------------------------------------------------------------------
    # Builders
    # ------------------------------------------------------------------
    def _build_ready_queue(self) -> ReadyQueueProtocol:
        # Import lazily to avoid breaching architecture boundaries enforced by import-linter.
        module = importlib.import_module("core.workflow.graph_engine.ready_queue")
        in_memory_cls = module.InMemoryReadyQueue
        return in_memory_cls()

    def _build_graph_execution(self) -> GraphExecutionProtocol:
        # Lazily import to keep the runtime domain decoupled from graph_engine modules.
        module = importlib.import_module("core.workflow.graph_engine.domain.graph_execution")
        graph_execution_cls = module.GraphExecution
        # Consume the workflow id recovered from a snapshot, if any.
        workflow_id = self._pending_graph_execution_workflow_id or ""
        self._pending_graph_execution_workflow_id = None
        return graph_execution_cls(workflow_id=workflow_id)

    def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
        # Lazily import to keep the runtime domain decoupled from graph_engine modules.
        module = importlib.import_module("core.workflow.graph_engine.response_coordinator")
        coordinator_cls = module.ResponseStreamCoordinator
        return coordinator_cls(variable_pool=self.variable_pool, graph=graph)

View File

@ -16,6 +16,10 @@ class ReadOnlyVariablePool(Protocol):
"""Get all variables for a node (read-only)."""
...
def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
"""Get all variables stored under a given node prefix (read-only)."""
...
class ReadOnlyGraphRuntimeState(Protocol):
"""
@ -56,6 +60,20 @@ class ReadOnlyGraphRuntimeState(Protocol):
"""Get the node run steps count (read-only)."""
...
@property
def ready_queue_size(self) -> int:
"""Get the number of nodes currently in the ready queue."""
...
@property
def exceptions_count(self) -> int:
"""Get the number of node execution exceptions recorded."""
...
def get_output(self, key: str, default: Any = None) -> Any:
"""Get a single output value (returns a copy)."""
...
def dumps(self) -> str:
"""Serialize the runtime state into a JSON snapshot (read-only)."""
...

View File

@ -1,77 +1,82 @@
from __future__ import annotations
from collections.abc import Mapping
from copy import deepcopy
from typing import Any
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables.segments import Segment
from core.workflow.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.entities.variable_pool import VariablePool
from .graph_runtime_state import GraphRuntimeState
from .variable_pool import VariablePool
class ReadOnlyVariablePoolWrapper:
"""Wrapper that provides read-only access to VariablePool."""
"""Provide defensive, read-only access to ``VariablePool``."""
def __init__(self, variable_pool: VariablePool):
def __init__(self, variable_pool: VariablePool) -> None:
self._variable_pool = variable_pool
def get(self, node_id: str, variable_key: str) -> Segment | None:
"""Get a variable value (returns a defensive copy)."""
"""Return a copy of a variable value if present."""
value = self._variable_pool.get([node_id, variable_key])
return deepcopy(value) if value is not None else None
def get_all_by_node(self, node_id: str) -> Mapping[str, object]:
"""Get all variables for a node (returns defensive copies)."""
"""Return a copy of all variables for the specified node."""
variables: dict[str, object] = {}
if node_id in self._variable_pool.variable_dictionary:
for key, var in self._variable_pool.variable_dictionary[node_id].items():
# Variables have a value property that contains the actual data
variables[key] = deepcopy(var.value)
for key, variable in self._variable_pool.variable_dictionary[node_id].items():
variables[key] = deepcopy(variable.value)
return variables
def get_by_prefix(self, prefix: str) -> Mapping[str, object]:
"""Return a copy of all variables stored under the given prefix."""
return self._variable_pool.get_by_prefix(prefix)
class ReadOnlyGraphRuntimeStateWrapper:
"""
Wrapper that provides read-only access to GraphRuntimeState.
"""Expose a defensive, read-only view of ``GraphRuntimeState``."""
This wrapper ensures that layers can observe the state without
modifying it. All returned values are defensive copies.
"""
def __init__(self, state: GraphRuntimeState):
def __init__(self, state: GraphRuntimeState) -> None:
self._state = state
self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool)
@property
def variable_pool(self) -> ReadOnlyVariablePoolWrapper:
"""Get read-only access to the variable pool."""
return self._variable_pool_wrapper
@property
def start_at(self) -> float:
"""Get the start time (read-only)."""
return self._state.start_at
@property
def total_tokens(self) -> int:
"""Get the total tokens count (read-only)."""
return self._state.total_tokens
@property
def llm_usage(self) -> LLMUsage:
"""Get a copy of LLM usage info (read-only)."""
# Return a copy to prevent modification
return self._state.llm_usage.model_copy()
@property
def outputs(self) -> dict[str, Any]:
"""Get a defensive copy of outputs (read-only)."""
return deepcopy(self._state.outputs)
@property
def node_run_steps(self) -> int:
"""Get the node run steps count (read-only)."""
return self._state.node_run_steps
@property
def ready_queue_size(self) -> int:
    """Number of nodes currently waiting in the ready queue (read-only)."""
    return self._state.ready_queue.qsize()

@property
def exceptions_count(self) -> int:
    """Number of node execution exceptions recorded so far (read-only)."""
    return self._state.graph_execution.exceptions_count

def get_output(self, key: str, default: Any = None) -> Any:
    """Get a single output value (returns a copy)."""
    return self._state.get_output(key, default)

def dumps(self) -> str:
    """Serialize the underlying runtime state for external persistence."""
    return self._state.dumps()

View File

@ -1,6 +1,7 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
@ -235,6 +236,20 @@ class VariablePool(BaseModel):
return segment
return None
def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]:
    """Return deep copies of every variable stored under the given node prefix.

    An empty mapping is returned when the prefix has no entries, so callers
    never receive references into the pool's internal storage.
    """
    node_variables = self.variable_dictionary.get(prefix)
    if not node_variables:
        return {}
    return {name: deepcopy(variable.value) for name, variable in node_variables.items()}
def _add_system_variables(self, system_variable: SystemVariable):
sys_var_mapping = system_variable.to_dict()
for key, value in sys_var_mapping.items():

View File

@ -5,7 +5,7 @@ from typing import Literal, NamedTuple
from core.file import FileAttribute, file_manager
from core.variables import ArrayFileSegment
from core.variables.segments import ArrayBooleanSegment, BooleanSegment
from core.workflow.entities import VariablePool
from core.workflow.runtime import VariablePool
from .entities import Condition, SubCondition, SupportedComparisonOperator

View File

@ -4,7 +4,7 @@ from typing import Any, Protocol
from core.variables import Variable
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.runtime import VariablePool
class VariableLoader(Protocol):

View File

@ -1,459 +0,0 @@
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.app.task_pipeline.exc import WorkflowRunNotFoundError
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.workflow.entities import (
WorkflowExecution,
WorkflowNodeExecution,
)
from core.workflow.enums import (
SystemVariableKey,
WorkflowExecutionStatus,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
WorkflowType,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
@dataclass
class CycleManagerWorkflowInfo:
    """Static descriptor of the workflow definition being executed in a run."""

    # Identifier of the workflow definition.
    workflow_id: str
    # Kind of workflow, as declared by ``WorkflowType``.
    workflow_type: WorkflowType
    # Version string of the workflow definition.
    version: str
    # Raw graph structure of the workflow (nodes/edges mapping).
    graph_data: Mapping[str, Any]
class WorkflowCycleManager:
    """Persist the lifecycle of a single workflow run.

    Converts queue events produced during graph execution into
    ``WorkflowExecution`` / ``WorkflowNodeExecution`` domain records and saves
    them through the injected repositories. An instance is scoped to one run
    and caches every record it creates, so later events can update the same
    objects without re-reading from the repositories.
    """

    def __init__(
        self,
        *,
        application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
        workflow_system_variables: SystemVariable,
        workflow_info: CycleManagerWorkflowInfo,
        workflow_execution_repository: WorkflowExecutionRepository,
        workflow_node_execution_repository: WorkflowNodeExecutionRepository,
    ) -> None:
        self._application_generate_entity = application_generate_entity
        self._workflow_system_variables = workflow_system_variables
        self._workflow_info = workflow_info
        self._workflow_execution_repository = workflow_execution_repository
        self._workflow_node_execution_repository = workflow_node_execution_repository

        # Initialize caches for workflow execution cycle
        # These caches avoid redundant repository calls during a single workflow execution
        self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
        self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}

    def handle_workflow_run_start(self) -> WorkflowExecution:
        """Create, persist, and cache the execution record for a new run."""
        inputs = self._prepare_workflow_inputs()
        execution_id = self._get_or_generate_execution_id()

        execution = WorkflowExecution.new(
            id_=execution_id,
            workflow_id=self._workflow_info.workflow_id,
            workflow_type=self._workflow_info.workflow_type,
            workflow_version=self._workflow_info.version,
            graph=self._workflow_info.graph_data,
            inputs=inputs,
            started_at=naive_utc_now(),
        )

        return self._save_and_cache_workflow_execution(execution)

    def handle_workflow_run_success(
        self,
        *,
        workflow_run_id: str,
        total_tokens: int,
        total_steps: int,
        outputs: Mapping[str, Any] | None = None,
        conversation_id: str | None = None,
        trace_manager: TraceQueueManager | None = None,
        external_trace_id: str | None = None,
    ) -> WorkflowExecution:
        """Mark the run SUCCEEDED, record totals/outputs, trace, and persist."""
        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)

        self._update_workflow_execution_completion(
            workflow_execution,
            status=WorkflowExecutionStatus.SUCCEEDED,
            outputs=outputs,
            total_tokens=total_tokens,
            total_steps=total_steps,
        )

        self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
        self._workflow_execution_repository.save(workflow_execution)

        return workflow_execution

    def handle_workflow_run_partial_success(
        self,
        *,
        workflow_run_id: str,
        total_tokens: int,
        total_steps: int,
        outputs: Mapping[str, Any] | None = None,
        exceptions_count: int = 0,
        conversation_id: str | None = None,
        trace_manager: TraceQueueManager | None = None,
        external_trace_id: str | None = None,
    ) -> WorkflowExecution:
        """Mark the run PARTIAL_SUCCEEDED (some nodes raised), then persist."""
        execution = self._get_workflow_execution_or_raise_error(workflow_run_id)

        self._update_workflow_execution_completion(
            execution,
            status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
            outputs=outputs,
            total_tokens=total_tokens,
            total_steps=total_steps,
            exceptions_count=exceptions_count,
        )

        self._add_trace_task_if_needed(trace_manager, execution, conversation_id, external_trace_id)
        self._workflow_execution_repository.save(execution)

        return execution

    def handle_workflow_run_failed(
        self,
        *,
        workflow_run_id: str,
        total_tokens: int,
        total_steps: int,
        status: WorkflowExecutionStatus,
        error_message: str,
        conversation_id: str | None = None,
        trace_manager: TraceQueueManager | None = None,
        exceptions_count: int = 0,
        external_trace_id: str | None = None,
    ) -> WorkflowExecution:
        """Finalize the run with a failure *status* and fail any cached
        node executions still marked RUNNING, then persist."""
        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
        now = naive_utc_now()

        self._update_workflow_execution_completion(
            workflow_execution,
            status=status,
            total_tokens=total_tokens,
            total_steps=total_steps,
            error_message=error_message,
            exceptions_count=exceptions_count,
            finished_at=now,
        )

        self._fail_running_node_executions(workflow_execution.id_, error_message, now)
        self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
        self._workflow_execution_repository.save(workflow_execution)

        return workflow_execution

    def handle_node_execution_start(
        self,
        *,
        workflow_execution_id: str,
        event: QueueNodeStartedEvent,
    ) -> WorkflowNodeExecution:
        """Create, persist, and cache a RUNNING node-execution record."""
        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)

        domain_execution = self._create_node_execution_from_event(
            workflow_execution=workflow_execution,
            event=event,
            status=WorkflowNodeExecutionStatus.RUNNING,
        )

        return self._save_and_cache_node_execution(domain_execution)

    def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
        """Finalize a cached node execution as SUCCEEDED and persist it
        together with its inputs/outputs payload."""
        domain_execution = self._get_node_execution_from_cache(event.node_execution_id)

        self._update_node_execution_completion(
            domain_execution,
            event=event,
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
        )

        self._workflow_node_execution_repository.save(domain_execution)
        self._workflow_node_execution_repository.save_execution_data(domain_execution)
        return domain_execution

    def handle_workflow_node_execution_failed(
        self,
        *,
        event: QueueNodeFailedEvent | QueueNodeExceptionEvent,
    ) -> WorkflowNodeExecution:
        """
        Workflow node execution failed

        Finalizes the cached node execution as FAILED (or EXCEPTION when the
        event is a ``QueueNodeExceptionEvent``) and persists it with its data.

        :param event: queue node failed event
        :return: the updated node execution
        """
        domain_execution = self._get_node_execution_from_cache(event.node_execution_id)

        status = (
            WorkflowNodeExecutionStatus.EXCEPTION
            if isinstance(event, QueueNodeExceptionEvent)
            else WorkflowNodeExecutionStatus.FAILED
        )

        self._update_node_execution_completion(
            domain_execution,
            event=event,
            status=status,
            error=event.error,
            handle_special_values=True,
        )

        self._workflow_node_execution_repository.save(domain_execution)
        self._workflow_node_execution_repository.save_execution_data(domain_execution)
        return domain_execution

    def handle_workflow_node_execution_retried(
        self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
    ) -> WorkflowNodeExecution:
        """Record a RETRY attempt as a new node execution and persist it."""
        workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)

        domain_execution = self._create_node_execution_from_event(
            workflow_execution=workflow_execution,
            event=event,
            status=WorkflowNodeExecutionStatus.RETRY,
            error=event.error,
            created_at=event.start_at,
        )

        # Handle inputs and outputs
        inputs = WorkflowEntry.handle_special_values(event.inputs)
        outputs = event.outputs
        metadata = self._merge_event_metadata(event)

        domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)

        execution = self._save_and_cache_node_execution(domain_execution)
        self._workflow_node_execution_repository.save_execution_data(execution)
        return execution

    def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
        """Return the cached workflow execution or raise ``WorkflowRunNotFoundError``.

        NOTE(review): only the in-run cache is consulted — executions created
        outside this manager instance will raise; confirm this is intended.
        """
        # Check cache first
        if id in self._workflow_execution_cache:
            return self._workflow_execution_cache[id]

        raise WorkflowRunNotFoundError(id)

    def _prepare_workflow_inputs(self) -> dict[str, Any]:
        """Prepare workflow inputs by merging application inputs with system variables."""
        inputs = {**self._application_generate_entity.inputs}
        if self._workflow_system_variables:
            for field_name, value in self._workflow_system_variables.to_dict().items():
                # The conversation id is deliberately excluded from the inputs.
                if field_name != SystemVariableKey.CONVERSATION_ID:
                    inputs[f"sys.{field_name}"] = value
        return dict(WorkflowEntry.handle_special_values(inputs) or {})

    def _get_or_generate_execution_id(self) -> str:
        """Get execution ID from system variables or generate a new one."""
        if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
            return str(self._workflow_system_variables.workflow_execution_id)
        return str(uuidv7())

    def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
        """Save workflow execution to repository and cache it."""
        self._workflow_execution_repository.save(execution)
        self._workflow_execution_cache[execution.id_] = execution
        return execution

    def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
        """Save node execution to repository and cache it if it has an ID.

        This does not persist the `inputs` / `process_data` / `outputs` fields of the execution model.
        """
        self._workflow_node_execution_repository.save(execution)
        if execution.node_execution_id:
            self._node_execution_cache[execution.node_execution_id] = execution
        return execution

    def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
        """Get node execution from cache or raise error if not found."""
        domain_execution = self._node_execution_cache.get(node_execution_id)
        if not domain_execution:
            raise ValueError(f"Domain node execution not found: {node_execution_id}")
        return domain_execution

    def _update_workflow_execution_completion(
        self,
        execution: WorkflowExecution,
        *,
        status: WorkflowExecutionStatus,
        total_tokens: int,
        total_steps: int,
        outputs: Mapping[str, Any] | None = None,
        error_message: str | None = None,
        exceptions_count: int = 0,
        finished_at: datetime | None = None,
    ) -> None:
        """Update workflow execution with completion data."""
        execution.status = status
        execution.outputs = outputs or {}
        execution.total_tokens = total_tokens
        execution.total_steps = total_steps
        execution.finished_at = finished_at or naive_utc_now()
        execution.exceptions_count = exceptions_count
        if error_message:
            execution.error_message = error_message

    def _add_trace_task_if_needed(
        self,
        trace_manager: TraceQueueManager | None,
        workflow_execution: WorkflowExecution,
        conversation_id: str | None,
        external_trace_id: str | None,
    ) -> None:
        """Add trace task if trace manager is provided."""
        if trace_manager:
            trace_manager.add_trace_task(
                TraceTask(
                    TraceTaskName.WORKFLOW_TRACE,
                    workflow_execution=workflow_execution,
                    conversation_id=conversation_id,
                    user_id=trace_manager.user_id,
                    external_trace_id=external_trace_id,
                )
            )

    def _fail_running_node_executions(
        self,
        workflow_execution_id: str,
        error_message: str,
        now: datetime,
    ) -> None:
        """Fail all running node executions for a workflow."""
        # Only executions known to this run's cache can be failed here.
        running_node_executions = [
            node_exec
            for node_exec in self._node_execution_cache.values()
            if node_exec.workflow_execution_id == workflow_execution_id
            and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
        ]

        for node_execution in running_node_executions:
            if node_execution.node_execution_id:
                node_execution.status = WorkflowNodeExecutionStatus.FAILED
                node_execution.error = error_message
                node_execution.finished_at = now
                node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
                self._workflow_node_execution_repository.save(node_execution)

    def _create_node_execution_from_event(
        self,
        *,
        workflow_execution: WorkflowExecution,
        event: QueueNodeStartedEvent,
        status: WorkflowNodeExecutionStatus,
        error: str | None = None,
        created_at: datetime | None = None,
    ) -> WorkflowNodeExecution:
        """Create a node execution from an event."""
        now = naive_utc_now()
        created_at = created_at or now

        metadata = {
            WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
            WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
        }

        domain_execution = WorkflowNodeExecution(
            id=event.node_execution_id,
            workflow_id=workflow_execution.workflow_id,
            workflow_execution_id=workflow_execution.id_,
            predecessor_node_id=event.predecessor_node_id,
            index=event.node_run_index,
            node_execution_id=event.node_execution_id,
            node_id=event.node_id,
            node_type=event.node_type,
            title=event.node_title,
            status=status,
            metadata=metadata,
            created_at=created_at,
            error=error,
        )

        # A RETRY record is already finished the moment it is created.
        if status == WorkflowNodeExecutionStatus.RETRY:
            domain_execution.finished_at = now
            domain_execution.elapsed_time = (now - created_at).total_seconds()

        return domain_execution

    def _update_node_execution_completion(
        self,
        domain_execution: WorkflowNodeExecution,
        *,
        event: Union[
            QueueNodeSucceededEvent,
            QueueNodeFailedEvent,
            QueueNodeExceptionEvent,
        ],
        status: WorkflowNodeExecutionStatus,
        error: str | None = None,
        handle_special_values: bool = False,
    ) -> None:
        """Update node execution with completion data."""
        finished_at = naive_utc_now()
        elapsed_time = (finished_at - event.start_at).total_seconds()

        # Process data
        if handle_special_values:
            inputs = WorkflowEntry.handle_special_values(event.inputs)
            process_data = WorkflowEntry.handle_special_values(event.process_data)
        else:
            inputs = event.inputs
            process_data = event.process_data

        outputs = event.outputs

        # Convert metadata
        execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
        if event.execution_metadata:
            execution_metadata_dict.update(event.execution_metadata)

        # Update domain model
        domain_execution.status = status
        domain_execution.update_from_mapping(
            inputs=inputs,
            process_data=process_data,
            outputs=outputs,
            metadata=execution_metadata_dict,
        )
        domain_execution.finished_at = finished_at
        domain_execution.elapsed_time = elapsed_time

        if error:
            domain_execution.error = error

    def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
        """Merge event metadata with origin metadata."""
        # Origin keys win on conflict: they describe the retry's placement.
        origin_metadata = {
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
            WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
            WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
        }
        execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
        if event.execution_metadata:
            execution_metadata_dict.update(event.execution_metadata)
        return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata

View File

@ -9,7 +9,7 @@ from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.entities import GraphInitParams
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
@ -20,6 +20,7 @@ from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, Gra
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory