mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 07:28:05 +08:00
fix the issue in mark_timeout (vibe-kanban db2a9506)
distinguish between the global timeout and node timeout. For node-level timeout, the status should be updated to timeout. for global timeout, the status should be updated to expired. For status in (timeout, expired, SUBMITTED), the form should not be processed by the `check_and_handle_human_input_timeouts` logic. only node-level timeout should resume the execution of workflow, global timeout should mark execution as STOPPED. Update the documentation of HumanInputFormStatus to reflect the facts above.
This commit is contained in:
@ -2,14 +2,13 @@ import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repositories.human_input_reposotiry import HumanInputFormSubmissionRepository
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus, TimeoutUnit
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
|
||||
@ -20,26 +19,14 @@ from services.human_input_service import HumanInputService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _calculate_node_deadline(definition: FormDefinition, created_at, *, start_time=None):
|
||||
start = start_time or created_at
|
||||
if definition.timeout_unit == TimeoutUnit.HOUR:
|
||||
return start + timedelta(hours=definition.timeout)
|
||||
if definition.timeout_unit == TimeoutUnit.DAY:
|
||||
return start + timedelta(days=definition.timeout)
|
||||
raise AssertionError("unknown timeout unit.")
|
||||
|
||||
|
||||
def _is_global_timeout(form_model: HumanInputForm, global_timeout_seconds: int) -> bool:
|
||||
def _is_global_timeout(form_model: HumanInputForm, global_timeout_seconds: int, *, now) -> bool:
|
||||
if global_timeout_seconds <= 0:
|
||||
return False
|
||||
|
||||
form_definition = FormDefinition.model_validate_json(form_model.form_definition)
|
||||
|
||||
if form_model.workflow_run_id is None:
|
||||
return False
|
||||
created_at = ensure_naive_utc(form_model.created_at)
|
||||
expiration_time = ensure_naive_utc(form_model.expiration_time)
|
||||
node_deadline = _calculate_node_deadline(form_definition, created_at)
|
||||
global_deadline = created_at + timedelta(seconds=global_timeout_seconds)
|
||||
return global_deadline <= node_deadline and expiration_time <= global_deadline
|
||||
return global_deadline <= now
|
||||
|
||||
|
||||
def _handle_global_timeout(*, form_id: str, workflow_run_id: str, node_id: str, session_factory: sessionmaker) -> None:
|
||||
@ -77,11 +64,15 @@ def check_and_handle_human_input_timeouts(limit: int = 100) -> None:
|
||||
global_timeout_seconds = dify_config.HITL_GLOBAL_TIMEOUT_SECONDS
|
||||
|
||||
with session_factory() as session:
|
||||
global_deadline = now - timedelta(seconds=global_timeout_seconds) if global_timeout_seconds > 0 else None
|
||||
timeout_filter = HumanInputForm.expiration_time <= now
|
||||
if global_deadline is not None:
|
||||
timeout_filter = or_(timeout_filter, HumanInputForm.created_at <= global_deadline)
|
||||
stmt = (
|
||||
select(HumanInputForm)
|
||||
.where(
|
||||
HumanInputForm.status == HumanInputFormStatus.WAITING,
|
||||
HumanInputForm.expiration_time <= now,
|
||||
timeout_filter,
|
||||
)
|
||||
.order_by(HumanInputForm.id.asc())
|
||||
.limit(limit)
|
||||
@ -97,7 +88,7 @@ def check_and_handle_human_input_timeouts(limit: int = 100) -> None:
|
||||
reason="delivery_test_timeout",
|
||||
)
|
||||
continue
|
||||
is_global = _is_global_timeout(form_model, global_timeout_seconds)
|
||||
is_global = _is_global_timeout(form_model, global_timeout_seconds, now=now)
|
||||
record = form_repo.mark_timeout(
|
||||
form_id=form_model.id,
|
||||
timeout_status=HumanInputFormStatus.EXPIRED if is_global else HumanInputFormStatus.TIMEOUT,
|
||||
|
||||
212
api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py
Normal file
212
api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py
Normal file
@ -0,0 +1,212 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from tasks import human_input_timeout_tasks as task_module
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, items: list[Any]):
|
||||
self._items = items
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return self._items
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, items: list[Any], capture: dict[str, Any]):
|
||||
self._items = items
|
||||
self._capture = capture
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def scalars(self, stmt):
|
||||
self._capture["stmt"] = stmt
|
||||
return _FakeScalarResult(self._items)
|
||||
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __init__(self, items: list[Any], capture: dict[str, Any]):
|
||||
self._items = items
|
||||
self._capture = capture
|
||||
self._capture["session_factory"] = self
|
||||
|
||||
def __call__(self):
|
||||
session = _FakeSession(self._items, self._capture)
|
||||
self._capture["session"] = session
|
||||
return session
|
||||
|
||||
|
||||
class _FakeFormRepo:
|
||||
def __init__(self, _session_factory, form_map: dict[str, Any] | None = None):
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
self._form_map = form_map or {}
|
||||
|
||||
def mark_timeout(self, *, form_id: str, timeout_status: HumanInputFormStatus, reason: str | None = None):
|
||||
self.calls.append(
|
||||
{
|
||||
"form_id": form_id,
|
||||
"timeout_status": timeout_status,
|
||||
"reason": reason,
|
||||
}
|
||||
)
|
||||
form = self._form_map.get(form_id)
|
||||
return SimpleNamespace(
|
||||
form_id=form_id,
|
||||
workflow_run_id=getattr(form, "workflow_run_id", None),
|
||||
node_id=getattr(form, "node_id", None),
|
||||
)
|
||||
|
||||
|
||||
class _FakeService:
|
||||
def __init__(self, _session_factory, form_repository=None):
|
||||
self.enqueued: list[str] = []
|
||||
|
||||
def _enqueue_resume(self, workflow_run_id: str | None) -> None:
|
||||
if workflow_run_id is not None:
|
||||
self.enqueued.append(workflow_run_id)
|
||||
|
||||
|
||||
def _build_form(
|
||||
*,
|
||||
form_id: str,
|
||||
form_kind: HumanInputFormKind,
|
||||
created_at: datetime,
|
||||
expiration_time: datetime,
|
||||
workflow_run_id: str | None,
|
||||
node_id: str,
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=form_id,
|
||||
form_kind=form_kind,
|
||||
created_at=created_at,
|
||||
expiration_time=expiration_time,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id=node_id,
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
|
||||
|
||||
def test_is_global_timeout_uses_created_at():
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
form = SimpleNamespace(created_at=now - timedelta(seconds=61), workflow_run_id="run-1")
|
||||
|
||||
assert task_module._is_global_timeout(form, 60, now=now) is True
|
||||
|
||||
form.workflow_run_id = None
|
||||
assert task_module._is_global_timeout(form, 60, now=now) is False
|
||||
|
||||
form.workflow_run_id = "run-1"
|
||||
form.created_at = now - timedelta(seconds=59)
|
||||
assert task_module._is_global_timeout(form, 60, now=now) is False
|
||||
|
||||
assert task_module._is_global_timeout(form, 0, now=now) is False
|
||||
|
||||
|
||||
def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pytest.MonkeyPatch):
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
|
||||
monkeypatch.setattr(task_module.dify_config, "HITL_GLOBAL_TIMEOUT_SECONDS", 3600)
|
||||
monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
forms = [
|
||||
_build_form(
|
||||
form_id="form-global",
|
||||
form_kind=HumanInputFormKind.RUNTIME,
|
||||
created_at=now - timedelta(hours=2),
|
||||
expiration_time=now + timedelta(hours=1),
|
||||
workflow_run_id="run-global",
|
||||
node_id="node-global",
|
||||
),
|
||||
_build_form(
|
||||
form_id="form-node",
|
||||
form_kind=HumanInputFormKind.RUNTIME,
|
||||
created_at=now - timedelta(minutes=5),
|
||||
expiration_time=now - timedelta(seconds=1),
|
||||
workflow_run_id="run-node",
|
||||
node_id="node-node",
|
||||
),
|
||||
_build_form(
|
||||
form_id="form-delivery",
|
||||
form_kind=HumanInputFormKind.DELIVERY_TEST,
|
||||
created_at=now - timedelta(minutes=1),
|
||||
expiration_time=now - timedelta(seconds=1),
|
||||
workflow_run_id=None,
|
||||
node_id="node-delivery",
|
||||
),
|
||||
]
|
||||
|
||||
capture: dict[str, Any] = {}
|
||||
monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture))
|
||||
|
||||
form_map = {form.id: form for form in forms}
|
||||
repo = _FakeFormRepo(None, form_map=form_map)
|
||||
|
||||
def _repo_factory(_session_factory):
|
||||
return repo
|
||||
|
||||
service = _FakeService(None)
|
||||
|
||||
def _service_factory(_session_factory, form_repository=None):
|
||||
return service
|
||||
|
||||
global_calls: list[dict[str, Any]] = []
|
||||
|
||||
monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _repo_factory)
|
||||
monkeypatch.setattr(task_module, "HumanInputService", _service_factory)
|
||||
monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **kwargs: global_calls.append(kwargs))
|
||||
|
||||
task_module.check_and_handle_human_input_timeouts(limit=100)
|
||||
|
||||
assert {
|
||||
(call["form_id"], call["timeout_status"], call["reason"]) for call in repo.calls
|
||||
} == {
|
||||
("form-global", HumanInputFormStatus.EXPIRED, "global_timeout"),
|
||||
("form-node", HumanInputFormStatus.TIMEOUT, "node_timeout"),
|
||||
("form-delivery", HumanInputFormStatus.TIMEOUT, "delivery_test_timeout"),
|
||||
}
|
||||
assert service.enqueued == ["run-node"]
|
||||
assert global_calls == [
|
||||
{
|
||||
"form_id": "form-global",
|
||||
"workflow_run_id": "run-global",
|
||||
"node_id": "node-global",
|
||||
"session_factory": capture.get("session_factory"),
|
||||
}
|
||||
]
|
||||
|
||||
stmt = capture.get("stmt")
|
||||
assert stmt is not None
|
||||
stmt_text = str(stmt)
|
||||
assert "created_at <=" in stmt_text
|
||||
assert "expiration_time <=" in stmt_text
|
||||
assert "ORDER BY human_input_forms.id" in stmt_text
|
||||
|
||||
|
||||
def test_check_and_handle_human_input_timeouts_omits_global_filter_when_disabled(monkeypatch: pytest.MonkeyPatch):
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
|
||||
monkeypatch.setattr(task_module.dify_config, "HITL_GLOBAL_TIMEOUT_SECONDS", 0)
|
||||
monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
capture: dict[str, Any] = {}
|
||||
monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory([], capture))
|
||||
monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _FakeFormRepo)
|
||||
monkeypatch.setattr(task_module, "HumanInputService", _FakeService)
|
||||
monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **_kwargs: None)
|
||||
|
||||
task_module.check_and_handle_human_input_timeouts(limit=1)
|
||||
|
||||
stmt = capture.get("stmt")
|
||||
assert stmt is not None
|
||||
stmt_text = str(stmt)
|
||||
assert "created_at <=" not in stmt_text
|
||||
Reference in New Issue
Block a user