fix the issue in mark_timeout (vibe-kanban db2a9506)

distinguish between the global timeout and node timeout.

For node-level timeout, the status should be updated to timeout.

for global timeout, the status should be updated to expired.

For status in (timeout, expired, SUBMITTED), the form should not be processed by the `check_and_handle_human_input_timeouts` logic.

only node-level timeout should resume the execution of workflow, global timeout should mark execution as STOPPED.

Update the documentation of HumanInputFormStatus to reflect the facts above.
This commit is contained in:
QuantumGhost
2026-01-27 13:53:27 +08:00
parent b467edb524
commit 4bfd33c6f2
2 changed files with 224 additions and 21 deletions

View File

@ -2,14 +2,13 @@ import logging
from datetime import timedelta
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import or_, select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from core.repositories.human_input_reposotiry import HumanInputFormSubmissionRepository
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes.human_input.entities import FormDefinition
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus, TimeoutUnit
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import ensure_naive_utc, naive_utc_now
@ -20,26 +19,14 @@ from services.human_input_service import HumanInputService
logger = logging.getLogger(__name__)
def _calculate_node_deadline(definition: FormDefinition, created_at, *, start_time=None):
start = start_time or created_at
if definition.timeout_unit == TimeoutUnit.HOUR:
return start + timedelta(hours=definition.timeout)
if definition.timeout_unit == TimeoutUnit.DAY:
return start + timedelta(days=definition.timeout)
raise AssertionError("unknown timeout unit.")
def _is_global_timeout(form_model: HumanInputForm, global_timeout_seconds: int) -> bool:
def _is_global_timeout(form_model: HumanInputForm, global_timeout_seconds: int, *, now) -> bool:
if global_timeout_seconds <= 0:
return False
form_definition = FormDefinition.model_validate_json(form_model.form_definition)
if form_model.workflow_run_id is None:
return False
created_at = ensure_naive_utc(form_model.created_at)
expiration_time = ensure_naive_utc(form_model.expiration_time)
node_deadline = _calculate_node_deadline(form_definition, created_at)
global_deadline = created_at + timedelta(seconds=global_timeout_seconds)
return global_deadline <= node_deadline and expiration_time <= global_deadline
return global_deadline <= now
def _handle_global_timeout(*, form_id: str, workflow_run_id: str, node_id: str, session_factory: sessionmaker) -> None:
@ -77,11 +64,15 @@ def check_and_handle_human_input_timeouts(limit: int = 100) -> None:
global_timeout_seconds = dify_config.HITL_GLOBAL_TIMEOUT_SECONDS
with session_factory() as session:
global_deadline = now - timedelta(seconds=global_timeout_seconds) if global_timeout_seconds > 0 else None
timeout_filter = HumanInputForm.expiration_time <= now
if global_deadline is not None:
timeout_filter = or_(timeout_filter, HumanInputForm.created_at <= global_deadline)
stmt = (
select(HumanInputForm)
.where(
HumanInputForm.status == HumanInputFormStatus.WAITING,
HumanInputForm.expiration_time <= now,
timeout_filter,
)
.order_by(HumanInputForm.id.asc())
.limit(limit)
@ -97,7 +88,7 @@ def check_and_handle_human_input_timeouts(limit: int = 100) -> None:
reason="delivery_test_timeout",
)
continue
is_global = _is_global_timeout(form_model, global_timeout_seconds)
is_global = _is_global_timeout(form_model, global_timeout_seconds, now=now)
record = form_repo.mark_timeout(
form_id=form_model.id,
timeout_status=HumanInputFormStatus.EXPIRED if is_global else HumanInputFormStatus.TIMEOUT,

View File

@ -0,0 +1,212 @@
from __future__ import annotations
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any
import pytest
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from tasks import human_input_timeout_tasks as task_module
class _FakeScalarResult:
def __init__(self, items: list[Any]):
self._items = items
def all(self) -> list[Any]:
return self._items
class _FakeSession:
def __init__(self, items: list[Any], capture: dict[str, Any]):
self._items = items
self._capture = capture
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def scalars(self, stmt):
self._capture["stmt"] = stmt
return _FakeScalarResult(self._items)
class _FakeSessionFactory:
def __init__(self, items: list[Any], capture: dict[str, Any]):
self._items = items
self._capture = capture
self._capture["session_factory"] = self
def __call__(self):
session = _FakeSession(self._items, self._capture)
self._capture["session"] = session
return session
class _FakeFormRepo:
def __init__(self, _session_factory, form_map: dict[str, Any] | None = None):
self.calls: list[dict[str, Any]] = []
self._form_map = form_map or {}
def mark_timeout(self, *, form_id: str, timeout_status: HumanInputFormStatus, reason: str | None = None):
self.calls.append(
{
"form_id": form_id,
"timeout_status": timeout_status,
"reason": reason,
}
)
form = self._form_map.get(form_id)
return SimpleNamespace(
form_id=form_id,
workflow_run_id=getattr(form, "workflow_run_id", None),
node_id=getattr(form, "node_id", None),
)
class _FakeService:
def __init__(self, _session_factory, form_repository=None):
self.enqueued: list[str] = []
def _enqueue_resume(self, workflow_run_id: str | None) -> None:
if workflow_run_id is not None:
self.enqueued.append(workflow_run_id)
def _build_form(
*,
form_id: str,
form_kind: HumanInputFormKind,
created_at: datetime,
expiration_time: datetime,
workflow_run_id: str | None,
node_id: str,
) -> SimpleNamespace:
return SimpleNamespace(
id=form_id,
form_kind=form_kind,
created_at=created_at,
expiration_time=expiration_time,
workflow_run_id=workflow_run_id,
node_id=node_id,
status=HumanInputFormStatus.WAITING,
)
def test_is_global_timeout_uses_created_at():
now = datetime(2025, 1, 1, 12, 0, 0)
form = SimpleNamespace(created_at=now - timedelta(seconds=61), workflow_run_id="run-1")
assert task_module._is_global_timeout(form, 60, now=now) is True
form.workflow_run_id = None
assert task_module._is_global_timeout(form, 60, now=now) is False
form.workflow_run_id = "run-1"
form.created_at = now - timedelta(seconds=59)
assert task_module._is_global_timeout(form, 60, now=now) is False
assert task_module._is_global_timeout(form, 0, now=now) is False
def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pytest.MonkeyPatch):
now = datetime(2025, 1, 1, 12, 0, 0)
monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
monkeypatch.setattr(task_module.dify_config, "HITL_GLOBAL_TIMEOUT_SECONDS", 3600)
monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
forms = [
_build_form(
form_id="form-global",
form_kind=HumanInputFormKind.RUNTIME,
created_at=now - timedelta(hours=2),
expiration_time=now + timedelta(hours=1),
workflow_run_id="run-global",
node_id="node-global",
),
_build_form(
form_id="form-node",
form_kind=HumanInputFormKind.RUNTIME,
created_at=now - timedelta(minutes=5),
expiration_time=now - timedelta(seconds=1),
workflow_run_id="run-node",
node_id="node-node",
),
_build_form(
form_id="form-delivery",
form_kind=HumanInputFormKind.DELIVERY_TEST,
created_at=now - timedelta(minutes=1),
expiration_time=now - timedelta(seconds=1),
workflow_run_id=None,
node_id="node-delivery",
),
]
capture: dict[str, Any] = {}
monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture))
form_map = {form.id: form for form in forms}
repo = _FakeFormRepo(None, form_map=form_map)
def _repo_factory(_session_factory):
return repo
service = _FakeService(None)
def _service_factory(_session_factory, form_repository=None):
return service
global_calls: list[dict[str, Any]] = []
monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _repo_factory)
monkeypatch.setattr(task_module, "HumanInputService", _service_factory)
monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **kwargs: global_calls.append(kwargs))
task_module.check_and_handle_human_input_timeouts(limit=100)
assert {
(call["form_id"], call["timeout_status"], call["reason"]) for call in repo.calls
} == {
("form-global", HumanInputFormStatus.EXPIRED, "global_timeout"),
("form-node", HumanInputFormStatus.TIMEOUT, "node_timeout"),
("form-delivery", HumanInputFormStatus.TIMEOUT, "delivery_test_timeout"),
}
assert service.enqueued == ["run-node"]
assert global_calls == [
{
"form_id": "form-global",
"workflow_run_id": "run-global",
"node_id": "node-global",
"session_factory": capture.get("session_factory"),
}
]
stmt = capture.get("stmt")
assert stmt is not None
stmt_text = str(stmt)
assert "created_at <=" in stmt_text
assert "expiration_time <=" in stmt_text
assert "ORDER BY human_input_forms.id" in stmt_text
def test_check_and_handle_human_input_timeouts_omits_global_filter_when_disabled(monkeypatch: pytest.MonkeyPatch):
now = datetime(2025, 1, 1, 12, 0, 0)
monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
monkeypatch.setattr(task_module.dify_config, "HITL_GLOBAL_TIMEOUT_SECONDS", 0)
monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
capture: dict[str, Any] = {}
monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory([], capture))
monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _FakeFormRepo)
monkeypatch.setattr(task_module, "HumanInputService", _FakeService)
monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **_kwargs: None)
task_module.check_and_handle_human_input_timeouts(limit=1)
stmt = capture.get("stmt")
assert stmt is not None
stmt_text = str(stmt)
assert "created_at <=" not in stmt_text