Mirror of https://github.com/langgenius/dify.git (synced 2026-05-03 17:08:03 +08:00)

feat(api): restore node state snapshots via /events api
Console workflow events controller (ConsoleWorkflowEventsApi):

@@ -6,7 +6,7 @@ import json
 import logging
 from collections.abc import Generator

-from flask import Response, jsonify
+from flask import Response, jsonify, request
 from flask_restx import Resource, reqparse
 from pydantic import BaseModel
 from sqlalchemy import select
@@ -28,6 +28,7 @@ from models.model import AppMode
 from models.workflow import WorkflowRun
 from repositories.factory import DifyAPIRepositoryFactory
 from services.human_input_service import Form, HumanInputService
+from services.workflow_event_snapshot_service import build_workflow_event_stream

 logger = logging.getLogger(__name__)

@@ -168,7 +169,19 @@ class ConsoleWorkflowEventsApi(Resource):
         else:
             raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")

+        include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
+
+        def _generate_stream_events():
+            if include_state_snapshot:
+                return generator.convert_to_event_stream(
+                    build_workflow_event_stream(
+                        app_mode=AppMode(app.mode),
+                        workflow_run=workflow_run,
+                        tenant_id=workflow_run.tenant_id,
+                        app_id=workflow_run.app_id,
+                        session_maker=session_maker,
+                    )
+                )
+            return generator.convert_to_event_stream(
+                msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id),
+            )
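For context, a client opts into the snapshot replay by adding the include_state_snapshot query parameter when subscribing to the events stream. A minimal consumer sketch using requests; the endpoint URL and auth header are placeholders (the route itself is not shown in this diff), only the query parameter comes from this change:

# Hypothetical SSE consumer; URL and Authorization header are placeholders.
import requests

resp = requests.get(
    "<console-events-endpoint-url>",  # placeholder: the route is not part of this diff
    params={"include_state_snapshot": "true"},  # opt into replayed snapshot events
    headers={"Authorization": "Bearer <token>"},  # placeholder credentials
    stream=True,
)
for line in resp.iter_lines():
    if line:
        # Snapshot events arrive first, then buffered and live events.
        print(line.decode())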
Web app workflow events controller (WorkflowEventsApi):

@@ -5,7 +5,7 @@ Web App Workflow Resume APIs.
 import json
 from collections.abc import Generator

-from flask import Response
+from flask import Response, request
 from sqlalchemy.orm import sessionmaker

 from controllers.web import api
@@ -19,6 +19,7 @@ from extensions.ext_database import db
 from models.enums import CreatorUserRole
 from models.model import App, AppMode, EndUser
 from repositories.factory import DifyAPIRepositoryFactory
+from services.workflow_event_snapshot_service import build_workflow_event_stream


 class WorkflowEventsApi(WebApiResource):
@@ -76,7 +77,19 @@ class WorkflowEventsApi(WebApiResource):
         else:
             raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")

+        include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
+
+        def _generate_stream_events():
+            if include_state_snapshot:
+                return generator.convert_to_event_stream(
+                    build_workflow_event_stream(
+                        app_mode=app_mode,
+                        workflow_run=workflow_run,
+                        tenant_id=app_model.tenant_id,
+                        app_id=app_model.id,
+                        session_maker=session_maker,
+                    )
+                )
+            return generator.convert_to_event_stream(
+                msg_generator.retrieve_events(app_mode, workflow_run.id),
+            )
Repository protocol (api/repositories/api_workflow_node_execution_repository.py):

@@ -10,6 +10,7 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain model
 """

 from collections.abc import Sequence
+from dataclasses import dataclass
 from datetime import datetime
 from typing import Protocol

@@ -17,6 +18,27 @@ from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from models.workflow import WorkflowNodeExecutionModel


+@dataclass(frozen=True)
+class WorkflowNodeExecutionSnapshot:
+    """
+    Minimal snapshot of workflow node execution for stream recovery.
+
+    Only includes fields required by snapshot events.
+    """
+
+    execution_id: str  # Unique execution identifier (node_execution_id or row id).
+    node_id: str  # Workflow graph node id.
+    node_type: str  # Workflow graph node type (e.g. "human-input").
+    title: str  # Human-friendly node title.
+    index: int  # Execution order index within the workflow run.
+    status: str  # Execution status (running/succeeded/failed/paused).
+    elapsed_time: float  # Execution elapsed time in seconds.
+    created_at: datetime  # Execution created timestamp.
+    finished_at: datetime | None  # Execution finished timestamp.
+    iteration_id: str | None = None  # Iteration id from execution metadata, if any.
+    loop_id: str | None = None  # Loop id from execution metadata, if any.
+
+
 class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
     """
     Protocol for service-layer operations on WorkflowNodeExecutionModel.
@@ -77,6 +99,8 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
         Args:
             tenant_id: The tenant identifier
             app_id: The application identifier
+            workflow_id: The workflow identifier
+            triggered_from: The workflow trigger source
             workflow_run_id: The workflow run identifier

         Returns:
@@ -84,6 +108,27 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
         """
         ...

+    def get_execution_snapshots_by_workflow_run(
+        self,
+        tenant_id: str,
+        app_id: str,
+        workflow_id: str,
+        triggered_from: str,
+        workflow_run_id: str,
+    ) -> Sequence[WorkflowNodeExecutionSnapshot]:
+        """
+        Get minimal snapshots for node executions in a workflow run.
+
+        Args:
+            tenant_id: The tenant identifier
+            app_id: The application identifier
+            workflow_run_id: The workflow run identifier
+
+        Returns:
+            A sequence of WorkflowNodeExecutionSnapshot ordered by creation time
+        """
+        ...
+
     def get_execution_by_id(
         self,
         execution_id: str,
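As a quick illustration of the dataclass above (all field values here are invented; the type itself is from the diff): instances are immutable, and the iteration/loop ids default to None.

# Illustrative only: values are made up.
from datetime import UTC, datetime

from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot

snap = WorkflowNodeExecutionSnapshot(
    execution_id="exec-1",
    node_id="node-1",
    node_type="human-input",
    title="Human Input",
    index=1,
    status="paused",
    elapsed_time=0.5,
    created_at=datetime(2024, 1, 1, tzinfo=UTC),
    finished_at=None,
)
# frozen=True makes the snapshot immutable:
# snap.status = "succeeded"  # would raise dataclasses.FrozenInstanceError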
SQLAlchemy implementation (DifyAPISQLAlchemyWorkflowNodeExecutionRepository):

@@ -5,17 +5,20 @@ This module provides a concrete implementation of the service repository protocol
 using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
 """

+import json
 from collections.abc import Sequence
+from datetime import datetime
 from typing import cast

 from sqlalchemy import asc, delete, desc, select
 from sqlalchemy.engine import CursorResult
 from sqlalchemy.orm import Session, sessionmaker

-from core.workflow.enums import WorkflowNodeExecutionStatus
+from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from models.workflow import WorkflowNodeExecutionModel
-from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
+from repositories.api_workflow_node_execution_repository import (
+    DifyAPIWorkflowNodeExecutionRepository,
+    WorkflowNodeExecutionSnapshot,
+)


 class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
@@ -117,6 +120,80 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
         with self._session_maker() as session:
             return session.execute(stmt).scalars().all()

+    def get_execution_snapshots_by_workflow_run(
+        self,
+        tenant_id: str,
+        app_id: str,
+        workflow_id: str,
+        triggered_from: str,
+        workflow_run_id: str,
+    ) -> Sequence[WorkflowNodeExecutionSnapshot]:
+        # Select only the columns the snapshot needs, ordered by creation time.
+        stmt = (
+            select(
+                WorkflowNodeExecutionModel.id,
+                WorkflowNodeExecutionModel.node_execution_id,
+                WorkflowNodeExecutionModel.node_id,
+                WorkflowNodeExecutionModel.node_type,
+                WorkflowNodeExecutionModel.title,
+                WorkflowNodeExecutionModel.index,
+                WorkflowNodeExecutionModel.status,
+                WorkflowNodeExecutionModel.elapsed_time,
+                WorkflowNodeExecutionModel.created_at,
+                WorkflowNodeExecutionModel.finished_at,
+                WorkflowNodeExecutionModel.execution_metadata,
+            )
+            .where(
+                WorkflowNodeExecutionModel.tenant_id == tenant_id,
+                WorkflowNodeExecutionModel.app_id == app_id,
+                WorkflowNodeExecutionModel.workflow_id == workflow_id,
+                WorkflowNodeExecutionModel.triggered_from == triggered_from,
+                WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
+            )
+            .order_by(
+                asc(WorkflowNodeExecutionModel.created_at),
+                asc(WorkflowNodeExecutionModel.index),
+            )
+        )
+
+        with self._session_maker() as session:
+            rows = session.execute(stmt).all()
+
+        return [self._row_to_snapshot(row) for row in rows]
+
+    @staticmethod
+    def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
+        # execution_metadata is stored as a JSON string; fall back to {} when absent or invalid.
+        metadata: dict[str, object] = {}
+        execution_metadata = getattr(row, "execution_metadata", None)
+        if execution_metadata:
+            try:
+                metadata = json.loads(execution_metadata)
+            except json.JSONDecodeError:
+                metadata = {}
+        iteration_id = metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID.value)
+        loop_id = metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID.value)
+        # Prefer the logical node_execution_id; fall back to the row's primary key.
+        execution_id = getattr(row, "node_execution_id", None) or row.id
+        elapsed_time = getattr(row, "elapsed_time", None)
+        created_at = row.created_at
+        finished_at = getattr(row, "finished_at", None)
+        # Derive the elapsed time from the timestamps when the column is NULL.
+        if elapsed_time is None:
+            if finished_at is not None and created_at is not None:
+                elapsed_time = (finished_at - created_at).total_seconds()
+            else:
+                elapsed_time = 0.0
+        return WorkflowNodeExecutionSnapshot(
+            execution_id=str(execution_id),
+            node_id=row.node_id,
+            node_type=row.node_type,
+            title=row.title,
+            index=row.index,
+            status=row.status,
+            elapsed_time=float(elapsed_time),
+            created_at=created_at,
+            finished_at=finished_at,
+            iteration_id=str(iteration_id) if iteration_id else None,
+            loop_id=str(loop_id) if loop_id else None,
+        )
+
     def get_execution_by_id(
         self,
         execution_id: str,
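To make the elapsed_time fallback in _row_to_snapshot concrete: when the column is NULL, the duration is recomputed from the two timestamps. A standalone, runnable distillation of just that rule — the Row is faked with a namedtuple (attribute access is all the real code relies on), so nothing here depends on dify:

from collections import namedtuple
from datetime import UTC, datetime

# Stand-in for a SQLAlchemy result row; only attribute access is exercised.
FakeRow = namedtuple("FakeRow", ["created_at", "finished_at", "elapsed_time"])

def derive_elapsed_time(row) -> float:
    """Mirror of the fallback rule in _row_to_snapshot (illustration only)."""
    if row.elapsed_time is not None:
        return float(row.elapsed_time)
    if row.finished_at is not None and row.created_at is not None:
        return (row.finished_at - row.created_at).total_seconds()
    return 0.0

row = FakeRow(
    created_at=datetime(2024, 1, 1, tzinfo=UTC),
    finished_at=datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC),
    elapsed_time=None,
)
print(derive_elapsed_time(row))  # 5.0 -- recovered from the timestamps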
api/services/workflow_event_snapshot_service.py (new file, 459 lines):

@@ -0,0 +1,459 @@
from __future__ import annotations

import json
import logging
import queue
import threading
import time
from collections.abc import Generator, Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any

from sqlalchemy import desc, select
from sqlalchemy.orm import Session, sessionmaker

from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import (
    NodeFinishStreamResponse,
    NodeStartStreamResponse,
    StreamEvent,
    WorkflowPauseStreamResponse,
    WorkflowStartStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.entities import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from core.workflow.runtime import GraphRuntimeState
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models.model import AppMode, Message
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.factory import DifyAPIRepositoryFactory

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class MessageContext:
    conversation_id: str
    message_id: str
    created_at: int


def build_workflow_event_stream(
    *,
    app_mode: AppMode,
    workflow_run: WorkflowRun,
    tenant_id: str,
    app_id: str,
    session_maker: sessionmaker[Session],
    idle_timeout: float = 300,
    ping_interval: float = 10.0,
) -> Generator[Mapping[str, Any] | str, None, None]:
    topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
    workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
    node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
    message_context = (
        _get_message_context(session_maker, workflow_run.id) if app_mode == AppMode.ADVANCED_CHAT else None
    )

    pause_entity: WorkflowPauseEntity | None = None
    if workflow_run.status == WorkflowExecutionStatus.PAUSED:
        try:
            pause_entity = workflow_run_repo.get_workflow_pause(workflow_run.id)
        except Exception:
            logger.exception("Failed to load workflow pause for run %s", workflow_run.id)
            pause_entity = None

    resumption_context = _load_resumption_context(pause_entity)
    node_snapshots = node_execution_repo.get_execution_snapshots_by_workflow_run(
        tenant_id=tenant_id,
        app_id=app_id,
        workflow_id=workflow_run.workflow_id,
        triggered_from=workflow_run.triggered_from,
        workflow_run_id=workflow_run.id,
    )

    def _generate() -> Generator[Mapping[str, Any] | str, None, None]:
        last_msg_time = time.time()
        last_ping_time = last_msg_time

        # Subscribe before building snapshots so no live event is lost in between:
        # a background worker buffers everything published while snapshots are assembled.
        with topic.subscribe() as sub:
            buffer_queue, stop_event, buffer_done = _start_buffering(sub)
            try:
                buffered_events = _drain_buffer(buffer_queue)
                task_id = _resolve_task_id(resumption_context, buffered_events, workflow_run.id)

                snapshot_events = _build_snapshot_events(
                    workflow_run=workflow_run,
                    node_snapshots=node_snapshots,
                    task_id=task_id,
                    message_context=message_context,
                    pause_entity=pause_entity,
                    resumption_context=resumption_context,
                )
                buffered_events.extend(_drain_buffer(buffer_queue))
                snapshot_keys = _collect_snapshot_keys(snapshot_events)

                # 1) Replay the reconstructed snapshot events.
                for event in snapshot_events:
                    last_msg_time = time.time()
                    last_ping_time = last_msg_time
                    yield event
                    if _is_terminal_event(event):
                        return

                # 2) Flush events buffered during snapshot construction, minus duplicates.
                for event in _filter_buffered_events(buffered_events, snapshot_keys):
                    last_msg_time = time.time()
                    last_ping_time = last_msg_time
                    yield event
                    if _is_terminal_event(event):
                        return

                # 3) Stream live events, emitting pings while idle.
                while True:
                    if buffer_done.is_set() and buffer_queue.empty():
                        return

                    try:
                        event = buffer_queue.get(timeout=0.1)
                    except queue.Empty:
                        current_time = time.time()
                        if current_time - last_msg_time > idle_timeout:
                            return
                        if current_time - last_ping_time >= ping_interval:
                            yield StreamEvent.PING.value
                            last_ping_time = current_time
                        continue

                    if _is_duplicate_event(event, snapshot_keys):
                        continue
                    last_msg_time = time.time()
                    last_ping_time = last_msg_time
                    yield event
                    if _is_terminal_event(event):
                        return
            finally:
                stop_event.set()

    return _generate()
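Consumers of this generator receive a mix of JSON-serializable mappings and bare ping strings. A hedged sketch of how a caller (such as convert_to_event_stream, referenced in the controllers above but not shown in this diff) might frame them as SSE — the framing details below are assumptions, not dify's actual implementation:

# Illustrative SSE framing; dify's convert_to_event_stream may differ.
import json
from collections.abc import Generator, Mapping
from typing import Any

def to_sse(events: Generator[Mapping[str, Any] | str, None, None]) -> Generator[str, None, None]:
    for event in events:
        if isinstance(event, str):
            # Bare strings are keep-alive pings (StreamEvent.PING.value).
            yield f"event: {event}\n\n"
        else:
            yield f"data: {json.dumps(event)}\n\n"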

def _get_message_context(session_maker: sessionmaker[Session], workflow_run_id: str) -> MessageContext | None:
    with session_maker() as session:
        stmt = select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(desc(Message.created_at))
        message = session.scalar(stmt)
        if message is None:
            return None
        created_at = int(message.created_at.timestamp()) if message.created_at else 0
        return MessageContext(
            conversation_id=message.conversation_id,
            message_id=message.id,
            created_at=created_at,
        )


def _load_resumption_context(pause_entity: WorkflowPauseEntity | None) -> WorkflowResumptionContext | None:
    if pause_entity is None:
        return None
    try:
        raw_state = pause_entity.get_state().decode()
        return WorkflowResumptionContext.loads(raw_state)
    except Exception:
        logger.exception("Failed to load resumption context")
        return None


def _resolve_task_id(
    resumption_context: WorkflowResumptionContext | None,
    buffered_events: Sequence[Mapping[str, Any]],
    workflow_run_id: str,
) -> str:
    # Priority: persisted resumption context, then any buffered live event, then the run id itself.
    if resumption_context is not None:
        generate_entity = resumption_context.get_generate_entity()
        if generate_entity.task_id:
            return generate_entity.task_id
    for event in buffered_events:
        task_id = event.get("task_id")
        if task_id:
            return str(task_id)
    return workflow_run_id


def _build_snapshot_events(
    *,
    workflow_run: WorkflowRun,
    node_snapshots: Sequence[WorkflowNodeExecutionSnapshot],
    task_id: str,
    message_context: MessageContext | None,
    pause_entity: WorkflowPauseEntity | None,
    resumption_context: WorkflowResumptionContext | None,
) -> list[Mapping[str, Any]]:
    events: list[Mapping[str, Any]] = []

    workflow_started = _build_workflow_started_event(
        workflow_run=workflow_run,
        task_id=task_id,
    )
    _apply_message_context(workflow_started, message_context)
    events.append(workflow_started)

    for snapshot in node_snapshots:
        node_started = _build_node_started_event(
            workflow_run_id=workflow_run.id,
            task_id=task_id,
            snapshot=snapshot,
        )
        _apply_message_context(node_started, message_context)
        events.append(node_started)

        # Running nodes get only a "started" event; all other statuses also get a "finished" event.
        if snapshot.status != WorkflowNodeExecutionStatus.RUNNING.value:
            node_finished = _build_node_finished_event(
                workflow_run_id=workflow_run.id,
                task_id=task_id,
                snapshot=snapshot,
            )
            _apply_message_context(node_finished, message_context)
            events.append(node_finished)

    if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
        pause_event = _build_pause_event(
            workflow_run_id=workflow_run.id,
            task_id=task_id,
            pause_entity=pause_entity,
            resumption_context=resumption_context,
        )
        if pause_event is not None:
            _apply_message_context(pause_event, message_context)
            events.append(pause_event)

    return events


def _build_workflow_started_event(
    *,
    workflow_run: WorkflowRun,
    task_id: str,
) -> dict[str, Any]:
    response = WorkflowStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=WorkflowStartStreamResponse.Data(
            id=workflow_run.id,
            workflow_id=workflow_run.workflow_id,
            inputs=workflow_run.inputs_dict or {},
            created_at=int(workflow_run.created_at.timestamp()),
            reason=WorkflowStartReason.INITIAL,
        ),
    )
    payload = response.model_dump(mode="json")
    payload["event"] = response.event.value
    return payload


def _build_node_started_event(
    *,
    workflow_run_id: str,
    task_id: str,
    snapshot: WorkflowNodeExecutionSnapshot,
) -> Mapping[str, Any]:
    created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0
    response = NodeStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        data=NodeStartStreamResponse.Data(
            id=snapshot.execution_id,
            node_id=snapshot.node_id,
            node_type=snapshot.node_type,
            title=snapshot.title,
            index=snapshot.index,
            predecessor_node_id=None,
            inputs=None,
            created_at=created_at,
            extras={},
            iteration_id=snapshot.iteration_id,
            loop_id=snapshot.loop_id,
        ),
    )
    return response.to_ignore_detail_dict()


def _build_node_finished_event(
    *,
    workflow_run_id: str,
    task_id: str,
    snapshot: WorkflowNodeExecutionSnapshot,
) -> Mapping[str, Any]:
    created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0
    finished_at = int(snapshot.finished_at.timestamp()) if snapshot.finished_at else created_at
    response = NodeFinishStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        data=NodeFinishStreamResponse.Data(
            id=snapshot.execution_id,
            node_id=snapshot.node_id,
            node_type=snapshot.node_type,
            title=snapshot.title,
            index=snapshot.index,
            predecessor_node_id=None,
            inputs=None,
            process_data=None,
            outputs=None,
            status=snapshot.status,
            error=None,
            elapsed_time=snapshot.elapsed_time,
            execution_metadata=None,
            created_at=created_at,
            finished_at=finished_at,
            files=[],
            iteration_id=snapshot.iteration_id,
            loop_id=snapshot.loop_id,
        ),
    )
    return response.to_ignore_detail_dict()


def _build_pause_event(
    *,
    workflow_run_id: str,
    task_id: str,
    pause_entity: WorkflowPauseEntity,
    resumption_context: WorkflowResumptionContext | None,
) -> Mapping[str, Any] | None:
    paused_nodes: list[str] = []
    outputs: dict[str, Any] = {}
    if resumption_context is not None:
        state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
        paused_nodes = state.get_paused_nodes()
        outputs = WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {})

    reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
    response = WorkflowPauseStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run_id,
        data=WorkflowPauseStreamResponse.Data(
            workflow_run_id=workflow_run_id,
            paused_nodes=paused_nodes,
            outputs=outputs,
            reasons=reasons,
        ),
    )
    payload = response.model_dump(mode="json")
    payload["event"] = response.event.value
    return payload


def _apply_message_context(payload: dict[str, Any], message_context: MessageContext | None) -> None:
    if message_context is None:
        return
    payload["conversation_id"] = message_context.conversation_id
    payload["message_id"] = message_context.message_id
    payload["created_at"] = message_context.created_at


def _start_buffering(subscription) -> tuple[queue.Queue[Mapping[str, Any]], threading.Event, threading.Event]:
    buffer_queue: queue.Queue[Mapping[str, Any]] = queue.Queue(maxsize=2048)
    stop_event = threading.Event()
    done_event = threading.Event()

    def _worker() -> None:
        dropped_count = 0
        try:
            while not stop_event.is_set():
                msg = subscription.receive(timeout=0.1)
                if msg is None:
                    continue
                event = _parse_event_message(msg)
                if event is None:
                    continue
                try:
                    buffer_queue.put_nowait(event)
                except queue.Full:
                    # Drop-oldest policy: evict the oldest buffered event to make room.
                    dropped_count += 1
                    try:
                        buffer_queue.get_nowait()
                    except queue.Empty:
                        pass
                    try:
                        buffer_queue.put_nowait(event)
                    except queue.Full:
                        continue
                    logger.warning("Dropped buffered workflow event, total_dropped=%s", dropped_count)
        except Exception:
            logger.exception("Failed while buffering workflow events")
        finally:
            done_event.set()

    thread = threading.Thread(target=_worker, name=f"workflow-event-buffer-{id(subscription)}", daemon=True)
    thread.start()
    return buffer_queue, stop_event, done_event
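The buffering worker above applies a drop-oldest policy when its bounded queue fills. A standalone, runnable distillation of that policy using only the stdlib (independent of dify):

import queue

def put_drop_oldest(q: queue.Queue, item) -> bool:
    """Insert item; if the queue is full, evict the oldest entry first.

    Returns False only if the item still cannot be enqueued (e.g. a racing
    producer refilled the slot), mirroring the worker's 'continue' branch.
    """
    try:
        q.put_nowait(item)
        return True
    except queue.Full:
        try:
            q.get_nowait()  # evict the oldest buffered event
        except queue.Empty:
            pass
        try:
            q.put_nowait(item)
            return True
        except queue.Full:
            return False

q = queue.Queue(maxsize=2)
for i in range(4):
    put_drop_oldest(q, i)
print([q.get_nowait() for _ in range(2)])  # [2, 3] -- the oldest items were dropped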

def _drain_buffer(
    buffer_queue: queue.Queue[Mapping[str, Any]],
) -> list[Mapping[str, Any]]:
    events: list[Mapping[str, Any]] = []
    while True:
        try:
            event = buffer_queue.get_nowait()
        except queue.Empty:
            break
        events.append(event)
    return events


def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
    try:
        event = json.loads(message)
    except json.JSONDecodeError:
        logger.warning("Failed to decode workflow event payload")
        return None
    if not isinstance(event, dict):
        return None
    return event


def _is_terminal_event(event: Mapping[str, Any] | str) -> bool:
    if not isinstance(event, Mapping):
        return False
    event_type = event.get("event")
    return event_type in (StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value)


def _collect_snapshot_keys(events: Iterable[Mapping[str, Any]]) -> set[tuple[str, str]]:
    keys: set[tuple[str, str]] = set()
    for event in events:
        key = _event_snapshot_key(event)
        if key is not None:
            keys.add(key)
    return keys


def _filter_buffered_events(
    events: Sequence[Mapping[str, Any]],
    snapshot_keys: set[tuple[str, str]],
) -> Iterable[Mapping[str, Any]]:
    for event in events:
        if _is_duplicate_event(event, snapshot_keys):
            continue
        yield event


def _is_duplicate_event(event: Mapping[str, Any], snapshot_keys: set[tuple[str, str]]) -> bool:
    key = _event_snapshot_key(event)
    if key is None:
        return False
    return key in snapshot_keys


def _event_snapshot_key(event: Mapping[str, Any]) -> tuple[str, str] | None:
    # Dedup identity: (event type, stable id). Events without a stable identity
    # return None and are never deduplicated.
    event_type = event.get("event")
    if not event_type:
        return None
    if event_type == StreamEvent.WORKFLOW_STARTED.value:
        return (event_type, event.get("workflow_run_id") or "")
    if event_type in {StreamEvent.NODE_STARTED.value, StreamEvent.NODE_FINISHED.value}:
        data = event.get("data") or {}
        return (event_type, str(data.get("id") or ""))
    if event_type == StreamEvent.WORKFLOW_PAUSED.value:
        return (event_type, event.get("workflow_run_id") or "")
    return None
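To see the dedup identity in action with plain dicts (event-type strings taken from the tests below; the helper here is a standalone mirror of _event_snapshot_key so the snippet runs without dify's imports):

def snapshot_key(event: dict) -> tuple[str, str] | None:
    """Standalone mirror of _event_snapshot_key, for illustration only."""
    event_type = event.get("event")
    if event_type in {"node_started", "node_finished"}:
        return (event_type, str((event.get("data") or {}).get("id") or ""))
    if event_type in {"workflow_started", "workflow_paused"}:
        return (event_type, event.get("workflow_run_id") or "")
    return None

snapshot_keys = {snapshot_key({"event": "node_started", "data": {"id": "exec-1"}})}
live = [
    {"event": "node_started", "data": {"id": "exec-1"}},   # already replayed -> dropped
    {"event": "node_finished", "data": {"id": "exec-1"}},  # different event type -> kept
]
print([e for e in live if snapshot_key(e) not in snapshot_keys])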
Unit tests for workflow_event_snapshot_service (new file, 237 lines):

@@ -0,0 +1,237 @@
from __future__ import annotations

import json
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime

import pytest

from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services.workflow_event_snapshot_service import (
    MessageContext,
    _build_snapshot_events,
    _collect_snapshot_keys,
    _filter_buffered_events,
    _resolve_task_id,
)


@dataclass(frozen=True)
class _FakePauseEntity(WorkflowPauseEntity):
    pause_id: str
    workflow_run_id: str
    paused_at_value: datetime
    pause_reasons: Sequence[HumanInputRequired]

    @property
    def id(self) -> str:
        return self.pause_id

    @property
    def workflow_execution_id(self) -> str:
        return self.workflow_run_id

    def get_state(self) -> bytes:
        raise AssertionError("state is not required for snapshot tests")

    @property
    def resumed_at(self) -> datetime | None:
        return None

    @property
    def paused_at(self) -> datetime:
        return self.paused_at_value

    def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
        return self.pause_reasons


def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
    return WorkflowRun(
        id="run-1",
        tenant_id="tenant-1",
        app_id="app-1",
        workflow_id="workflow-1",
        type="workflow",
        triggered_from="app-run",
        version="v1",
        graph=None,
        inputs=json.dumps({"input": "value"}),
        status=status,
        outputs=json.dumps({}),
        error=None,
        elapsed_time=0.0,
        total_tokens=0,
        total_steps=0,
        created_by_role=CreatorUserRole.END_USER,
        created_by="user-1",
        created_at=datetime(2024, 1, 1, tzinfo=UTC),
    )


def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
    created_at = datetime(2024, 1, 1, tzinfo=UTC)
    finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
    return WorkflowNodeExecutionSnapshot(
        execution_id="exec-1",
        node_id="node-1",
        node_type="human-input",
        title="Human Input",
        index=1,
        status=status.value,
        elapsed_time=0.5,
        created_at=created_at,
        finished_at=finished_at,
        iteration_id=None,
        loop_id=None,
    )


def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
    app_config = WorkflowUIBasedAppConfig(
        tenant_id="tenant-1",
        app_id="app-1",
        app_mode=AppMode.WORKFLOW,
        workflow_id="workflow-1",
    )
    generate_entity = WorkflowAppGenerateEntity(
        task_id=task_id,
        app_config=app_config,
        inputs={},
        files=[],
        user_id="user-1",
        stream=True,
        invoke_from=InvokeFrom.EXPLORE,
        call_depth=0,
        workflow_execution_id="run-1",
    )
    runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
    runtime_state.register_paused_node("node-1")
    runtime_state.outputs = {"result": "value"}
    wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
    return WorkflowResumptionContext(
        generate_entity=wrapper,
        serialized_graph_runtime_state=runtime_state.dumps(),
    )


def test_build_snapshot_events_includes_pause_event() -> None:
    workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
    snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
    resumption_context = _build_resumption_context("task-ctx")
    pause_entity = _FakePauseEntity(
        pause_id="pause-1",
        workflow_run_id="run-1",
        paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
        pause_reasons=[
            HumanInputRequired(
                form_id="form-1",
                form_content="content",
                node_id="node-1",
                node_title="Human Input",
            )
        ],
    )

    events = _build_snapshot_events(
        workflow_run=workflow_run,
        node_snapshots=[snapshot],
        task_id="task-ctx",
        message_context=None,
        pause_entity=pause_entity,
        resumption_context=resumption_context,
    )

    assert [event["event"] for event in events] == [
        "workflow_started",
        "node_started",
        "node_finished",
        "workflow_paused",
    ]
    assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
    pause_data = events[-1]["data"]
    assert pause_data["paused_nodes"] == ["node-1"]
    assert pause_data["outputs"] == {"result": "value"}


def test_build_snapshot_events_applies_message_context() -> None:
    workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING)
    snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED)
    message_context = MessageContext(
        conversation_id="conv-1",
        message_id="msg-1",
        created_at=1700000000,
    )

    events = _build_snapshot_events(
        workflow_run=workflow_run,
        node_snapshots=[snapshot],
        task_id="task-1",
        message_context=message_context,
        pause_entity=None,
        resumption_context=None,
    )

    for event in events:
        assert event["conversation_id"] == "conv-1"
        assert event["message_id"] == "msg-1"
        assert event["created_at"] == 1700000000


@pytest.mark.parametrize(
    ("context_task_id", "buffered_task_id", "expected"),
    [
        ("task-ctx", "task-buffer", "task-ctx"),
        (None, "task-buffer", "task-buffer"),
        (None, None, "run-1"),
    ],
)
def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None:
    resumption_context = _build_resumption_context(context_task_id) if context_task_id else None
    buffered_events = [{"task_id": buffered_task_id}] if buffered_task_id else []
    task_id = _resolve_task_id(resumption_context, buffered_events, "run-1")
    assert task_id == expected


def test_filter_buffered_events_deduplicates_snapshot_nodes() -> None:
    workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING)
    snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED)
    events = _build_snapshot_events(
        workflow_run=workflow_run,
        node_snapshots=[snapshot],
        task_id="task-1",
        message_context=None,
        pause_entity=None,
        resumption_context=None,
    )
    snapshot_keys = _collect_snapshot_keys(events)

    buffered_events = [
        {
            "event": "node_started",
            "data": {"id": "exec-1"},
        },
        {
            "event": "node_finished",
            "data": {"id": "exec-2"},
        },
    ]

    filtered = list(_filter_buffered_events(buffered_events, snapshot_keys))
    assert filtered == [
        {
            "event": "node_finished",
            "data": {"id": "exec-2"},
        }
    ]