feat(api): restore node state snapshots via /events api

Author: QuantumGhost
Date: 2026-01-19 09:49:58 +08:00
Parent: 6bf6bf6a2a
Commit: b085df9425
6 changed files with 849 additions and 5 deletions
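
For context: clients opt into the snapshot replay through a new include_state_snapshot query parameter on the existing /events endpoints. A minimal consumer sketch, assuming a hypothetical route and token (the concrete URL is whatever the console/web controllers below are mounted at, which this diff does not show):

import requests

# Hypothetical endpoint and credentials, for illustration only.
url = "https://dify.example.com/console/api/apps/<app_id>/workflow-runs/<run_id>/events"
resp = requests.get(
    url,
    params={"include_state_snapshot": "true"},  # server defaults to "false"
    headers={"Authorization": "Bearer <access_token>"},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())  # snapshot events replay first, then live events follow

With the flag off, both endpoints fall back to the pre-existing retrieve_events stream unchanged.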

View File

@@ -6,7 +6,7 @@ import json
import logging
from collections.abc import Generator
-from flask import Response, jsonify
+from flask import Response, jsonify, request
from flask_restx import Resource, reqparse
from pydantic import BaseModel
from sqlalchemy import select
@@ -28,6 +28,7 @@ from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.human_input_service import Form, HumanInputService
from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
@@ -168,7 +169,19 @@ class ConsoleWorkflowEventsApi(Resource):
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=AppMode(app.mode),
workflow_run=workflow_run,
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
session_maker=session_maker,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id),
)

View File

@@ -5,7 +5,7 @@ Web App Workflow Resume APIs.
import json
from collections.abc import Generator
-from flask import Response
+from flask import Response, request
from sqlalchemy.orm import sessionmaker
from controllers.web import api
@@ -19,6 +19,7 @@ from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
class WorkflowEventsApi(WebApiResource):
@@ -76,7 +77,19 @@ class WorkflowEventsApi(WebApiResource):
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(app_mode, workflow_run.id),
)

View File

@@ -10,6 +10,7 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m
"""
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Protocol
@@ -17,6 +18,27 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
from models.workflow import WorkflowNodeExecutionModel
@dataclass(frozen=True)
class WorkflowNodeExecutionSnapshot:
"""
Minimal snapshot of workflow node execution for stream recovery.
Only includes fields required by snapshot events.
"""
execution_id: str # Unique execution identifier (node_execution_id or row id).
node_id: str # Workflow graph node id.
node_type: str # Workflow graph node type (e.g. "human-input").
title: str # Human-friendly node title.
index: int # Execution order index within the workflow run.
status: str # Execution status (running/succeeded/failed/paused).
elapsed_time: float # Execution elapsed time in seconds.
created_at: datetime # Execution created timestamp.
finished_at: datetime | None # Execution finished timestamp.
iteration_id: str | None = None # Iteration id from execution metadata, if any.
loop_id: str | None = None # Loop id from execution metadata, if any.
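
A minimal construction for reference, mirroring the fixture in the tests below (values illustrative; the dataclass is frozen, so instances are immutable):

from datetime import UTC, datetime

snapshot = WorkflowNodeExecutionSnapshot(
    execution_id="exec-1",
    node_id="node-1",
    node_type="human-input",
    title="Human Input",
    index=1,
    status="succeeded",
    elapsed_time=0.5,
    created_at=datetime(2024, 1, 1, tzinfo=UTC),
    finished_at=datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC),
)  # iteration_id and loop_id default to None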
class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
"""
Protocol for service-layer operations on WorkflowNodeExecutionModel.
@@ -77,6 +99,8 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
Args:
tenant_id: The tenant identifier
app_id: The application identifier
workflow_id: The workflow identifier
triggered_from: The workflow trigger source
workflow_run_id: The workflow run identifier
Returns:
@@ -84,6 +108,27 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
"""
...
def get_execution_snapshots_by_workflow_run(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
triggered_from: str,
workflow_run_id: str,
) -> Sequence[WorkflowNodeExecutionSnapshot]:
"""
Get minimal snapshots for node executions in a workflow run.
Args:
tenant_id: The tenant identifier
app_id: The application identifier
workflow_id: The workflow identifier
triggered_from: The workflow trigger source
workflow_run_id: The workflow run identifier
Returns:
A sequence of WorkflowNodeExecutionSnapshot ordered by creation time
"""
...
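
A call sketch, assuming a repository obtained from the same factory the snapshot service below uses, plus a loaded WorkflowRun:

node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
snapshots = node_execution_repo.get_execution_snapshots_by_workflow_run(
    tenant_id=workflow_run.tenant_id,
    app_id=workflow_run.app_id,
    workflow_id=workflow_run.workflow_id,
    triggered_from=workflow_run.triggered_from,
    workflow_run_id=workflow_run.id,
)
for s in snapshots:  # ordered by created_at, then index
    print(s.node_id, s.status, s.elapsed_time)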
def get_execution_by_id(
self,
execution_id: str,

View File

@@ -5,17 +5,20 @@ This module provides a concrete implementation of the service repository protoco
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
"""
import json
from collections.abc import Sequence
from datetime import datetime
from typing import Any, cast
from sqlalchemy import asc, delete, desc, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker
-from core.workflow.enums import WorkflowNodeExecutionStatus
+from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionModel
-from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
+from repositories.api_workflow_node_execution_repository import (
+    DifyAPIWorkflowNodeExecutionRepository,
+    WorkflowNodeExecutionSnapshot,
+)
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
@@ -117,6 +120,80 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
with self._session_maker() as session:
return session.execute(stmt).scalars().all()
def get_execution_snapshots_by_workflow_run(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
triggered_from: str,
workflow_run_id: str,
) -> Sequence[WorkflowNodeExecutionSnapshot]:
stmt = (
select(
WorkflowNodeExecutionModel.id,
WorkflowNodeExecutionModel.node_execution_id,
WorkflowNodeExecutionModel.node_id,
WorkflowNodeExecutionModel.node_type,
WorkflowNodeExecutionModel.title,
WorkflowNodeExecutionModel.index,
WorkflowNodeExecutionModel.status,
WorkflowNodeExecutionModel.elapsed_time,
WorkflowNodeExecutionModel.created_at,
WorkflowNodeExecutionModel.finished_at,
WorkflowNodeExecutionModel.execution_metadata,
)
.where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_id,
WorkflowNodeExecutionModel.workflow_id == workflow_id,
WorkflowNodeExecutionModel.triggered_from == triggered_from,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
.order_by(
asc(WorkflowNodeExecutionModel.created_at),
asc(WorkflowNodeExecutionModel.index),
)
)
with self._session_maker() as session:
rows = session.execute(stmt).all()
return [self._row_to_snapshot(row) for row in rows]
@staticmethod
def _row_to_snapshot(row: Any) -> WorkflowNodeExecutionSnapshot:
metadata: dict[str, object] = {}
execution_metadata = getattr(row, "execution_metadata", None)
if execution_metadata:
try:
metadata = json.loads(execution_metadata)
except json.JSONDecodeError:
metadata = {}
iteration_id = metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID.value)
loop_id = metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID.value)
execution_id = getattr(row, "node_execution_id", None) or row.id
elapsed_time = getattr(row, "elapsed_time", None)
created_at = row.created_at
finished_at = getattr(row, "finished_at", None)
if elapsed_time is None:
if finished_at is not None and created_at is not None:
elapsed_time = (finished_at - created_at).total_seconds()
else:
elapsed_time = 0.0
return WorkflowNodeExecutionSnapshot(
execution_id=str(execution_id),
node_id=row.node_id,
node_type=row.node_type,
title=row.title,
index=row.index,
status=row.status,
elapsed_time=float(elapsed_time),
created_at=created_at,
finished_at=finished_at,
iteration_id=str(iteration_id) if iteration_id else None,
loop_id=str(loop_id) if loop_id else None,
)
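
The elapsed_time and execution_id fallbacks above can be exercised with a hypothetical stub row (a stand-in for the SQLAlchemy result row; the attribute names match the selected columns):

from dataclasses import dataclass
from datetime import UTC, datetime

@dataclass
class _StubRow:  # hypothetical test double, not part of this commit
    id: str = "row-1"
    node_execution_id: str | None = None  # absent -> execution_id falls back to id
    node_id: str = "node-1"
    node_type: str = "code"
    title: str = "Code"
    index: int = 1
    status: str = "succeeded"
    elapsed_time: float | None = None  # absent -> derived from the timestamps
    created_at: datetime = datetime(2024, 1, 1, tzinfo=UTC)
    finished_at: datetime = datetime(2024, 1, 1, 0, 0, 3, tzinfo=UTC)
    execution_metadata: str | None = None

snap = DifyAPISQLAlchemyWorkflowNodeExecutionRepository._row_to_snapshot(_StubRow())
assert snap.elapsed_time == 3.0  # (finished_at - created_at).total_seconds()
assert snap.execution_id == "row-1"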
def get_execution_by_id(
self,
execution_id: str,

View File

@@ -0,0 +1,459 @@
from __future__ import annotations
import json
import logging
import queue
import threading
import time
from collections.abc import Generator, Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any
from sqlalchemy import desc, select
from sqlalchemy.orm import Session, sessionmaker
from core.app.apps.message_generator import MessageGenerator
from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeStartStreamResponse,
StreamEvent,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext
from core.workflow.entities import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from core.workflow.runtime import GraphRuntimeState
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models.model import AppMode, Message
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.factory import DifyAPIRepositoryFactory
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class MessageContext:
conversation_id: str
message_id: str
created_at: int
def build_workflow_event_stream(
*,
app_mode: AppMode,
workflow_run: WorkflowRun,
tenant_id: str,
app_id: str,
session_maker: sessionmaker[Session],
idle_timeout: float = 300.0,
ping_interval: float = 10.0,
) -> Generator[Mapping[str, Any] | str, None, None]:
topic = MessageGenerator.get_response_topic(app_mode, workflow_run.id)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
message_context = (
_get_message_context(session_maker, workflow_run.id) if app_mode == AppMode.ADVANCED_CHAT else None
)
pause_entity: WorkflowPauseEntity | None = None
if workflow_run.status == WorkflowExecutionStatus.PAUSED:
try:
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run.id)
except Exception:
logger.exception("Failed to load workflow pause for run %s", workflow_run.id)
pause_entity = None
resumption_context = _load_resumption_context(pause_entity)
node_snapshots = node_execution_repo.get_execution_snapshots_by_workflow_run(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_run.workflow_id,
triggered_from=workflow_run.triggered_from,
workflow_run_id=workflow_run.id,
)
def _generate() -> Generator[Mapping[str, Any] | str, None, None]:
last_msg_time = time.time()
last_ping_time = last_msg_time
with topic.subscribe() as sub:
buffer_queue, stop_event, buffer_done = _start_buffering(sub)
try:
buffered_events = _drain_buffer(buffer_queue)
task_id = _resolve_task_id(resumption_context, buffered_events, workflow_run.id)
snapshot_events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=node_snapshots,
task_id=task_id,
message_context=message_context,
pause_entity=pause_entity,
resumption_context=resumption_context,
)
buffered_events.extend(_drain_buffer(buffer_queue))
snapshot_keys = _collect_snapshot_keys(snapshot_events)
for event in snapshot_events:
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event):
return
for event in _filter_buffered_events(buffered_events, snapshot_keys):
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event):
return
while True:
if buffer_done.is_set() and buffer_queue.empty():
return
try:
event = buffer_queue.get(timeout=0.1)
except queue.Empty:
current_time = time.time()
if current_time - last_msg_time > idle_timeout:
return
if current_time - last_ping_time >= ping_interval:
yield StreamEvent.PING.value
last_ping_time = current_time
continue
if _is_duplicate_event(event, snapshot_keys):
continue
last_msg_time = time.time()
last_ping_time = last_msg_time
yield event
if _is_terminal_event(event):
return
finally:
stop_event.set()
return _generate()
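
A minimal consumption sketch, assuming a session factory and a loaded WorkflowRun; the controllers above wrap this same generator with generator.convert_to_event_stream(...) to build the SSE response:

stream = build_workflow_event_stream(
    app_mode=AppMode.WORKFLOW,
    workflow_run=workflow_run,
    tenant_id=workflow_run.tenant_id,
    app_id=workflow_run.app_id,
    session_maker=session_maker,
)
for event in stream:
    if isinstance(event, str):  # keep-alive pings are yielded as bare strings
        continue
    print(event.get("event"))  # snapshot replay first, then deduplicated live events

The ordering matters here: the topic subscription starts buffering before the snapshot events are built, so anything published while the snapshot is assembled is captured in the buffer and deduplicated rather than lost.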
def _get_message_context(session_maker: sessionmaker[Session], workflow_run_id: str) -> MessageContext | None:
with session_maker() as session:
stmt = select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(desc(Message.created_at))
message = session.scalar(stmt)
if message is None:
return None
created_at = int(message.created_at.timestamp()) if message.created_at else 0
return MessageContext(
conversation_id=message.conversation_id,
message_id=message.id,
created_at=created_at,
)
def _load_resumption_context(pause_entity: WorkflowPauseEntity | None) -> WorkflowResumptionContext | None:
if pause_entity is None:
return None
try:
raw_state = pause_entity.get_state().decode()
return WorkflowResumptionContext.loads(raw_state)
except Exception:
logger.exception("Failed to load resumption context")
return None
def _resolve_task_id(
resumption_context: WorkflowResumptionContext | None,
buffered_events: Sequence[Mapping[str, Any]],
workflow_run_id: str,
) -> str:
if resumption_context is not None:
generate_entity = resumption_context.get_generate_entity()
if generate_entity.task_id:
return generate_entity.task_id
for event in buffered_events:
task_id = event.get("task_id")
if task_id:
return str(task_id)
return workflow_run_id
def _build_snapshot_events(
*,
workflow_run: WorkflowRun,
node_snapshots: Sequence[WorkflowNodeExecutionSnapshot],
task_id: str,
message_context: MessageContext | None,
pause_entity: WorkflowPauseEntity | None,
resumption_context: WorkflowResumptionContext | None,
) -> list[Mapping[str, Any]]:
events: list[Mapping[str, Any]] = []
workflow_started = _build_workflow_started_event(
workflow_run=workflow_run,
task_id=task_id,
)
_apply_message_context(workflow_started, message_context)
events.append(workflow_started)
for snapshot in node_snapshots:
node_started = _build_node_started_event(
workflow_run_id=workflow_run.id,
task_id=task_id,
snapshot=snapshot,
)
_apply_message_context(node_started, message_context)
events.append(node_started)
if snapshot.status != WorkflowNodeExecutionStatus.RUNNING.value:
node_finished = _build_node_finished_event(
workflow_run_id=workflow_run.id,
task_id=task_id,
snapshot=snapshot,
)
_apply_message_context(node_finished, message_context)
events.append(node_finished)
if workflow_run.status == WorkflowExecutionStatus.PAUSED and pause_entity is not None:
pause_event = _build_pause_event(
workflow_run_id=workflow_run.id,
task_id=task_id,
pause_entity=pause_entity,
resumption_context=resumption_context,
)
if pause_event is not None:
_apply_message_context(pause_event, message_context)
events.append(pause_event)
return events
def _build_workflow_started_event(
*,
workflow_run: WorkflowRun,
task_id: str,
) -> dict[str, Any]:
response = WorkflowStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run.id,
data=WorkflowStartStreamResponse.Data(
id=workflow_run.id,
workflow_id=workflow_run.workflow_id,
inputs=workflow_run.inputs_dict or {},
created_at=int(workflow_run.created_at.timestamp()),
reason=WorkflowStartReason.INITIAL,
),
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
return payload
def _build_node_started_event(
*,
workflow_run_id: str,
task_id: str,
snapshot: WorkflowNodeExecutionSnapshot,
) -> Mapping[str, Any]:
created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0
response = NodeStartStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
data=NodeStartStreamResponse.Data(
id=snapshot.execution_id,
node_id=snapshot.node_id,
node_type=snapshot.node_type,
title=snapshot.title,
index=snapshot.index,
predecessor_node_id=None,
inputs=None,
created_at=created_at,
extras={},
iteration_id=snapshot.iteration_id,
loop_id=snapshot.loop_id,
),
)
return response.to_ignore_detail_dict()
def _build_node_finished_event(
*,
workflow_run_id: str,
task_id: str,
snapshot: WorkflowNodeExecutionSnapshot,
) -> Mapping[str, Any]:
created_at = int(snapshot.created_at.timestamp()) if snapshot.created_at else 0
finished_at = int(snapshot.finished_at.timestamp()) if snapshot.finished_at else created_at
response = NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
data=NodeFinishStreamResponse.Data(
id=snapshot.execution_id,
node_id=snapshot.node_id,
node_type=snapshot.node_type,
title=snapshot.title,
index=snapshot.index,
predecessor_node_id=None,
inputs=None,
process_data=None,
outputs=None,
status=snapshot.status,
error=None,
elapsed_time=snapshot.elapsed_time,
execution_metadata=None,
created_at=created_at,
finished_at=finished_at,
files=[],
iteration_id=snapshot.iteration_id,
loop_id=snapshot.loop_id,
),
)
return response.to_ignore_detail_dict()
def _build_pause_event(
*,
workflow_run_id: str,
task_id: str,
pause_entity: WorkflowPauseEntity,
resumption_context: WorkflowResumptionContext | None,
) -> Mapping[str, Any] | None:
paused_nodes: list[str] = []
outputs: dict[str, Any] = {}
if resumption_context is not None:
state = GraphRuntimeState.from_snapshot(resumption_context.serialized_graph_runtime_state)
paused_nodes = state.get_paused_nodes()
outputs = WorkflowRuntimeTypeConverter().to_json_encodable(state.outputs or {})
reasons = [reason.model_dump(mode="json") for reason in pause_entity.get_pause_reasons()]
response = WorkflowPauseStreamResponse(
task_id=task_id,
workflow_run_id=workflow_run_id,
data=WorkflowPauseStreamResponse.Data(
workflow_run_id=workflow_run_id,
paused_nodes=paused_nodes,
outputs=outputs,
reasons=reasons,
),
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
return payload
def _apply_message_context(payload: dict[str, Any], message_context: MessageContext | None) -> None:
if message_context is None:
return
payload["conversation_id"] = message_context.conversation_id
payload["message_id"] = message_context.message_id
payload["created_at"] = message_context.created_at
def _start_buffering(subscription) -> tuple[queue.Queue[Mapping[str, Any]], threading.Event, threading.Event]:
buffer_queue: queue.Queue[Mapping[str, Any]] = queue.Queue(maxsize=2048)
stop_event = threading.Event()
done_event = threading.Event()
def _worker() -> None:
dropped_count = 0
try:
while not stop_event.is_set():
msg = subscription.receive(timeout=0.1)
if msg is None:
continue
event = _parse_event_message(msg)
if event is None:
continue
try:
buffer_queue.put_nowait(event)
except queue.Full:
dropped_count += 1
try:
buffer_queue.get_nowait()
except queue.Empty:
pass
try:
buffer_queue.put_nowait(event)
except queue.Full:
continue
logger.warning("Dropped buffered workflow event, total_dropped=%s", dropped_count)
except Exception:
logger.exception("Failed while buffering workflow events")
finally:
done_event.set()
thread = threading.Thread(target=_worker, name=f"workflow-event-buffer-{id(subscription)}", daemon=True)
thread.start()
return buffer_queue, stop_event, done_event
def _drain_buffer(
buffer_queue: queue.Queue[Mapping[str, Any]],
) -> list[Mapping[str, Any]]:
events: list[Mapping[str, Any]] = []
while True:
try:
event = buffer_queue.get_nowait()
except queue.Empty:
break
events.append(event)
return events
def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
try:
event = json.loads(message)
except json.JSONDecodeError:
logger.warning("Failed to decode workflow event payload")
return None
if not isinstance(event, dict):
return None
return event
def _is_terminal_event(event: Mapping[str, Any] | str) -> bool:
if not isinstance(event, Mapping):
return False
event_type = event.get("event")
return event_type in (StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value)
def _collect_snapshot_keys(events: Iterable[Mapping[str, Any]]) -> set[tuple[str, str]]:
keys: set[tuple[str, str]] = set()
for event in events:
key = _event_snapshot_key(event)
if key is not None:
keys.add(key)
return keys
def _filter_buffered_events(
events: Sequence[Mapping[str, Any]],
snapshot_keys: set[tuple[str, str]],
) -> Iterable[Mapping[str, Any]]:
for event in events:
if _is_duplicate_event(event, snapshot_keys):
continue
yield event
def _is_duplicate_event(event: Mapping[str, Any], snapshot_keys: set[tuple[str, str]]) -> bool:
key = _event_snapshot_key(event)
if key is None:
return False
return key in snapshot_keys
def _event_snapshot_key(event: Mapping[str, Any]) -> tuple[str, str] | None:
event_type = event.get("event")
if not event_type:
return None
if event_type == StreamEvent.WORKFLOW_STARTED.value:
return (event_type, event.get("workflow_run_id") or "")
if event_type in {StreamEvent.NODE_STARTED.value, StreamEvent.NODE_FINISHED.value}:
data = event.get("data") or {}
return (event_type, str(data.get("id") or ""))
if event_type == StreamEvent.WORKFLOW_PAUSED.value:
return (event_type, event.get("workflow_run_id") or "")
return None
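
The dedup keys above can be illustrated with hypothetical payloads:

snapshot_events = [{"event": "node_started", "data": {"id": "exec-1"}}]
keys = _collect_snapshot_keys(snapshot_events)  # {("node_started", "exec-1")}

live_duplicate = {"event": "node_started", "data": {"id": "exec-1"}}
assert _is_duplicate_event(live_duplicate, keys)  # already replayed, so skipped

other = {"event": "node_finished", "data": {"id": "exec-2"}}
assert not _is_duplicate_event(other, keys)  # different execution id passes through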

View File

@@ -0,0 +1,237 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
import pytest
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services.workflow_event_snapshot_service import (
MessageContext,
_build_snapshot_events,
_collect_snapshot_keys,
_filter_buffered_events,
_resolve_task_id,
)
@dataclass(frozen=True)
class _FakePauseEntity(WorkflowPauseEntity):
pause_id: str
workflow_run_id: str
paused_at_value: datetime
pause_reasons: Sequence[HumanInputRequired]
@property
def id(self) -> str:
return self.pause_id
@property
def workflow_execution_id(self) -> str:
return self.workflow_run_id
def get_state(self) -> bytes:
raise AssertionError("state is not required for snapshot tests")
@property
def resumed_at(self) -> datetime | None:
return None
@property
def paused_at(self) -> datetime:
return self.paused_at_value
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
return self.pause_reasons
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
return WorkflowRun(
id="run-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
type="workflow",
triggered_from="app-run",
version="v1",
graph=None,
inputs=json.dumps({"input": "value"}),
status=status,
outputs=json.dumps({}),
error=None,
elapsed_time=0.0,
total_tokens=0,
total_steps=0,
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
created_at = datetime(2024, 1, 1, tzinfo=UTC)
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
return WorkflowNodeExecutionSnapshot(
execution_id="exec-1",
node_id="node-1",
node_type="human-input",
title="Human Input",
index=1,
status=status.value,
elapsed_time=0.5,
created_at=created_at,
finished_at=finished_at,
iteration_id=None,
loop_id=None,
)
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-1",
app_id="app-1",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-1",
)
generate_entity = WorkflowAppGenerateEntity(
task_id=task_id,
app_config=app_config,
inputs={},
files=[],
user_id="user-1",
stream=True,
invoke_from=InvokeFrom.EXPLORE,
call_depth=0,
workflow_execution_id="run-1",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
runtime_state.register_paused_node("node-1")
runtime_state.outputs = {"result": "value"}
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
return WorkflowResumptionContext(
generate_entity=wrapper,
serialized_graph_runtime_state=runtime_state.dumps(),
)
def test_build_snapshot_events_includes_pause_event() -> None:
workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
resumption_context = _build_resumption_context("task-ctx")
pause_entity = _FakePauseEntity(
pause_id="pause-1",
workflow_run_id="run-1",
paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
pause_reasons=[
HumanInputRequired(
form_id="form-1",
form_content="content",
node_id="node-1",
node_title="Human Input",
)
],
)
events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=[snapshot],
task_id="task-ctx",
message_context=None,
pause_entity=pause_entity,
resumption_context=resumption_context,
)
assert [event["event"] for event in events] == [
"workflow_started",
"node_started",
"node_finished",
"workflow_paused",
]
assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
pause_data = events[-1]["data"]
assert pause_data["paused_nodes"] == ["node-1"]
assert pause_data["outputs"] == {"result": "value"}
def test_build_snapshot_events_applies_message_context() -> None:
workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING)
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED)
message_context = MessageContext(
conversation_id="conv-1",
message_id="msg-1",
created_at=1700000000,
)
events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=[snapshot],
task_id="task-1",
message_context=message_context,
pause_entity=None,
resumption_context=None,
)
for event in events:
assert event["conversation_id"] == "conv-1"
assert event["message_id"] == "msg-1"
assert event["created_at"] == 1700000000
@pytest.mark.parametrize(
("context_task_id", "buffered_task_id", "expected"),
[
("task-ctx", "task-buffer", "task-ctx"),
(None, "task-buffer", "task-buffer"),
(None, None, "run-1"),
],
)
def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None:
resumption_context = _build_resumption_context(context_task_id) if context_task_id else None
buffered_events = [{"task_id": buffered_task_id}] if buffered_task_id else []
task_id = _resolve_task_id(resumption_context, buffered_events, "run-1")
assert task_id == expected
def test_filter_buffered_events_deduplicates_snapshot_nodes() -> None:
workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING)
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED)
events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=[snapshot],
task_id="task-1",
message_context=None,
pause_entity=None,
resumption_context=None,
)
snapshot_keys = _collect_snapshot_keys(events)
buffered_events = [
{
"event": "node_started",
"data": {"id": "exec-1"},
},
{
"event": "node_finished",
"data": {"id": "exec-2"},
},
]
filtered = list(_filter_buffered_events(buffered_events, snapshot_keys))
assert filtered == [
{
"event": "node_finished",
"data": {"id": "exec-2"},
}
]