mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 16:38:04 +08:00
feat(api): adjust /events resumption mechanism
Avoid drain_queue and race condition caused by drain queue. The current approach starts a background thread and buffer in-fly events to an intermediate queue.Queue. The queue is bound and will drop events once it's full.
This commit is contained in:
@ -25,6 +25,7 @@ from core.workflow.entities import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import AppMode, Message
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun
|
||||
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
|
||||
@ -41,6 +42,15 @@ class MessageContext:
|
||||
created_at: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class BufferState:
|
||||
queue: queue.Queue[Mapping[str, Any]]
|
||||
stop_event: threading.Event
|
||||
done_event: threading.Event
|
||||
task_id_ready: threading.Event
|
||||
task_id_hint: str | None = None
|
||||
|
||||
|
||||
def build_workflow_event_stream(
|
||||
*,
|
||||
app_mode: AppMode,
|
||||
@ -66,7 +76,7 @@ def build_workflow_event_stream(
|
||||
logger.exception("Failed to load workflow pause for run %s", workflow_run.id)
|
||||
pause_entity = None
|
||||
|
||||
resumption_context = _load_resumption_context(pause_entity=pause_entity)
|
||||
resumption_context = _load_resumption_context(pause_entity)
|
||||
node_snapshots = node_execution_repo.get_execution_snapshots_by_workflow_run(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
@ -87,10 +97,9 @@ def build_workflow_event_stream(
|
||||
last_ping_time = last_msg_time
|
||||
|
||||
with topic.subscribe() as sub:
|
||||
buffer_queue, stop_event, buffer_done = _start_buffering(sub)
|
||||
buffer_state = _start_buffering(sub)
|
||||
try:
|
||||
buffered_events = _drain_buffer(buffer_queue)
|
||||
task_id = _resolve_task_id(resumption_context, buffered_events, workflow_run.id)
|
||||
task_id = _resolve_task_id(resumption_context, buffer_state, workflow_run.id)
|
||||
|
||||
snapshot_events = _build_snapshot_events(
|
||||
workflow_run=workflow_run,
|
||||
@ -100,7 +109,6 @@ def build_workflow_event_stream(
|
||||
pause_entity=pause_entity,
|
||||
resumption_context=resumption_context,
|
||||
)
|
||||
buffered_events.extend(_drain_buffer(buffer_queue))
|
||||
snapshot_keys = _collect_snapshot_keys(snapshot_events)
|
||||
|
||||
for event in snapshot_events:
|
||||
@ -110,19 +118,12 @@ def build_workflow_event_stream(
|
||||
if _is_terminal_event(event):
|
||||
return
|
||||
|
||||
for event in _filter_buffered_events(buffered_events, snapshot_keys):
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
yield event
|
||||
if _is_terminal_event(event):
|
||||
return
|
||||
|
||||
while True:
|
||||
if buffer_done.is_set() and buffer_queue.empty():
|
||||
if buffer_state.done_event.is_set() and buffer_state.queue.empty():
|
||||
return
|
||||
|
||||
try:
|
||||
event = buffer_queue.get(timeout=0.1)
|
||||
event = buffer_state.queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
current_time = time.time()
|
||||
if current_time - last_msg_time > idle_timeout:
|
||||
@ -144,7 +145,7 @@ def build_workflow_event_stream(
|
||||
if _is_terminal_event(event):
|
||||
return
|
||||
finally:
|
||||
stop_event.set()
|
||||
buffer_state.stop_event.set()
|
||||
|
||||
return _generate()
|
||||
|
||||
@ -176,17 +177,20 @@ def _load_resumption_context(pause_entity: WorkflowPauseEntity | None) -> Workfl
|
||||
|
||||
def _resolve_task_id(
|
||||
resumption_context: WorkflowResumptionContext | None,
|
||||
buffered_events: Sequence[Mapping[str, Any]],
|
||||
buffer_state: BufferState | None,
|
||||
workflow_run_id: str,
|
||||
wait_timeout: float = 0.2,
|
||||
) -> str:
|
||||
if resumption_context is not None:
|
||||
generate_entity = resumption_context.get_generate_entity()
|
||||
if generate_entity.task_id:
|
||||
return generate_entity.task_id
|
||||
for event in buffered_events:
|
||||
task_id = event.get("task_id")
|
||||
if task_id:
|
||||
return str(task_id)
|
||||
if buffer_state is None:
|
||||
return workflow_run_id
|
||||
if buffer_state.task_id_hint is None:
|
||||
buffer_state.task_id_ready.wait(timeout=wait_timeout)
|
||||
if buffer_state.task_id_hint:
|
||||
return buffer_state.task_id_hint
|
||||
return workflow_run_id
|
||||
|
||||
|
||||
@ -361,55 +365,49 @@ def _apply_message_context(payload: dict[str, Any], message_context: MessageCont
|
||||
payload["created_at"] = message_context.created_at
|
||||
|
||||
|
||||
def _start_buffering(subscription) -> tuple[queue.Queue[Mapping[str, Any]], threading.Event, threading.Event]:
|
||||
buffer_queue: queue.Queue[Mapping[str, Any]] = queue.Queue(maxsize=2048)
|
||||
stop_event = threading.Event()
|
||||
done_event = threading.Event()
|
||||
def _start_buffering(subscription) -> BufferState:
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(maxsize=2048),
|
||||
stop_event=threading.Event(),
|
||||
done_event=threading.Event(),
|
||||
task_id_ready=threading.Event(),
|
||||
)
|
||||
|
||||
def _worker() -> None:
|
||||
dropped_count = 0
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
while not buffer_state.stop_event.is_set():
|
||||
msg = subscription.receive(timeout=0.1)
|
||||
if msg is None:
|
||||
continue
|
||||
event = _parse_event_message(msg)
|
||||
if event is None:
|
||||
continue
|
||||
task_id = event.get("task_id")
|
||||
if task_id and buffer_state.task_id_hint is None:
|
||||
buffer_state.task_id_hint = str(task_id)
|
||||
buffer_state.task_id_ready.set()
|
||||
try:
|
||||
buffer_queue.put_nowait(event)
|
||||
buffer_state.queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
dropped_count += 1
|
||||
try:
|
||||
buffer_queue.get_nowait()
|
||||
buffer_state.queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
try:
|
||||
buffer_queue.put_nowait(event)
|
||||
buffer_state.queue.put_nowait(event)
|
||||
except queue.Full:
|
||||
continue
|
||||
logger.warning("Dropped buffered workflow event, total_dropped=%s", dropped_count)
|
||||
except Exception:
|
||||
logger.exception("Failed while buffering workflow events")
|
||||
finally:
|
||||
done_event.set()
|
||||
buffer_state.done_event.set()
|
||||
|
||||
thread = threading.Thread(target=_worker, name=f"workflow-event-buffer-{id(subscription)}", daemon=True)
|
||||
thread.start()
|
||||
return buffer_queue, stop_event, done_event
|
||||
|
||||
|
||||
def _drain_buffer(
|
||||
buffer_queue: queue.Queue[Mapping[str, Any]],
|
||||
) -> list[Mapping[str, Any]]:
|
||||
events: list[Mapping[str, Any]] = []
|
||||
while True:
|
||||
try:
|
||||
event = buffer_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
events.append(event)
|
||||
return events
|
||||
return buffer_state
|
||||
|
||||
|
||||
def _parse_event_message(message: bytes) -> Mapping[str, Any] | None:
|
||||
@ -468,3 +466,26 @@ def _event_snapshot_key(event: Mapping[str, Any]) -> tuple[str, str] | None:
|
||||
if event_type == StreamEvent.WORKFLOW_PAUSED.value:
|
||||
return (event_type, event.get("workflow_run_id") or "")
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_node_triggered_from(workflow_run_triggered_from: str | None) -> str:
|
||||
if not workflow_run_triggered_from:
|
||||
return WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
|
||||
mapping = {
|
||||
WorkflowRunTriggeredFrom.DEBUGGING.value: WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowRunTriggeredFrom.APP_RUN.value: WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowRunTriggeredFrom.WEBHOOK.value: WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowRunTriggeredFrom.SCHEDULE.value: WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowRunTriggeredFrom.PLUGIN.value: WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value: WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN.value,
|
||||
WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value: WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value,
|
||||
}
|
||||
if workflow_run_triggered_from in mapping:
|
||||
return mapping[workflow_run_triggered_from]
|
||||
|
||||
logger.warning(
|
||||
"Unknown workflow run triggered_from %s, defaulting node executions to workflow-run",
|
||||
workflow_run_triggered_from,
|
||||
)
|
||||
return WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from threading import Event
|
||||
|
||||
import pytest
|
||||
|
||||
@ -19,6 +21,7 @@ from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services.workflow_event_snapshot_service import (
|
||||
BufferState,
|
||||
MessageContext,
|
||||
_build_snapshot_events,
|
||||
_collect_snapshot_keys,
|
||||
@ -199,8 +202,16 @@ def test_build_snapshot_events_applies_message_context() -> None:
|
||||
)
|
||||
def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None:
|
||||
resumption_context = _build_resumption_context(context_task_id) if context_task_id else None
|
||||
buffered_events = [{"task_id": buffered_task_id}] if buffered_task_id else []
|
||||
task_id = _resolve_task_id(resumption_context, buffered_events, "run-1")
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint=buffered_task_id,
|
||||
)
|
||||
if buffered_task_id:
|
||||
buffer_state.task_id_ready.set()
|
||||
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
|
||||
assert task_id == expected
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user