feat: Human Input Node (#32060)

The frontend and backend implementation for the human input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
QuantumGhost
2026-02-09 14:57:23 +08:00
committed by GitHub
parent 56e3a55023
commit a1fc280102
474 changed files with 32667 additions and 2050 deletions

View File

@ -164,6 +164,62 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
assert "timezone=UTC" in options
def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch):
    """The pub/sub Redis URL is derived from the plain REDIS_* settings when no override is set."""
    os.environ.clear()
    env = {
        "CONSOLE_API_URL": "https://example.com",
        "CONSOLE_WEB_URL": "https://example.com",
        "DB_USERNAME": "postgres",
        "DB_PASSWORD": "postgres",
        "DB_HOST": "localhost",
        "DB_PORT": "5432",
        "DB_DATABASE": "dify",
        "REDIS_HOST": "redis.example.com",
        "REDIS_PORT": "6380",
        "REDIS_USERNAME": "user",
        "REDIS_PASSWORD": "pass@word",  # '@' must come back percent-encoded
        "REDIS_DB": "2",
        "REDIS_USE_SSL": "true",  # SSL selects the rediss:// scheme
    }
    for key, value in env.items():
        monkeypatch.setenv(key, value)

    config = DifyConfig()

    assert config.normalized_pubsub_redis_url == "rediss://user:pass%40word@redis.example.com:6380/2"
    assert config.PUBSUB_REDIS_CHANNEL_TYPE == "pubsub"
def test_pubsub_redis_url_override(monkeypatch: pytest.MonkeyPatch):
    """An explicit PUBSUB_REDIS_URL wins over any URL derived from REDIS_* settings."""
    os.environ.clear()
    env = {
        "CONSOLE_API_URL": "https://example.com",
        "CONSOLE_WEB_URL": "https://example.com",
        "DB_USERNAME": "postgres",
        "DB_PASSWORD": "postgres",
        "DB_HOST": "localhost",
        "DB_PORT": "5432",
        "DB_DATABASE": "dify",
        "PUBSUB_REDIS_URL": "redis://pubsub-host:6381/5",
    }
    for key, value in env.items():
        monkeypatch.setenv(key, value)

    # The override value is returned exactly as configured.
    assert DifyConfig().normalized_pubsub_redis_url == "redis://pubsub-host:6381/5"
def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.MonkeyPatch):
    """Reading the pub/sub URL fails loudly when REDIS_HOST is blank and no override exists."""
    os.environ.clear()
    env = {
        "CONSOLE_API_URL": "https://example.com",
        "CONSOLE_WEB_URL": "https://example.com",
        "DB_USERNAME": "postgres",
        "DB_PASSWORD": "postgres",
        "DB_HOST": "localhost",
        "DB_PORT": "5432",
        "DB_DATABASE": "dify",
        "REDIS_HOST": "",  # blank host: the default URL cannot be derived
    }
    for key, value in env.items():
        monkeypatch.setenv(key, value)

    with pytest.raises(ValueError, match="PUBSUB_REDIS_URL must be set"):
        _ = DifyConfig().normalized_pubsub_redis_url
@pytest.mark.parametrize(
("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
[

View File

@ -51,6 +51,8 @@ def _patch_redis_clients_on_loaded_modules():
continue
if hasattr(module, "redis_client"):
module.redis_client = redis_mock
if hasattr(module, "pubsub_redis_client"):
module.pubsub_redis_client = redis_mock
@pytest.fixture
@ -68,7 +70,10 @@ def _provide_app_context(app: Flask):
def _patch_redis_clients():
"""Patch redis_client to MagicMock only for unit test executions."""
with patch.object(ext_redis, "redis_client", redis_mock):
with (
patch.object(ext_redis, "redis_client", redis_mock),
patch.object(ext_redis, "pubsub_redis_client", redis_mock),
):
_patch_redis_clients_on_loaded_modules()
yield

View File

@ -16,11 +16,9 @@ if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
def _load_app_module():
@pytest.fixture(scope="module")
def app_module():
module_name = "controllers.console.app.app"
if module_name in sys.modules:
return sys.modules[module_name]
root = Path(__file__).resolve().parents[5]
module_path = root / "controllers" / "console" / "app" / "app.py"
@ -59,8 +57,12 @@ def _load_app_module():
stub_namespace = _StubNamespace()
original_console = sys.modules.get("controllers.console")
original_app_pkg = sys.modules.get("controllers.console.app")
original_modules: dict[str, ModuleType | None] = {
"controllers.console": sys.modules.get("controllers.console"),
"controllers.console.app": sys.modules.get("controllers.console.app"),
"controllers.common.schema": sys.modules.get("controllers.common.schema"),
module_name: sys.modules.get(module_name),
}
stubbed_modules: list[tuple[str, ModuleType | None]] = []
console_module = ModuleType("controllers.console")
@ -105,35 +107,35 @@ def _load_app_module():
module = util.module_from_spec(spec)
sys.modules[module_name] = module
assert spec.loader is not None
spec.loader.exec_module(module)
try:
assert spec.loader is not None
spec.loader.exec_module(module)
yield module
finally:
for name, original in reversed(stubbed_modules):
if original is not None:
sys.modules[name] = original
else:
sys.modules.pop(name, None)
if original_console is not None:
sys.modules["controllers.console"] = original_console
else:
sys.modules.pop("controllers.console", None)
if original_app_pkg is not None:
sys.modules["controllers.console.app"] = original_app_pkg
else:
sys.modules.pop("controllers.console.app", None)
return module
for name, original in original_modules.items():
if original is not None:
sys.modules[name] = original
else:
sys.modules.pop(name, None)
_app_module = _load_app_module()
AppDetailWithSite = _app_module.AppDetailWithSite
AppPagination = _app_module.AppPagination
AppPartial = _app_module.AppPartial
@pytest.fixture(scope="module")
def app_models(app_module):
return SimpleNamespace(
AppDetailWithSite=app_module.AppDetailWithSite,
AppPagination=app_module.AppPagination,
AppPartial=app_module.AppPartial,
)
@pytest.fixture(autouse=True)
def patch_signed_url(monkeypatch):
def patch_signed_url(monkeypatch, app_module):
"""Ensure icon URL generation uses a deterministic helper for tests."""
def _fake_signed_url(key: str | None) -> str | None:
@ -141,7 +143,7 @@ def patch_signed_url(monkeypatch):
return None
return f"signed:{key}"
monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
def _ts(hour: int = 12) -> datetime:
@ -169,7 +171,8 @@ def _dummy_workflow():
)
def test_app_partial_serialization_uses_aliases():
def test_app_partial_serialization_uses_aliases(app_models):
AppPartial = app_models.AppPartial
created_at = _ts()
app_obj = SimpleNamespace(
id="app-1",
@ -204,7 +207,8 @@ def test_app_partial_serialization_uses_aliases():
assert serialized["tags"][0]["name"] == "Utilities"
def test_app_detail_with_site_includes_nested_serialization():
def test_app_detail_with_site_includes_nested_serialization(app_models):
AppDetailWithSite = app_models.AppDetailWithSite
timestamp = _ts(14)
site = SimpleNamespace(
code="site-code",
@ -253,7 +257,8 @@ def test_app_detail_with_site_includes_nested_serialization():
assert serialized["site"]["created_at"] == int(timestamp.timestamp())
def test_app_pagination_aliases_per_page_and_has_next():
def test_app_pagination_aliases_per_page_and_has_next(app_models):
AppPagination = app_models.AppPagination
item_one = SimpleNamespace(
id="app-10",
name="Paginated One",

View File

@ -0,0 +1,229 @@
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from pydantic import ValidationError
from controllers.console import wraps as console_wraps
from controllers.console.app import workflow as workflow_module
from controllers.console.app import wraps as app_wraps
from libs import login as login_lib
from models.account import Account, AccountStatus, TenantAccountRole
from models.model import AppMode
def _make_account() -> Account:
    """Build an active owner account bound to a stub tenant, for bypassing console guards."""
    acct = Account(name="tester", email="tester@example.com")
    acct.status = AccountStatus.ACTIVE
    acct.role = TenantAccountRole.OWNER
    acct.id = "account-123"  # type: ignore[assignment]
    # Fake the tenant binding and the flask-login proxy resolution.
    acct._current_tenant = SimpleNamespace(id="tenant-123")  # type: ignore[attr-defined]
    acct._get_current_object = lambda: acct  # type: ignore[attr-defined]
    return acct
def _make_app(mode: AppMode) -> SimpleNamespace:
    """Return a minimal stand-in for an App row in the given mode."""
    stub = SimpleNamespace(id="app-123", tenant_id="tenant-123")
    stub.mode = mode.value
    return stub
def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None:
    """Neutralize console auth/setup decorators so resource methods can be called directly.

    Patches login, CSRF, and tenant-resolution hooks in every module that imported
    them, and short-circuits the app-model DB lookup with the provided stub.
    """
    # Skip setup and auth guardrails
    monkeypatch.setattr("configs.dify_config.EDITION", "CLOUD")
    monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
    monkeypatch.setattr(login_lib, "current_user", account)
    monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
    monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
    # Each consumer module holds its own reference, so patch all three.
    monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
    monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
    monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
    monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
    monkeypatch.delenv("INIT_PASSWORD", raising=False)
    # Avoid hitting the database when resolving the app model
    monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
@dataclass
class PreviewCase:
    """Parametrization record for one form-preview endpoint variant."""

    resource_cls: type  # resource class under test
    path: str  # request path used for the test context
    mode: AppMode  # app mode the stub app model reports
@pytest.mark.parametrize(
    "case",
    [
        PreviewCase(
            resource_cls=workflow_module.AdvancedChatDraftHumanInputFormPreviewApi,
            path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-42/form/preview",
            mode=AppMode.ADVANCED_CHAT,
        ),
        PreviewCase(
            resource_cls=workflow_module.WorkflowDraftHumanInputFormPreviewApi,
            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-42/form/preview",
            mode=AppMode.WORKFLOW,
        ),
    ],
)
def test_human_input_preview_delegates_to_service(
    app: Flask, monkeypatch: pytest.MonkeyPatch, case: PreviewCase
) -> None:
    """Preview endpoints return the service payload and forward request args unchanged."""
    account = _make_account()
    app_model = _make_app(case.mode)
    _patch_console_guards(monkeypatch, account, app_model)

    preview_payload = {
        "form_id": "node-42",
        "form_content": "<div>example</div>",
        "inputs": [{"name": "topic"}],
        "actions": [{"id": "continue"}],
    }
    # Stub WorkflowService so the resource delegates to a mock instead of real logic.
    service_instance = MagicMock()
    service_instance.get_human_input_form_preview.return_value = preview_payload
    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))

    with app.test_request_context(case.path, method="POST", json={"inputs": {"topic": "tech"}}):
        response = case.resource_cls().post(app_id=app_model.id, node_id="node-42")

    assert response == preview_payload
    # The resource must pass the resolved app/account plus request data verbatim.
    service_instance.get_human_input_form_preview.assert_called_once_with(
        app_model=app_model,
        account=account,
        node_id="node-42",
        inputs={"topic": "tech"},
    )
@dataclass
class SubmitCase:
    """Parametrization record for one form-run (submit) endpoint variant."""

    resource_cls: type  # resource class under test
    path: str  # request path used for the test context
    mode: AppMode  # app mode the stub app model reports
@pytest.mark.parametrize(
    "case",
    [
        SubmitCase(
            resource_cls=workflow_module.AdvancedChatDraftHumanInputFormRunApi,
            path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-99/form/run",
            mode=AppMode.ADVANCED_CHAT,
        ),
        SubmitCase(
            resource_cls=workflow_module.WorkflowDraftHumanInputFormRunApi,
            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-99/form/run",
            mode=AppMode.WORKFLOW,
        ),
    ],
)
def test_human_input_submit_forwards_payload(app: Flask, monkeypatch: pytest.MonkeyPatch, case: SubmitCase) -> None:
    """Submit endpoints return the service result and pass the request payload through unchanged."""
    account = _make_account()
    app_model = _make_app(case.mode)
    _patch_console_guards(monkeypatch, account, app_model)

    result_payload = {"node_id": "node-99", "outputs": {"__rendered_content": "<p>done</p>"}, "action": "approve"}
    # Stub WorkflowService so the resource delegates to a mock.
    service_instance = MagicMock()
    service_instance.submit_human_input_form_preview.return_value = result_payload
    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))

    with app.test_request_context(
        case.path,
        method="POST",
        json={"form_inputs": {"answer": "42"}, "inputs": {"#node-1.result#": "LLM output"}, "action": "approve"},
    ):
        response = case.resource_cls().post(app_id=app_model.id, node_id="node-99")

    assert response == result_payload
    # Every request field must be forwarded verbatim to the service call.
    service_instance.submit_human_input_form_preview.assert_called_once_with(
        app_model=app_model,
        account=account,
        node_id="node-99",
        form_inputs={"answer": "42"},
        inputs={"#node-1.result#": "LLM output"},
        action="approve",
    )
@dataclass
class DeliveryTestCase:
    """Parametrization record for one delivery-test endpoint variant."""

    resource_cls: type  # resource class under test
    path: str  # request path used for the test context
    mode: AppMode  # app mode the stub app model reports
@pytest.mark.parametrize(
    "case",
    [
        # Same resource and path for both cases; only the stub app's mode differs.
        DeliveryTestCase(
            resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi,
            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test",
            mode=AppMode.ADVANCED_CHAT,
        ),
        DeliveryTestCase(
            resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi,
            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test",
            mode=AppMode.WORKFLOW,
        ),
    ],
)
def test_human_input_delivery_test_calls_service(
    app: Flask, monkeypatch: pytest.MonkeyPatch, case: DeliveryTestCase
) -> None:
    """Delivery-test endpoint returns an empty body and delegates to the service."""
    account = _make_account()
    app_model = _make_app(case.mode)
    _patch_console_guards(monkeypatch, account, app_model)

    service_instance = MagicMock()
    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))

    with app.test_request_context(
        case.path,
        method="POST",
        json={"delivery_method_id": "delivery-123"},
    ):
        response = case.resource_cls().post(app_id=app_model.id, node_id="node-7")

    assert response == {}
    service_instance.test_human_input_delivery.assert_called_once_with(
        app_model=app_model,
        account=account,
        node_id="node-7",
        delivery_method_id="delivery-123",
        # "inputs" was omitted from the request body, so the endpoint supplies {}.
        inputs={},
    )
def test_human_input_delivery_test_maps_validation_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
    """A ValueError raised by the service propagates out of the delivery-test resource."""
    account = _make_account()
    app_model = _make_app(AppMode.ADVANCED_CHAT)
    _patch_console_guards(monkeypatch, account, app_model)

    failing_service = MagicMock()
    failing_service.test_human_input_delivery.side_effect = ValueError("bad delivery method")
    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=failing_service))

    url = "/console/api/apps/app-123/workflows/draft/human-input/nodes/node-1/delivery-test"
    with app.test_request_context(url, method="POST", json={"delivery_method_id": "bad"}):
        with pytest.raises(ValueError):
            workflow_module.WorkflowDraftHumanInputDeliveryTestApi().post(app_id=app_model.id, node_id="node-1")
def test_human_input_preview_rejects_non_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
    """A list-valued ``inputs`` payload fails request validation with a ValidationError."""
    account = _make_account()
    app_model = _make_app(AppMode.ADVANCED_CHAT)
    _patch_console_guards(monkeypatch, account, app_model)

    url = "/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-1/form/preview"
    with app.test_request_context(url, method="POST", json={"inputs": ["not-a-dict"]}):
        with pytest.raises(ValidationError):
            workflow_module.AdvancedChatDraftHumanInputFormPreviewApi().post(app_id=app_model.id, node_id="node-1")

View File

@ -0,0 +1,110 @@
from __future__ import annotations
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Flask
from controllers.console import wraps as console_wraps
from controllers.console.app import workflow_run as workflow_run_module
from controllers.web.error import NotFoundError
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes.human_input.entities import FormInput, UserAction
from core.workflow.nodes.human_input.enums import FormInputType
from libs import login as login_lib
from models.account import Account, AccountStatus, TenantAccountRole
from models.workflow import WorkflowRun
def _make_account() -> Account:
    """Create an active owner account wired to a stub tenant."""
    user = Account(name="tester", email="tester@example.com")
    user.status = AccountStatus.ACTIVE
    user.role = TenantAccountRole.OWNER
    user.id = "account-123"  # type: ignore[assignment]
    user._current_tenant = SimpleNamespace(id="tenant-123")  # type: ignore[attr-defined]
    # flask-login resolves proxies via _get_current_object; hand back the stub itself.
    user._get_current_object = lambda: user  # type: ignore[attr-defined]
    return user
def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account) -> None:
    """Disable console auth decorators so the resource can be invoked directly."""
    monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
    monkeypatch.setattr(login_lib, "current_user", account)
    monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
    monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
    monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
    # The module under test reads current_user through its own imported name.
    monkeypatch.setattr(workflow_run_module, "current_user", account)
    monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
class _PauseEntity:
def __init__(self, paused_at: datetime, reasons: list[HumanInputRequired]):
self.paused_at = paused_at
self._reasons = reasons
def get_pause_reasons(self):
return self._reasons
def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
    """Pause details include a backstage form URL built from APP_WEB_URL and the form token."""
    account = _make_account()
    _patch_console_guards(monkeypatch, account)
    monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")

    # Paused run in the caller's tenant, served through a stubbed db.session.get.
    workflow_run = Mock(spec=WorkflowRun)
    workflow_run.tenant_id = "tenant-123"
    workflow_run.status = WorkflowExecutionStatus.PAUSED
    workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
    fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
    monkeypatch.setattr(workflow_run_module, "db", fake_db)

    reason = HumanInputRequired(
        form_id="form-1",
        form_content="content",
        inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
        actions=[UserAction(id="approve", title="Approve")],
        node_id="node-1",
        node_title="Ask Name",
        form_token="backstage-token",
    )
    pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason])
    # Repository factory is patched to hand back a mock holding the pause entity.
    repo = Mock()
    repo.get_workflow_pause.return_value = pause_entity
    monkeypatch.setattr(
        workflow_run_module.DifyAPIRepositoryFactory,
        "create_api_workflow_run_repository",
        lambda *_, **__: repo,
    )

    with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
        response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")

    assert status == 200
    # Timestamp is rendered as UTC ISO-8601 with a trailing 'Z'.
    assert response["paused_at"] == "2024-01-01T12:00:00Z"
    assert response["paused_nodes"][0]["node_id"] == "node-1"
    assert response["paused_nodes"][0]["pause_type"]["type"] == "human_input"
    assert (
        response["paused_nodes"][0]["pause_type"]["backstage_input_url"]
        == "https://web.example.com/form/backstage-token"
    )
    # This key is intentionally absent from the console response shape.
    assert "pending_human_inputs" not in response
def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
    """Runs owned by another tenant must raise NotFoundError instead of exposing pause details."""
    account = _make_account()
    _patch_console_guards(monkeypatch, account)
    monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")

    # The run belongs to tenant-456 while the authenticated account is in tenant-123.
    workflow_run = Mock(spec=WorkflowRun)
    workflow_run.tenant_id = "tenant-456"
    workflow_run.status = WorkflowExecutionStatus.PAUSED
    workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
    fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
    monkeypatch.setattr(workflow_run_module, "db", fake_db)

    with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
        with pytest.raises(NotFoundError):
            # Fix: the original unpacked `response, status = ...` here, but the call is
            # expected to raise, so that assignment was dead code and flagged unused (F841).
            # pytest.raises now also wraps only the raising statement, not the context setup.
            workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")

View File

@ -0,0 +1,25 @@
from types import SimpleNamespace
from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField
from core.workflow.enums import WorkflowExecutionStatus
def test_workflow_run_status_field_with_enum() -> None:
    """An enum status value serializes to its plain string form."""
    run = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED)
    assert WorkflowRunStatusField().output("status", run) == "paused"
def test_workflow_run_outputs_field_paused_returns_empty() -> None:
    """A paused run yields an empty outputs mapping even when outputs are stored."""
    run = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED, outputs_dict={"foo": "bar"})
    assert WorkflowRunOutputsField().output("outputs", run) == {}
def test_workflow_run_outputs_field_running_returns_outputs() -> None:
    """A running run surfaces the stored outputs dict unchanged."""
    run = SimpleNamespace(status=WorkflowExecutionStatus.RUNNING, outputs_dict={"foo": "bar"})
    assert WorkflowRunOutputsField().output("outputs", run) == {"foo": "bar"}

View File

@ -0,0 +1,456 @@
"""Unit tests for controllers.web.human_input_form endpoints."""
from __future__ import annotations
import json
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
import controllers.web.human_input_form as human_input_module
import controllers.web.site as site_module
from controllers.web.error import WebFormRateLimitExceededError
from models.human_input import RecipientType
from services.human_input_service import FormExpiredError
HumanInputFormApi = human_input_module.HumanInputFormApi
TenantStatus = human_input_module.TenantStatus
@pytest.fixture
def app() -> Flask:
    """Provide a bare Flask app so tests can open request contexts."""
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    return flask_app
class _FakeSession:
"""Simple stand-in for db.session that returns pre-seeded objects."""
def __init__(self, mapping: dict[str, Any]):
self._mapping = mapping
self._model_name: str | None = None
def query(self, model):
self._model_name = model.__name__
return self
def where(self, *args, **kwargs):
return self
def first(self):
assert self._model_name is not None
return self._mapping.get(self._model_name)
class _FakeDB:
"""Minimal db stub exposing engine and session."""
def __init__(self, session: _FakeSession):
self.session = session
self.engine = object()
def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
    """GET returns form definition merged with site payload."""
    expiration_time = datetime(2099, 1, 1, tzinfo=UTC)

    class _FakeDefinition:
        # Stand-in for the stored form definition model.
        def model_dump(self):
            return {
                "form_content": "Raw content",
                "rendered_content": "Rendered {{#$output.name#}}",
                "inputs": [{"type": "text", "output_variable_name": "name", "default": None}],
                "default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}},
                "user_actions": [{"id": "approve", "title": "Approve", "button_style": "default"}],
            }

    class _FakeForm:
        # Minimal form stand-in carrying the ids and expiry the endpoint reads.
        def __init__(self, expiration: datetime):
            self.workflow_run_id = "workflow-1"
            self.app_id = "app-1"
            self.tenant_id = "tenant-1"
            self.expiration_time = expiration
            self.recipient_type = RecipientType.BACKSTAGE

        def get_definition(self):
            return _FakeDefinition()

    form = _FakeForm(expiration_time)
    # Rate limiter allows the request; the client IP is pinned for the assertions below.
    limiter_mock = MagicMock()
    limiter_mock.is_rate_limited.return_value = False
    monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
    monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")

    # Tenant/app/run/site rows the endpoint resolves through the stubbed session.
    tenant = SimpleNamespace(
        id="tenant-1",
        status=TenantStatus.NORMAL,
        plan="basic",
        custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False},
    )
    app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
    workflow_run = SimpleNamespace(app_id="app-1")
    site_model = SimpleNamespace(
        title="My Site",
        icon_type="emoji",
        icon="robot",
        icon_background="#fff",
        description="desc",
        default_language="en",
        chat_color_theme="light",
        chat_color_theme_inverted=False,
        copyright=None,
        privacy_policy=None,
        custom_disclaimer=None,
        prompt_public=False,
        show_workflow_steps=True,
        use_icon_as_answer_icon=False,
    )
    # Patch service to return fake form.
    service_mock = MagicMock()
    service_mock.get_form_by_token.return_value = form
    monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
    # Patch db session.
    db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model}))
    monkeypatch.setattr(human_input_module, "db", db_stub)
    monkeypatch.setattr(
        site_module.FeatureService,
        "get_features",
        lambda tenant_id: SimpleNamespace(can_replace_logo=True),
    )

    with app.test_request_context("/api/form/human_input/token-1", method="GET"):
        response = HumanInputFormApi().get("token-1")

    body = json.loads(response.get_data(as_text=True))
    assert set(body.keys()) == {
        "site",
        "form_content",
        "inputs",
        "resolved_default_values",
        "user_actions",
        "expiration_time",
    }
    # The rendered (not raw) content is what the client sees.
    assert body["form_content"] == "Rendered {{#$output.name#}}"
    assert body["inputs"] == [{"type": "text", "output_variable_name": "name", "default": None}]
    # Non-string defaults are stringified: numbers via str(), mappings as JSON text.
    assert body["resolved_default_values"] == {"name": "Alice", "age": "30", "meta": '{"k": "v"}'}
    assert body["user_actions"] == [{"id": "approve", "title": "Approve", "button_style": "default"}]
    assert body["expiration_time"] == int(expiration_time.timestamp())
    assert body["site"] == {
        "app_id": "app-1",
        "end_user_id": None,
        "enable_site": True,
        "site": {
            "title": "My Site",
            "chat_color_theme": "light",
            "chat_color_theme_inverted": False,
            "icon_type": "emoji",
            "icon": "robot",
            "icon_background": "#fff",
            "icon_url": None,
            "description": "desc",
            "copyright": None,
            "privacy_policy": None,
            "custom_disclaimer": None,
            "default_language": "en",
            "prompt_public": False,
            "show_workflow_steps": True,
            "use_icon_as_answer_icon": False,
        },
        "model_config": None,
        "plan": "basic",
        "can_replace_logo": True,
        "custom_config": {
            "remove_webapp_brand": True,
            "replace_webapp_logo": None,
        },
    }
    service_mock.get_form_by_token.assert_called_once_with("token-1")
    # Successful access is both checked against and counted into the rate limit.
    limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
    limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask):
    """GET returns form payload for backstage token."""
    expiration_time = datetime(2099, 1, 2, tzinfo=UTC)

    class _FakeDefinition:
        # Definition with empty inputs/defaults/actions to exercise the minimal path.
        def model_dump(self):
            return {
                "form_content": "Raw content",
                "rendered_content": "Rendered",
                "inputs": [],
                "default_values": {},
                "user_actions": [],
            }

    class _FakeForm:
        # NOTE(review): unlike the previous test, this form sets no recipient_type —
        # presumably the endpoint tolerates its absence for backstage tokens; confirm.
        def __init__(self, expiration: datetime):
            self.workflow_run_id = "workflow-1"
            self.app_id = "app-1"
            self.tenant_id = "tenant-1"
            self.expiration_time = expiration

        def get_definition(self):
            return _FakeDefinition()

    form = _FakeForm(expiration_time)
    limiter_mock = MagicMock()
    limiter_mock.is_rate_limited.return_value = False
    monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
    monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")

    tenant = SimpleNamespace(
        id="tenant-1",
        status=TenantStatus.NORMAL,
        plan="basic",
        custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False},
    )
    app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
    workflow_run = SimpleNamespace(app_id="app-1")
    site_model = SimpleNamespace(
        title="My Site",
        icon_type="emoji",
        icon="robot",
        icon_background="#fff",
        description="desc",
        default_language="en",
        chat_color_theme="light",
        chat_color_theme_inverted=False,
        copyright=None,
        privacy_policy=None,
        custom_disclaimer=None,
        prompt_public=False,
        show_workflow_steps=True,
        use_icon_as_answer_icon=False,
    )
    service_mock = MagicMock()
    service_mock.get_form_by_token.return_value = form
    monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
    db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model}))
    monkeypatch.setattr(human_input_module, "db", db_stub)
    monkeypatch.setattr(
        site_module.FeatureService,
        "get_features",
        lambda tenant_id: SimpleNamespace(can_replace_logo=True),
    )

    with app.test_request_context("/api/form/human_input/token-1", method="GET"):
        response = HumanInputFormApi().get("token-1")

    body = json.loads(response.get_data(as_text=True))
    assert set(body.keys()) == {
        "site",
        "form_content",
        "inputs",
        "resolved_default_values",
        "user_actions",
        "expiration_time",
    }
    assert body["form_content"] == "Rendered"
    assert body["inputs"] == []
    assert body["resolved_default_values"] == {}
    assert body["user_actions"] == []
    assert body["expiration_time"] == int(expiration_time.timestamp())
    assert body["site"] == {
        "app_id": "app-1",
        "end_user_id": None,
        "enable_site": True,
        "site": {
            "title": "My Site",
            "chat_color_theme": "light",
            "chat_color_theme_inverted": False,
            "icon_type": "emoji",
            "icon": "robot",
            "icon_background": "#fff",
            "icon_url": None,
            "description": "desc",
            "copyright": None,
            "privacy_policy": None,
            "custom_disclaimer": None,
            "default_language": "en",
            "prompt_public": False,
            "show_workflow_steps": True,
            "use_icon_as_answer_icon": False,
        },
        "model_config": None,
        "plan": "basic",
        "can_replace_logo": True,
        "custom_config": {
            "remove_webapp_brand": True,
            "replace_webapp_logo": None,
        },
    }
    service_mock.get_form_by_token.assert_called_once_with("token-1")
    limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
    limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
def test_get_form_raises_forbidden_when_site_missing(monkeypatch: pytest.MonkeyPatch, app: Flask):
    """GET raises Forbidden if site cannot be resolved."""
    expiration_time = datetime(2099, 1, 3, tzinfo=UTC)

    class _FakeDefinition:
        def model_dump(self):
            return {
                "form_content": "Raw content",
                "rendered_content": "Rendered",
                "inputs": [],
                "default_values": {},
                "user_actions": [],
            }

    class _FakeForm:
        def __init__(self, expiration: datetime):
            self.workflow_run_id = "workflow-1"
            self.app_id = "app-1"
            self.tenant_id = "tenant-1"
            self.expiration_time = expiration

        def get_definition(self):
            return _FakeDefinition()

    form = _FakeForm(expiration_time)
    limiter_mock = MagicMock()
    limiter_mock.is_rate_limited.return_value = False
    monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
    monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")

    tenant = SimpleNamespace(status=TenantStatus.NORMAL)
    app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
    workflow_run = SimpleNamespace(app_id="app-1")

    service_mock = MagicMock()
    service_mock.get_form_by_token.return_value = form
    monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
    # "Site" is seeded as None, which is the condition under test.
    db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": None}))
    monkeypatch.setattr(human_input_module, "db", db_stub)

    with app.test_request_context("/api/form/human_input/token-1", method="GET"):
        with pytest.raises(Forbidden):
            HumanInputFormApi().get("token-1")

    # Rate limiting still runs — and counts the attempt — before the failure.
    limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
    limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
def test_submit_form_accepts_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask):
    """POST forwards backstage submissions to the service."""

    class _FakeForm:
        # Only the recipient type is consulted on the submit path.
        recipient_type = RecipientType.BACKSTAGE

    form = _FakeForm()
    limiter_mock = MagicMock()
    limiter_mock.is_rate_limited.return_value = False
    monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock)
    monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")

    service_mock = MagicMock()
    service_mock.get_form_by_token.return_value = form
    monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
    monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))

    with app.test_request_context(
        "/api/form/human_input/token-1",
        method="POST",
        json={"inputs": {"content": "ok"}, "action": "approve"},
    ):
        response, status = HumanInputFormApi().post("token-1")

    assert status == 200
    assert response == {}
    # Request body maps onto the service call: inputs -> form_data, action -> selected_action_id.
    service_mock.submit_form_by_token.assert_called_once_with(
        recipient_type=RecipientType.BACKSTAGE,
        form_token="token-1",
        selected_action_id="approve",
        form_data={"content": "ok"},
        submission_end_user_id=None,
    )
    limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
    limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
def test_submit_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask):
    """POST rejects submissions when rate limit is exceeded."""
    # Limiter that reports the caller's IP as over quota.
    rate_limiter = MagicMock()
    rate_limiter.is_rate_limited.return_value = True
    human_input_service = MagicMock()
    human_input_service.get_form_by_token.return_value = None
    monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", rate_limiter)
    monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
    monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: human_input_service)
    monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
    payload = {"inputs": {"content": "ok"}, "action": "approve"}
    with app.test_request_context("/api/form/human_input/token-1", method="POST", json=payload):
        with pytest.raises(WebFormRateLimitExceededError):
            HumanInputFormApi().post("token-1")
    # Rejected before the counter is incremented or the service is consulted.
    rate_limiter.is_rate_limited.assert_called_once_with("203.0.113.10")
    rate_limiter.increment_rate_limit.assert_not_called()
    human_input_service.get_form_by_token.assert_not_called()
def test_get_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask):
    """GET rejects requests when rate limit is exceeded."""
    # Limiter that reports the caller's IP as over quota.
    rate_limiter = MagicMock()
    rate_limiter.is_rate_limited.return_value = True
    human_input_service = MagicMock()
    human_input_service.get_form_by_token.return_value = None
    monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", rate_limiter)
    monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
    monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: human_input_service)
    monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
    with app.test_request_context("/api/form/human_input/token-1", method="GET"):
        with pytest.raises(WebFormRateLimitExceededError):
            HumanInputFormApi().get("token-1")
    # Rejected before the counter is incremented or the service is consulted.
    rate_limiter.is_rate_limited.assert_called_once_with("203.0.113.10")
    rate_limiter.increment_rate_limit.assert_not_called()
    human_input_service.get_form_by_token.assert_not_called()
def test_get_form_raises_expired(monkeypatch: pytest.MonkeyPatch, app: Flask):
    """GET propagates FormExpiredError raised by ensure_form_active."""

    class _ExpiredForm:
        pass

    expired_form = _ExpiredForm()
    rate_limiter = MagicMock()
    rate_limiter.is_rate_limited.return_value = False
    human_input_service = MagicMock()
    human_input_service.get_form_by_token.return_value = expired_form
    # Have the activity check fail so the API must surface the error.
    human_input_service.ensure_form_active.side_effect = FormExpiredError("form-id")
    monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", rate_limiter)
    monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
    monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: human_input_service)
    monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
    with app.test_request_context("/api/form/human_input/token-1", method="GET"):
        with pytest.raises(FormExpiredError):
            HumanInputFormApi().get("token-1")
    human_input_service.ensure_form_active.assert_called_once_with(expired_form)
    rate_limiter.is_rate_limited.assert_called_once_with("203.0.113.10")
    rate_limiter.increment_rate_limit.assert_called_once_with("203.0.113.10")

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import builtins
import uuid
from datetime import datetime
from types import ModuleType, SimpleNamespace
from unittest.mock import patch
@ -12,6 +13,8 @@ import pytest
from flask import Flask
from flask.views import MethodView
from core.entities.execution_extra_content import HumanInputContent
# Ensure flask_restx.api finds MethodView during import.
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@ -137,6 +140,12 @@ def test_message_list_mapping(app: Flask) -> None:
status="success",
error=None,
message_metadata_dict={"meta": "value"},
extra_contents=[
HumanInputContent(
workflow_run_id=str(uuid.uuid4()),
submitted=True,
)
],
)
pagination = SimpleNamespace(limit=20, has_more=False, data=[message])
@ -169,6 +178,8 @@ def test_message_list_mapping(app: Flask) -> None:
assert item["agent_thoughts"][0]["chain_id"] == "chain-1"
assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp())
assert item["extra_contents"][0]["workflow_run_id"] == message.extra_contents[0].workflow_run_id
assert item["extra_contents"][0]["submitted"] == message.extra_contents[0].submitted
assert item["message_files"][0]["id"] == "file-dict"
assert item["message_files"][1]["id"] == "file-obj"

View File

@ -0,0 +1,187 @@
from __future__ import annotations
from contextlib import contextmanager
from datetime import datetime
from types import SimpleNamespace
from unittest import mock
import pytest
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent
from core.workflow.entities.pause_reason import HumanInputRequired
from models.enums import MessageStatus
from models.execution_extra_content import HumanInputContent
from models.model import EndUser
def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline:
    """Create a bare pipeline instance without running ``__init__``."""
    pipeline_cls = pipeline_module.AdvancedChatAppGenerateTaskPipeline
    instance = pipeline_cls.__new__(pipeline_cls)
    # Minimal identifiers read by the helpers under test.
    instance._workflow_run_id = "run-1"
    instance._message_id = "message-1"
    instance._workflow_tenant_id = "tenant-1"
    return instance
def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None:
    """A new HumanInputContent row is inserted when none exists yet."""
    pipeline = _build_pipeline()
    monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")
    sessions: dict[str, mock.Mock] = {}

    @contextmanager
    def fake_session():
        db_session = mock.Mock()
        db_session.scalar.return_value = None  # no pre-existing record
        sessions["session"] = db_session
        yield db_session

    pipeline._database_session = fake_session  # type: ignore[method-assign]
    pipeline._persist_human_input_extra_content(node_id="node-1")
    db_session = sessions["session"]
    db_session.add.assert_called_once()
    (persisted,) = db_session.add.call_args.args
    assert isinstance(persisted, HumanInputContent)
    assert persisted.workflow_run_id == "run-1"
    assert persisted.message_id == "message-1"
    assert persisted.form_id == "form-1"
def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
    """No database session is opened when the form id cannot be resolved."""
    pipeline = _build_pipeline()
    monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None)
    session_opened = {"value": False}

    @contextmanager
    def fake_session():
        session_opened["value"] = True
        yield mock.Mock()

    pipeline._database_session = fake_session  # type: ignore[method-assign]
    pipeline._persist_human_input_extra_content(node_id="node-1")
    assert not session_opened["value"]
def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None:
    """Nothing is inserted when a matching HumanInputContent row already exists."""
    pipeline = _build_pipeline()
    monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")
    sessions: dict[str, mock.Mock] = {}

    @contextmanager
    def fake_session():
        db_session = mock.Mock()
        # Simulate an already-persisted record for the same run/message/form.
        db_session.scalar.return_value = HumanInputContent(
            workflow_run_id="run-1",
            message_id="message-1",
            form_id="form-1",
        )
        sessions["session"] = db_session
        yield db_session

    pipeline._database_session = fake_session  # type: ignore[method-assign]
    pipeline._persist_human_input_extra_content(node_id="node-1")
    sessions["session"].add.assert_not_called()
def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None:
    """A paused-workflow event persists the human-input record and marks the message paused."""
    pipeline = _build_pipeline()
    # Wire up the collaborators the pause handler touches; all are mocks so the
    # test only exercises the handler's own orchestration.
    pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
    pipeline._workflow_response_converter = mock.Mock()
    pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = []
    pipeline._ensure_graph_runtime_initialized = mock.Mock(
        return_value=SimpleNamespace(
            total_tokens=0,
            node_run_steps=0,
        ),
    )
    pipeline._save_message = mock.Mock()
    # Message starts NORMAL; the handler is expected to flip it to PAUSED.
    message = SimpleNamespace(status=MessageStatus.NORMAL)
    pipeline._get_message = mock.Mock(return_value=message)
    pipeline._persist_human_input_extra_content = mock.Mock()
    pipeline._base_task_pipeline = mock.Mock()
    pipeline._base_task_pipeline.queue_manager = mock.Mock()
    pipeline._message_saved_on_pause = False
    @contextmanager
    def fake_session():
        session = mock.Mock()
        yield session
    pipeline._database_session = fake_session  # type: ignore[method-assign]
    # Pause reason describing a pending human-input form on node-1.
    reason = HumanInputRequired(
        form_id="form-1",
        form_content="content",
        inputs=[],
        actions=[],
        node_id="node-1",
        node_title="Approval",
        form_token="token-1",
        resolved_default_values={},
    )
    event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"])
    # Drain the generator so the handler runs to completion.
    list(pipeline._handle_workflow_paused_event(event))
    pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1")
    assert message.status == MessageStatus.PAUSED
def test_resume_appends_chunks_to_paused_answer() -> None:
    """Resuming a PAUSED message appends new text chunks and restores NORMAL status."""
    app_config = SimpleNamespace(app_id="app-1", tenant_id="tenant-1", sensitive_word_avoidance=None)
    application_generate_entity = SimpleNamespace(
        app_config=app_config,
        files=[],
        workflow_run_id="run-1",
        query="hello",
        invoke_from=InvokeFrom.WEB_APP,
        inputs={},
        task_id="task-1",
    )
    queue_manager = SimpleNamespace(graph_runtime_state=None)
    conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat")
    # Message paused mid-generation with a partial answer already stored.
    message = SimpleNamespace(
        id="message-1",
        created_at=datetime(2024, 1, 1),
        query="hello",
        answer="before",
        status=MessageStatus.PAUSED,
    )
    user = EndUser()
    user.id = "user-1"
    user.session_id = "session-1"
    workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={})
    # Construct a real pipeline (not a bare instance) so resume wiring is exercised.
    pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline(
        application_generate_entity=application_generate_entity,
        workflow=workflow,
        queue_manager=queue_manager,
        conversation=conversation,
        message=message,
        user=user,
        stream=True,
        dialogue_count=1,
        draft_var_saver_factory=SimpleNamespace(),
    )
    pipeline._get_message = mock.Mock(return_value=message)
    pipeline._recorded_files = []
    # Stream one more chunk, then persist the message.
    list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after")))
    pipeline._save_message(session=mock.Mock())
    # The new chunk is appended to the pre-pause answer rather than replacing it.
    assert message.answer == "beforeafter"
    assert message.status == MessageStatus.NORMAL

View File

@ -0,0 +1,87 @@
from datetime import UTC, datetime
from types import SimpleNamespace
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
def _build_converter():
    """Build a WorkflowResponseConverter bound to workflow run "run-1".

    Returns:
        A converter whose generate entity, user and system variables carry the
        fixed ids used throughout these stream-response tests.
    """
    system_variables = SystemVariable(
        files=[],
        user_id="user-1",
        app_id="app-1",
        workflow_id="wf-1",
        workflow_execution_id="run-1",
    )
    # NOTE: the previous version also built an unused GraphRuntimeState here;
    # the converter only needs the entity, user and system variables.
    app_entity = SimpleNamespace(
        task_id="task-1",
        app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"),
        invoke_from=InvokeFrom.EXPLORE,
        files=[],
        inputs={},
        workflow_execution_id="run-1",
        call_depth=0,
    )
    account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com")
    return WorkflowResponseConverter(
        application_generate_entity=app_entity,
        user=account,
        system_variables=system_variables,
    )
def test_human_input_form_filled_stream_response_contains_rendered_content():
    """The filled-form stream response exposes node info and the rendered markdown."""
    converter = _build_converter()
    # Register the run first so the converter can resolve its workflow_run_id.
    converter.workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.INITIAL,
    )
    event = QueueHumanInputFormFilledEvent(
        node_execution_id="exec-1",
        node_id="node-1",
        node_type="human-input",
        node_title="Human Input",
        rendered_content="# Title\nvalue",
        action_id="Approve",
        action_text="Approve",
    )
    response = converter.human_input_form_filled_to_stream_response(event=event, task_id="task-1")
    assert response.workflow_run_id == "run-1"
    data = response.data
    assert (data.node_id, data.node_title) == ("node-1", "Human Input")
    assert data.rendered_content.startswith("# Title")
    assert data.action_id == "Approve"
def test_human_input_form_timeout_stream_response_contains_timeout_metadata():
    """The timeout stream response carries node info and an epoch expiration time."""
    converter = _build_converter()
    # Register the run first so the converter can resolve its workflow_run_id.
    converter.workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.INITIAL,
    )
    event = QueueHumanInputFormTimeoutEvent(
        node_id="node-1",
        node_type="human-input",
        node_title="Human Input",
        expiration_time=datetime(2025, 1, 1, tzinfo=UTC),
    )
    response = converter.human_input_form_timeout_to_stream_response(event=event, task_id="task-1")
    assert response.workflow_run_id == "run-1"
    assert (response.data.node_id, response.data.node_title) == ("node-1", "Human Input")
    # 2025-01-01T00:00:00Z expressed as a unix timestamp.
    assert response.data.expiration_time == 1735689600

View File

@ -0,0 +1,56 @@
from types import SimpleNamespace
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
def _build_converter() -> WorkflowResponseConverter:
    """Construct a minimal WorkflowResponseConverter for testing.

    Returns:
        A converter whose generate entity, user and system variables carry the
        fixed ids asserted by the workflow-start tests below.
    """
    system_variables = SystemVariable(
        files=[],
        user_id="user-1",
        app_id="app-1",
        workflow_id="wf-1",
        workflow_execution_id="run-1",
    )
    # NOTE: the previous version also built an unused GraphRuntimeState here;
    # the converter only needs the entity, user and system variables.
    app_entity = SimpleNamespace(
        task_id="task-1",
        app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"),
        invoke_from=InvokeFrom.EXPLORE,
        files=[],
        inputs={},
        workflow_execution_id="run-1",
        call_depth=0,
    )
    account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com")
    return WorkflowResponseConverter(
        application_generate_entity=app_entity,
        user=account,
        system_variables=system_variables,
    )
def test_workflow_start_stream_response_carries_resumption_reason():
    """A resumed run reports WorkflowStartReason.RESUMPTION in its start event."""
    response = _build_converter().workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.RESUMPTION,
    )
    assert response.data.reason is WorkflowStartReason.RESUMPTION
def test_workflow_start_stream_response_carries_initial_reason():
    """A fresh run reports WorkflowStartReason.INITIAL in its start event."""
    response = _build_converter().workflow_start_to_stream_response(
        task_id="task-1",
        workflow_run_id="run-1",
        workflow_id="wf-1",
        reason=WorkflowStartReason.INITIAL,
    )
    assert response.data.reason is WorkflowStartReason.INITIAL

View File

@ -23,6 +23,7 @@ from core.app.entities.queue_entities import (
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
)
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import NodeType
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
@ -124,7 +125,12 @@ class TestWorkflowResponseConverter:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -160,7 +166,12 @@ class TestWorkflowResponseConverter:
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -191,7 +202,12 @@ class TestWorkflowResponseConverter:
"""Test node finish response when process_data is None."""
converter = self.create_workflow_response_converter()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -225,7 +241,12 @@ class TestWorkflowResponseConverter:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -261,7 +282,12 @@ class TestWorkflowResponseConverter:
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -400,6 +426,7 @@ class TestWorkflowResponseConverterServiceApiTruncation:
task_id="test-task-id",
workflow_run_id="test-workflow-run-id",
workflow_id="test-workflow-id",
reason=WorkflowStartReason.INITIAL,
)
return converter

View File

@ -0,0 +1,139 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
from core.app.apps import message_based_app_generator
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.task_pipeline import message_cycle_manager
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from models.model import AppMode, Conversation, Message
def _make_app_config() -> WorkflowUIBasedAppConfig:
    """Minimal advanced-chat app config shared by the tests in this module."""
    config_kwargs = {
        "tenant_id": "tenant-id",
        "app_id": "app-id",
        "app_mode": AppMode.ADVANCED_CHAT,
        "workflow_id": "workflow-id",
        "additional_features": AppAdditionalFeatures(),
        "variables": [],
    }
    return WorkflowUIBasedAppConfig(**config_kwargs)
def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity:
    """Advanced-chat generate entity with fixed task/user/query values."""
    entity_kwargs = {
        "task_id": "task-id",
        "app_config": app_config,
        "file_upload_config": None,
        "conversation_id": None,
        "inputs": {},
        "query": "hello",
        "files": [],
        "parent_message_id": None,
        "user_id": "user-id",
        "stream": True,
        "invoke_from": InvokeFrom.WEB_APP,
        "extras": {},
        "workflow_run_id": "workflow-run-id",
    }
    return AdvancedChatAppGenerateEntity(**entity_kwargs)
@pytest.fixture(autouse=True)
def _mock_db_session(monkeypatch):
    """Replace the generator's db handle with a mock session that fakes PK assignment."""
    fake_session = MagicMock()

    def assign_generated_ids(obj):
        # Mimic the flush/refresh step that would populate primary keys.
        if isinstance(obj, Conversation) and obj.id is None:
            obj.id = "generated-conversation-id"
        if isinstance(obj, Message) and obj.id is None:
            obj.id = "generated-message-id"

    fake_session.refresh.side_effect = assign_generated_ids
    fake_session.add.return_value = None
    fake_session.commit.return_value = None
    monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=fake_session))
    return fake_session
def test_init_generate_records_sets_conversation_metadata():
    """A fresh run creates a conversation and marks the entity as new."""
    entity = _make_generate_entity(_make_app_config())
    conversation, _ = AdvancedChatAppGenerator()._init_generate_records(entity, conversation=None)
    assert conversation.id == "generated-conversation-id"
    assert entity.conversation_id == "generated-conversation-id"
    assert entity.is_new_conversation is True
def test_init_generate_records_marks_existing_conversation():
    """Passing an existing conversation reuses it and clears the new-conversation flag."""
    app_config = _make_app_config()
    entity = _make_generate_entity(app_config)
    # A pre-existing conversation row with an id already assigned.
    existing_conversation = Conversation(
        app_id=app_config.app_id,
        app_model_config_id=None,
        model_provider=None,
        override_model_configs=None,
        model_id=None,
        mode=app_config.app_mode.value,
        name="existing",
        inputs={},
        introduction="",
        system_instruction="",
        system_instruction_tokens=0,
        status="normal",
        invoke_from=InvokeFrom.WEB_APP.value,
        from_source="api",
        from_end_user_id="user-id",
        from_account_id=None,
    )
    existing_conversation.id = "existing-conversation-id"
    generator = AdvancedChatAppGenerator()
    conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation)
    # The generator adopts the given conversation instead of creating one.
    assert entity.conversation_id == "existing-conversation-id"
    assert conversation is existing_conversation
    assert entity.is_new_conversation is False
def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
    """Name generation starts a thread for new conversations and resets the flag."""
    app_config = _make_app_config()
    entity = _make_generate_entity(app_config)
    entity.conversation_id = "existing-conversation-id"
    entity.is_new_conversation = True
    entity.extras = {"auto_generate_conversation_name": True}
    captured = {}
    class DummyThread:
        # Stand-in for threading.Thread that records whether start() was called.
        def __init__(self, **kwargs):
            self.kwargs = kwargs
            self.started = False
        def start(self):
            self.started = True
    def fake_thread(**kwargs):
        thread = DummyThread(**kwargs)
        captured["thread"] = thread
        return thread
    # Intercept thread creation so no real background work happens.
    monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread)
    manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())
    thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")
    assert thread is captured["thread"]
    assert thread.started is True
    # The manager consumes the flag so the name is only generated once.
    assert entity.is_new_conversation is False

View File

@ -0,0 +1,127 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.app.app_config.entities import (
AppAdditionalFeatures,
EasyUIBasedAppConfig,
EasyUIBasedAppModelConfigFrom,
ModelConfigEntity,
PromptTemplateEntity,
)
from core.app.apps import message_based_app_generator
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
from models.model import AppMode, Conversation, Message
class DummyModelConf:
    """Minimal stand-in for a model configuration: just provider and model names."""

    def __init__(self, provider: str = "mock-provider", model: str = "mock-model") -> None:
        # Only these two attributes are read by the code under test.
        self.provider, self.model = provider, model
class DummyCompletionGenerateEntity:
    """Completion-style generate entity without conversation bookkeeping fields.

    ``__slots__`` deliberately omits ``conversation_id`` and
    ``is_new_conversation`` so tests can assert the generator never sets
    those attributes on non-chat entities.
    """

    __slots__ = ("app_config", "invoke_from", "user_id", "query", "inputs", "files", "model_conf")
    app_config: EasyUIBasedAppConfig
    invoke_from: InvokeFrom
    user_id: str
    query: str
    inputs: dict
    files: list
    model_conf: DummyModelConf
    def __init__(self, app_config: EasyUIBasedAppConfig) -> None:
        # Fixed values mirroring what a real completion request would carry.
        self.app_config = app_config
        self.invoke_from = InvokeFrom.WEB_APP
        self.user_id = "user-id"
        self.query = "hello"
        self.inputs = {}
        self.files = []
        self.model_conf = DummyModelConf()
def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig:
    """Easy-UI app config for the given mode with a trivial simple prompt."""
    prompt = PromptTemplateEntity(
        prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
        simple_prompt_template="Hello",
    )
    return EasyUIBasedAppConfig(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=app_mode,
        app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG,
        app_model_config_id="model-config-id",
        app_model_config_dict={},
        model=ModelConfigEntity(provider="mock-provider", model="mock-model", mode="chat"),
        prompt_template=prompt,
        additional_features=AppAdditionalFeatures(),
        variables=[],
    )
def _make_chat_generate_entity(app_config: EasyUIBasedAppConfig) -> ChatAppGenerateEntity:
    """Chat generate entity built via model_construct (skips pydantic validation)."""
    entity_kwargs = {
        "task_id": "task-id",
        "app_config": app_config,
        "model_conf": DummyModelConf(),
        "file_upload_config": None,
        "conversation_id": None,
        "inputs": {},
        "query": "hello",
        "files": [],
        "parent_message_id": None,
        "user_id": "user-id",
        "stream": False,
        "invoke_from": InvokeFrom.WEB_APP,
        "extras": {},
        "call_depth": 0,
        "trace_manager": None,
    }
    return ChatAppGenerateEntity.model_construct(**entity_kwargs)
@pytest.fixture(autouse=True)
def _mock_db_session(monkeypatch):
    """Swap the generator's db handle for a mock whose refresh() fakes generated ids."""
    stub_session = MagicMock()

    def _fake_refresh(obj):
        # Behave like a flush that assigns primary keys to new rows.
        if isinstance(obj, Conversation) and obj.id is None:
            obj.id = "generated-conversation-id"
        if isinstance(obj, Message) and obj.id is None:
            obj.id = "generated-message-id"

    stub_session.refresh.side_effect = _fake_refresh
    stub_session.add.return_value = None
    stub_session.commit.return_value = None
    monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=stub_session))
    return stub_session
def test_init_generate_records_skips_conversation_fields_for_non_conversation_entity():
    """Completion entities never gain conversation bookkeeping attributes."""
    entity = DummyCompletionGenerateEntity(app_config=_make_app_config(AppMode.COMPLETION))
    conversation, message = MessageBasedAppGenerator()._init_generate_records(entity, conversation=None)
    assert conversation.id == "generated-conversation-id"
    assert message.id == "generated-message-id"
    # __slots__ on the dummy entity means these attributes simply do not exist.
    assert not hasattr(entity, "conversation_id")
    assert not hasattr(entity, "is_new_conversation")
def test_init_generate_records_sets_conversation_fields_for_chat_entity():
    """Chat entities gain the conversation id and new-conversation flag on first run."""
    entity = _make_chat_generate_entity(_make_app_config(AppMode.CHAT))
    conversation, _ = MessageBasedAppGenerator()._init_generate_records(entity, conversation=None)
    assert conversation.id == "generated-conversation-id"
    assert entity.conversation_id == "generated-conversation-id"
    assert entity.is_new_conversation is True

View File

@ -0,0 +1,287 @@
import sys
import time
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
API_DIR = str(Path(__file__).resolve().parents[5])
if API_DIR not in sys.path:
sys.path.insert(0, API_DIR)
import core.workflow.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
# Stub out the trace-manager module when it is not already importable in this
# environment; the code under test only needs a constructible TraceQueueManager.
if "core.ops.ops_trace_manager" not in sys.modules:
    ops_stub = ModuleType("core.ops.ops_trace_manager")
    class _StubTraceQueueManager:
        # Accepts any constructor arguments and does nothing.
        def __init__(self, *_, **__):
            pass
    ops_stub.TraceQueueManager = _StubTraceQueueManager
    sys.modules["core.ops.ops_trace_manager"] = ops_stub
class _StubToolNodeData(BaseNodeData):
    # When True, the stub node yields a PauseRequestedEvent instead of succeeding.
    pause_on: bool = False
class _StubToolNode(Node[_StubToolNodeData]):
    """Tool-node stand-in that either succeeds immediately or requests a pause."""

    node_type = NodeType.TOOL
    @classmethod
    def version(cls) -> str:
        return "1"
    def init_node_data(self, data):
        self._node_data = _StubToolNodeData.model_validate(data)
    # The following accessors simply delegate to the validated node data,
    # satisfying the abstract Node interface.
    def _get_error_strategy(self):
        return self._node_data.error_strategy
    def _get_retry_config(self) -> RetryConfig:
        return self._node_data.retry_config
    def _get_title(self) -> str:
        return self._node_data.title
    def _get_description(self):
        return self._node_data.desc
    def _get_default_value_dict(self) -> dict[str, Any]:
        return self._node_data.default_value_dict
    def get_base_node_data(self) -> BaseNodeData:
        return self._node_data
    def _run(self):
        # Request a pause instead of producing output when configured to do so.
        if self.node_data.pause_on:
            yield PauseRequestedEvent(reason=SchedulingPause(message="test pause"))
            return
        # Otherwise succeed with a deterministic, node-specific output value.
        result = NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"value": f"{self.id}-done"},
        )
        yield self._convert_node_run_result_to_graph_node_event(result)
def _patch_tool_node(mocker):
    """Make DifyNodeFactory build _StubToolNode for TOOL nodes, leaving others untouched."""
    unpatched_create_node = DifyNodeFactory.create_node

    def _create_node_with_stub(self, node_config: dict[str, object]) -> Node:
        data = node_config.get("data", {})
        is_tool = isinstance(data, dict) and data.get("type") == NodeType.TOOL.value
        if not is_tool:
            # Anything that is not a tool node is built by the real factory.
            return unpatched_create_node(self, node_config)
        return _StubToolNode(
            id=str(node_config["id"]),
            config=node_config,
            graph_init_params=self.graph_init_params,
            graph_runtime_state=self.graph_runtime_state,
        )

    mocker.patch.object(DifyNodeFactory, "create_node", _create_node_with_stub)
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
    """Serialize node data and tag it with the workflow node type."""
    return {**data.model_dump(), "type": node_type.value}
def _build_graph_config(*, pause_on: str | None) -> dict[str, object]:
    """Build a linear start→tool_a→tool_b→tool_c→end graph config.

    Args:
        pause_on: Id of the tool node whose stub should request a pause,
            or None for a run with no pauses.
    """
    start_data = StartNodeData(title="start", variables=[])
    # Exactly one of the three tool stubs (or none) is configured to pause.
    tool_data_a = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_a")
    tool_data_b = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_b")
    tool_data_c = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_c")
    # The end node surfaces tool_c's output as the workflow result.
    end_data = EndNodeData(
        title="end",
        outputs=[OutputVariableEntity(variable="result", value_selector=["tool_c", "value"])],
        desc=None,
    )
    nodes = [
        {"id": "start", "data": _node_data(NodeType.START, start_data)},
        {"id": "tool_a", "data": _node_data(NodeType.TOOL, tool_data_a)},
        {"id": "tool_b", "data": _node_data(NodeType.TOOL, tool_data_b)},
        {"id": "tool_c", "data": _node_data(NodeType.TOOL, tool_data_c)},
        {"id": "end", "data": _node_data(NodeType.END, end_data)},
    ]
    edges = [
        {"source": "start", "target": "tool_a"},
        {"source": "tool_a", "target": "tool_b"},
        {"source": "tool_b", "target": "tool_c"},
        {"source": "tool_c", "target": "end"},
    ]
    return {"nodes": nodes, "edges": edges}
def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph:
    """Instantiate a Graph for the linear test config, wired to the stub node factory."""
    graph_config = _build_graph_config(pause_on=pause_on)
    # Fixed tenant/app/user identifiers; only the graph structure matters here.
    params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="service-api",
        call_depth=0,
    )
    node_factory = DifyNodeFactory(
        graph_init_params=params,
        graph_runtime_state=runtime_state,
    )
    return Graph.init(graph_config=graph_config, node_factory=node_factory)
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
    """Fresh runtime state whose system variables carry the given execution id."""
    pool = VariablePool(
        system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
        user_inputs={},
        conversation_variables=[],
    )
    pool.system_variables.workflow_execution_id = run_id
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]:
    """Run a freshly-built graph and collect every emitted engine event.

    Args:
        runtime_state: State the engine mutates while running; callers inspect
            it afterwards (e.g. for outputs or snapshots).
        pause_on: Node id whose stub should request a pause, or None to let the
            run complete.

    Returns:
        All events the engine yielded, in emission order.
    """
    graph = _build_graph(runtime_state, pause_on=pause_on)
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
    )
    # engine.run() is a generator; materializing it drives the run to its end state.
    return list(engine.run())
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
    """Return the node ids of successful node runs, in emission order."""
    succeeded: list[str] = []
    for event in events:
        if isinstance(event, NodeRunSucceededEvent):
            succeeded.append(event.node_id)
    return succeeded
def test_workflow_app_pause_resume_matches_baseline(mocker):
    """Pausing then resuming covers the same nodes and outputs as an uninterrupted run."""
    _patch_tool_node(mocker)
    # Baseline: run the graph straight through with no pause.
    baseline_state = _build_runtime_state("baseline")
    baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
    baseline_nodes = _node_successes(baseline_events)
    baseline_outputs = baseline_state.outputs
    # Paused run: tool_a requests a pause partway through.
    paused_state = _build_runtime_state("paused-run")
    paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
    assert isinstance(paused_events[-1], GraphRunPausedEvent)
    paused_nodes = _node_successes(paused_events)
    # Round-trip the runtime state through its snapshot, as a real resume would.
    snapshot = paused_state.dumps()
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    generator = wf_app_gen_module.WorkflowAppGenerator()
    def _fake_generate(**kwargs):
        # Stand-in for the real generation: finish the run and report node ids.
        state: GraphRuntimeState = kwargs["graph_runtime_state"]
        events = _run_with_optional_pause(state, pause_on=None)
        return _node_successes(events)
    mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
    resumed_nodes = generator.resume(
        app_model=SimpleNamespace(mode="workflow"),
        workflow=SimpleNamespace(),
        user=SimpleNamespace(),
        application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
        graph_runtime_state=resumed_state,
        workflow_execution_repository=SimpleNamespace(),
        workflow_node_execution_repository=SimpleNamespace(),
    )
    # Pause + resume together must reproduce the baseline run exactly.
    assert paused_nodes + resumed_nodes == baseline_nodes
    assert resumed_state.outputs == baseline_outputs
def test_advanced_chat_pause_resume_matches_baseline(mocker):
    """Advanced-chat resume must replay the same node sequence and outputs as an uninterrupted run."""
    _patch_tool_node(mocker)
    # Baseline: uninterrupted run used for comparison.
    baseline_state = _build_runtime_state("adv-baseline")
    baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
    baseline_nodes = _node_successes(baseline_events)
    baseline_outputs = baseline_state.outputs
    # Pause at tool_a, snapshot, then restore into a fresh state.
    paused_state = _build_runtime_state("adv-paused")
    paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
    assert isinstance(paused_events[-1], GraphRunPausedEvent)
    paused_nodes = _node_successes(paused_events)
    snapshot = paused_state.dumps()
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    generator = adv_app_gen_module.AdvancedChatAppGenerator()
    def _fake_generate(**kwargs):
        # Stand-in for AdvancedChatAppGenerator._generate: finish the restored run.
        state: GraphRuntimeState = kwargs["graph_runtime_state"]
        events = _run_with_optional_pause(state, pause_on=None)
        return _node_successes(events)
    mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
    resumed_nodes = generator.resume(
        app_model=SimpleNamespace(mode="workflow"),
        workflow=SimpleNamespace(),
        user=SimpleNamespace(),
        conversation=SimpleNamespace(id="conv"),
        message=SimpleNamespace(id="msg"),
        application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
        workflow_execution_repository=SimpleNamespace(),
        workflow_node_execution_repository=SimpleNamespace(),
        graph_runtime_state=resumed_state,
    )
    # Pre-pause plus post-resume node successes must equal the baseline exactly.
    assert paused_nodes + resumed_nodes == baseline_nodes
    assert resumed_state.outputs == baseline_outputs
def test_resume_emits_resumption_start_reason(mocker) -> None:
    """A run restored from a snapshot must report RESUMPTION as its start reason."""
    _patch_tool_node(mocker)
    state = _build_runtime_state("resume-reason")
    events_before_resume = _run_with_optional_pause(state, pause_on="tool_a")
    # The first run starts from scratch, so its start event carries INITIAL.
    first_start = next(e for e in events_before_resume if isinstance(e, GraphRunStartedEvent))
    assert first_start.reason == WorkflowStartReason.INITIAL
    restored = GraphRuntimeState.from_snapshot(state.dumps())
    events_after_resume = _run_with_optional_pause(restored, pause_on=None)
    # The restored run must instead announce itself as a resumption.
    second_start = next(e for e in events_after_resume if isinstance(e, GraphRunStartedEvent))
    assert second_start.reason == WorkflowStartReason.RESUMPTION

View File

@@ -0,0 +1,80 @@
from __future__ import annotations
import json
import queue
import pytest
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.entities.task_entities import StreamEvent
from models.model import AppMode
class FakeSubscription:
    """Test double for a pub/sub subscription backed by a plain queue.

    Entering the context flips the shared ``subscribed`` flag so tests can
    verify ordering; once closed, ``receive`` always yields ``None``.
    """

    def __init__(self, message_queue: queue.Queue[bytes], state: dict[str, bool]) -> None:
        self._queue = message_queue
        self._state = state
        self._closed = False

    def __enter__(self):
        # Record that a subscriber is attached before any message is consumed.
        self._state["subscribed"] = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self) -> None:
        self._closed = True

    def receive(self, timeout: float | None = 0.1) -> bytes | None:
        """Pop the next payload, or None when closed / nothing arrives in time."""
        if self._closed:
            return None
        try:
            # timeout=None means block indefinitely, mirroring queue.Queue.get.
            return self._queue.get() if timeout is None else self._queue.get(timeout=timeout)
        except queue.Empty:
            return None
class FakeTopic:
    """In-memory topic: published payloads land on a queue shared with its subscriptions."""

    def __init__(self) -> None:
        # Shared mutable state lets the topic observe subscription side effects.
        self._state = {"subscribed": False}
        self._queue: queue.Queue[bytes] = queue.Queue()

    def subscribe(self) -> FakeSubscription:
        """Create a subscription wired to this topic's queue and flag dict."""
        return FakeSubscription(self._queue, self._state)

    def publish(self, payload: bytes) -> None:
        """Enqueue *payload* for any current or future subscription."""
        self._queue.put(payload)

    @property
    def subscribed(self) -> bool:
        """True once any FakeSubscription context has been entered."""
        return bool(self._state["subscribed"])
def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch):
    """on_subscribe must fire only after the subscription is active, so no published event is missed."""
    topic = FakeTopic()
    def fake_get_response_topic(cls, app_mode, workflow_run_id):
        # Route the generator to the in-memory fake topic.
        return topic
    monkeypatch.setattr(MessageBasedAppGenerator, "get_response_topic", classmethod(fake_get_response_topic))
    def on_subscribe() -> None:
        # By the time this callback runs, the subscription must already exist.
        assert topic.subscribed is True
        event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
        topic.publish(json.dumps(event).encode())
    generator = MessageBasedAppGenerator.retrieve_events(
        AppMode.WORKFLOW,
        "workflow-run-id",
        idle_timeout=0.5,
        on_subscribe=on_subscribe,
    )
    # First yield is a keep-alive ping, then the published finish event.
    assert next(generator) == StreamEvent.PING.value
    event = next(generator)
    assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value
    # The WORKFLOW_FINISHED event terminates the stream.
    with pytest.raises(StopIteration):
        next(generator)

View File

@@ -1,3 +1,6 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
@@ -17,3 +20,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}
assert WorkflowAppGenerator()._should_prepare_user_inputs(args)
def test_resume_delegates_to_generate(mocker):
    """resume() must forward runtime state, pause config, streaming flag and invoke_from to _generate."""
    generator = WorkflowAppGenerator()
    mock_generate = mocker.patch.object(generator, "_generate", return_value="ok")
    application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger")
    runtime_state = MagicMock(name="runtime-state")
    pause_config = MagicMock(name="pause-config")
    result = generator.resume(
        app_model=MagicMock(),
        workflow=MagicMock(),
        user=MagicMock(),
        application_generate_entity=application_generate_entity,
        graph_runtime_state=runtime_state,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        graph_engine_layers=("layer",),
        pause_state_config=pause_config,
        variable_loader=MagicMock(),
    )
    assert result == "ok"
    mock_generate.assert_called_once()
    # Verify the delegation forwarded the important arguments unchanged.
    kwargs = mock_generate.call_args.kwargs
    assert kwargs["graph_runtime_state"] is runtime_state
    assert kwargs["pause_state_config"] is pause_config
    # streaming / invoke_from are derived from the generate entity.
    assert kwargs["streaming"] is False
    assert kwargs["invoke_from"] == "debugger"
def test_generate_appends_pause_layer_and_forwards_state(mocker):
    """_generate must append the pause-persistence layer after caller layers and hand the runtime state to the worker."""
    generator = WorkflowAppGenerator()
    mock_queue_manager = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager)
    fake_current_app = MagicMock()
    fake_current_app._get_current_object.return_value = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app)
    mocker.patch(
        "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
        return_value="converted",
    )
    mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response")
    mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock())
    pause_layer = MagicMock(name="pause-layer")
    mocker.patch(
        "core.app.apps.workflow.app_generator.PauseStatePersistenceLayer",
        return_value=pause_layer,
    )
    dummy_session = MagicMock()
    dummy_session.close = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session)
    # Capture the worker thread's target/kwargs instead of actually running it.
    worker_kwargs: dict[str, object] = {}
    class DummyThread:
        def __init__(self, target, kwargs):
            worker_kwargs["target"] = target
            worker_kwargs["kwargs"] = kwargs
        def start(self):
            return None
    mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread)
    app_model = SimpleNamespace(mode="workflow")
    app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf")
    application_generate_entity = SimpleNamespace(
        task_id="task",
        user_id="user",
        invoke_from="service-api",
        app_config=app_config,
        files=[],
        stream=True,
        workflow_execution_id="run",
    )
    graph_runtime_state = MagicMock()
    result = generator._generate(
        app_model=app_model,
        workflow=MagicMock(),
        user=MagicMock(),
        application_generate_entity=application_generate_entity,
        invoke_from="service-api",
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        streaming=True,
        graph_engine_layers=("base-layer",),
        graph_runtime_state=graph_runtime_state,
        pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"),
    )
    assert result == "converted"
    # The pause layer must be appended after the caller-provided layers, and
    # the runtime state must reach the worker thread untouched.
    assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer)
    assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state
def test_resume_path_runs_worker_with_runtime_state(mocker):
    """End-to-end resume: the worker thread must construct WorkflowAppRunner with the restored runtime state."""
    generator = WorkflowAppGenerator()
    runtime_state = MagicMock(name="runtime-state")
    pause_layer = MagicMock(name="pause-layer")
    mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer)
    queue_manager = MagicMock()
    mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager)
    mocker.patch.object(generator, "_handle_response", return_value="raw-response")
    mocker.patch(
        "core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
        side_effect=lambda response, invoke_from: response,
    )
    fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock())
    mocker.patch("core.app.apps.workflow.app_generator.db", fake_db)
    workflow = SimpleNamespace(
        id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1"
    )
    end_user = SimpleNamespace(session_id="end-user-session")
    app_record = SimpleNamespace(id="app")
    session = MagicMock()
    session.__enter__.return_value = session
    session.__exit__.return_value = False
    # The worker looks up workflow, end user and app record — in this order.
    session.scalar.side_effect = [workflow, end_user, app_record]
    mocker.patch("core.app.apps.workflow.app_generator.session_factory", return_value=session)
    runner_instance = MagicMock()
    def runner_ctor(**kwargs):
        # The runner must be constructed with the restored runtime state.
        assert kwargs["graph_runtime_state"] is runtime_state
        return runner_instance
    mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor)
    # Run the worker synchronously so runner_ctor's assertion fires inside the test.
    class ImmediateThread:
        def __init__(self, target, kwargs):
            target(**kwargs)
        def start(self):
            return None
    mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread)
    mocker.patch(
        "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
        return_value=MagicMock(),
    )
    mocker.patch(
        "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
        return_value=MagicMock(),
    )
    pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner")
    app_model = SimpleNamespace(mode="workflow")
    app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow")
    application_generate_entity = SimpleNamespace(
        task_id="task",
        user_id="user",
        invoke_from="service-api",
        app_config=app_config,
        files=[],
        stream=True,
        workflow_execution_id="run",
        trace_manager=MagicMock(),
    )
    result = generator.resume(
        app_model=app_model,
        workflow=workflow,
        user=MagicMock(),
        application_generate_entity=application_generate_entity,
        graph_runtime_state=runtime_state,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        pause_state_config=pause_config,
    )
    assert result == "raw-response"
    runner_instance.run.assert_called_once()
    # NOTE(review): this line assigns to the mock instead of asserting — was
    # `assert queue_manager.graph_runtime_state is runtime_state` intended? Verify.
    queue_manager.graph_runtime_state = runtime_state

View File

@@ -0,0 +1,59 @@
from unittest.mock import MagicMock
import pytest
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.graph_events.graph import GraphRunPausedEvent
class _DummyQueueManager:
def __init__(self):
self.published = []
def publish(self, event, _from):
self.published.append(event)
class _DummyRuntimeState:
def get_paused_nodes(self):
return ["node-1"]
class _DummyGraphEngine:
    """Graph-engine stand-in exposing only the runtime state the runner reads."""

    def __init__(self):
        self.graph_runtime_state = _DummyRuntimeState()
class _DummyWorkflowEntry:
    """Workflow-entry stand-in carrying just a dummy graph engine."""

    def __init__(self):
        self.graph_engine = _DummyGraphEngine()
def test_handle_pause_event_enqueues_email_task(monkeypatch: pytest.MonkeyPatch):
    """A GraphRunPausedEvent with a HumanInputRequired reason must enqueue the email task and publish a pause event."""
    queue_manager = _DummyQueueManager()
    runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app-id")
    workflow_entry = _DummyWorkflowEntry()
    reason = HumanInputRequired(
        form_id="form-123",
        form_content="content",
        inputs=[],
        actions=[],
        node_id="node-1",
        node_title="Review",
    )
    event = GraphRunPausedEvent(reasons=[reason], outputs={})
    email_task = MagicMock()
    monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", email_task)
    runner._handle_event(workflow_entry, event)
    # The email task is dispatched asynchronously with the form details.
    email_task.apply_async.assert_called_once()
    kwargs = email_task.apply_async.call_args.kwargs["kwargs"]
    assert kwargs["form_id"] == "form-123"
    assert kwargs["node_title"] == "Review"
    # A queue-level pause event must also be emitted for downstream consumers.
    assert any(isinstance(evt, QueueWorkflowPausedEvent) for evt in queue_manager.published)

View File

@@ -0,0 +1,183 @@
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.app.apps.common import workflow_response_converter
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph_events.graph import GraphRunPausedEvent
from core.workflow.nodes.human_input.entities import FormInput, UserAction
from core.workflow.nodes.human_input.enums import FormInputType
from core.workflow.system_variable import SystemVariable
from models.account import Account
class _RecordingWorkflowAppRunner(WorkflowAppRunner):
    """WorkflowAppRunner that captures emitted events instead of publishing them."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.published_events: list = []

    def _publish_event(self, event):
        # Capture rather than publish, so tests can inspect the emitted events.
        self.published_events.append(event)
class _FakeRuntimeState:
def get_paused_nodes(self):
return ["node-pause-1"]
def _build_runner():
    """Construct a recording WorkflowAppRunner wired entirely to lightweight stand-ins."""
    generate_entity = SimpleNamespace(
        app_config=SimpleNamespace(app_id="app-id"),
        inputs={},
        files=[],
        invoke_from=InvokeFrom.SERVICE_API,
        single_iteration_run=None,
        single_loop_run=None,
        workflow_execution_id="run-id",
        user_id="user-id",
    )
    fake_workflow = SimpleNamespace(
        graph_dict={},
        tenant_id="tenant-id",
        environment_variables={},
        id="workflow-id",
    )
    # The queue manager is a no-op; events are captured by the recording subclass.
    silent_queue_manager = SimpleNamespace(publish=lambda event, pub_from: None)
    return _RecordingWorkflowAppRunner(
        application_generate_entity=generate_entity,
        queue_manager=silent_queue_manager,
        variable_loader=MagicMock(),
        workflow=fake_workflow,
        system_user_id="sys-user",
        root_node_id=None,
        workflow_execution_repository=MagicMock(),
        workflow_node_execution_repository=MagicMock(),
        graph_engine_layers=(),
        graph_runtime_state=None,
    )
def test_graph_run_paused_event_emits_queue_pause_event():
    """Handling a GraphRunPausedEvent must publish exactly one QueueWorkflowPausedEvent mirroring it."""
    runner = _build_runner()
    pause_reason = HumanInputRequired(
        form_id="form-1",
        form_content="content",
        inputs=[],
        actions=[],
        node_id="node-human",
        node_title="Human Step",
        form_token="tok",
    )
    graph_event = GraphRunPausedEvent(reasons=[pause_reason], outputs={"foo": "bar"})
    engine_stub = SimpleNamespace(graph_runtime_state=_FakeRuntimeState())
    entry_stub = SimpleNamespace(graph_engine=engine_stub)
    runner._handle_event(entry_stub, graph_event)
    # Exactly one queue event, mirroring the graph event's reasons and outputs.
    assert len(runner.published_events) == 1
    published = runner.published_events[0]
    assert isinstance(published, QueueWorkflowPausedEvent)
    assert published.reasons == [pause_reason]
    assert published.outputs == {"foo": "bar"}
    assert published.paused_nodes == ["node-pause-1"]
def _build_converter():
    """Create a WorkflowResponseConverter backed by a mocked Account and fixed system variables."""
    tester = MagicMock(spec=Account)
    tester.id = "account-id"
    tester.name = "Tester"
    tester.email = "tester@example.com"
    generate_entity = SimpleNamespace(
        inputs={},
        files=[],
        invoke_from=InvokeFrom.SERVICE_API,
        app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
    )
    sys_vars = SystemVariable(
        user_id="user",
        app_id="app-id",
        workflow_id="workflow-id",
        workflow_execution_id="run-id",
    )
    return WorkflowResponseConverter(
        application_generate_entity=generate_entity,
        user=tester,
        system_variables=sys_vars,
    )
def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.MonkeyPatch):
    """A QueueWorkflowPausedEvent must yield a HumanInputRequiredResponse followed by a WorkflowPauseStreamResponse."""
    converter = _build_converter()
    # Prime the converter with a started run so the pause can reference run-id.
    converter.workflow_start_to_stream_response(
        task_id="task",
        workflow_run_id="run-id",
        workflow_id="workflow-id",
        reason=WorkflowStartReason.INITIAL,
    )
    expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
    # Fake DB session returning the form's expiration timestamp.
    class _FakeSession:
        def execute(self, _stmt):
            return [("form-1", expiration_time)]
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc, tb):
            return False
    monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
    monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
    reason = HumanInputRequired(
        form_id="form-1",
        form_content="Rendered",
        inputs=[
            FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
        ],
        actions=[UserAction(id="approve", title="Approve")],
        display_in_ui=True,
        node_id="node-id",
        node_title="Human Step",
        form_token="token",
    )
    queue_event = QueueWorkflowPausedEvent(
        reasons=[reason],
        outputs={"answer": "value"},
        paused_nodes=["node-id"],
    )
    runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
    responses = converter.workflow_pause_to_stream_response(
        event=queue_event,
        task_id="task",
        graph_runtime_state=runtime_state,
    )
    # Last response is the pause summary; outputs are emptied in the stream payload.
    assert isinstance(responses[-1], WorkflowPauseStreamResponse)
    pause_resp = responses[-1]
    assert pause_resp.workflow_run_id == "run-id"
    assert pause_resp.data.paused_nodes == ["node-id"]
    assert pause_resp.data.outputs == {}
    assert pause_resp.data.reasons[0]["form_id"] == "form-1"
    assert pause_resp.data.reasons[0]["display_in_ui"] is True
    # First response carries the human-input form details for the client.
    assert isinstance(responses[0], HumanInputRequiredResponse)
    hi_resp = responses[0]
    assert hi_resp.data.form_id == "form-1"
    assert hi_resp.data.node_id == "node-id"
    assert hi_resp.data.node_title == "Human Step"
    assert hi_resp.data.inputs[0].output_variable_name == "field"
    assert hi_resp.data.actions[0].id == "approve"
    assert hi_resp.data.display_in_ui is True
    # Expiration comes from the (faked) DB lookup, as a unix timestamp.
    assert hi_resp.data.expiration_time == int(expiration_time.timestamp())

View File

@@ -0,0 +1,96 @@
import time
from contextlib import contextmanager
from unittest.mock import MagicMock
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import QueueWorkflowStartedEvent
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.account import Account
from models.model import AppMode
def _build_workflow_app_config() -> WorkflowUIBasedAppConfig:
    """Minimal workflow-mode app config used by the pipeline fixtures."""
    return WorkflowUIBasedAppConfig(
        app_mode=AppMode.WORKFLOW,
        tenant_id="tenant-id",
        app_id="app-id",
        workflow_id="workflow-id",
    )
def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity:
    """Non-streaming generate entity whose workflow execution id is *run_id*."""
    config = _build_workflow_app_config()
    return WorkflowAppGenerateEntity(
        task_id="task-id",
        app_config=config,
        inputs={},
        files=[],
        user_id="user-id",
        stream=False,
        invoke_from=InvokeFrom.SERVICE_API,
        workflow_execution_id=run_id,
    )
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
    """Runtime state whose system variables reference *run_id* as the execution id."""
    pool = VariablePool(
        system_variables=SystemVariable(workflow_execution_id=run_id),
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
@contextmanager
def _noop_session():
yield MagicMock()
def _build_pipeline(run_id: str) -> WorkflowAppGenerateTaskPipeline:
    """Assemble a non-streaming pipeline whose queue manager carries the runtime state for *run_id*."""
    manager = MagicMock(spec=AppQueueManager)
    manager.invoke_from = InvokeFrom.SERVICE_API
    manager.graph_runtime_state = _build_runtime_state(run_id)
    fake_workflow = MagicMock()
    fake_workflow.id = "workflow-id"
    fake_workflow.features_dict = {}
    account = Account(name="user", email="user@example.com")
    pipeline = WorkflowAppGenerateTaskPipeline(
        application_generate_entity=_build_generate_entity(run_id),
        workflow=fake_workflow,
        queue_manager=manager,
        user=account,
        stream=False,
        draft_var_saver_factory=MagicMock(),
    )
    # Replace DB access with a no-op session so event handlers can run in isolation.
    pipeline._database_session = _noop_session
    return pipeline
def test_workflow_app_log_saved_only_on_initial_start() -> None:
    """An INITIAL start event must persist an app log tied to the current run."""
    run_id = "run-initial"
    pipeline = _build_pipeline(run_id)
    pipeline._save_workflow_app_log = MagicMock()
    started = QueueWorkflowStartedEvent(reason=WorkflowStartReason.INITIAL)
    list(pipeline._handle_workflow_started_event(started))
    # The log must be written exactly once, tied to this run.
    pipeline._save_workflow_app_log.assert_called_once()
    _, call_kwargs = pipeline._save_workflow_app_log.call_args
    assert call_kwargs["workflow_run_id"] == run_id
    assert pipeline._workflow_execution_id == run_id
def test_workflow_app_log_skipped_on_resumption_start() -> None:
    """A RESUMPTION start event must not create a second app log entry."""
    run_id = "run-resume"
    pipeline = _build_pipeline(run_id)
    pipeline._save_workflow_app_log = MagicMock()
    started = QueueWorkflowStartedEvent(reason=WorkflowStartReason.RESUMPTION)
    list(pipeline._handle_workflow_started_event(started))
    # Resumption reuses the original log; the execution id is still tracked.
    pipeline._save_workflow_app_log.assert_not_called()
    assert pipeline._workflow_execution_id == run_id

View File

@@ -0,0 +1,143 @@
import json
from collections.abc import Callable
from dataclasses import dataclass
import pytest
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
InvokeFrom,
WorkflowAppGenerateEntity,
)
from core.app.layers.pause_state_persist_layer import (
WorkflowResumptionContext,
_AdvancedChatAppGenerateEntityWrapper,
_WorkflowGenerateEntityWrapper,
)
from core.ops.ops_trace_manager import TraceQueueManager
from models.model import AppMode
class TraceQueueManagerStub(TraceQueueManager):
    """Minimal TraceQueueManager stub that avoids Flask dependencies.

    Instances satisfy isinstance checks against TraceQueueManager without
    doing any real tracing work.
    """
    def __init__(self):
        # Skip parent initialization to avoid starting timers or accessing Flask globals.
        pass
def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig:
    """App config for *app_mode* with a mode-derived workflow id."""
    workflow_id = f"{app_mode.value}-workflow-id"
    return WorkflowUIBasedAppConfig(
        tenant_id="tenant-id",
        app_id="app-id",
        app_mode=app_mode,
        workflow_id=workflow_id,
    )
def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity:
    """Workflow generate entity with an optional trace manager attached."""
    config = _build_workflow_app_config(AppMode.WORKFLOW)
    return WorkflowAppGenerateEntity(
        task_id="workflow-task",
        app_config=config,
        inputs={"topic": "serialization"},
        files=[],
        user_id="user-workflow",
        stream=True,
        invoke_from=InvokeFrom.DEBUGGER,
        call_depth=1,
        trace_manager=trace_manager,
        workflow_execution_id="workflow-exec-id",
        extras={"external_trace_id": "trace-id"},
    )
def _create_advanced_chat_generate_entity(
    trace_manager: TraceQueueManager | None = None,
) -> AdvancedChatAppGenerateEntity:
    """Advanced-chat generate entity with an optional trace manager attached."""
    config = _build_workflow_app_config(AppMode.ADVANCED_CHAT)
    return AdvancedChatAppGenerateEntity(
        task_id="advanced-task",
        app_config=config,
        conversation_id="conversation-id",
        inputs={"topic": "roundtrip"},
        files=[],
        user_id="user-advanced",
        stream=False,
        invoke_from=InvokeFrom.DEBUGGER,
        query="Explain serialization",
        extras={"auto_generate_conversation_name": True},
        trace_manager=trace_manager,
        workflow_run_id="workflow-run-id",
    )
def test_workflow_app_generate_entity_roundtrip_excludes_trace_manager():
    """JSON round-trip must drop trace_manager while preserving every other field."""
    original = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
    as_json = original.model_dump_json()
    # The trace manager is excluded from the serialized payload entirely.
    assert "trace_manager" not in json.loads(as_json)
    restored = WorkflowAppGenerateEntity.model_validate_json(as_json)
    assert restored.model_dump() == original.model_dump()
    assert restored.trace_manager is None
def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager():
    """Advanced-chat entity JSON round-trip must also drop trace_manager."""
    original = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
    as_json = original.model_dump_json()
    # The trace manager never appears in the serialized payload.
    assert "trace_manager" not in json.loads(as_json)
    restored = AdvancedChatAppGenerateEntity.model_validate_json(as_json)
    assert restored.model_dump() == original.model_dump()
    assert restored.trace_manager is None
@dataclass(frozen=True)
class ResumptionContextCase:
    """Parametrization record pairing a case name with a factory building (context, expected entity type)."""
    name: str
    context_factory: Callable[[], tuple[WorkflowResumptionContext, type]]
def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]:
    """Build a workflow-mode resumption context plus its expected entity type."""
    wrapped_entity = _WorkflowGenerateEntityWrapper(
        entity=_create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
    )
    context = WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
        generate_entity=wrapped_entity,
    )
    return context, WorkflowAppGenerateEntity
def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]:
    """Build an advanced-chat resumption context plus its expected entity type."""
    wrapped_entity = _AdvancedChatAppGenerateEntityWrapper(
        entity=_create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
    )
    context = WorkflowResumptionContext(
        serialized_graph_runtime_state=json.dumps({"state": "advanced"}),
        generate_entity=wrapped_entity,
    )
    return context, AdvancedChatAppGenerateEntity
@pytest.mark.parametrize(
    "case",
    [
        pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"),
        pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"),
    ],
)
def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase):
    """dumps()/loads() must round-trip both context flavors and restore the right entity type."""
    context, expected_type = case.context_factory()
    serialized = context.dumps()
    restored = WorkflowResumptionContext.loads(serialized)
    assert restored.serialized_graph_runtime_state == context.serialized_graph_runtime_state
    entity = restored.get_generate_entity()
    assert isinstance(entity, expected_type)
    assert entity.model_dump() == context.get_generate_entity().model_dump()
    # trace_manager is excluded from serialization, so it never survives a round-trip.
    assert entity.trace_manager is None

View File

@@ -0,0 +1,72 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from models.model import AppMode
def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker):
    """Plugin backwards-invocation of an advanced-chat app must pass a PauseStateLayerConfig owned by the workflow creator."""
    workflow = MagicMock()
    workflow.created_by = "owner-id"
    app = MagicMock()
    app.mode = AppMode.ADVANCED_CHAT
    app.workflow = workflow
    mocker.patch(
        "core.plugin.backwards_invocation.app.db",
        SimpleNamespace(engine=MagicMock()),
    )
    generator_spy = mocker.patch(
        "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate",
        return_value={"result": "ok"},
    )
    result = PluginAppBackwardsInvocation.invoke_chat_app(
        app=app,
        user=MagicMock(),
        conversation_id="conv-1",
        query="hello",
        stream=False,
        inputs={"k": "v"},
        files=[],
    )
    assert result == {"result": "ok"}
    # The pause-state config must identify the workflow's creator as owner.
    call_kwargs = generator_spy.call_args.kwargs
    pause_state_config = call_kwargs.get("pause_state_config")
    assert isinstance(pause_state_config, PauseStateLayerConfig)
    assert pause_state_config.state_owner_user_id == "owner-id"
def test_invoke_workflow_app_injects_pause_state_config(mocker):
    """Plugin backwards-invocation of a workflow app must pass a PauseStateLayerConfig owned by the workflow creator."""
    workflow = MagicMock()
    workflow.created_by = "owner-id"
    app = MagicMock()
    app.mode = AppMode.WORKFLOW
    app.workflow = workflow
    mocker.patch(
        "core.plugin.backwards_invocation.app.db",
        SimpleNamespace(engine=MagicMock()),
    )
    generator_spy = mocker.patch(
        "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate",
        return_value={"result": "ok"},
    )
    result = PluginAppBackwardsInvocation.invoke_workflow_app(
        app=app,
        user=MagicMock(),
        stream=False,
        inputs={"k": "v"},
        files=[],
    )
    assert result == {"result": "ok"}
    # The pause-state config must identify the workflow's creator as owner.
    call_kwargs = generator_spy.call_args.kwargs
    pause_state_config = call_kwargs.get("pause_state_config")
    assert isinstance(pause_state_config, PauseStateLayerConfig)
    assert pause_state_config.state_owner_user_id == "owner-id"

View File

@@ -0,0 +1,574 @@
"""Unit tests for HumanInputFormRepositoryImpl private helpers."""
from __future__ import annotations
import dataclasses
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormRepositoryImpl,
HumanInputFormSubmissionRepository,
_WorkspaceMemberInfo,
)
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
FormDefinition,
MemberRecipient,
UserAction,
)
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from libs.datetime_utils import naive_utc_now
from models.human_input import (
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputFormRecipient,
RecipientType,
)
def _build_repository() -> HumanInputFormRepositoryImpl:
    """Repository instance wired to a mock session factory for unit tests."""
    factory = MagicMock()
    return HumanInputFormRepositoryImpl(session_factory=factory, tenant_id="tenant-id")
def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]:
    """Replace HumanInputFormRecipient.new with a fake; returns the list of created recipients."""
    created: list[SimpleNamespace] = []

    def fake_new(cls, form_id: str, delivery_id: str, payload):  # type: ignore[no-untyped-def]
        # Build a lightweight stand-in that mirrors the real recipient's fields.
        stub = SimpleNamespace(
            form_id=form_id,
            delivery_id=delivery_id,
            recipient_type=payload.TYPE,
            recipient_payload=payload.model_dump_json(),
        )
        created.append(stub)
        return stub

    monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new))
    return created
@pytest.fixture(autouse=True)
def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None:
    """Avoid SQLAlchemy mapper configuration in tests using fake sessions."""

    class _SelectStub:
        # Both chained query-builder methods simply return the stub itself.
        def options(self, *_args, **_kwargs):  # type: ignore[no-untyped-def]
            return self

        def where(self, *_args, **_kwargs):  # type: ignore[no-untyped-def]
            return self

    monkeypatch.setattr(
        "core.repositories.human_input_repository.selectinload", lambda *args, **kwargs: "_loader_option"
    )
    monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *args, **kwargs: _SelectStub())
class TestHumanInputFormRepositoryImplHelpers:
    def test_build_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Member recipients are resolved via the workspace lookup; externals are taken as-is."""
        repo = _build_repository()
        session_stub = object()
        _patch_recipient_factory(monkeypatch)
        def fake_query(self, session, restrict_to_user_ids):  # type: ignore[no-untyped-def]
            # The repository must query only the referenced member ids.
            assert session is session_stub
            assert restrict_to_user_ids == ["member-1"]
            return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")]
        monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
        recipients = repo._build_email_recipients(
            session=session_stub,
            form_id="form-id",
            delivery_id="delivery-id",
            recipients_config=EmailRecipients(
                whole_workspace=False,
                items=[
                    MemberRecipient(user_id="member-1"),
                    ExternalRecipient(email="external@example.com"),
                ],
            ),
        )
        assert len(recipients) == 2
        member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER)
        external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL)
        member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload)
        assert member_payload.user_id == "member-1"
        # The member's email comes from the workspace lookup, not from the config.
        assert member_payload.email == "member@example.com"
        external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload)
        assert external_payload.email == "external@example.com"
    def test_build_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Member ids that do not resolve to workspace members are silently dropped."""
        repo = _build_repository()
        session_stub = object()
        created = _patch_recipient_factory(monkeypatch)
        def fake_query(self, session, restrict_to_user_ids):  # type: ignore[no-untyped-def]
            # The unknown member id is looked up but yields no workspace member.
            assert session is session_stub
            assert restrict_to_user_ids == ["missing-member"]
            return []
        monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
        recipients = repo._build_email_recipients(
            session=session_stub,
            form_id="form-id",
            delivery_id="delivery-id",
            recipients_config=EmailRecipients(
                whole_workspace=False,
                items=[
                    MemberRecipient(user_id="missing-member"),
                    ExternalRecipient(email="external@example.com"),
                ],
            ),
        )
        # Only the external recipient survives; the unresolved member is skipped.
        assert len(recipients) == 1
        assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL
        assert len(created) == 1  # only external recipient created via factory
def test_build_email_recipients_whole_workspace_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
repo = _build_repository()
session_stub = object()
_patch_recipient_factory(monkeypatch)
def fake_query(self, session): # type: ignore[no-untyped-def]
assert session is session_stub
return [
_WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"),
_WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"),
]
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query)
recipients = repo._build_email_recipients(
session=session_stub,
form_id="form-id",
delivery_id="delivery-id",
recipients_config=EmailRecipients(
whole_workspace=True,
items=[],
),
)
assert len(recipients) == 2
emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients}
assert emails == {"member1@example.com", "member2@example.com"}
def test_build_email_recipients_dedupes_external_by_email(self, monkeypatch: pytest.MonkeyPatch) -> None:
repo = _build_repository()
session_stub = object()
created = _patch_recipient_factory(monkeypatch)
def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def]
assert session is session_stub
assert restrict_to_user_ids == []
return []
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
recipients = repo._build_email_recipients(
session=session_stub,
form_id="form-id",
delivery_id="delivery-id",
recipients_config=EmailRecipients(
whole_workspace=False,
items=[
ExternalRecipient(email="external@example.com"),
ExternalRecipient(email="external@example.com"),
],
),
)
assert len(recipients) == 1
assert len(created) == 1
def test_build_email_recipients_prefers_member_over_external_by_email(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
repo = _build_repository()
session_stub = object()
_patch_recipient_factory(monkeypatch)
def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def]
assert session is session_stub
assert restrict_to_user_ids == ["member-1"]
return [_WorkspaceMemberInfo(user_id="member-1", email="shared@example.com")]
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
recipients = repo._build_email_recipients(
session=session_stub,
form_id="form-id",
delivery_id="delivery-id",
recipients_config=EmailRecipients(
whole_workspace=False,
items=[
MemberRecipient(user_id="member-1"),
ExternalRecipient(email="shared@example.com"),
],
),
)
assert len(recipients) == 1
assert recipients[0].recipient_type == RecipientType.EMAIL_MEMBER
def test_delivery_method_to_model_includes_external_recipients_with_whole_workspace(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
repo = _build_repository()
session_stub = object()
_patch_recipient_factory(monkeypatch)
def fake_query(self, session): # type: ignore[no-untyped-def]
assert session is session_stub
return [
_WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"),
_WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"),
]
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query)
method = EmailDeliveryMethod(
config=EmailDeliveryConfig(
recipients=EmailRecipients(
whole_workspace=True,
items=[ExternalRecipient(email="external@example.com")],
),
subject="subject",
body="body",
)
)
result = repo._delivery_method_to_model(session=session_stub, form_id="form-id", delivery_method=method)
assert len(result.recipients) == 3
member_emails = {
EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email
for r in result.recipients
if r.recipient_type == RecipientType.EMAIL_MEMBER
}
assert member_emails == {"member1@example.com", "member2@example.com"}
external_payload = EmailExternalRecipientPayload.model_validate_json(
next(r for r in result.recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL).recipient_payload
)
assert external_payload.email == "external@example.com"
def _make_form_definition() -> str:
    """Serialize a minimal, already-expired FormDefinition to its JSON form.

    Returns the `model_dump_json()` string expected by `_DummyForm.form_definition`.
    """
    return FormDefinition(
        form_content="hello",
        inputs=[],
        user_actions=[UserAction(id="submit", title="Submit")],
        rendered_content="<p>hello</p>",
        # naive_utc_now() matches the naive-UTC timestamps used everywhere else
        # in this module and avoids datetime.utcnow(), deprecated in Python 3.12.
        expiration_time=naive_utc_now(),
    ).model_dump_json()
@dataclasses.dataclass
class _DummyForm:
    """In-memory stand-in for the HumanInputForm ORM model used by `_FakeSession`."""

    # Identity / routing columns.
    id: str
    workflow_run_id: str
    node_id: str
    tenant_id: str
    app_id: str
    # Serialized FormDefinition JSON plus its rendered HTML.
    form_definition: str
    rendered_content: str
    expiration_time: datetime
    form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
    created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
    # Submission state; populated once a recipient submits the form.
    selected_action_id: str | None = None
    submitted_data: str | None = None
    submitted_at: datetime | None = None
    submission_user_id: str | None = None
    submission_end_user_id: str | None = None
    completed_by_recipient_id: str | None = None
    status: HumanInputFormStatus = HumanInputFormStatus.WAITING
@dataclasses.dataclass
class _DummyRecipient:
    """In-memory stand-in for the HumanInputFormRecipient ORM model."""

    id: str
    form_id: str
    recipient_type: RecipientType
    access_token: str
    # Back-reference to the owning form; only populated when a test needs the join.
    form: _DummyForm | None = None
class _FakeScalarResult:
def __init__(self, obj):
self._obj = obj
def first(self):
if isinstance(self._obj, list):
return self._obj[0] if self._obj else None
return self._obj
def all(self):
if isinstance(self._obj, list):
return list(self._obj)
if self._obj is None:
return []
return [self._obj]
class _FakeSession:
    """Minimal SQLAlchemy Session double with a scripted scalars() queue.

    Each scalars() call consumes one pre-seeded result (falling back to None),
    and get() resolves objects by model-class name from the seeded dicts.
    """

    def __init__(
        self,
        *,
        scalars_result=None,
        scalars_results: list[object] | None = None,
        forms: dict[str, _DummyForm] | None = None,
        recipients: dict[str, _DummyRecipient] | None = None,
    ):
        # An explicit result list wins over a single result.
        if scalars_results is not None:
            queued = list(scalars_results)
        else:
            queued = [] if scalars_result is None else [scalars_result]
        self._scalars_queue = queued
        self.forms = forms or {}
        self.recipients = recipients or {}

    def scalars(self, _query):
        """Pop and wrap the next scripted result; None once the queue is empty."""
        next_result = self._scalars_queue.pop(0) if self._scalars_queue else None
        return _FakeScalarResult(next_result)

    def get(self, model_cls, obj_id): # type: ignore[no-untyped-def]
        """Emulate Session.get() by dispatching on the model class name."""
        store = {
            "HumanInputForm": self.forms,
            "HumanInputFormRecipient": self.recipients,
        }.get(getattr(model_cls, "__name__", None))
        return None if store is None else store.get(obj_id)

    # Persistence operations are accepted and ignored.
    def add(self, _obj):
        return None

    def flush(self):
        return None

    def refresh(self, _obj):
        return None

    # begin() acts as its own context manager, mirroring session.begin().
    def begin(self):
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        return None
def _session_factory(session: _FakeSession):
    """Return a sessionmaker-like callable whose context manager yields *session*."""

    class _Ctx:
        def __enter__(self):
            return session

        def __exit__(self, exc_type, exc, tb):
            return None

    def _make(*_args, **_kwargs):
        # Accept (and ignore) any sessionmaker-style arguments.
        return _Ctx()

    return _make
class TestHumanInputFormRepositoryImplPublicMethods:
    """Tests for HumanInputFormRepositoryImpl.get_form driven by `_FakeSession`.

    The fake session's scalars() queue is seeded in query order: the form row
    first, then the recipient list.
    """

    def test_get_form_returns_entity_and_recipients(self):
        """get_form maps the ORM row and its recipients into the entity."""
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-id",
            app_id="app-id",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=naive_utc_now(),
        )
        recipient = _DummyRecipient(
            id="recipient-1",
            form_id=form.id,
            recipient_type=RecipientType.STANDALONE_WEB_APP,
            access_token="token-123",
        )
        # First scalars() call returns the form, second the recipient list.
        session = _FakeSession(scalars_results=[form, [recipient]])
        repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
        entity = repo.get_form(form.workflow_run_id, form.node_id)
        assert entity is not None
        assert entity.id == form.id
        assert entity.web_app_token == "token-123"
        assert len(entity.recipients) == 1
        assert entity.recipients[0].token == "token-123"

    def test_get_form_returns_none_when_missing(self):
        """get_form yields None when no matching form row exists."""
        session = _FakeSession(scalars_results=[None])
        repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
        assert repo.get_form("run-1", "node-1") is None

    def test_get_form_returns_unsubmitted_state(self):
        """A form with no submission reports submitted=False and empty fields."""
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-id",
            app_id="app-id",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=naive_utc_now(),
        )
        session = _FakeSession(scalars_results=[form, []])
        repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
        entity = repo.get_form(form.workflow_run_id, form.node_id)
        assert entity is not None
        assert entity.submitted is False
        assert entity.selected_action_id is None
        assert entity.submitted_data is None

    def test_get_form_returns_submission_when_completed(self):
        """Submitted form data (action id + JSON payload) round-trips into the entity."""
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-id",
            app_id="app-id",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=naive_utc_now(),
            selected_action_id="approve",
            submitted_data='{"field": "value"}',
            submitted_at=naive_utc_now(),
        )
        session = _FakeSession(scalars_results=[form, []])
        repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
        entity = repo.get_form(form.workflow_run_id, form.node_id)
        assert entity is not None
        assert entity.submitted is True
        assert entity.selected_action_id == "approve"
        # The stored JSON string is parsed back into a mapping.
        assert entity.submitted_data == {"field": "value"}
class TestHumanInputFormSubmissionRepository:
    """Tests for HumanInputFormSubmissionRepository token lookup and submission."""

    def test_get_by_token_returns_record(self):
        """Looking up by access token returns the joined form/recipient record."""
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-1",
            app_id="app-1",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=naive_utc_now(),
        )
        recipient = _DummyRecipient(
            id="recipient-1",
            form_id=form.id,
            recipient_type=RecipientType.STANDALONE_WEB_APP,
            access_token="token-123",
            form=form,
        )
        session = _FakeSession(scalars_result=recipient)
        repo = HumanInputFormSubmissionRepository(_session_factory(session))
        record = repo.get_by_token("token-123")
        assert record is not None
        assert record.form_id == form.id
        assert record.recipient_type == RecipientType.STANDALONE_WEB_APP
        assert record.submitted is False

    def test_get_by_form_id_and_recipient_type_uses_recipient(self):
        """Lookup by (form id, recipient type) surfaces the recipient's id and token."""
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-1",
            app_id="app-1",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=naive_utc_now(),
        )
        recipient = _DummyRecipient(
            id="recipient-1",
            form_id=form.id,
            recipient_type=RecipientType.STANDALONE_WEB_APP,
            access_token="token-123",
            form=form,
        )
        session = _FakeSession(scalars_result=recipient)
        repo = HumanInputFormSubmissionRepository(_session_factory(session))
        record = repo.get_by_form_id_and_recipient_type(
            form_id=form.id,
            recipient_type=RecipientType.STANDALONE_WEB_APP,
        )
        assert record is not None
        assert record.recipient_id == recipient.id
        assert record.access_token == recipient.access_token

    def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch):
        """mark_submitted writes the action, submitter ids, and the (frozen) timestamp."""
        # Pin the repository's clock so submitted_at is deterministic.
        fixed_now = datetime(2024, 1, 1, 0, 0, 0)
        monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
        form = _DummyForm(
            id="form-1",
            workflow_run_id="run-1",
            node_id="node-1",
            tenant_id="tenant-1",
            app_id="app-1",
            form_definition=_make_form_definition(),
            rendered_content="<p>hello</p>",
            expiration_time=fixed_now,
        )
        recipient = _DummyRecipient(
            id="recipient-1",
            form_id="form-1",
            recipient_type=RecipientType.STANDALONE_WEB_APP,
            access_token="token-123",
        )
        # mark_submitted fetches both objects via session.get(), so seed the maps.
        session = _FakeSession(
            forms={form.id: form},
            recipients={recipient.id: recipient},
        )
        repo = HumanInputFormSubmissionRepository(_session_factory(session))
        record: HumanInputFormRecord = repo.mark_submitted(
            form_id=form.id,
            recipient_id=recipient.id,
            selected_action_id="approve",
            form_data={"field": "value"},
            submission_user_id="user-1",
            submission_end_user_id="end-user-1",
        )
        assert form.selected_action_id == "approve"
        assert form.completed_by_recipient_id == recipient.id
        assert form.submission_user_id == "user-1"
        assert form.submission_end_user_id == "end-user-1"
        assert form.submitted_at == fixed_now
        assert record.submitted is True
        assert record.selected_action_id == "approve"
        assert record.submitted_data == {"field": "value"}

View File

@ -0,0 +1,33 @@
import pytest
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
def test_ensure_no_human_input_nodes_passes_for_non_human_input():
    """A graph containing no human-input nodes validates without raising."""
    graph = {"nodes": [{"id": "start_node", "data": {"type": "start"}}]}
    WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph)
def test_ensure_no_human_input_nodes_raises_for_human_input():
    """A graph with a human-input node is rejected with the dedicated error code."""
    graph = {"nodes": [{"id": "human_input_node", "data": {"type": "human-input"}}]}
    with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
        WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph)
    assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"

View File

@ -55,6 +55,43 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
assert exc_info.value.args == ("oops",)
def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch):
    """WorkflowTool.invoke passes an explicit pause_state_config=None to the generator."""
    entity = ToolEntity(
        identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
        parameters=[],
        description=None,
        has_runtime_parameters=False,
    )
    runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
    tool = WorkflowTool(
        workflow_app_id="",
        workflow_as_tool_id="",
        version="1",
        workflow_entities={},
        workflow_call_depth=1,
        entity=entity,
        runtime=runtime,
    )
    # Stub the DB-backed lookups so invoke() never touches persistence.
    monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
    monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
    from unittest.mock import MagicMock, Mock
    mock_user = Mock()
    monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
    generate_mock = MagicMock(return_value={"data": {}})
    monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
    monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
    # Drain the generator so generate() is actually called.
    list(tool.invoke("test_user", {}))
    call_kwargs = generate_mock.call_args.kwargs
    assert "pause_state_config" in call_kwargs
    assert call_kwargs["pause_state_config"] is None
def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch):
"""Test that WorkflowTool should generate variable messages when there are outputs"""
entity = ToolEntity(

View File

@ -118,7 +118,6 @@ class TestGraphRuntimeState:
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
assert isinstance(queue, InMemoryReadyQueue)
assert state.ready_queue is queue
def test_graph_execution_lazy_instantiation(self):
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())

View File

@ -0,0 +1,88 @@
"""
Tests for PauseReason discriminated union serialization/deserialization.
"""
import pytest
from pydantic import BaseModel, ValidationError
from core.workflow.entities.pause_reason import (
HumanInputRequired,
PauseReason,
SchedulingPause,
)
class _Holder(BaseModel):
    """Helper model that embeds PauseReason for union tests."""

    # Discriminated-union field; validation dispatches on the payload's TYPE tag.
    reason: PauseReason
class TestPauseReasonDiscriminator:
    """Test suite for PauseReason union discriminator."""

    @pytest.mark.parametrize(
        ("dict_value", "expected"),
        [
            pytest.param(
                {
                    "reason": {
                        "TYPE": "human_input_required",
                        "form_id": "form_id",
                        "form_content": "form_content",
                        "node_id": "node_id",
                        "node_title": "node_title",
                    },
                },
                HumanInputRequired(
                    form_id="form_id",
                    form_content="form_content",
                    node_id="node_id",
                    node_title="node_title",
                ),
                id="HumanInputRequired",
            ),
            pytest.param(
                {
                    "reason": {
                        "TYPE": "scheduled_pause",
                        "message": "Hold on",
                    }
                },
                SchedulingPause(message="Hold on"),
                id="SchedulingPause",
            ),
        ],
    )
    def test_model_validate(self, dict_value, expected):
        """Each TYPE-tagged payload deserializes to the matching concrete class."""
        holder = _Holder.model_validate(dict_value)
        # Exact-class check: `is` compares the class objects themselves, the
        # idiomatic spelling of what `type(a) == type(b)` expresses.
        assert type(holder.reason) is type(expected)
        assert holder.reason == expected

    @pytest.mark.parametrize(
        "reason",
        [
            HumanInputRequired(
                form_id="form_id",
                form_content="form_content",
                node_id="node_id",
                node_title="node_title",
            ),
            SchedulingPause(message="Hold on"),
        ],
        ids=lambda x: type(x).__name__,
    )
    def test_model_construct(self, reason):
        """Already-constructed union members are accepted unchanged."""
        holder = _Holder(reason=reason)
        assert holder.reason == reason

    def test_model_construct_with_invalid_type(self):
        """Objects outside the union are rejected at construction time."""
        with pytest.raises(ValidationError):
            # No need to bind the (never-created) result to a variable.
            _Holder(reason=object())  # type: ignore

    def test_unknown_type_fails_validation(self):
        """Unknown TYPE values should raise a validation error."""
        with pytest.raises(ValidationError):
            _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}})

View File

@ -0,0 +1,131 @@
"""Utilities for testing HumanInputNode without database dependencies."""
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any

from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.repositories.human_input_form_repository import (
    FormCreateParams,
    HumanInputFormEntity,
    HumanInputFormRecipientEntity,
    HumanInputFormRepository,
)
from libs.datetime_utils import naive_utc_now
class _InMemoryFormRecipient(HumanInputFormRecipientEntity):
    """Minimal recipient entity required by the repository interface."""

    def __init__(self, recipient_id: str, token: str) -> None:
        # Stored verbatim; exposed through the read-only properties below.
        self._id = recipient_id
        self._token = token

    @property
    def id(self) -> str:
        """Opaque recipient identifier."""
        return self._id

    @property
    def token(self) -> str:
        """Access token a recipient would use to open the form."""
        return self._token
@dataclass
class _InMemoryFormEntity(HumanInputFormEntity):
    """Mutable in-memory form record exposing the read-only entity interface."""

    form_id: str
    rendered: str
    token: str | None = None
    action_id: str | None = None
    data: Mapping[str, Any] | None = None
    is_submitted: bool = False
    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
    # default_factory evaluates per instance; a plain `naive_utc_now()` default
    # would be computed once at class-definition time and silently shared by
    # every form created during the test run.
    expiration: datetime = field(default_factory=naive_utc_now)

    @property
    def id(self) -> str:
        return self.form_id

    @property
    def web_app_token(self) -> str | None:
        return self.token

    @property
    def recipients(self) -> list[HumanInputFormRecipientEntity]:
        # No recipient bookkeeping is needed for graph-engine tests.
        return []

    @property
    def rendered_content(self) -> str:
        return self.rendered

    @property
    def selected_action_id(self) -> str | None:
        return self.action_id

    @property
    def submitted_data(self) -> Mapping[str, Any] | None:
        return self.data

    @property
    def submitted(self) -> bool:
        return self.is_submitted

    @property
    def status(self) -> HumanInputFormStatus:
        return self.status_value

    @property
    def expiration_time(self) -> datetime:
        return self.expiration
class InMemoryHumanInputFormRepository(HumanInputFormRepository):
    """Pure in-memory repository used by workflow graph engine tests."""

    def __init__(self) -> None:
        self._form_counter = 0
        # Book-keeping exposed to tests for assertions.
        self.created_params: list[FormCreateParams] = []
        self.created_forms: list[_InMemoryFormEntity] = []
        # Lookup key mirrors the production repository: (workflow_execution_id, node_id).
        self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {}

    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
        """Create and index a new form entity with deterministic id and token."""
        self.created_params.append(params)
        self._form_counter += 1
        form_id = f"form-{self._form_counter}"
        # Console recipients get a distinguishable token prefix.
        token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}"
        entity = _InMemoryFormEntity(
            form_id=form_id,
            rendered=params.rendered_content,
            token=token,
        )
        self.created_forms.append(entity)
        self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity
        return entity

    def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
        """Return the form created for this execution/node pair, if any."""
        return self._forms_by_key.get((workflow_execution_id, node_id))

    # Convenience helpers for tests -------------------------------------
    def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None:
        """Simulate a human submission for the next repository lookup."""
        if not self.created_forms:
            raise AssertionError("no form has been created to attach submission data")
        entity = self.created_forms[-1]
        entity.action_id = action_id
        entity.data = form_data or {}
        entity.is_submitted = True
        entity.status_value = HumanInputFormStatus.SUBMITTED
        # Push expiry into the future so the submission still counts as valid.
        entity.expiration = naive_utc_now() + timedelta(days=1)

    def clear_submission(self) -> None:
        """Reset every created form back to the waiting, unsubmitted state."""
        if not self.created_forms:
            return
        for form in self.created_forms:
            form.action_id = None
            form.data = None
            form.is_submitted = False
            form.status_value = HumanInputFormStatus.WAITING

View File

@ -0,0 +1,74 @@
import queue
import threading
from datetime import datetime
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher
from core.workflow.graph_events import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
class StubExecutionCoordinator:
    """Execution-coordinator double with a fixed paused state and recorded calls."""

    def __init__(self, paused: bool) -> None:
        self._is_paused = paused
        # Flags recorded for test assertions.
        self.mark_complete_called = False
        self.failed_error: Exception | None = None

    @property
    def aborted(self) -> bool:
        """The stub is never aborted."""
        return False

    @property
    def paused(self) -> bool:
        return self._is_paused

    @property
    def execution_complete(self) -> bool:
        """The stub never reports completion on its own."""
        return False

    def check_scaling(self) -> None:
        """No-op: worker scaling is irrelevant here."""
        return None

    def process_commands(self) -> None:
        """No-op: there is no command channel in the stub."""
        return None

    def mark_complete(self) -> None:
        self.mark_complete_called = True

    def mark_failed(self, error: Exception) -> None:
        self.failed_error = error
class StubEventHandler:
    """Event-handler double recording every dispatched event in arrival order."""

    def __init__(self) -> None:
        # Appended to by dispatch(); inspected directly by tests.
        self.events: list[object] = []

    def dispatch(self, event: object) -> None:
        """Record *event* without any further processing."""
        self.events.append(event)
def test_dispatcher_drains_events_when_paused() -> None:
    """A paused dispatcher still drains queued events and then marks completion."""
    pending: queue.Queue = queue.Queue()
    succeeded = NodeRunSucceededEvent(
        id="exec-1",
        node_id="node-1",
        node_type=NodeType.START,
        start_at=datetime.utcnow(),
        node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
    )
    pending.put(succeeded)
    recorder = StubEventHandler()
    paused_coordinator = StubExecutionCoordinator(paused=True)
    dispatcher = Dispatcher(
        event_queue=pending,
        event_handler=recorder,
        execution_coordinator=paused_coordinator,
        event_emitter=None,
        stop_event=threading.Event(),
    )
    # Run the loop body directly; the paused coordinator lets it exit cleanly.
    dispatcher._dispatcher_loop()
    assert recorder.events == [succeeded]
    assert paused_coordinator.mark_complete_called is True

View File

@ -2,6 +2,8 @@
from unittest.mock import MagicMock
import pytest
from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
@ -48,3 +50,13 @@ def test_handle_pause_noop_when_execution_running() -> None:
worker_pool.stop.assert_not_called()
state_manager.clear_executing.assert_not_called()
def test_has_executing_nodes_requires_pause() -> None:
    """has_executing_nodes() asserts the execution is paused before answering.

    Calling it while the graph execution is still running must trip the
    internal assertion instead of returning a value.
    """
    graph_execution = GraphExecution(workflow_id="workflow")
    graph_execution.start()
    coordinator, _, _ = _build_coordinator(graph_execution)
    with pytest.raises(AssertionError):
        coordinator.has_executing_nodes()

View File

@ -0,0 +1,189 @@
import time
from collections.abc import Mapping
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeState
from core.workflow.graph import Graph
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
def _build_runtime_state() -> GraphRuntimeState:
    """Create a fresh runtime state over a minimal system-variable pool."""
    pool = VariablePool(
        system_variables=SystemVariable(
            user_id="user",
            app_id="app",
            workflow_id="workflow",
            workflow_execution_id="exec-1",
        ),
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
def _build_llm_node(
    *,
    node_id: str,
    runtime_state: GraphRuntimeState,
    graph_init_params: GraphInitParams,
    mock_config: MockConfig,
) -> MockLLMNode:
    """Construct a mock LLM node (id-derived title/prompt) on the shared runtime state."""
    node_data = LLMNodeData(
        title=f"LLM {node_id}",
        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
        prompt_template=[
            LLMNodeChatModelMessage(
                text=f"Prompt {node_id}",
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
        ],
        context=ContextConfig(enabled=False, variable_selector=None),
        vision=VisionConfig(enabled=False),
        reasoning_format="tagged",
    )
    node_config = {"id": node_id, "data": node_data.model_dump()}
    return MockLLMNode(
        id=node_config["id"],
        config=node_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        mock_config=mock_config,
    )
def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
    """Build a diamond graph: start fans out to llm_a/llm_b, which join at end.

    Both LLM nodes share one MockConfig; the graph is wired via the builder so
    edge states exist for snapshot/restore tests.
    """
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    mock_config = MockConfig()
    llm_a = _build_llm_node(
        node_id="llm_a",
        runtime_state=runtime_state,
        graph_init_params=graph_init_params,
        mock_config=mock_config,
    )
    llm_b = _build_llm_node(
        node_id="llm_b",
        runtime_state=runtime_state,
        graph_init_params=graph_init_params,
        mock_config=mock_config,
    )
    end_data = EndNodeData(title="End", outputs=[], desc=None)
    end_config = {"id": "end", "data": end_data.model_dump()}
    end_node = EndNode(
        id=end_config["id"],
        config=end_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    builder = (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_a, from_node_id="start")
        .add_node(llm_b, from_node_id="start")
        .add_node(end_node, from_node_id="llm_a")
    )
    # llm_b -> end is added explicitly so "end" becomes a two-input join node.
    return builder.connect(tail="llm_b", head="end").build()
def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]:
    """Index every edge's state by its (tail, head, source_handle) triple."""
    states: dict[tuple[str, str, str], NodeState] = {}
    for edge in graph.edges.values():
        states[(edge.tail, edge.head, edge.source_handle)] = edge.state
    return states
def test_runtime_state_snapshot_restores_graph_states() -> None:
    """dumps()/from_snapshot round-trips node and edge states onto a rebuilt graph."""
    runtime_state = _build_runtime_state()
    graph = _build_graph(runtime_state)
    runtime_state.attach_graph(graph)
    # Mark one branch taken and the other skipped, on nodes and their edges.
    graph.nodes["llm_a"].state = NodeState.TAKEN
    graph.nodes["llm_b"].state = NodeState.SKIPPED
    for edge in graph.edges.values():
        if edge.tail == "start" and edge.head == "llm_a":
            edge.state = NodeState.TAKEN
        elif edge.tail == "start" and edge.head == "llm_b":
            edge.state = NodeState.SKIPPED
        elif edge.head == "end" and edge.tail == "llm_a":
            edge.state = NodeState.TAKEN
        elif edge.head == "end" and edge.tail == "llm_b":
            edge.state = NodeState.SKIPPED
    snapshot = runtime_state.dumps()
    # Rebuild a fresh graph and attach the restored state to it.
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    resumed_graph = _build_graph(resumed_state)
    resumed_state.attach_graph(resumed_graph)
    assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN
    assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED
    assert _edge_state_map(resumed_graph) == _edge_state_map(graph)
def test_join_readiness_uses_restored_edge_states() -> None:
    """Join readiness of 'end' follows edge states, before and after a snapshot restore."""
    runtime_state = _build_runtime_state()
    graph = _build_graph(runtime_state)
    runtime_state.attach_graph(graph)
    ready_queue = InMemoryReadyQueue()
    state_manager = GraphStateManager(graph, ready_queue)
    # With one incoming edge still UNKNOWN the join node must not be ready.
    for edge in graph.get_incoming_edges("end"):
        if edge.tail == "llm_a":
            edge.state = NodeState.TAKEN
        if edge.tail == "llm_b":
            edge.state = NodeState.UNKNOWN
    assert state_manager.is_node_ready("end") is False
    for edge in graph.get_incoming_edges("end"):
        if edge.tail == "llm_b":
            edge.state = NodeState.TAKEN
    assert state_manager.is_node_ready("end") is True
    # Readiness must survive the snapshot/restore round-trip.
    snapshot = runtime_state.dumps()
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    resumed_graph = _build_graph(resumed_state)
    resumed_state.attach_graph(resumed_graph)
    resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue())
    assert resumed_state_manager.is_node_ready("end") is True

View File

@ -1,5 +1,7 @@
import datetime
import time
from collections.abc import Iterable
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
@ -14,11 +16,12 @@ from core.workflow.graph_events import (
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
@ -28,15 +31,21 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
def _build_branching_graph(
mock_config: MockConfig,
form_repository: HumanInputFormRepository,
graph_runtime_state: GraphRuntimeState | None = None,
) -> tuple[Graph, GraphRuntimeState]:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
graph_init_params = GraphInitParams(
tenant_id="tenant",
@ -49,12 +58,18 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
call_depth=0,
)
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
if graph_runtime_state is None:
variable_pool = VariablePool(
system_variables=SystemVariable(
user_id="user",
app_id="app",
workflow_id="workflow",
workflow_execution_id="test-execution-id",
),
user_inputs={},
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
@ -93,15 +108,21 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
human_data = HumanInputNodeData(
title="Human Input",
required_variables=["human.input_ready"],
pause_reason="Awaiting human input",
form_content="Human input required",
inputs=[],
user_actions=[
UserAction(id="primary", title="Primary"),
UserAction(id="secondary", title="Secondary"),
],
)
human_config = {"id": "human", "data": human_data.model_dump()}
human_node = HumanInputNode(
id=human_config["id"],
config=human_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=form_repository,
)
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
@ -219,8 +240,18 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
for scenario in branch_scenarios:
runner = TableTestRunner()
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
return _build_branching_graph(mock_config)
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
mock_create_repo.get_form.return_value = None
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
mock_form_entity.id = "test_form_id"
mock_form_entity.web_app_token = "test_web_app_token"
mock_form_entity.recipients = []
mock_form_entity.rendered_content = "rendered"
mock_form_entity.submitted = False
mock_create_repo.create_form.return_value = mock_form_entity
def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]:
return _build_branching_graph(mock_config, mock_create_repo)
initial_case = WorkflowTestCase(
description="HumanInput pause before branching decision",
@ -242,23 +273,16 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
assert initial_result.success, initial_result.event_mismatch_details
assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events)
graph_runtime_state = initial_result.graph_runtime_state
graph = initial_result.graph
assert graph_runtime_state is not None
assert graph is not None
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"])
graph_runtime_state.graph_execution.pause_reason = None
pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])
expected_pre_chunk_events_in_resumption = [
GraphRunStartedEvent,
NodeRunStartedEvent,
NodeRunHumanInputFormFilledEvent,
]
expected_resume_sequence: list[type] = (
[
GraphRunStartedEvent,
NodeRunStartedEvent,
]
expected_pre_chunk_events_in_resumption
+ [NodeRunStreamChunkEvent] * pre_chunk_count
+ [
NodeRunSucceededEvent,
@ -273,11 +297,25 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
]
)
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
submitted_form = MagicMock(spec=HumanInputFormEntity)
submitted_form.id = mock_form_entity.id
submitted_form.web_app_token = mock_form_entity.web_app_token
submitted_form.recipients = []
submitted_form.rendered_content = mock_form_entity.rendered_content
submitted_form.submitted = True
submitted_form.selected_action_id = scenario["handle"]
submitted_form.submitted_data = {}
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
mock_get_repo.get_form.return_value = submitted_form
def resume_graph_factory(
graph_snapshot: Graph = graph,
state_snapshot: GraphRuntimeState = graph_runtime_state,
initial_result=initial_result, mock_get_repo=mock_get_repo
) -> tuple[Graph, GraphRuntimeState]:
return graph_snapshot, state_snapshot
assert initial_result.graph_runtime_state is not None
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state)
resume_case = WorkflowTestCase(
description=f"HumanInput resumes via {scenario['handle']} branch",
@ -321,7 +359,8 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
for index, event in enumerate(resume_events)
if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index
]
assert pre_indices == list(range(2, 2 + pre_chunk_count))
expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption)
assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index))
resume_chunk_indices = [
index

View File

@ -1,4 +1,6 @@
import datetime
import time
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
@ -13,11 +15,12 @@ from core.workflow.graph_events import (
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input import HumanInputNode
from core.workflow.nodes.human_input.entities import HumanInputNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
@ -27,15 +30,21 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig
from .test_mock_nodes import MockLLMNode
from .test_table_runner import TableTestRunner, WorkflowTestCase
def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
def _build_llm_human_llm_graph(
mock_config: MockConfig,
form_repository: HumanInputFormRepository,
graph_runtime_state: GraphRuntimeState | None = None,
) -> tuple[Graph, GraphRuntimeState]:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
graph_init_params = GraphInitParams(
tenant_id="tenant",
@ -48,12 +57,15 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
call_depth=0,
)
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
if graph_runtime_state is None:
    variable_pool = VariablePool(
        system_variables=SystemVariable(
            user_id="user",
            app_id="app",
            workflow_id="workflow",
            # Fixed typo: the comma previously sat INSIDE the string literal
            # ("test-execution-id,"), silently changing the execution id used
            # to look up human-input forms. The branching-test builder uses
            # the comma-free id, so this now matches it.
            workflow_execution_id="test-execution-id",
        ),
        user_inputs={},
        conversation_variables=[],
    )
    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
@ -92,15 +104,21 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
human_data = HumanInputNodeData(
title="Human Input",
required_variables=["human.input_ready"],
pause_reason="Awaiting human input",
form_content="Human input required",
inputs=[],
user_actions=[
UserAction(id="accept", title="Accept"),
UserAction(id="reject", title="Reject"),
],
)
human_config = {"id": "human", "data": human_data.model_dump()}
human_node = HumanInputNode(
id=human_config["id"],
config=human_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
form_repository=form_repository,
)
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
@ -130,7 +148,7 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
.add_root(start_node)
.add_node(llm_first)
.add_node(human_node)
.add_node(llm_second)
.add_node(llm_second, source_handle="accept")
.add_node(end_node)
.build()
)
@ -167,8 +185,18 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
GraphRunPausedEvent, # graph run pauses awaiting resume
]
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
mock_create_repo.get_form.return_value = None
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
mock_form_entity.id = "test_form_id"
mock_form_entity.web_app_token = "test_web_app_token"
mock_form_entity.recipients = []
mock_form_entity.rendered_content = "rendered"
mock_form_entity.submitted = False
mock_create_repo.create_form.return_value = mock_form_entity
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
return _build_llm_human_llm_graph(mock_config)
return _build_llm_human_llm_graph(mock_config, mock_create_repo)
initial_case = WorkflowTestCase(
description="HumanInput pause preserves LLM streaming order",
@ -210,6 +238,8 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
expected_resume_sequence: list[type] = [
GraphRunStartedEvent, # resumed graph run begins
NodeRunStartedEvent, # human node restarts
# Form Filled should be generated first, then the node execution ends and stream chunk is generated.
NodeRunHumanInputFormFilledEvent,
NodeRunStreamChunkEvent, # cached llm_initial chunk 1
NodeRunStreamChunkEvent, # cached llm_initial chunk 2
NodeRunStreamChunkEvent, # cached llm_initial final chunk
@ -225,12 +255,27 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
GraphRunSucceededEvent, # graph run succeeds after resume
]
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
submitted_form = MagicMock(spec=HumanInputFormEntity)
submitted_form.id = mock_form_entity.id
submitted_form.web_app_token = mock_form_entity.web_app_token
submitted_form.recipients = []
submitted_form.rendered_content = mock_form_entity.rendered_content
submitted_form.submitted = True
submitted_form.selected_action_id = "accept"
submitted_form.submitted_data = {}
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
mock_get_repo.get_form.return_value = submitted_form
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
assert graph_runtime_state is not None
assert graph is not None
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
graph_runtime_state.graph_execution.pause_reason = None
return graph, graph_runtime_state
# restruct the graph runtime state
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
return _build_llm_human_llm_graph(
mock_config,
mock_get_repo,
resume_runtime_state,
)
resume_case = WorkflowTestCase(
description="HumanInput resume continues LLM streaming order",

View File

@ -0,0 +1,270 @@
import time
from collections.abc import Mapping
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Protocol

from core.workflow.entities import GraphInitParams
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.config import GraphEngineConfig
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
    GraphRunPausedEvent,
    GraphRunStartedEvent,
    GraphRunSucceededEvent,
    NodeRunSucceededEvent,
)
from core.workflow.nodes.base.entities import OutputVariableEntity
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
    FormCreateParams,
    HumanInputFormEntity,
    HumanInputFormRepository,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
class PauseStateStore(Protocol):
    """Structural interface for persisting a paused graph run's runtime state between pause and resume."""

    def save(self, runtime_state: GraphRuntimeState) -> None: ...

    def load(self) -> GraphRuntimeState: ...
class InMemoryPauseStore:
    """Holds the serialized state of a paused run in memory.

    Stands in for a real persistence layer between the pause and resume
    phases of a workflow test.
    """

    def __init__(self) -> None:
        # Snapshot string produced by GraphRuntimeState.dumps(); None until save().
        self._serialized: str | None = None

    def save(self, runtime_state: GraphRuntimeState) -> None:
        """Serialize *runtime_state* and keep it for a later load()."""
        self._serialized = runtime_state.dumps()

    def load(self) -> GraphRuntimeState:
        """Rebuild the runtime state captured by the most recent save()."""
        snapshot = self._serialized
        assert snapshot is not None
        return GraphRuntimeState.from_snapshot(snapshot)
@dataclass
class StaticForm(HumanInputFormEntity):
    """In-memory stand-in for a persisted human-input form.

    Dataclass fields carry the raw values; the read-only properties adapt
    them to the HumanInputFormEntity interface read by HumanInputNode.
    """

    form_id: str
    rendered: str
    is_submitted: bool
    action_id: str | None = None
    data: Mapping[str, Any] | None = None
    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
    # default_factory so the expiry is computed per instance at construction
    # time; a plain default would be evaluated once at module import and
    # shared (stale) across every instance for the life of the process.
    expiration: datetime = field(default_factory=lambda: naive_utc_now() + timedelta(days=1))

    @property
    def id(self) -> str:
        return self.form_id

    @property
    def web_app_token(self) -> str | None:
        return "token"

    @property
    def recipients(self) -> list:
        return []

    @property
    def rendered_content(self) -> str:
        return self.rendered

    @property
    def selected_action_id(self) -> str | None:
        return self.action_id

    @property
    def submitted_data(self) -> Mapping[str, Any] | None:
        return self.data

    @property
    def submitted(self) -> bool:
        return self.is_submitted

    @property
    def status(self) -> HumanInputFormStatus:
        return self.status_value

    @property
    def expiration_time(self) -> datetime:
        return self.expiration
class StaticRepo(HumanInputFormRepository):
    """Read-only form repository backed by a fixed node-id -> form mapping."""

    def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
        # Copy defensively so later mutation of the caller's mapping has no effect.
        self._forms = dict(forms_by_node_id)

    def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
        """Return the canned form for *node_id*; the execution id is ignored."""
        return self._forms.get(node_id, None)

    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
        """Resume scenarios must never create forms; fail the test if one tries."""
        raise AssertionError("create_form should not be called in resume scenario")
def _build_runtime_state() -> GraphRuntimeState:
    """Create a fresh runtime state whose pool carries execution id ``exec-1``."""
    system_vars = SystemVariable(
        user_id="user",
        app_id="app",
        workflow_id="workflow",
        workflow_execution_id="exec-1",
    )
    pool = VariablePool(system_variables=system_vars, user_inputs={}, conversation_variables=[])
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph:
    """Build start -> (human_a, human_b) -> end, where end joins both branches.

    Both human-input nodes share the same form definition and repository; the
    end node reads the selected action of each, so it can only run after both
    branches have produced an "approve".
    """
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    # Both human nodes reuse the same form data with a single "approve" action.
    human_data = HumanInputNodeData(
        title="Human Input",
        form_content="Human input required",
        inputs=[],
        user_actions=[UserAction(id="approve", title="Approve")],
    )
    human_a_config = {"id": "human_a", "data": human_data.model_dump()}
    human_a = HumanInputNode(
        id=human_a_config["id"],
        config=human_a_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        form_repository=repo,
    )
    human_b_config = {"id": "human_b", "data": human_data.model_dump()}
    human_b = HumanInputNode(
        id=human_b_config["id"],
        config=human_b_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        form_repository=repo,
    )
    # The end node exposes the action chosen on each branch as its outputs.
    end_data = EndNodeData(
        title="End",
        outputs=[
            OutputVariableEntity(variable="res_a", value_selector=["human_a", "__action_id"]),
            OutputVariableEntity(variable="res_b", value_selector=["human_b", "__action_id"]),
        ],
        desc=None,
    )
    end_config = {"id": "end", "data": end_data.model_dump()}
    end_node = EndNode(
        id=end_config["id"],
        config=end_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    # Fan out from start to both human nodes, then join them on the shared end
    # node through each node's "approve" source handle.
    builder = (
        Graph.new()
        .add_root(start_node)
        .add_node(human_a, from_node_id="start")
        .add_node(human_b, from_node_id="start")
        .add_node(end_node, from_node_id="human_a", source_handle="approve")
    )
    return builder.connect(tail="human_b", head="end", source_handle="approve").build()
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[object]:
    """Drive *graph* to completion (or pause) and collect every emitted event."""
    engine_config = GraphEngineConfig(
        min_workers=2,
        max_workers=2,
        scale_up_threshold=1,
        scale_down_idle_time=30.0,
    )
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
        config=engine_config,
    )
    events = list(engine.run())
    return events
def _form(submitted: bool, action_id: str | None) -> StaticForm:
    """Build a StaticForm whose status mirrors the *submitted* flag."""
    if submitted:
        status = HumanInputFormStatus.SUBMITTED
    else:
        status = HumanInputFormStatus.WAITING
    return StaticForm(
        form_id="form",
        rendered="rendered",
        is_submitted=submitted,
        action_id=action_id,
        data={},
        status_value=status,
    )
def test_parallel_human_input_join_completes_after_second_resume() -> None:
    """The end node joining two parallel human-input branches must run only
    after BOTH branches are approved, across two pause/resume cycles."""
    pause_store: PauseStateStore = InMemoryPauseStore()
    # Phase 1: neither form is submitted, so the run pauses on both branches.
    initial_state = _build_runtime_state()
    initial_repo = StaticRepo(
        {
            "human_a": _form(submitted=False, action_id=None),
            "human_b": _form(submitted=False, action_id=None),
        }
    )
    initial_graph = _build_graph(initial_state, initial_repo)
    initial_events = _run_graph(initial_graph, initial_state)
    assert isinstance(initial_events[-1], GraphRunPausedEvent)
    pause_store.save(initial_state)
    # Phase 2: only human_a is approved -> the run resumes but pauses again on human_b.
    first_resume_state = pause_store.load()
    first_resume_repo = StaticRepo(
        {
            "human_a": _form(submitted=True, action_id="approve"),
            "human_b": _form(submitted=False, action_id=None),
        }
    )
    first_resume_graph = _build_graph(first_resume_state, first_resume_repo)
    first_resume_events = _run_graph(first_resume_graph, first_resume_state)
    assert isinstance(first_resume_events[0], GraphRunStartedEvent)
    assert first_resume_events[0].reason is WorkflowStartReason.RESUMPTION
    assert isinstance(first_resume_events[-1], GraphRunPausedEvent)
    pause_store.save(first_resume_state)
    # Phase 3: both branches approved -> the join ("end") runs and the graph succeeds.
    second_resume_state = pause_store.load()
    second_resume_repo = StaticRepo(
        {
            "human_a": _form(submitted=True, action_id="approve"),
            "human_b": _form(submitted=True, action_id="approve"),
        }
    )
    second_resume_graph = _build_graph(second_resume_state, second_resume_repo)
    second_resume_events = _run_graph(second_resume_graph, second_resume_state)
    assert isinstance(second_resume_events[0], GraphRunStartedEvent)
    assert second_resume_events[0].reason is WorkflowStartReason.RESUMPTION
    assert isinstance(second_resume_events[-1], GraphRunSucceededEvent)
    assert any(isinstance(event, NodeRunSucceededEvent) and event.node_id == "end" for event in second_resume_events)

View File

@ -0,0 +1,333 @@
import time
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.config import GraphEngineConfig
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
NodeRunPauseRequestedEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig, NodeMockConfig
from .test_mock_nodes import MockLLMNode
@dataclass
class StaticForm(HumanInputFormEntity):
    """In-memory stand-in for a persisted human-input form.

    Dataclass fields carry the raw values; the read-only properties adapt
    them to the HumanInputFormEntity interface read by HumanInputNode.
    """

    form_id: str
    rendered: str
    is_submitted: bool
    action_id: str | None = None
    data: Mapping[str, Any] | None = None
    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
    # NOTE(review): this default is evaluated once at module import time, so
    # every instance shares the same expiry timestamp. Fine for short test
    # runs; a field(default_factory=...) would compute it per instance.
    expiration: datetime = naive_utc_now() + timedelta(days=1)

    @property
    def id(self) -> str:
        return self.form_id

    @property
    def web_app_token(self) -> str | None:
        return "token"

    @property
    def recipients(self) -> list:
        return []

    @property
    def rendered_content(self) -> str:
        return self.rendered

    @property
    def selected_action_id(self) -> str | None:
        return self.action_id

    @property
    def submitted_data(self) -> Mapping[str, Any] | None:
        return self.data

    @property
    def submitted(self) -> bool:
        return self.is_submitted

    @property
    def status(self) -> HumanInputFormStatus:
        return self.status_value

    @property
    def expiration_time(self) -> datetime:
        return self.expiration
class StaticRepo(HumanInputFormRepository):
    """Form repository returning pre-canned forms keyed by node id."""

    def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
        # Snapshot the mapping so the repository is immune to later edits by the caller.
        self._registry: dict[str, HumanInputFormEntity] = dict(forms_by_node_id)

    def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
        """Return the form registered for *node_id*, ignoring the execution id."""
        return self._registry.get(node_id)

    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
        """Creation is forbidden here; resume tests only ever read forms."""
        raise AssertionError("create_form should not be called in resume scenario")
class DelayedHumanInputNode(HumanInputNode):
    """Human-input node that sleeps before executing.

    Used to force a deterministic ordering between parallel branches in tests.
    """

    def __init__(self, delay_seconds: float, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # How long _run() waits before delegating; zero or negative means no delay.
        self._delay_seconds = delay_seconds

    def _run(self):
        delay = self._delay_seconds
        if delay > 0:
            time.sleep(delay)
        yield from super()._run()
def _build_runtime_state() -> GraphRuntimeState:
    """Fresh runtime state bound to workflow execution ``exec-1``."""
    return GraphRuntimeState(
        variable_pool=VariablePool(
            system_variables=SystemVariable(
                user_id="user",
                app_id="app",
                workflow_id="workflow",
                workflow_execution_id="exec-1",
            ),
            user_inputs={},
            conversation_variables=[],
        ),
        start_at=time.perf_counter(),
    )
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
    """Build start -> (human_a, human_b[delayed]) with llm_a on human_a's approve branch.

    human_b is delayed so human_a's branch (and its downstream LLM) can make
    progress before the run pauses on human_b.
    """
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    # Shared form definition with a single "approve" action.
    human_data = HumanInputNodeData(
        title="Human Input",
        form_content="Human input required",
        inputs=[],
        user_actions=[UserAction(id="approve", title="Approve")],
    )
    human_a_config = {"id": "human_a", "data": human_data.model_dump()}
    human_a = HumanInputNode(
        id=human_a_config["id"],
        config=human_a_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        form_repository=repo,
    )
    # human_b sleeps briefly before running so the human_a branch wins the race.
    human_b_config = {"id": "human_b", "data": human_data.model_dump()}
    human_b = DelayedHumanInputNode(
        id=human_b_config["id"],
        config=human_b_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        form_repository=repo,
        delay_seconds=0.2,
    )
    llm_data = LLMNodeData(
        title="LLM A",
        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
        prompt_template=[
            LLMNodeChatModelMessage(
                text="Prompt A",
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
        ],
        context=ContextConfig(enabled=False, variable_selector=None),
        vision=VisionConfig(enabled=False),
        reasoning_format="tagged",
        structured_output_enabled=False,
    )
    llm_config = {"id": "llm_a", "data": llm_data.model_dump()}
    llm_a = MockLLMNode(
        id=llm_config["id"],
        config=llm_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        mock_config=mock_config,
    )
    # Fan out from start; llm_a only runs on human_a's "approve" branch.
    return (
        Graph.new()
        .add_root(start_node)
        .add_node(human_a, from_node_id="start")
        .add_node(human_b, from_node_id="start")
        .add_node(llm_a, from_node_id="human_a", source_handle="approve")
        .build()
    )
def test_parallel_human_input_pause_preserves_node_finished() -> None:
    """A still-pending parallel human-input branch must not cancel work already
    unblocked on the sibling branch: llm_a starts and finishes even though
    human_b keeps the graph paused."""
    runtime_state = _build_runtime_state()
    # Simulate a run that previously paused on both human nodes.
    runtime_state.graph_execution.start()
    runtime_state.register_paused_node("human_a")
    runtime_state.register_paused_node("human_b")
    submitted = StaticForm(
        form_id="form-a",
        rendered="rendered",
        is_submitted=True,
        action_id="approve",
        data={},
        status_value=HumanInputFormStatus.SUBMITTED,
    )
    pending = StaticForm(
        form_id="form-b",
        rendered="rendered",
        is_submitted=False,
        action_id=None,
        data=None,
        status_value=HumanInputFormStatus.WAITING,
    )
    repo = StaticRepo({"human_a": submitted, "human_b": pending})
    mock_config = MockConfig()
    # Delay the mocked LLM so it is still running when human_b re-pauses the graph.
    mock_config.simulate_delays = True
    mock_config.set_node_config(
        "llm_a",
        NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
    )
    graph = _build_graph(runtime_state, repo, mock_config)
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
        config=GraphEngineConfig(
            min_workers=2,
            max_workers=2,
            scale_up_threshold=1,
            scale_down_idle_time=30.0,
        ),
    )
    events = list(engine.run())
    llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
    llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
    human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
    graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
    graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events)
    # The run pauses again on human_b, yet the approved branch completed its LLM.
    assert graph_started
    assert graph_paused
    assert human_b_pause
    assert llm_started
    assert llm_succeeded
def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None:
    """Same guarantee as the test above, but after serializing the paused state
    to a snapshot and resuming from it (the persistence round-trip a real
    resume performs)."""
    base_state = _build_runtime_state()
    base_state.graph_execution.start()
    base_state.register_paused_node("human_a")
    base_state.register_paused_node("human_b")
    # Round-trip the paused state through its serialized snapshot.
    snapshot = base_state.dumps()
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    submitted = StaticForm(
        form_id="form-a",
        rendered="rendered",
        is_submitted=True,
        action_id="approve",
        data={},
        status_value=HumanInputFormStatus.SUBMITTED,
    )
    pending = StaticForm(
        form_id="form-b",
        rendered="rendered",
        is_submitted=False,
        action_id=None,
        data=None,
        status_value=HumanInputFormStatus.WAITING,
    )
    repo = StaticRepo({"human_a": submitted, "human_b": pending})
    mock_config = MockConfig()
    # Delay the mocked LLM so it is still running when human_b re-pauses the graph.
    mock_config.simulate_delays = True
    mock_config.set_node_config(
        "llm_a",
        NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
    )
    graph = _build_graph(resumed_state, repo, mock_config)
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=resumed_state,
        command_channel=InMemoryChannel(),
        config=GraphEngineConfig(
            min_workers=2,
            max_workers=2,
            scale_up_threshold=1,
            scale_down_idle_time=30.0,
        ),
    )
    events = list(engine.run())
    # The resumed run must report RESUMPTION as its start reason.
    start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent))
    assert start_event.reason is WorkflowStartReason.RESUMPTION
    llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
    llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
    human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
    graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
    assert graph_paused
    assert human_b_pause
    assert llm_started
    assert llm_succeeded

View File

@ -0,0 +1,309 @@
import time
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.config import GraphEngineConfig
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig, NodeMockConfig
from .test_mock_nodes import MockLLMNode
@dataclass
class StaticForm(HumanInputFormEntity):
    """In-memory stand-in for a persisted human-input form.

    Dataclass fields carry the raw values; the read-only properties adapt
    them to the HumanInputFormEntity interface read by HumanInputNode.
    """

    form_id: str
    rendered: str
    is_submitted: bool
    action_id: str | None = None
    data: Mapping[str, Any] | None = None
    status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
    # NOTE(review): this default is evaluated once at module import time, so
    # every instance shares the same expiry timestamp. Fine for short test
    # runs; a field(default_factory=...) would compute it per instance.
    expiration: datetime = naive_utc_now() + timedelta(days=1)

    @property
    def id(self) -> str:
        return self.form_id

    @property
    def web_app_token(self) -> str | None:
        return "token"

    @property
    def recipients(self) -> list:
        return []

    @property
    def rendered_content(self) -> str:
        return self.rendered

    @property
    def selected_action_id(self) -> str | None:
        return self.action_id

    @property
    def submitted_data(self) -> Mapping[str, Any] | None:
        return self.data

    @property
    def submitted(self) -> bool:
        return self.is_submitted

    @property
    def status(self) -> HumanInputFormStatus:
        return self.status_value

    @property
    def expiration_time(self) -> datetime:
        return self.expiration
class StaticRepo(HumanInputFormRepository):
    """Repository exposing a single canned form for the ``human_pause`` node."""

    def __init__(self, form: HumanInputFormEntity) -> None:
        # The only form this repository ever serves.
        self._form = form

    def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
        """Return the canned form for ``human_pause``; any other node id gets None."""
        return self._form if node_id == "human_pause" else None

    def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
        """This test never creates forms; fail loudly if the node tries."""
        raise AssertionError("create_form should not be called in this test")
def _build_runtime_state() -> GraphRuntimeState:
    """Runtime state seeded with system variables for execution ``exec-1``."""
    sys_vars = SystemVariable(
        user_id="user",
        app_id="app",
        workflow_id="workflow",
        workflow_execution_id="exec-1",
    )
    return GraphRuntimeState(
        variable_pool=VariablePool(system_variables=sys_vars, user_inputs={}, conversation_variables=[]),
        start_at=time.perf_counter(),
    )
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
    """Assemble the test graph with an LLM branch and a human-input branch.

    Topology::

        start -> llm_a -> llm_b
        start -> human_pause -(approve)-> end_human

    The human node is wired to *repo*, so the caller controls whether the run
    pauses (no submission yet) or proceeds (form already submitted).
    """
    # Empty graph_config: nodes are constructed explicitly below, not parsed.
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )
    start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
    start_node = StartNode(
        id=start_config["id"],
        config=start_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    # llm_a / llm_b are mocked LLM nodes; their outputs and delays come from mock_config.
    llm_a_data = LLMNodeData(
        title="LLM A",
        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
        prompt_template=[
            LLMNodeChatModelMessage(
                text="Prompt A",
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
        ],
        context=ContextConfig(enabled=False, variable_selector=None),
        vision=VisionConfig(enabled=False),
        reasoning_format="tagged",
        structured_output_enabled=False,
    )
    llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()}
    llm_a = MockLLMNode(
        id=llm_a_config["id"],
        config=llm_a_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        mock_config=mock_config,
    )
    llm_b_data = LLMNodeData(
        title="LLM B",
        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
        prompt_template=[
            LLMNodeChatModelMessage(
                text="Prompt B",
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
        ],
        context=ContextConfig(enabled=False, variable_selector=None),
        vision=VisionConfig(enabled=False),
        reasoning_format="tagged",
        structured_output_enabled=False,
    )
    llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()}
    llm_b = MockLLMNode(
        id=llm_b_config["id"],
        config=llm_b_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        mock_config=mock_config,
    )
    human_data = HumanInputNodeData(
        title="Human Input",
        form_content="Pause here",
        inputs=[],
        user_actions=[UserAction(id="approve", title="Approve")],
    )
    human_config = {"id": "human_pause", "data": human_data.model_dump()}
    human_node = HumanInputNode(
        id=human_config["id"],
        config=human_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
        form_repository=repo,
    )
    end_human_data = EndNodeData(title="End Human", outputs=[], desc=None)
    end_human_config = {"id": "end_human", "data": end_human_data.model_dump()}
    end_human = EndNode(
        id=end_human_config["id"],
        config=end_human_config,
        graph_init_params=graph_init_params,
        graph_runtime_state=runtime_state,
    )
    # The edge out of human_pause uses the "approve" action as its source handle,
    # so the end_human branch only runs when that action is selected.
    return (
        Graph.new()
        .add_root(start_node)
        .add_node(llm_a, from_node_id="start")
        .add_node(human_node, from_node_id="start")
        .add_node(llm_b, from_node_id="llm_a")
        .add_node(end_human, from_node_id="human_pause", source_handle="approve")
        .build()
    )
def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None:
for event in events:
if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
return event
return None
def test_pause_defers_ready_nodes_until_resume() -> None:
    """While one branch pauses at a human-input node, nodes that only became
    ready during the pause (llm_b) must not start until the run is resumed."""
    runtime_state = _build_runtime_state()
    # A waiting (unsubmitted) form forces the human branch to pause.
    paused_form = StaticForm(
        form_id="form-pause",
        rendered="rendered",
        is_submitted=False,
        status_value=HumanInputFormStatus.WAITING,
    )
    pause_repo = StaticRepo(paused_form)
    mock_config = MockConfig()
    mock_config.simulate_delays = True
    # llm_a is slowed down so the pause is requested before llm_b becomes ready.
    mock_config.set_node_config(
        "llm_a",
        NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
    )
    mock_config.set_node_config(
        "llm_b",
        NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0),
    )
    graph = _build_graph(runtime_state, pause_repo, mock_config)
    engine = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
        config=GraphEngineConfig(
            min_workers=2,
            max_workers=2,
            scale_up_threshold=1,
            scale_down_idle_time=30.0,
        ),
    )
    paused_events = list(engine.run())
    # First pass: run pauses, llm_a finishes, but llm_b must NOT have started.
    assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events)
    assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events)
    assert _get_node_started_event(paused_events, "llm_b") is None
    # Snapshot the paused state and rebuild the graph with a submitted form.
    snapshot = runtime_state.dumps()
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    submitted_form = StaticForm(
        form_id="form-pause",
        rendered="rendered",
        is_submitted=True,
        action_id="approve",
        data={},
        status_value=HumanInputFormStatus.SUBMITTED,
    )
    resume_repo = StaticRepo(submitted_form)
    resumed_graph = _build_graph(resumed_state, resume_repo, mock_config)
    resumed_engine = GraphEngine(
        workflow_id="workflow",
        graph=resumed_graph,
        graph_runtime_state=resumed_state,
        command_channel=InMemoryChannel(),
        config=GraphEngineConfig(
            min_workers=2,
            max_workers=2,
            scale_up_threshold=1,
            scale_down_idle_time=30.0,
        ),
    )
    resumed_events = list(resumed_engine.run())
    # Second pass: reported as a resumption, and the deferred llm_b now runs to success.
    start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent))
    assert start_event.reason is WorkflowStartReason.RESUMPTION
    llm_b_started = _get_node_started_event(resumed_events, "llm_b")
    assert llm_b_started is not None
    assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events)

View File

@ -0,0 +1,217 @@
import datetime
import time
from typing import Any
from unittest.mock import MagicMock
from core.workflow.entities import GraphInitParams
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunPausedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_events.graph import GraphRunStartedEvent
from core.workflow.nodes.base.entities import OutputVariableEntity
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
def _build_runtime_state() -> GraphRuntimeState:
    """Fresh runtime state carrying the fixed system variables used across this module."""
    system_vars = SystemVariable(
        user_id="user",
        app_id="app",
        workflow_id="workflow",
        workflow_execution_id="test-execution-id",
    )
    pool = VariablePool(
        system_variables=system_vars,
        user_inputs={},
        conversation_variables=[],
    )
    return GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter())
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
    """Mock repository whose stored form was already submitted with *action_id*."""
    submitted_form = MagicMock(spec=HumanInputFormEntity)
    submitted_form.id = "test-form-id"
    submitted_form.web_app_token = "test-form-token"
    submitted_form.recipients = []
    submitted_form.rendered_content = "rendered"
    submitted_form.submitted = True
    submitted_form.selected_action_id = action_id
    submitted_form.submitted_data = {}
    submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
    repository = MagicMock(spec=HumanInputFormRepository)
    repository.get_form.return_value = submitted_form
    return repository
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
    """Mock repository with no persisted form; newly created forms start unsubmitted."""
    pending_form = MagicMock(spec=HumanInputFormEntity)
    pending_form.id = "test-form-id"
    pending_form.web_app_token = "test-form-token"
    pending_form.recipients = []
    pending_form.rendered_content = "rendered"
    pending_form.submitted = False
    repository = MagicMock(spec=HumanInputFormRepository)
    repository.create_form.return_value = pending_form
    # Nothing persisted yet -> the human-input node must pause the run.
    repository.get_form.return_value = None
    return repository
def _build_human_input_graph(
    runtime_state: GraphRuntimeState,
    form_repository: HumanInputFormRepository,
) -> Graph:
    """Build the minimal three-node graph: start -> human -(continue)-> end.

    The human node is backed by *form_repository*, so the caller decides whether
    the run pauses (no form yet) or flows straight through (submitted form).
    """
    graph_config: dict[str, object] = {"nodes": [], "edges": []}
    params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="service-api",
        call_depth=0,
    )
    start_data = StartNodeData(title="start", variables=[])
    start_node = StartNode(
        id="start",
        config={"id": "start", "data": start_data.model_dump()},
        graph_init_params=params,
        graph_runtime_state=runtime_state,
    )
    human_data = HumanInputNodeData(
        title="human",
        form_content="Awaiting human input",
        inputs=[],
        user_actions=[
            UserAction(id="continue", title="Continue"),
        ],
    )
    human_node = HumanInputNode(
        id="human",
        config={"id": "human", "data": human_data.model_dump()},
        graph_init_params=params,
        graph_runtime_state=runtime_state,
        form_repository=form_repository,
    )
    # The end node surfaces the chosen action id as the workflow result.
    end_data = EndNodeData(
        title="end",
        outputs=[
            OutputVariableEntity(variable="result", value_selector=["human", "action_id"]),
        ],
        desc=None,
    )
    end_node = EndNode(
        id="end",
        config={"id": "end", "data": end_data.model_dump()},
        graph_init_params=params,
        graph_runtime_state=runtime_state,
    )
    # The edge out of the human node uses the "continue" action as its source handle.
    return (
        Graph.new()
        .add_root(start_node)
        .add_node(human_node)
        .add_node(end_node, from_node_id="human", source_handle="continue")
        .build()
    )
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]:
    """Drive *graph* to completion (or pause) and collect every emitted engine event."""
    event_stream = GraphEngine(
        workflow_id="workflow",
        graph=graph,
        graph_runtime_state=runtime_state,
        command_channel=InMemoryChannel(),
    ).run()
    return list(event_stream)
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None:
for event in events:
if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
return event
return None
def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
segment = variable_pool.get(selector)
assert segment is not None
return getattr(segment, "value", segment)
def test_engine_resume_restores_state_and_completion():
    """End-to-end pause/resume parity check.

    Runs the same human-input graph three ways and asserts that a paused run,
    once resumed from a state snapshot, converges to exactly the same outcome
    as an uninterrupted baseline run.
    """
    # Baseline run: the form is already submitted, so the graph completes in one pass.
    baseline_state = _build_runtime_state()
    baseline_repo = _mock_form_repository_with_submission(action_id="continue")
    baseline_graph = _build_human_input_graph(baseline_state, baseline_repo)
    baseline_events = _run_graph(baseline_graph, baseline_state)
    assert baseline_events
    # Renamed from `first_paused_event`: this is the baseline run, not a paused one.
    first_baseline_event = baseline_events[0]
    assert isinstance(first_baseline_event, GraphRunStartedEvent)
    assert first_baseline_event.reason is WorkflowStartReason.INITIAL
    assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
    baseline_success_nodes = _node_successes(baseline_events)
    # Run with pause: no submission exists yet, so the engine must stop at the human node.
    paused_state = _build_runtime_state()
    pause_repo = _mock_form_repository_without_submission()
    paused_graph = _build_human_input_graph(paused_state, pause_repo)
    paused_events = _run_graph(paused_graph, paused_state)
    assert paused_events
    first_paused_event = paused_events[0]
    assert isinstance(first_paused_event, GraphRunStartedEvent)
    assert first_paused_event.reason is WorkflowStartReason.INITIAL
    assert isinstance(paused_events[-1], GraphRunPausedEvent)
    snapshot = paused_state.dumps()
    # Resume from snapshot with the submission now available.
    resumed_state = GraphRuntimeState.from_snapshot(snapshot)
    resume_repo = _mock_form_repository_with_submission(action_id="continue")
    resumed_graph = _build_human_input_graph(resumed_state, resume_repo)
    resumed_events = _run_graph(resumed_graph, resumed_state)
    assert resumed_events
    first_resumed_event = resumed_events[0]
    assert isinstance(first_resumed_event, GraphRunStartedEvent)
    assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION
    assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
    # Paused + resumed node successes must equal the baseline's, in order.
    combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
    assert combined_success_nodes == baseline_success_nodes
    paused_human_started = _node_start_event(paused_events, "human")
    resumed_human_started = _node_start_event(resumed_events, "human")
    assert paused_human_started is not None
    assert resumed_human_started is not None
    # The resumed run must reuse the same node execution, not start a new one.
    assert paused_human_started.id == resumed_human_started.id
    assert baseline_state.outputs == resumed_state.outputs
    assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
        resumed_state.variable_pool, ("human", "__action_id")
    )
    assert baseline_state.graph_execution.completed
    assert resumed_state.graph_execution.completed

View File

@ -7,6 +7,7 @@ from core.workflow.nodes.base.node import Node
# Ensures that all node classes are imported.
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed.
_ = NODE_TYPE_CLASSES_MAPPING
@ -45,7 +46,9 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
assert isinstance(cls.node_type, NodeType)
assert isinstance(node_version, str)
node_type_and_version = (node_type, node_version)
assert node_type_and_version not in type_version_set
assert node_type_and_version not in type_version_set, (
f"Duplicate node type and version for class: {cls=} {node_type_and_version=}"
)
type_version_set.add(node_type_and_version)

View File

@ -0,0 +1 @@
# Unit tests for human input node

View File

@ -0,0 +1,16 @@
from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients
from core.workflow.runtime import VariablePool
def test_render_body_template_replaces_variable_values():
    """Variable and URL placeholders in the e-mail body are substituted on render."""
    pool = VariablePool()
    pool.add(["node1", "value"], "World")
    delivery_config = EmailDeliveryConfig(
        recipients=EmailRecipients(),
        subject="Subject",
        body="Hello {{#node1.value#}} {{#url#}}",
    )
    rendered = delivery_config.render_body_template(
        body=delivery_config.body, url="https://example.com", variable_pool=pool
    )
    assert rendered == "Hello World https://example.com"

View File

@ -0,0 +1,597 @@
"""
Unit tests for human input node entities.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from pydantic import ValidationError
from core.workflow.entities import GraphInitParams
from core.workflow.node_events import PauseRequestedEvent
from core.workflow.node_events.node import StreamCompletedEvent
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
FormInput,
FormInputDefault,
HumanInputNodeData,
MemberRecipient,
UserAction,
WebAppDeliveryMethod,
_WebAppDeliveryConfig,
)
from core.workflow.nodes.human_input.enums import (
ButtonStyle,
DeliveryMethodType,
EmailRecipientType,
FormInputType,
PlaceholderType,
TimeoutUnit,
)
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository
class TestDeliveryMethod:
    """Tests for the delivery-method entities."""

    def test_webapp_delivery_method(self):
        """A web-app delivery method carries its type, enabled flag, and config."""
        method = WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())
        assert method.type == DeliveryMethodType.WEBAPP
        assert method.enabled is True
        assert isinstance(method.config, _WebAppDeliveryConfig)

    def test_email_delivery_method(self):
        """An email delivery method keeps its subject and recipient list intact."""
        recipient_items = [
            MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"),
            ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"),
        ]
        email_config = EmailDeliveryConfig(
            recipients=EmailRecipients(whole_workspace=False, items=recipient_items),
            subject="Test Subject",
            body="Test body with {{#url#}} placeholder",
        )
        method = EmailDeliveryMethod(enabled=True, config=email_config)
        assert method.type == DeliveryMethodType.EMAIL
        assert method.enabled is True
        assert isinstance(method.config, EmailDeliveryConfig)
        assert method.config.subject == "Test Subject"
        assert len(method.config.recipients.items) == 2
class TestFormInput:
    """Tests for the FormInput entity."""

    def test_text_input_with_constant_default(self):
        """A constant default value is stored verbatim on the input."""
        form_field = FormInput(
            type=FormInputType.TEXT_INPUT,
            output_variable_name="user_input",
            default=FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter your response here..."),
        )
        assert form_field.type == FormInputType.TEXT_INPUT
        assert form_field.output_variable_name == "user_input"
        assert form_field.default.type == PlaceholderType.CONSTANT
        assert form_field.default.value == "Enter your response here..."

    def test_text_input_with_variable_default(self):
        """A variable default records the source variable selector."""
        form_field = FormInput(
            type=FormInputType.TEXT_INPUT,
            output_variable_name="user_input",
            default=FormInputDefault(type=PlaceholderType.VARIABLE, selector=["node_123", "output_var"]),
        )
        assert form_field.default.type == PlaceholderType.VARIABLE
        assert form_field.default.selector == ["node_123", "output_var"]

    def test_form_input_without_default(self):
        """Omitting the default leaves it as None."""
        form_field = FormInput(type=FormInputType.PARAGRAPH, output_variable_name="description")
        assert form_field.type == FormInputType.PARAGRAPH
        assert form_field.output_variable_name == "description"
        assert form_field.default is None
class TestUserAction:
    """Tests for the UserAction entity."""

    def test_user_action_creation(self):
        """All explicitly supplied fields round-trip."""
        action = UserAction(id="approve", title="Approve", button_style=ButtonStyle.PRIMARY)
        assert action.id == "approve"
        assert action.title == "Approve"
        assert action.button_style == ButtonStyle.PRIMARY

    def test_user_action_default_button_style(self):
        """button_style falls back to DEFAULT when omitted."""
        assert UserAction(id="cancel", title="Cancel").button_style == ButtonStyle.DEFAULT

    def test_user_action_length_boundaries(self):
        """Exactly 20 characters is accepted for both id and title."""
        boundary_action = UserAction(id="a" * 20, title="b" * 20)
        assert boundary_action.id == "a" * 20
        assert boundary_action.title == "b" * 20

    @pytest.mark.parametrize(
        ("field_name", "value"),
        [
            ("id", "a" * 21),
            ("title", "b" * 21),
        ],
    )
    def test_user_action_length_limits(self, field_name: str, value: str):
        """Fields one character over the 20-char limit are rejected."""
        payload = {"id": "approve", "title": "Approve"}
        payload[field_name] = value
        with pytest.raises(ValidationError) as exc_info:
            UserAction(**payload)
        assert any(
            err["loc"] == (field_name,) and err["type"] == "string_too_long"
            for err in exc_info.value.errors()
        )
class TestHumanInputNodeData:
    """Test HumanInputNodeData entity."""

    def test_valid_node_data_creation(self):
        """Test creating valid human input node data."""
        delivery_methods = [WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())]
        inputs = [
            FormInput(
                type=FormInputType.TEXT_INPUT,
                output_variable_name="content",
                default=FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter content..."),
            )
        ]
        user_actions = [UserAction(id="submit", title="Submit", button_style=ButtonStyle.PRIMARY)]
        node_data = HumanInputNodeData(
            title="Human Input Test",
            desc="Test node description",
            delivery_methods=delivery_methods,
            form_content="# Test Form\n\nPlease provide input:\n\n{{#$output.content#}}",
            inputs=inputs,
            user_actions=user_actions,
            timeout=24,
            timeout_unit=TimeoutUnit.HOUR,
        )
        assert node_data.title == "Human Input Test"
        assert node_data.desc == "Test node description"
        assert len(node_data.delivery_methods) == 1
        assert node_data.form_content.startswith("# Test Form")
        assert len(node_data.inputs) == 1
        assert len(node_data.user_actions) == 1
        assert node_data.timeout == 24
        assert node_data.timeout_unit == TimeoutUnit.HOUR

    def test_node_data_with_multiple_delivery_methods(self):
        """Test node data with multiple delivery methods."""
        delivery_methods = [
            WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()),
            EmailDeliveryMethod(
                enabled=False,  # Disabled method should be fine
                config=EmailDeliveryConfig(
                    subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True)
                ),
            ),
        ]
        node_data = HumanInputNodeData(
            title="Test Node", delivery_methods=delivery_methods, timeout=1, timeout_unit=TimeoutUnit.DAY
        )
        assert len(node_data.delivery_methods) == 2
        assert node_data.timeout == 1
        assert node_data.timeout_unit == TimeoutUnit.DAY

    def test_node_data_defaults(self):
        """Test node data with default values."""
        node_data = HumanInputNodeData(title="Test Node")
        assert node_data.title == "Test Node"
        assert node_data.desc is None
        assert node_data.delivery_methods == []
        assert node_data.form_content == ""
        assert node_data.inputs == []
        assert node_data.user_actions == []
        # Default timeout is 36 hours.
        assert node_data.timeout == 36
        assert node_data.timeout_unit == TimeoutUnit.HOUR

    def test_duplicate_input_output_variable_name_raises_validation_error(self):
        """Duplicate form input output_variable_name should raise validation error."""
        duplicate_inputs = [
            FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
            FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
        ]
        with pytest.raises(ValidationError, match="duplicated output_variable_name 'content'"):
            HumanInputNodeData(title="Test Node", inputs=duplicate_inputs)

    def test_duplicate_user_action_ids_raise_validation_error(self):
        """Duplicate user action ids should raise validation error."""
        duplicate_actions = [
            UserAction(id="submit", title="Submit"),
            UserAction(id="submit", title="Submit Again"),
        ]
        with pytest.raises(ValidationError, match="duplicated user action id 'submit'"):
            HumanInputNodeData(title="Test Node", user_actions=duplicate_actions)

    def test_extract_outputs_field_names(self):
        # The fixture mixes a normal variable reference with two
        # "{{#$output.<name>#}}" placeholders; only the latter should be reported
        # as output field names, in order of appearance.  (The "titile" typo is
        # intentional test content and must stay as-is.)
        content = r"""This is titile {{#start.title#}}
A content is required:
{{#$output.content#}}
A ending is required:
{{#$output.ending#}}
"""
        node_data = HumanInputNodeData(title="Human Input", form_content=content)
        field_names = node_data.outputs_field_names()
        assert field_names == ["content", "ending"]
class TestRecipients:
    """Tests for the email recipient entities."""

    def test_member_recipient(self):
        """A member recipient stores the workspace user id."""
        member = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")
        assert member.type == EmailRecipientType.MEMBER
        assert member.user_id == "user-123"

    def test_external_recipient(self):
        """An external recipient stores the raw email address."""
        external = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com")
        assert external.type == EmailRecipientType.EXTERNAL
        assert external.email == "test@example.com"

    def test_email_recipients_whole_workspace(self):
        """Explicit items survive even when whole_workspace is enabled."""
        recipients = EmailRecipients(
            whole_workspace=True,
            items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")],
        )
        assert recipients.whole_workspace is True
        assert len(recipients.items) == 1  # Items are preserved even when whole_workspace is True

    def test_email_recipients_specific_users(self):
        """Mixed member/external recipient lists keep order and payloads."""
        recipients = EmailRecipients(
            whole_workspace=False,
            items=[
                MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"),
                ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"),
            ],
        )
        assert recipients.whole_workspace is False
        assert len(recipients.items) == 2
        assert recipients.items[0].user_id == "user-123"
        assert recipients.items[1].email == "external@example.com"
class TestHumanInputNodeVariableResolution:
    """Tests for resolving variable-based defaults in HumanInputNode."""

    def test_resolves_variable_defaults(self):
        # Only VARIABLE-typed defaults should be resolved against the pool;
        # CONSTANT defaults are left out of resolved_default_values.
        variable_pool = VariablePool(
            system_variables=SystemVariable(
                user_id="user",
                app_id="app",
                workflow_id="workflow",
                workflow_execution_id="exec-1",
            ),
            user_inputs={},
            conversation_variables=[],
        )
        variable_pool.add(("start", "name"), "Jane Doe")
        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
        graph_init_params = GraphInitParams(
            tenant_id="tenant",
            app_id="app",
            workflow_id="workflow",
            graph_config={"nodes": [], "edges": []},
            user_id="user",
            user_from="account",
            invoke_from="debugger",
            call_depth=0,
        )
        node_data = HumanInputNodeData(
            title="Human Input",
            form_content="Provide your name",
            inputs=[
                FormInput(
                    type=FormInputType.TEXT_INPUT,
                    output_variable_name="user_name",
                    default=FormInputDefault(type=PlaceholderType.VARIABLE, selector=["start", "name"]),
                ),
                FormInput(
                    type=FormInputType.TEXT_INPUT,
                    output_variable_name="user_email",
                    default=FormInputDefault(type=PlaceholderType.CONSTANT, value="foo@example.com"),
                ),
            ],
            user_actions=[UserAction(id="submit", title="Submit")],
        )
        config = {"id": "human", "data": node_data.model_dump()}
        # No stored form, so the node creates one and pauses.
        mock_repo = MagicMock(spec=HumanInputFormRepository)
        mock_repo.get_form.return_value = None
        mock_repo.create_form.return_value = SimpleNamespace(
            id="form-1",
            rendered_content="Provide your name",
            web_app_token="token",
            recipients=[],
            submitted=False,
        )
        node = HumanInputNode(
            id=config["id"],
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=runtime_state,
            form_repository=mock_repo,
        )
        run_result = node._run()
        pause_event = next(run_result)
        assert isinstance(pause_event, PauseRequestedEvent)
        # Only the variable-backed default appears; the constant default does not.
        expected_values = {"user_name": "Jane Doe"}
        assert pause_event.reason.resolved_default_values == expected_values
        # The same resolved values must be passed to the repository on form creation.
        params = mock_repo.create_form.call_args.args[0]
        assert params.resolved_default_values == expected_values

    def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
        variable_pool = VariablePool(
            system_variables=SystemVariable(
                user_id="user",
                app_id="app",
                workflow_id="workflow",
                workflow_execution_id="exec-2",
            ),
            user_inputs={},
            conversation_variables=[],
        )
        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
        graph_init_params = GraphInitParams(
            tenant_id="tenant",
            app_id="app",
            workflow_id="workflow",
            graph_config={"nodes": [], "edges": []},
            user_id="user",
            user_from="account",
            invoke_from="debugger",
            call_depth=0,
        )
        node_data = HumanInputNodeData(
            title="Human Input",
            form_content="Provide your name",
            inputs=[],
            user_actions=[UserAction(id="submit", title="Submit")],
        )
        config = {"id": "human", "data": node_data.model_dump()}
        mock_repo = MagicMock(spec=HumanInputFormRepository)
        mock_repo.get_form.return_value = None
        # Form exposes both a console token and a recipient token.
        mock_repo.create_form.return_value = SimpleNamespace(
            id="form-2",
            rendered_content="Provide your name",
            web_app_token="console-token",
            recipients=[SimpleNamespace(token="recipient-token")],
            submitted=False,
        )
        node = HumanInputNode(
            id=config["id"],
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=runtime_state,
            form_repository=mock_repo,
        )
        run_result = node._run()
        pause_event = next(run_result)
        assert isinstance(pause_event, PauseRequestedEvent)
        # In debugger mode the console token is surfaced on the pause reason.
        assert pause_event.reason.form_token == "console-token"

    def test_debugger_debug_mode_overrides_email_recipients(self):
        variable_pool = VariablePool(
            system_variables=SystemVariable(
                user_id="user-123",
                app_id="app",
                workflow_id="workflow",
                workflow_execution_id="exec-3",
            ),
            user_inputs={},
            conversation_variables=[],
        )
        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
        graph_init_params = GraphInitParams(
            tenant_id="tenant",
            app_id="app",
            workflow_id="workflow",
            graph_config={"nodes": [], "edges": []},
            user_id="user-123",
            user_from="account",
            invoke_from="debugger",
            call_depth=0,
        )
        node_data = HumanInputNodeData(
            title="Human Input",
            form_content="Provide your name",
            inputs=[],
            user_actions=[UserAction(id="submit", title="Submit")],
            delivery_methods=[
                EmailDeliveryMethod(
                    enabled=True,
                    config=EmailDeliveryConfig(
                        recipients=EmailRecipients(
                            whole_workspace=False,
                            items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")],
                        ),
                        subject="Subject",
                        body="Body",
                        debug_mode=True,
                    ),
                )
            ],
        )
        config = {"id": "human", "data": node_data.model_dump()}
        mock_repo = MagicMock(spec=HumanInputFormRepository)
        mock_repo.get_form.return_value = None
        mock_repo.create_form.return_value = SimpleNamespace(
            id="form-3",
            rendered_content="Provide your name",
            web_app_token="token",
            recipients=[],
            submitted=False,
        )
        node = HumanInputNode(
            id=config["id"],
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=runtime_state,
            form_repository=mock_repo,
        )
        run_result = node._run()
        pause_event = next(run_result)
        assert isinstance(pause_event, PauseRequestedEvent)
        # With debug_mode on, the configured external recipient is replaced by
        # a single member recipient pointing at the invoking debugger user.
        params = mock_repo.create_form.call_args.args[0]
        assert len(params.delivery_methods) == 1
        method = params.delivery_methods[0]
        assert isinstance(method, EmailDeliveryMethod)
        assert method.config.debug_mode is True
        assert method.config.recipients.whole_workspace is False
        assert len(method.config.recipients.items) == 1
        recipient = method.config.recipients.items[0]
        assert isinstance(recipient, MemberRecipient)
        assert recipient.user_id == "user-123"
class TestValidation:
    """Negative validation scenarios for the human-input entities."""

    def test_invalid_form_input_type(self):
        """An unknown form input type is rejected."""
        with pytest.raises(ValidationError):
            FormInput(type="invalid-type", output_variable_name="test")

    def test_invalid_button_style(self):
        """An unknown button style is rejected."""
        with pytest.raises(ValidationError):
            UserAction(id="test", title="Test", button_style="invalid-style")

    def test_invalid_timeout_unit(self):
        """An unknown timeout unit is rejected."""
        with pytest.raises(ValidationError):
            HumanInputNodeData(title="Test", timeout_unit="invalid-unit")
class TestHumanInputNodeRenderedContent:
    """Tests for rendering submitted content."""

    def test_replaces_outputs_placeholders_after_submission(self):
        # Build a node whose form declares a single "name" output that the form
        # content references via the {{#$output.name#}} placeholder.
        variable_pool = VariablePool(
            system_variables=SystemVariable(
                user_id="user",
                app_id="app",
                workflow_id="workflow",
                workflow_execution_id="exec-1",
            ),
            user_inputs={},
            conversation_variables=[],
        )
        runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
        graph_init_params = GraphInitParams(
            tenant_id="tenant",
            app_id="app",
            workflow_id="workflow",
            graph_config={"nodes": [], "edges": []},
            user_id="user",
            user_from="account",
            invoke_from="debugger",
            call_depth=0,
        )
        node_data = HumanInputNodeData(
            title="Human Input",
            form_content="Name: {{#$output.name#}}",
            inputs=[
                FormInput(
                    type=FormInputType.TEXT_INPUT,
                    output_variable_name="name",
                )
            ],
            user_actions=[UserAction(id="approve", title="Approve")],
        )
        config = {"id": "human", "data": node_data.model_dump()}
        form_repository = InMemoryHumanInputFormRepository()
        node = HumanInputNode(
            id=config["id"],
            config=config,
            graph_init_params=graph_init_params,
            graph_runtime_state=runtime_state,
            form_repository=form_repository,
        )
        # First run: no submission yet, so the node emits a pause request and
        # the generator is exhausted immediately afterwards.
        pause_gen = node._run()
        pause_event = next(pause_gen)
        assert isinstance(pause_event, PauseRequestedEvent)
        with pytest.raises(StopIteration):
            next(pause_gen)
        # Simulate the user submitting the form, then re-run the node.
        form_repository.set_submission(action_id="approve", form_data={"name": "Alice"})
        events = list(node._run())
        last_event = events[-1]
        assert isinstance(last_event, StreamCompletedEvent)
        node_run_result = last_event.node_run_result
        # The submitted value must be substituted into the rendered content.
        assert node_run_result.outputs["__rendered_content"] == "Name: Alice"

View File

@ -0,0 +1,172 @@
import datetime
from types import SimpleNamespace
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.enums import NodeType
from core.workflow.graph_events import (
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunStartedEvent,
)
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
class _FakeFormRepository:
def __init__(self, form):
self._form = form
def get_form(self, *_args, **_kwargs):
return self._form
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
    """Build a HumanInputNode whose backing form is already SUBMITTED.

    The fake repository reports a form submitted with action "Accept" and
    data {"name": "Alice"}, so running the node is expected to emit a
    NodeRunHumanInputFormFilledEvent right after the start event.
    """
    system_variables = SystemVariable.default()
    graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
        start_at=0.0,
    )
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config={"nodes": [], "edges": []},
        user_id="user",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    # Minimal node config: one text input bound to {{#$output.name#}} and a
    # single user action whose id matches the fake form's selected action.
    config = {
        "id": "node-1",
        "type": NodeType.HUMAN_INPUT.value,
        "data": {
            "title": "Human Input",
            "form_content": form_content,
            "inputs": [
                {
                    "type": "text_input",
                    "output_variable_name": "name",
                    "default": {"type": "constant", "value": ""},
                }
            ],
            "user_actions": [
                {
                    "id": "Accept",
                    "title": "Approve",
                    "button_style": "default",
                }
            ],
        },
    }
    # Duck-typed stand-in for the persisted form record: already submitted,
    # not yet expired (deadline one day in the future).
    fake_form = SimpleNamespace(
        id="form-1",
        rendered_content=form_content,
        submitted=True,
        selected_action_id="Accept",
        submitted_data={"name": "Alice"},
        status=HumanInputFormStatus.SUBMITTED,
        expiration_time=naive_utc_now() + datetime.timedelta(days=1),
    )
    repo = _FakeFormRepository(fake_form)
    return HumanInputNode(
        id="node-1",
        config=config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
        form_repository=repo,
    )
def _build_timeout_node() -> HumanInputNode:
    """Build a HumanInputNode whose backing form has TIMED OUT.

    The fake repository reports an unsubmitted form whose expiration time is
    already in the past, so running the node is expected to emit a
    NodeRunHumanInputFormTimeoutEvent right after the start event.
    """
    system_variables = SystemVariable.default()
    graph_runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
        start_at=0.0,
    )
    graph_init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config={"nodes": [], "edges": []},
        user_id="user",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    # Same minimal config as _build_node, with the default form content.
    config = {
        "id": "node-1",
        "type": NodeType.HUMAN_INPUT.value,
        "data": {
            "title": "Human Input",
            "form_content": "Please enter your name:\n\n{{#$output.name#}}",
            "inputs": [
                {
                    "type": "text_input",
                    "output_variable_name": "name",
                    "default": {"type": "constant", "value": ""},
                }
            ],
            "user_actions": [
                {
                    "id": "Accept",
                    "title": "Approve",
                    "button_style": "default",
                }
            ],
        },
    }
    # Duck-typed stand-in for the persisted form record: never submitted and
    # already past its deadline (one minute ago), status TIMEOUT.
    fake_form = SimpleNamespace(
        id="form-1",
        rendered_content="content",
        submitted=False,
        selected_action_id=None,
        submitted_data=None,
        status=HumanInputFormStatus.TIMEOUT,
        expiration_time=naive_utc_now() - datetime.timedelta(minutes=1),
    )
    repo = _FakeFormRepository(fake_form)
    return HumanInputNode(
        id="node-1",
        config=config,
        graph_init_params=graph_init_params,
        graph_runtime_state=graph_runtime_state,
        form_repository=repo,
    )
def test_human_input_node_emits_form_filled_event_before_succeeded():
    """A submitted form must surface a form-filled event right after start."""
    events = list(_build_node().run())
    started, filled = events[0], events[1]
    assert isinstance(started, NodeRunStartedEvent)
    assert isinstance(filled, NodeRunHumanInputFormFilledEvent)
    # The event carries the node title, the rendered (substituted) content,
    # and the action the submitter picked.
    assert filled.node_title == "Human Input"
    assert filled.rendered_content.endswith("Alice")
    assert filled.action_id == "Accept"
    assert filled.action_text == "Approve"
def test_human_input_node_emits_timeout_event_before_succeeded():
    """An expired, unsubmitted form must surface a timeout event after start."""
    events = list(_build_timeout_node().run())
    started, timed_out = events[0], events[1]
    assert isinstance(started, NodeRunStartedEvent)
    assert isinstance(timed_out, NodeRunHumanInputFormTimeoutEvent)
    assert timed_out.node_title == "Human Input"

View File

@ -104,6 +104,7 @@ class TestCelerySSLConfiguration:
def test_celery_init_applies_ssl_to_broker_and_backend(self):
"""Test that SSL options are applied to both broker and backend when using Redis."""
mock_config = MagicMock()
mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
mock_config.CELERY_BACKEND = "redis"
mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"

View File

@ -0,0 +1,20 @@
from configs import dify_config
from extensions import ext_redis
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
    """Channel type "pubsub" yields the plain Redis broadcast channel."""
    monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
    assert isinstance(ext_redis.get_pubsub_broadcast_channel(), RedisBroadcastChannel)
def test_get_pubsub_broadcast_channel_sharded(monkeypatch):
    """Channel type "sharded" yields the sharded Redis broadcast channel."""
    monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded")
    assert isinstance(ext_redis.get_pubsub_broadcast_channel(), ShardedRedisBroadcastChannel)

View File

@ -0,0 +1 @@
# Treat this directory as a package so support modules can be imported relatively.

View File

@ -0,0 +1,249 @@
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any

from core.workflow.nodes.human_input.entities import FormInput
from core.workflow.nodes.human_input.enums import TimeoutUnit
# Exceptions
class HumanInputError(Exception):
    """Base class for human-input form errors surfaced to API responses.

    Subclasses override the class-level ``error_code``; a per-instance code
    may also be supplied at construction time.
    """

    # Default machine-readable code; shadowed on the instance when overridden.
    error_code: str = "unknown"

    def __init__(self, message: str = "", error_code: str | None = None):
        super().__init__(message)
        # An empty message falls back to the concrete class name.
        self.message = message if message else self.__class__.__name__
        if error_code:
            # Instance attribute shadows the class-level default.
            self.error_code = error_code
class FormNotFoundError(HumanInputError):
    """Raised when no form matches the given ID or token."""
    error_code = "form_not_found"
class FormExpiredError(HumanInputError):
    """Raised when the form's expiration deadline has already passed."""
    error_code = "human_input_form_expired"
class FormAlreadySubmittedError(HumanInputError):
    """Raised when a submission was already recorded for the form."""
    error_code = "human_input_form_submitted"
class InvalidFormDataError(HumanInputError):
    """Raised when a submission carries an unknown action or missing inputs."""
    error_code = "invalid_form_data"
# Models
@dataclass
class HumanInputForm:
    """In-memory model of a human-input form awaiting (or holding) a submission.

    ``expires_at`` is derived from ``timeout``/``timeout_unit`` in
    ``__post_init__`` unless given explicitly. All timestamps are naive UTC
    datetimes, matching the rest of the codebase's convention.
    """

    form_id: str
    workflow_run_id: str
    node_id: str
    tenant_id: str
    app_id: str | None
    form_content: str
    inputs: list[FormInput]
    user_actions: list[dict[str, Any]]
    timeout: int
    timeout_unit: TimeoutUnit
    form_token: str | None = None
    # datetime.utcnow() is deprecated since Python 3.12; take an aware "now"
    # and strip the tzinfo to keep the established naive-UTC convention.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None))
    expires_at: datetime | None = None
    submitted_at: datetime | None = None
    submitted_data: dict[str, Any] | None = None
    submitted_action: str | None = None

    def __post_init__(self) -> None:
        # Derive the expiry only when the caller did not pin one explicitly.
        if self.expires_at is None:
            self.calculate_expiration()

    @property
    def is_expired(self) -> bool:
        """True once the current naive-UTC time is past ``expires_at``."""
        return self.expires_at is not None and datetime.now(timezone.utc).replace(tzinfo=None) > self.expires_at

    @property
    def is_submitted(self) -> bool:
        """True once a submission has been recorded."""
        return self.submitted_at is not None

    def mark_submitted(self, inputs: dict[str, Any], action: str) -> None:
        """Record the submitted values, the chosen action, and the time."""
        self.submitted_data = inputs
        self.submitted_action = action
        self.submitted_at = datetime.now(timezone.utc).replace(tzinfo=None)

    def submit(self, inputs: dict[str, Any], action: str) -> None:
        """Alias for :meth:`mark_submitted`, kept for API symmetry."""
        self.mark_submitted(inputs, action)

    def calculate_expiration(self) -> None:
        """Set ``expires_at`` to ``created_at`` plus the configured timeout.

        Raises:
            ValueError: If ``timeout_unit`` is neither HOUR nor DAY.
        """
        start = self.created_at
        if self.timeout_unit == TimeoutUnit.HOUR:
            self.expires_at = start + timedelta(hours=self.timeout)
        elif self.timeout_unit == TimeoutUnit.DAY:
            self.expires_at = start + timedelta(days=self.timeout)
        else:
            raise ValueError(f"Unsupported timeout unit {self.timeout_unit}")

    def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]:
        """Serialize the form for API responses.

        Input type names are rendered kebab-case (TEXT_INPUT -> "text-input").
        Site metadata is attached only for token-based (web app) access.
        """
        inputs_response = [
            {
                "type": form_input.type.name.lower().replace("_", "-"),
                "output_variable_name": form_input.output_variable_name,
            }
            for form_input in self.inputs
        ]
        response = {
            "form_content": self.form_content,
            "inputs": inputs_response,
            "user_actions": self.user_actions,
        }
        if include_site_info:
            response["site"] = {"app_id": self.app_id, "title": "Workflow Form"}
        return response
@dataclass
class FormSubmissionData:
    """A form submission: the values entered plus the action chosen."""

    form_id: str
    inputs: dict[str, Any]
    action: str
    # datetime.utcnow() is deprecated since Python 3.12; take an aware "now"
    # and strip the tzinfo to keep the established naive-UTC convention.
    submitted_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None))

    @classmethod
    def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData:
        """Build submission data from an API request payload.

        The forward reference to ``FormSubmissionRequest`` (defined later in
        this module) is legal because the module uses
        ``from __future__ import annotations`` — no ``# type: ignore`` needed.
        """
        return cls(form_id=form_id, inputs=request.inputs, action=request.action)
@dataclass
class FormSubmissionRequest:
    """Raw API request payload: submitted input values and chosen action id."""
    inputs: dict[str, Any]
    action: str
# Repository
class InMemoryFormRepository:
    """In-memory form store keyed by form ID; the unit-test stand-in for a DB."""

    def __init__(self):
        # Maps form_id -> HumanInputForm.
        self._store: dict[str, HumanInputForm] = {}

    @property
    def forms(self) -> dict[str, HumanInputForm]:
        """Expose the live backing mapping (mutations are visible)."""
        return self._store

    def save(self, form: HumanInputForm) -> None:
        """Insert or overwrite the entry under ``form.form_id``."""
        self._store[form.form_id] = form

    def get_by_id(self, form_id: str) -> HumanInputForm | None:
        """Return the stored form, or None when absent."""
        return self._store.get(form_id)

    def get_by_token(self, token: str) -> HumanInputForm | None:
        """Linear-scan for the form carrying the given web-app token."""
        return next((form for form in self._store.values() if form.form_token == token), None)

    def delete(self, form_id: str) -> None:
        """Remove the entry if present; unknown IDs are ignored silently."""
        self._store.pop(form_id, None)
# Service
class FormService:
    """Service layer for managing human input forms in tests.

    Wraps an :class:`InMemoryFormRepository` with creation, lookup,
    submission, and expiry-cleanup operations mirroring the production flow.
    """

    def __init__(self, repository: InMemoryFormRepository):
        self.repository = repository

    def create_form(
        self,
        *,
        form_id: str,
        workflow_run_id: str,
        node_id: str,
        tenant_id: str,
        app_id: str | None,
        form_content: str,
        inputs,
        user_actions,
        timeout: int,
        timeout_unit: TimeoutUnit,
        form_token: str | None = None,
    ) -> HumanInputForm:
        """Create, persist, and return a new form.

        ``user_actions`` are normalized to plain dicts so the stored form is
        serialization-friendly. Expiry is computed by
        ``HumanInputForm.__post_init__`` (no explicit ``expires_at`` is
        passed), so no extra ``calculate_expiration()`` call is needed here.
        """
        form = HumanInputForm(
            form_id=form_id,
            workflow_run_id=workflow_run_id,
            node_id=node_id,
            tenant_id=tenant_id,
            app_id=app_id,
            form_content=form_content,
            inputs=list(inputs),
            user_actions=[{"id": action.id, "title": action.title} for action in user_actions],
            timeout=timeout,
            timeout_unit=timeout_unit,
            form_token=form_token,
        )
        self.repository.save(form)
        return form

    def get_form_by_id(self, form_id: str) -> HumanInputForm:
        """Return the form with ``form_id``.

        Raises:
            FormNotFoundError: If no form has this ID.
        """
        form = self.repository.get_by_id(form_id)
        if form is None:
            raise FormNotFoundError()
        return form

    def get_form_by_token(self, token: str) -> HumanInputForm:
        """Return the form carrying ``token``.

        Raises:
            FormNotFoundError: If no form has this token.
        """
        form = self.repository.get_by_token(token)
        if form is None:
            raise FormNotFoundError()
        return form

    def _resolve_open_form(self, form_id: str, *, is_token: bool) -> HumanInputForm:
        """Look up a form (by token or ID) and ensure it is still open.

        Raises:
            FormNotFoundError: Unknown ID/token.
            FormExpiredError: The form's deadline has passed.
            FormAlreadySubmittedError: A submission was already recorded.
        """
        form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
        if form.is_expired:
            raise FormExpiredError()
        if form.is_submitted:
            raise FormAlreadySubmittedError()
        return form

    def get_form_definition(self, form_id: str, *, is_token: bool) -> dict:
        """Return the renderable definition of an open form.

        Raises the same errors as :meth:`_resolve_open_form`.
        """
        form = self._resolve_open_form(form_id, is_token=is_token)
        definition = {
            "form_content": form.form_content,
            "inputs": form.inputs,
            "user_actions": form.user_actions,
        }
        if is_token:
            # Token-based (web app) access additionally exposes site metadata.
            definition["site"] = {"title": "Workflow Form"}
        return definition

    def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None:
        """Validate and record a submission for an open form.

        Raises the same errors as :meth:`_resolve_open_form`, plus
        InvalidFormDataError for bad actions or missing inputs.
        """
        form = self._resolve_open_form(form_id, is_token=is_token)
        self._validate_submission(form=form, submission_data=submission_data)
        form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action)
        self.repository.save(form)

    def cleanup_expired_forms(self) -> int:
        """Delete every expired form and return how many were removed."""
        expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired]
        for form_id in expired_ids:
            self.repository.delete(form_id)
        return len(expired_ids)

    def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None:
        """Reject unknown actions and missing required inputs.

        Extra inputs beyond the form definition are deliberately allowed.
        """
        defined_actions = {action["id"] for action in form.user_actions}
        if submission_data.action not in defined_actions:
            raise InvalidFormDataError(f"Invalid action: {submission_data.action}")
        missing_inputs = [
            form_input.output_variable_name
            for form_input in form.inputs
            if form_input.output_variable_name not in submission_data.inputs
        ]
        if missing_inputs:
            raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")

View File

@ -0,0 +1,326 @@
"""
Unit tests for FormService.
"""
from datetime import datetime, timedelta
import pytest
from core.workflow.nodes.human_input.entities import (
FormInput,
UserAction,
)
from core.workflow.nodes.human_input.enums import (
FormInputType,
TimeoutUnit,
)
from libs.datetime_utils import naive_utc_now
from .support import (
FormAlreadySubmittedError,
FormExpiredError,
FormNotFoundError,
FormService,
FormSubmissionData,
InMemoryFormRepository,
InvalidFormDataError,
)
class TestFormService:
    """Test FormService functionality.

    Note: form expiry is forced with ``naive_utc_now()`` (already imported in
    this module) instead of the deprecated ``datetime.utcnow()``, keeping the
    whole file on one time source.
    """

    @pytest.fixture
    def repository(self):
        """Create in-memory repository for testing."""
        return InMemoryFormRepository()

    @pytest.fixture
    def form_service(self, repository):
        """Create FormService with in-memory repository."""
        return FormService(repository)

    @pytest.fixture
    def sample_form_data(self):
        """Create sample form data."""
        return {
            "form_id": "form-123",
            "workflow_run_id": "run-456",
            "node_id": "node-789",
            "tenant_id": "tenant-abc",
            "app_id": "app-def",
            "form_content": "# Test Form\n\nInput: {{#$output.input#}}",
            "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)],
            "user_actions": [UserAction(id="submit", title="Submit")],
            "timeout": 1,
            "timeout_unit": TimeoutUnit.HOUR,
            "form_token": "token-xyz",
        }

    def test_create_form(self, form_service, sample_form_data):
        """Test form creation."""
        form = form_service.create_form(**sample_form_data)
        assert form.form_id == "form-123"
        assert form.workflow_run_id == "run-456"
        assert form.node_id == "node-789"
        assert form.tenant_id == "tenant-abc"
        assert form.app_id == "app-def"
        assert form.form_token == "token-xyz"
        assert form.timeout == 1
        assert form.timeout_unit == TimeoutUnit.HOUR
        assert form.expires_at is not None
        assert not form.is_expired
        assert not form.is_submitted

    def test_get_form_by_id(self, form_service, sample_form_data):
        """Test getting form by ID."""
        # Create form first
        created_form = form_service.create_form(**sample_form_data)
        # Retrieve form
        retrieved_form = form_service.get_form_by_id("form-123")
        assert retrieved_form.form_id == created_form.form_id
        assert retrieved_form.workflow_run_id == created_form.workflow_run_id

    def test_get_form_by_id_not_found(self, form_service):
        """Test getting non-existent form by ID."""
        with pytest.raises(FormNotFoundError) as exc_info:
            form_service.get_form_by_id("non-existent-form")
        assert exc_info.value.error_code == "form_not_found"

    def test_get_form_by_token(self, form_service, sample_form_data):
        """Test getting form by token."""
        # Create form first
        created_form = form_service.create_form(**sample_form_data)
        # Retrieve form by token
        retrieved_form = form_service.get_form_by_token("token-xyz")
        assert retrieved_form.form_id == created_form.form_id
        assert retrieved_form.form_token == "token-xyz"

    def test_get_form_by_token_not_found(self, form_service):
        """Test getting non-existent form by token."""
        with pytest.raises(FormNotFoundError) as exc_info:
            form_service.get_form_by_token("non-existent-token")
        assert exc_info.value.error_code == "form_not_found"

    def test_get_form_definition_by_id(self, form_service, sample_form_data):
        """Test getting form definition by ID."""
        # Create form first
        form_service.create_form(**sample_form_data)
        # Get form definition
        definition = form_service.get_form_definition("form-123", is_token=False)
        assert "form_content" in definition
        assert "inputs" in definition
        assert definition["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}"
        assert len(definition["inputs"]) == 1
        assert "site" not in definition  # Should not include site info for ID-based access

    def test_get_form_definition_by_token(self, form_service, sample_form_data):
        """Test getting form definition by token."""
        # Create form first
        form_service.create_form(**sample_form_data)
        # Get form definition
        definition = form_service.get_form_definition("token-xyz", is_token=True)
        assert "form_content" in definition
        assert "inputs" in definition
        assert "site" in definition  # Should include site info for token-based access

    def test_get_form_definition_expired_form(self, form_service, sample_form_data):
        """Test getting definition for expired form."""
        # Create form with past expiry
        form_service.create_form(**sample_form_data)
        # Manually expire the form by modifying expiry time
        form = form_service.get_form_by_id("form-123")
        form.expires_at = naive_utc_now() - timedelta(hours=1)
        form_service.repository.save(form)
        # Should raise FormExpiredError
        with pytest.raises(FormExpiredError) as exc_info:
            form_service.get_form_definition("form-123", is_token=False)
        assert exc_info.value.error_code == "human_input_form_expired"

    def test_get_form_definition_submitted_form(self, form_service, sample_form_data):
        """Test getting definition for already submitted form."""
        # Create form first
        form_service.create_form(**sample_form_data)
        # Submit the form
        submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
        form_service.submit_form("form-123", submission_data, is_token=False)
        # Should raise FormAlreadySubmittedError
        with pytest.raises(FormAlreadySubmittedError) as exc_info:
            form_service.get_form_definition("form-123", is_token=False)
        assert exc_info.value.error_code == "human_input_form_submitted"

    def test_submit_form_success(self, form_service, sample_form_data):
        """Test successful form submission."""
        # Create form first
        form_service.create_form(**sample_form_data)
        # Submit form
        submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
        # Should not raise any exception
        form_service.submit_form("form-123", submission_data, is_token=False)
        # Verify form is marked as submitted
        form = form_service.get_form_by_id("form-123")
        assert form.is_submitted
        assert form.submitted_data == {"input": "test value"}
        assert form.submitted_action == "submit"
        assert form.submitted_at is not None

    def test_submit_form_missing_inputs(self, form_service, sample_form_data):
        """Test form submission with missing inputs."""
        # Create form first
        form_service.create_form(**sample_form_data)
        # Submit form with missing required input
        submission_data = FormSubmissionData(
            form_id="form-123",
            inputs={},  # Missing required "input" field
            action="submit",
        )
        with pytest.raises(InvalidFormDataError) as exc_info:
            form_service.submit_form("form-123", submission_data, is_token=False)
        assert "Missing required inputs" in exc_info.value.message
        assert "input" in exc_info.value.message

    def test_submit_form_invalid_action(self, form_service, sample_form_data):
        """Test form submission with invalid action."""
        # Create form first
        form_service.create_form(**sample_form_data)
        # Submit form with invalid action
        submission_data = FormSubmissionData(
            form_id="form-123",
            inputs={"input": "test value"},
            action="invalid_action",  # Not in the allowed actions
        )
        with pytest.raises(InvalidFormDataError) as exc_info:
            form_service.submit_form("form-123", submission_data, is_token=False)
        assert "Invalid action" in exc_info.value.message
        assert "invalid_action" in exc_info.value.message

    def test_submit_form_expired(self, form_service, sample_form_data):
        """Test submitting expired form."""
        # Create form first
        form_service.create_form(**sample_form_data)
        # Manually expire the form
        form = form_service.get_form_by_id("form-123")
        form.expires_at = naive_utc_now() - timedelta(hours=1)
        form_service.repository.save(form)
        # Try to submit expired form
        submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
        with pytest.raises(FormExpiredError) as exc_info:
            form_service.submit_form("form-123", submission_data, is_token=False)
        assert exc_info.value.error_code == "human_input_form_expired"

    def test_submit_form_already_submitted(self, form_service, sample_form_data):
        """Test submitting form that's already submitted."""
        # Create and submit form first
        form_service.create_form(**sample_form_data)
        submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "first submission"}, action="submit")
        form_service.submit_form("form-123", submission_data, is_token=False)
        # Try to submit again
        second_submission = FormSubmissionData(
            form_id="form-123", inputs={"input": "second submission"}, action="submit"
        )
        with pytest.raises(FormAlreadySubmittedError) as exc_info:
            form_service.submit_form("form-123", second_submission, is_token=False)
        assert exc_info.value.error_code == "human_input_form_submitted"

    def test_cleanup_expired_forms(self, form_service, sample_form_data):
        """Test cleanup of expired forms."""
        # Create multiple forms
        for i in range(3):
            data = sample_form_data.copy()
            data["form_id"] = f"form-{i}"
            data["form_token"] = f"token-{i}"
            form_service.create_form(**data)
        # Manually expire some forms
        for i in range(2):  # Expire first 2 forms
            form = form_service.get_form_by_id(f"form-{i}")
            form.expires_at = naive_utc_now() - timedelta(hours=1)
            form_service.repository.save(form)
        # Clean up expired forms
        cleaned_count = form_service.cleanup_expired_forms()
        assert cleaned_count == 2
        # Verify expired forms are gone
        with pytest.raises(FormNotFoundError):
            form_service.get_form_by_id("form-0")
        with pytest.raises(FormNotFoundError):
            form_service.get_form_by_id("form-1")
        # Verify non-expired form still exists
        form = form_service.get_form_by_id("form-2")
        assert form.form_id == "form-2"
class TestFormValidation:
    """Test form validation logic."""
    def test_validate_submission_with_extra_inputs(self) -> None:
        """Test validation allows extra inputs that aren't defined in form."""
        repository = InMemoryFormRepository()
        form_service = FormService(repository)
        # Create form with one input
        form_data = {
            "form_id": "form-123",
            "workflow_run_id": "run-456",
            "node_id": "node-789",
            "tenant_id": "tenant-abc",
            "app_id": "app-def",
            "form_content": "Test form",
            "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="required_input", default=None)],
            "user_actions": [UserAction(id="submit", title="Submit")],
            "timeout": 1,
            "timeout_unit": TimeoutUnit.HOUR,
        }
        form_service.create_form(**form_data)
        # Submit with extra input (should be allowed)
        submission_data = FormSubmissionData(
            form_id="form-123",
            inputs={
                "required_input": "value1",
                "extra_input": "value2",  # Extra input not defined in form
            },
            action="submit",
        )
        # Should not raise any exception
        form_service.submit_form("form-123", submission_data, is_token=False)

View File

@ -0,0 +1,232 @@
"""
Unit tests for human input form models.
"""
from datetime import datetime, timedelta
import pytest
from core.workflow.nodes.human_input.entities import (
FormInput,
UserAction,
)
from core.workflow.nodes.human_input.enums import (
FormInputType,
TimeoutUnit,
)
from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm
class TestHumanInputForm:
    """Test HumanInputForm model."""
    @pytest.fixture
    def sample_form_data(self):
        """Create sample form data."""
        return {
            "form_id": "form-123",
            "workflow_run_id": "run-456",
            "node_id": "node-789",
            "tenant_id": "tenant-abc",
            "app_id": "app-def",
            "form_content": "# Test Form\n\nInput: {{#$output.input#}}",
            "inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)],
            "user_actions": [UserAction(id="submit", title="Submit")],
            "timeout": 2,
            "timeout_unit": TimeoutUnit.HOUR,
            "form_token": "token-xyz",
        }
    def test_form_creation(self, sample_form_data) -> None:
        """Test form creation."""
        form = HumanInputForm(**sample_form_data)
        assert form.form_id == "form-123"
        assert form.workflow_run_id == "run-456"
        assert form.node_id == "node-789"
        assert form.tenant_id == "tenant-abc"
        assert form.app_id == "app-def"
        assert form.form_token == "token-xyz"
        assert form.timeout == 2
        assert form.timeout_unit == TimeoutUnit.HOUR
        assert form.created_at is not None
        assert form.expires_at is not None
        assert form.submitted_at is None
        assert form.submitted_data is None
        assert form.submitted_action is None
    def test_form_expiry_calculation_hours(self, sample_form_data) -> None:
        """Test form expiry calculation for hours."""
        form = HumanInputForm(**sample_form_data)
        # Should expire 2 hours after creation
        expected_expiry = form.created_at + timedelta(hours=2)
        assert abs((form.expires_at - expected_expiry).total_seconds()) < 1  # Within 1 second
    def test_form_expiry_calculation_days(self, sample_form_data) -> None:
        """Test form expiry calculation for days."""
        sample_form_data["timeout"] = 3
        sample_form_data["timeout_unit"] = TimeoutUnit.DAY
        form = HumanInputForm(**sample_form_data)
        # Should expire 3 days after creation
        expected_expiry = form.created_at + timedelta(days=3)
        assert abs((form.expires_at - expected_expiry).total_seconds()) < 1  # Within 1 second
    def test_form_expiry_property_not_expired(self, sample_form_data) -> None:
        """Test is_expired property for non-expired form."""
        form = HumanInputForm(**sample_form_data)
        assert not form.is_expired
    def test_form_expiry_property_expired(self, sample_form_data) -> None:
        """Test is_expired property for expired form."""
        # Create form with past expiry
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # consider libs.datetime_utils.naive_utc_now() here — confirm convention.
        past_time = datetime.utcnow() - timedelta(hours=1)
        sample_form_data["created_at"] = past_time
        form = HumanInputForm(**sample_form_data)
        # Manually set expiry to past time
        form.expires_at = past_time
        assert form.is_expired
    def test_form_submission_property_not_submitted(self, sample_form_data) -> None:
        """Test is_submitted property for non-submitted form."""
        form = HumanInputForm(**sample_form_data)
        assert not form.is_submitted
    def test_form_submission_property_submitted(self, sample_form_data) -> None:
        """Test is_submitted property for submitted form."""
        form = HumanInputForm(**sample_form_data)
        form.submit({"input": "test value"}, "submit")
        assert form.is_submitted
        assert form.submitted_at is not None
        assert form.submitted_data == {"input": "test value"}
        assert form.submitted_action == "submit"
    def test_form_submit_method(self, sample_form_data) -> None:
        """Test form submit method."""
        form = HumanInputForm(**sample_form_data)
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # the bracketing pattern below assumes submitted_at is naive UTC.
        submission_time_before = datetime.utcnow()
        form.submit({"input": "test value"}, "submit")
        submission_time_after = datetime.utcnow()
        assert form.is_submitted
        assert form.submitted_data == {"input": "test value"}
        assert form.submitted_action == "submit"
        assert submission_time_before <= form.submitted_at <= submission_time_after
    def test_form_to_response_dict_without_site_info(self, sample_form_data) -> None:
        """Test converting form to response dict without site info."""
        form = HumanInputForm(**sample_form_data)
        response = form.to_response_dict(include_site_info=False)
        assert "form_content" in response
        assert "inputs" in response
        assert "site" not in response
        assert response["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}"
        assert len(response["inputs"]) == 1
        # Enum name TEXT_INPUT is rendered kebab-case in the response.
        assert response["inputs"][0]["type"] == "text-input"
        assert response["inputs"][0]["output_variable_name"] == "input"
    def test_form_to_response_dict_with_site_info(self, sample_form_data) -> None:
        """Test converting form to response dict with site info."""
        form = HumanInputForm(**sample_form_data)
        response = form.to_response_dict(include_site_info=True)
        assert "form_content" in response
        assert "inputs" in response
        assert "site" in response
        assert response["site"]["app_id"] == "app-def"
        assert response["site"]["title"] == "Workflow Form"
    def test_form_without_web_app_token(self, sample_form_data) -> None:
        """Test form creation without web app token."""
        sample_form_data["form_token"] = None
        form = HumanInputForm(**sample_form_data)
        assert form.form_token is None
        assert form.form_id == "form-123"  # Other fields should still work
    def test_form_with_explicit_timestamps(self) -> None:
        """Test form creation with explicit timestamps."""
        created_time = datetime(2024, 1, 15, 10, 30, 0)
        expires_time = datetime(2024, 1, 15, 12, 30, 0)
        form = HumanInputForm(
            form_id="form-123",
            workflow_run_id="run-456",
            node_id="node-789",
            tenant_id="tenant-abc",
            app_id="app-def",
            form_content="Test content",
            inputs=[],
            user_actions=[],
            timeout=2,
            timeout_unit=TimeoutUnit.HOUR,
            created_at=created_time,
            expires_at=expires_time,
        )
        assert form.created_at == created_time
        assert form.expires_at == expires_time
class TestFormSubmissionData:
    """Test FormSubmissionData model."""
    def test_submission_data_creation(self) -> None:
        """Test submission data creation."""
        submission_data = FormSubmissionData(
            form_id="form-123", inputs={"field1": "value1", "field2": "value2"}, action="submit"
        )
        assert submission_data.form_id == "form-123"
        assert submission_data.inputs == {"field1": "value1", "field2": "value2"}
        assert submission_data.action == "submit"
        assert submission_data.submitted_at is not None
    def test_submission_data_from_request(self) -> None:
        """Test creating submission data from API request."""
        request = FormSubmissionRequest(inputs={"input": "test value"}, action="confirm")
        submission_data = FormSubmissionData.from_request("form-456", request)
        assert submission_data.form_id == "form-456"
        assert submission_data.inputs == {"input": "test value"}
        assert submission_data.action == "confirm"
        assert submission_data.submitted_at is not None
    def test_submission_data_with_empty_inputs(self) -> None:
        """Test submission data with empty inputs."""
        submission_data = FormSubmissionData(form_id="form-123", inputs={}, action="cancel")
        assert submission_data.inputs == {}
        assert submission_data.action == "cancel"
    def test_submission_data_timestamps(self) -> None:
        """Test submission data timestamp handling."""
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
        # the bracketing pattern assumes submitted_at defaults to naive UTC.
        before_time = datetime.utcnow()
        submission_data = FormSubmissionData(form_id="form-123", inputs={"test": "value"}, action="submit")
        after_time = datetime.utcnow()
        assert before_time <= submission_data.submitted_at <= after_time
    def test_submission_data_with_explicit_timestamp(self) -> None:
        """Test submission data with explicit timestamp."""
        specific_time = datetime(2024, 1, 15, 14, 30, 0)
        submission_data = FormSubmissionData(
            form_id="form-123", inputs={"test": "value"}, action="submit", submitted_at=specific_time
        )
        assert submission_data.submitted_at == specific_time

View File

@ -181,6 +181,7 @@ class TestShardedTopic:
subscription = sharded_topic.subscribe()
assert isinstance(subscription, _RedisShardedSubscription)
assert subscription._client is mock_redis_client
assert subscription._pubsub is mock_redis_client.pubsub.return_value
assert subscription._topic == "test-sharded-topic"
@ -200,6 +201,11 @@ class SubscriptionTestCase:
class TestRedisSubscription:
"""Test cases for the _RedisSubscription class."""
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
client = MagicMock()
return client
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
"""Create a mock PubSub instance for testing."""
@ -211,9 +217,12 @@ class TestRedisSubscription:
return pubsub
@pytest.fixture
def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]:
def subscription(
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
) -> Generator[_RedisSubscription, None, None]:
"""Create a _RedisSubscription instance for testing."""
subscription = _RedisSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-topic",
)
@ -228,13 +237,15 @@ class TestRedisSubscription:
# ==================== Lifecycle Tests ====================
def test_subscription_initialization(self, mock_pubsub: MagicMock):
def test_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
"""Test that subscription is properly initialized."""
subscription = _RedisSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-topic",
)
assert subscription._client is mock_redis_client
assert subscription._pubsub is mock_pubsub
assert subscription._topic == "test-topic"
assert not subscription._closed.is_set()
@ -486,9 +497,12 @@ class TestRedisSubscription:
),
],
)
def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
def test_subscription_scenarios(
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
):
"""Test various subscription scenarios using table-driven approach."""
subscription = _RedisSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-topic",
)
@ -572,7 +586,7 @@ class TestRedisSubscription:
# Close should still work
subscription.close() # Should not raise
def test_channel_name_variations(self, mock_pubsub: MagicMock):
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
"""Test various channel name formats."""
channel_names = [
"simple",
@ -586,6 +600,7 @@ class TestRedisSubscription:
for channel_name in channel_names:
subscription = _RedisSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic=channel_name,
)
@ -604,6 +619,11 @@ class TestRedisSubscription:
class TestRedisShardedSubscription:
"""Test cases for the _RedisShardedSubscription class."""
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
client = MagicMock()
return client
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
"""Create a mock PubSub instance for testing."""
@ -615,9 +635,12 @@ class TestRedisShardedSubscription:
return pubsub
@pytest.fixture
def sharded_subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisShardedSubscription, None, None]:
def sharded_subscription(
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
) -> Generator[_RedisShardedSubscription, None, None]:
"""Create a _RedisShardedSubscription instance for testing."""
subscription = _RedisShardedSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-sharded-topic",
)
@ -634,13 +657,15 @@ class TestRedisShardedSubscription:
# ==================== Lifecycle Tests ====================
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock):
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
"""Test that sharded subscription is properly initialized."""
subscription = _RedisShardedSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-sharded-topic",
)
assert subscription._client is mock_redis_client
assert subscription._pubsub is mock_pubsub
assert subscription._topic == "test-sharded-topic"
assert not subscription._closed.is_set()
@ -808,6 +833,37 @@ class TestRedisShardedSubscription:
assert not sharded_subscription._queue.empty()
assert sharded_subscription._queue.get_nowait() == b"test sharded payload"
def test_get_message_uses_target_node_for_cluster_client(self, mock_pubsub: MagicMock, monkeypatch):
    """Test that cluster clients use target_node for sharded messages."""

    class DummyRedisCluster:
        # Minimal cluster stand-in exposing only the slot-routing lookup.
        def __init__(self):
            self.get_node_from_key = MagicMock(return_value="node-1")

    # Patch the module-level RedisCluster name so the subscription treats our
    # dummy as a cluster client (presumably via an isinstance check — confirm
    # against sharded_channel's _get_message implementation).
    monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.RedisCluster", DummyRedisCluster)

    client = DummyRedisCluster()
    subscription = _RedisShardedSubscription(
        client=client,
        pubsub=mock_pubsub,
        topic="test-sharded-topic",
    )
    mock_pubsub.get_sharded_message.return_value = {
        "type": "smessage",
        "channel": "test-sharded-topic",
        "data": b"payload",
    }

    result = subscription._get_message()

    # The node owning the topic's hash slot must be looked up and targeted explicitly.
    client.get_node_from_key.assert_called_once_with("test-sharded-topic")
    mock_pubsub.get_sharded_message.assert_called_once_with(
        ignore_subscribe_messages=False,
        timeout=1,
        target_node="node-1",
    )
    assert result == mock_pubsub.get_sharded_message.return_value
def test_listener_thread_ignores_subscribe_messages(
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
):
@ -913,9 +969,12 @@ class TestRedisShardedSubscription:
),
],
)
def test_sharded_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
def test_sharded_subscription_scenarios(
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
):
"""Test various sharded subscription scenarios using table-driven approach."""
subscription = _RedisShardedSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic="test-sharded-topic",
)
@ -999,7 +1058,7 @@ class TestRedisShardedSubscription:
# Close should still work
sharded_subscription.close() # Should not raise
def test_channel_name_variations(self, mock_pubsub: MagicMock):
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
"""Test various sharded channel name formats."""
channel_names = [
"simple",
@ -1013,6 +1072,7 @@ class TestRedisShardedSubscription:
for channel_name in channel_names:
subscription = _RedisShardedSubscription(
client=mock_redis_client,
pubsub=mock_pubsub,
topic=channel_name,
)
@ -1060,6 +1120,11 @@ class TestRedisSubscriptionCommon:
"""Parameterized fixture providing subscription type and class."""
return request.param
@pytest.fixture
def mock_redis_client(self) -> MagicMock:
client = MagicMock()
return client
@pytest.fixture
def mock_pubsub(self) -> MagicMock:
"""Create a mock PubSub instance for testing."""
@ -1075,11 +1140,12 @@ class TestRedisSubscriptionCommon:
return pubsub
@pytest.fixture
def subscription(self, subscription_params, mock_pubsub: MagicMock):
def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
"""Create a subscription instance based on parameterized type."""
subscription_type, subscription_class = subscription_params
topic_name = f"test-{subscription_type}-topic"
subscription = subscription_class(
client=mock_redis_client,
pubsub=mock_pubsub,
topic=topic_name,
)

View File

@ -1,6 +1,8 @@
from datetime import datetime
import pytest
from libs.helper import escape_like_pattern, extract_tenant_id
from libs.helper import OptionalTimestampField, escape_like_pattern, extract_tenant_id
from models.account import Account
from models.model import EndUser
@ -65,6 +67,19 @@ class TestExtractTenantId:
extract_tenant_id(dict_user)
class TestOptionalTimestampField:
    """Tests for OptionalTimestampField's datetime-to-unix-timestamp formatting."""

    def test_format_returns_none_for_none(self):
        # None must pass through unchanged rather than raising.
        field = OptionalTimestampField()
        assert field.format(None) is None

    def test_format_returns_unix_timestamp_for_datetime(self):
        # A datetime must be formatted as int(value.timestamp()).
        field = OptionalTimestampField()
        value = datetime(2024, 1, 2, 3, 4, 5)
        assert field.format(value) == int(value.timestamp())
class TestEscapeLikePattern:
"""Test cases for the escape_like_pattern utility function."""

View File

@ -0,0 +1,68 @@
from unittest.mock import MagicMock
from libs import helper as helper_module
class _FakeRedis:
def __init__(self) -> None:
self._zsets: dict[str, dict[str, float]] = {}
self._expiry: dict[str, int] = {}
def zadd(self, key: str, mapping: dict[str, float]) -> int:
zset = self._zsets.setdefault(key, {})
for member, score in mapping.items():
zset[str(member)] = float(score)
return len(mapping)
def zremrangebyscore(self, key: str, min_score: str | float, max_score: str | float) -> int:
zset = self._zsets.get(key, {})
min_value = float("-inf") if min_score == "-inf" else float(min_score)
max_value = float("inf") if max_score == "+inf" else float(max_score)
to_delete = [member for member, score in zset.items() if min_value <= score <= max_value]
for member in to_delete:
del zset[member]
return len(to_delete)
def zcard(self, key: str) -> int:
return len(self._zsets.get(key, {}))
def expire(self, key: str, ttl: int) -> bool:
self._expiry[key] = ttl
return True
def test_rate_limiter_counts_attempts_within_same_second(monkeypatch):
    """Two attempts at one frozen timestamp must exhaust a max_attempts=2 window."""
    ip = "203.0.113.10"
    # Freeze the clock so both attempts share an identical zset score.
    monkeypatch.setattr(helper_module.time, "time", lambda: 1000)
    limiter = helper_module.RateLimiter(
        prefix="test_rate_limit",
        max_attempts=2,
        time_window=60,
        redis_client=_FakeRedis(),
    )
    for _ in range(2):
        limiter.increment_rate_limit(ip)
    assert limiter.is_rate_limited(ip) is True
def test_rate_limiter_uses_injected_redis(monkeypatch):
    """Every Redis operation must go through the injected client, not a global one."""
    ip = "203.0.113.10"
    injected = MagicMock()
    injected.zcard.return_value = 1
    # Deterministic clock keeps the score arithmetic stable.
    monkeypatch.setattr(helper_module.time, "time", lambda: 1000)
    limiter = helper_module.RateLimiter(
        prefix="test_rate_limit",
        max_attempts=1,
        time_window=60,
        redis_client=injected,
    )
    limiter.increment_rate_limit(ip)
    limiter.is_rate_limited(ip)
    # increment writes and prunes; the limit check counts members.
    assert injected.zadd.called is True
    assert injected.zremrangebyscore.called is True
    assert injected.zcard.called is True

View File

@ -1296,6 +1296,7 @@ class TestConversationStatusCount:
assert result["success"] == 1 # One SUCCEEDED
assert result["failed"] == 1 # One FAILED
assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED
assert result["paused"] == 0
def test_status_count_app_id_filtering(self):
"""Test that status_count filters workflow runs by app_id for security."""
@ -1350,6 +1351,7 @@ class TestConversationStatusCount:
assert result["success"] == 0
assert result["failed"] == 0
assert result["partial_success"] == 0
assert result["paused"] == 0
def test_status_count_handles_invalid_workflow_status(self):
"""Test that status_count gracefully handles invalid workflow status values."""
@ -1404,3 +1406,57 @@ class TestConversationStatusCount:
assert result["success"] == 0
assert result["failed"] == 0
assert result["partial_success"] == 0
assert result["paused"] == 0
def test_status_count_paused(self):
    """Test status_count includes paused workflow runs."""
    # Arrange
    # Imported locally so the PAUSED enum does not become a module-level dependency.
    from core.workflow.enums import WorkflowExecutionStatus

    app_id = str(uuid4())
    conversation_id = str(uuid4())
    workflow_run_id = str(uuid4())
    conversation = Conversation(
        app_id=app_id,
        mode=AppMode.CHAT,
        name="Test Conversation",
        status="normal",
        from_source="api",
    )
    conversation.id = conversation_id
    # One message linked to one PAUSED workflow run in the same app.
    mock_messages = [
        MagicMock(
            conversation_id=conversation_id,
            workflow_run_id=workflow_run_id,
        ),
    ]
    mock_workflow_runs = [
        MagicMock(
            id=workflow_run_id,
            status=WorkflowExecutionStatus.PAUSED.value,
            app_id=app_id,
        ),
    ]
    with patch("models.model.db.session.scalars") as mock_scalars:
        # Route each scalars() query to its fixture by inspecting the rendered SQL.
        def mock_scalars_side_effect(query):
            mock_result = MagicMock()
            if "messages" in str(query):
                mock_result.all.return_value = mock_messages
            elif "workflow_runs" in str(query):
                mock_result.all.return_value = mock_workflow_runs
            else:
                mock_result.all.return_value = []
            return mock_result

        mock_scalars.side_effect = mock_scalars_side_effect

        # Act
        result = conversation.status_count

        # Assert
        assert result["paused"] == 1

View File

@ -0,0 +1,40 @@
"""Unit tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository implementation."""
from unittest.mock import Mock
from sqlalchemy.orm import Session, sessionmaker
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
class TestDifyAPISQLAlchemyWorkflowNodeExecutionRepository:
    """Query-shape tests for the API workflow-node-execution repository."""

    def test_get_executions_by_workflow_run_keeps_paused_records(self):
        """The query must filter only by tenant/app/run — no status predicate."""
        # Fake session so we can capture the statement passed to execute().
        mock_session = Mock(spec=Session)
        execute_result = Mock()
        execute_result.scalars.return_value.all.return_value = []
        mock_session.execute.return_value = execute_result

        # sessionmaker() must act as a context manager yielding the fake session.
        session_maker = Mock(spec=sessionmaker)
        context_manager = Mock()
        context_manager.__enter__ = Mock(return_value=mock_session)
        context_manager.__exit__ = Mock(return_value=None)
        session_maker.return_value = context_manager

        repository = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker)
        repository.get_executions_by_workflow_run(
            tenant_id="tenant-123",
            app_id="app-123",
            workflow_run_id="workflow-run-123",
        )

        # Inspect the captured SELECT's WHERE clauses as lowercase strings.
        stmt = mock_session.execute.call_args[0][0]
        where_clauses = list(getattr(stmt, "_where_criteria", []) or [])
        where_strs = [str(clause).lower() for clause in where_clauses]
        assert any("tenant_id" in clause for clause in where_strs)
        assert any("app_id" in clause for clause in where_strs)
        assert any("workflow_run_id" in clause for clause in where_strs)
        # Absence of a status predicate means paused records are not filtered out.
        assert not any("paused" in clause for clause in where_strs)

View File

@ -1,5 +1,6 @@
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
import secrets
from datetime import UTC, datetime
from unittest.mock import Mock, patch
@ -7,12 +8,17 @@ import pytest
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session, sessionmaker
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
from models.workflow import WorkflowPause as WorkflowPauseModel
from models.workflow import WorkflowRun
from models.workflow import WorkflowPauseReason, WorkflowRun
from repositories.entities.workflow_pause import WorkflowPauseEntity
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_build_human_input_required_reason,
_PrivateWorkflowPauseEntity,
_WorkflowRunError,
)
@ -205,11 +211,11 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
):
"""Test workflow pause creation when workflow not in RUNNING status."""
# Arrange
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
mock_session.get.return_value = sample_workflow_run
# Act & Assert
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
repository.create_workflow_pause(
workflow_run_id="workflow-run-123",
state_owner_user_id="user-123",
@ -295,6 +301,7 @@ class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
sample_workflow_pause.resumed_at = None
mock_session.scalar.return_value = sample_workflow_run
mock_session.scalars.return_value.all.return_value = []
with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
mock_now.return_value = datetime.now(UTC)
@ -455,3 +462,53 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
assert result1 == expected_state
assert result2 == expected_state
mock_storage.load.assert_called_once() # Only called once due to caching
class TestBuildHumanInputRequiredReason:
    """Tests for the module-level _build_human_input_required_reason helper."""

    def test_prefers_backstage_token_when_available(self):
        """A BACKSTAGE recipient's access token should become the reason's form_token."""
        expiration_time = datetime.now(UTC)
        form_definition = FormDefinition(
            form_content="content",
            inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
            user_actions=[UserAction(id="approve", title="Approve")],
            rendered_content="rendered",
            expiration_time=expiration_time,
            default_values={"name": "Alice"},
            node_title="Ask Name",
            display_in_ui=True,
        )
        # Persisted form row carrying the serialized definition above.
        form_model = HumanInputForm(
            id="form-1",
            tenant_id="tenant-1",
            app_id="app-1",
            workflow_run_id="run-1",
            node_id="node-1",
            form_definition=form_definition.model_dump_json(),
            rendered_content="rendered",
            status=HumanInputFormStatus.WAITING,
            expiration_time=expiration_time,
        )
        reason_model = WorkflowPauseReason(
            pause_id="pause-1",
            type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
            form_id="form-1",
            node_id="node-1",
            message="",
        )
        # Random token standing in for a real backstage access token.
        access_token = secrets.token_urlsafe(8)
        backstage_recipient = HumanInputFormRecipient(
            form_id="form-1",
            delivery_id="delivery-1",
            recipient_type=RecipientType.BACKSTAGE,
            recipient_payload=BackstageRecipientPayload().model_dump_json(),
            access_token=access_token,
        )

        reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])

        assert isinstance(reason, HumanInputRequired)
        # The backstage recipient's token is preferred as the form token.
        assert reason.form_token == access_token
        assert reason.node_title == "Ask Name"
        assert reason.form_content == "content"
        assert reason.inputs[0].output_variable_name == "name"
        assert reason.actions[0].id == "approve"

View File

@ -0,0 +1,180 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain
from core.entities.execution_extra_content import HumanInputFormSubmissionData
from core.workflow.nodes.human_input.entities import (
FormDefinition,
UserAction,
)
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from models.execution_extra_content import HumanInputContent as HumanInputContentModel
from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
class _FakeScalarResult:
def __init__(self, values: Sequence[HumanInputContentModel]):
self._values = list(values)
def all(self) -> list[HumanInputContentModel]:
return list(self._values)
class _FakeSession:
    """Fake Session: each scalars() call consumes the next pre-canned result batch."""

    def __init__(self, values: Sequence[Sequence[object]]):
        self._values = list(values)

    def scalars(self, _stmt):
        # Serve batches FIFO; once drained, keep returning empty results.
        batch = self._values.pop(0) if self._values else []
        return _FakeScalarResult(batch)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never swallow exceptions raised inside the with-block.
        return False
@dataclass
class _FakeSessionMaker:
    """Callable stand-in for a SQLAlchemy sessionmaker that always yields one fake session."""

    # The single _FakeSession instance returned by every call.
    session: _FakeSession

    def __call__(self) -> _FakeSession:
        """Mimic calling sessionmaker() to obtain a session."""
        return self.session
def _build_form(action_id: str, action_title: str, rendered_content: str) -> HumanInputForm:
    """Construct a SUBMITTED HumanInputForm whose single action is already selected."""
    expires_at = datetime.now(UTC) + timedelta(days=1)
    sole_action = UserAction(id=action_id, title=action_title)
    definition = FormDefinition(
        form_content="content",
        inputs=[],
        user_actions=[sole_action],
        rendered_content="rendered",
        expiration_time=expires_at,
        node_title="Approval",
        display_in_ui=True,
    )
    submitted_form = HumanInputForm(
        id=f"form-{action_id}",
        tenant_id="tenant-id",
        app_id="app-id",
        workflow_run_id="workflow-run",
        node_id="node-id",
        form_definition=definition.model_dump_json(),
        rendered_content=rendered_content,
        status=HumanInputFormStatus.SUBMITTED,
        expiration_time=expires_at,
    )
    # Record which action completed the form.
    submitted_form.selected_action_id = action_id
    return submitted_form
def _build_content(message_id: str, action_id: str, action_title: str) -> HumanInputContentModel:
    """Build a HumanInputContent row wired to a freshly submitted form."""
    backing_form = _build_form(
        action_id=action_id,
        action_title=action_title,
        rendered_content=f"Rendered {action_title}",
    )
    record = HumanInputContentModel(
        id=f"content-{message_id}",
        form_id=backing_form.id,
        message_id=message_id,
        workflow_run_id=backing_form.workflow_run_id,
    )
    # Attach the relationship eagerly so no database session is needed.
    record.form = backing_form
    return record
def test_get_by_message_ids_groups_contents_by_message() -> None:
    """Results are grouped per requested message id; ids without contents get []."""
    message_ids = ["msg-1", "msg-2"]
    # Only msg-1 has a (submitted) human-input content row.
    contents = [_build_content("msg-1", "approve", "Approve")]
    # The fake session serves contents on the first scalars() call, recipients on the second.
    repository = SQLAlchemyExecutionExtraContentRepository(
        session_maker=_FakeSessionMaker(session=_FakeSession(values=[contents, []]))
    )

    result = repository.get_by_message_ids(message_ids)

    assert len(result) == 2
    # Compare serialized dumps to avoid depending on object identity.
    assert [content.model_dump(mode="json", exclude_none=True) for content in result[0]] == [
        HumanInputContentDomain(
            workflow_run_id="workflow-run",
            submitted=True,
            form_submission_data=HumanInputFormSubmissionData(
                node_id="node-id",
                node_title="Approval",
                rendered_content="Rendered Approve",
                action_id="approve",
                action_text="Approve",
            ),
        ).model_dump(mode="json", exclude_none=True)
    ]
    assert result[1] == []
def test_get_by_message_ids_returns_unsubmitted_form_definition() -> None:
    """A WAITING (unsubmitted) form must surface its full definition, token included.

    Fix: the original duplicated two assertions (`form_definition is not None`
    and the expiration_time check each appeared twice); each is asserted once.
    """
    expiration_time = datetime.now(UTC) + timedelta(days=1)
    definition = FormDefinition(
        form_content="content",
        inputs=[],
        user_actions=[UserAction(id="approve", title="Approve")],
        rendered_content="rendered",
        expiration_time=expiration_time,
        default_values={"name": "John"},
        node_title="Approval",
        display_in_ui=True,
    )
    form = HumanInputForm(
        id="form-1",
        tenant_id="tenant-id",
        app_id="app-id",
        workflow_run_id="workflow-run",
        node_id="node-id",
        form_definition=definition.model_dump_json(),
        rendered_content="Rendered block",
        status=HumanInputFormStatus.WAITING,
        expiration_time=expiration_time,
    )
    content = HumanInputContentModel(
        id="content-msg-1",
        form_id=form.id,
        message_id="msg-1",
        workflow_run_id=form.workflow_run_id,
    )
    content.form = form
    # Console recipient supplies the access token the definition must expose.
    recipient = HumanInputFormRecipient(
        form_id=form.id,
        delivery_id="delivery-1",
        recipient_type=RecipientType.CONSOLE,
        recipient_payload=ConsoleRecipientPayload(account_id=None).model_dump_json(),
        access_token="token-1",
    )
    # First scalars() call returns contents, second returns recipients.
    repository = SQLAlchemyExecutionExtraContentRepository(
        session_maker=_FakeSessionMaker(session=_FakeSession(values=[[content], [recipient]]))
    )

    result = repository.get_by_message_ids(["msg-1"])

    assert len(result) == 1
    assert len(result[0]) == 1
    domain_content = result[0][0]
    assert domain_content.submitted is False
    assert domain_content.workflow_run_id == "workflow-run"
    # The definition must mirror both the form row and the recipient token.
    form_definition = domain_content.form_definition
    assert form_definition is not None
    assert form_definition.form_id == "form-1"
    assert form_definition.node_id == "node-id"
    assert form_definition.node_title == "Approval"
    assert form_definition.form_content == "Rendered block"
    assert form_definition.display_in_ui is True
    assert form_definition.form_token == "token-1"
    assert form_definition.resolved_default_values == {"name": "John"}
    assert form_definition.expiration_time == int(form.expiration_time.timestamp())

View File

@ -0,0 +1,65 @@
from unittest.mock import MagicMock
import services.app_generate_service as app_generate_service_module
from models.model import AppMode
from services.app_generate_service import AppGenerateService
class _DummyRateLimit:
def __init__(self, client_id: str, max_active_requests: int) -> None:
self.client_id = client_id
self.max_active_requests = max_active_requests
@staticmethod
def gen_request_key() -> str:
return "dummy-request-id"
def enter(self, request_id: str | None = None) -> str:
return request_id or "dummy-request-id"
def exit(self, request_id: str) -> None:
return None
def generate(self, generator, request_id: str):
return generator
def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch):
    """Blocking workflow generation must pass a pause_state_config owned by the workflow creator."""
    # Disable billing so quota handling does not interfere.
    monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False)
    # Replace the rate limiter with a pass-through double.
    mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)

    workflow = MagicMock()
    workflow.id = "workflow-id"
    workflow.created_by = "owner-id"
    mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
    # Spy on the underlying generator to capture its keyword arguments.
    generator_spy = mocker.patch(
        "services.app_generate_service.WorkflowAppGenerator.generate",
        return_value={"result": "ok"},
    )

    app_model = MagicMock()
    app_model.mode = AppMode.WORKFLOW
    app_model.id = "app-id"
    app_model.tenant_id = "tenant-id"
    app_model.max_active_requests = 0
    app_model.is_agent = False
    user = MagicMock()
    user.id = "user-id"

    result = AppGenerateService.generate(
        app_model=app_model,
        user=user,
        args={"inputs": {"k": "v"}},
        invoke_from=MagicMock(),
        streaming=False,
    )

    assert result == {"result": "ok"}
    call_kwargs = generator_spy.call_args.kwargs
    pause_state_config = call_kwargs.get("pause_state_config")
    assert pause_state_config is not None
    # Pause state is owned by the workflow's creator, not the requesting user.
    assert pause_state_config.state_owner_user_id == "owner-id"

View File

@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation:
within conversations.
"""
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
def test_pagination_by_first_id_without_first_id(
self, mock_get_conversation, mock_db_session, mock_create_extra_repo
):
"""
Test message pagination without specifying first_id.
@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act - Call the pagination method without first_id
result = MessageService.pagination_by_first_id(
@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation:
# Verify conversation was looked up with correct parameters
mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test message pagination with first_id specified.
@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation:
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.first.return_value = first_message # First message returned
mock_query.all.return_value = messages # Remaining messages returned
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act - Call the pagination method with first_id
result = MessageService.pagination_by_first_id(
@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation:
assert result.data == []
assert result.has_more is False
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test that has_more flag is correctly set when there are more messages.
@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act
result = MessageService.pagination_by_first_id(
@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == limit # Extra message should be removed
assert result.has_more is True # Flag should be set
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test message pagination with ascending order.
@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act
result = MessageService.pagination_by_first_id(

View File

@ -0,0 +1,104 @@
from dataclasses import dataclass
import pytest
from enums.cloud_plan import CloudPlan
from services import feature_service as feature_service_module
from services.feature_service import FeatureModel, FeatureService
@dataclass(frozen=True)
class HumanInputEmailDeliveryCase:
    """One row of the email-delivery feature-flag decision table."""

    # Human-readable case id, used for the pytest parametrize ids.
    name: str
    # Value patched into dify_config.ENTERPRISE_ENABLED.
    enterprise_enabled: bool
    # Value patched into dify_config.BILLING_ENABLED.
    billing_enabled: bool
    # Tenant under test; None simulates a missing tenant context.
    tenant_id: str | None
    # Value assigned to FeatureModel.billing.enabled.
    billing_feature_enabled: bool
    # Subscription plan (a CloudPlan value).
    plan: str
    # Expected result of _resolve_human_input_email_delivery_enabled.
    expected: bool
# Decision table: per the expected values below, enterprise mode always enables
# delivery; with billing fully disabled it is also enabled; once billing is on,
# a tenant id plus a paid plan (professional/team) is required.
CASES = [
    HumanInputEmailDeliveryCase(
        name="enterprise_enabled",
        enterprise_enabled=True,
        billing_enabled=True,
        tenant_id=None,
        billing_feature_enabled=False,
        plan=CloudPlan.SANDBOX,
        expected=True,
    ),
    HumanInputEmailDeliveryCase(
        name="billing_disabled",
        enterprise_enabled=False,
        billing_enabled=False,
        tenant_id=None,
        billing_feature_enabled=False,
        plan=CloudPlan.SANDBOX,
        expected=True,
    ),
    HumanInputEmailDeliveryCase(
        name="billing_enabled_requires_tenant",
        enterprise_enabled=False,
        billing_enabled=True,
        tenant_id=None,
        billing_feature_enabled=True,
        plan=CloudPlan.PROFESSIONAL,
        expected=False,
    ),
    HumanInputEmailDeliveryCase(
        name="billing_feature_off",
        enterprise_enabled=False,
        billing_enabled=True,
        tenant_id="tenant-1",
        billing_feature_enabled=False,
        plan=CloudPlan.PROFESSIONAL,
        expected=False,
    ),
    HumanInputEmailDeliveryCase(
        name="professional_plan",
        enterprise_enabled=False,
        billing_enabled=True,
        tenant_id="tenant-1",
        billing_feature_enabled=True,
        plan=CloudPlan.PROFESSIONAL,
        expected=True,
    ),
    HumanInputEmailDeliveryCase(
        name="team_plan",
        enterprise_enabled=False,
        billing_enabled=True,
        tenant_id="tenant-1",
        billing_feature_enabled=True,
        plan=CloudPlan.TEAM,
        expected=True,
    ),
    HumanInputEmailDeliveryCase(
        name="sandbox_plan",
        enterprise_enabled=False,
        billing_enabled=True,
        tenant_id="tenant-1",
        billing_feature_enabled=True,
        plan=CloudPlan.SANDBOX,
        expected=False,
    ),
]
@pytest.mark.parametrize("case", CASES, ids=lambda case: case.name)
def test_resolve_human_input_email_delivery_enabled_matrix(
    monkeypatch: pytest.MonkeyPatch,
    case: HumanInputEmailDeliveryCase,
):
    """Drive the resolver through every row of the CASES decision table."""
    # Pin the global config flags for this case.
    monkeypatch.setattr(feature_service_module.dify_config, "ENTERPRISE_ENABLED", case.enterprise_enabled)
    monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", case.billing_enabled)

    features = FeatureModel()
    features.billing.enabled = case.billing_feature_enabled
    features.billing.subscription.plan = case.plan

    result = FeatureService._resolve_human_input_email_delivery_enabled(
        features=features,
        tenant_id=case.tenant_id,
    )

    assert result is case.expected

View File

@ -0,0 +1,97 @@
from types import SimpleNamespace
import pytest
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
)
from core.workflow.runtime import VariablePool
from services import human_input_delivery_test_service as service_module
from services.human_input_delivery_test_service import (
DeliveryTestContext,
DeliveryTestError,
EmailDeliveryTestHandler,
)
def _make_email_method() -> EmailDeliveryMethod:
    """Build a minimal email delivery method with one external recipient."""
    return EmailDeliveryMethod(
        config=EmailDeliveryConfig(
            recipients=EmailRecipients(
                whole_workspace=False,
                items=[ExternalRecipient(email="tester@example.com")],
            ),
            subject="Test subject",
            body="Test body",
        )
    )
def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch):
    """With the tenant feature flag off, send_test must raise DeliveryTestError."""
    # Force the feature lookup to report email delivery as unavailable.
    monkeypatch.setattr(
        service_module.FeatureService,
        "get_features",
        lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False),
    )
    handler = EmailDeliveryTestHandler(session_factory=object())
    context = DeliveryTestContext(
        tenant_id="tenant-1",
        app_id="app-1",
        node_id="node-1",
        node_title="Human Input",
        rendered_content="content",
    )
    method = _make_email_method()

    with pytest.raises(DeliveryTestError, match="Email delivery is not available"):
        handler.send_test(context=context, method=method)
def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch):
    """Variable selectors like {{#node1.value#}} in the body must be resolved before sending."""

    class DummyMail:
        # Captures outgoing messages instead of talking to a mail server.
        def __init__(self):
            self.sent: list[dict[str, str]] = []

        def is_inited(self) -> bool:
            return True

        def send(self, *, to: str, subject: str, html: str):
            self.sent.append({"to": to, "subject": subject, "html": html})

    mail = DummyMail()
    monkeypatch.setattr(service_module, "mail", mail)
    # Identity template renderer keeps the substituted body observable.
    monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template)
    monkeypatch.setattr(
        service_module.FeatureService,
        "get_features",
        lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
    )

    handler = EmailDeliveryTestHandler(session_factory=object())
    # Bypass recipient resolution, which would otherwise need a database session.
    handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"]  # type: ignore[assignment]
    method = EmailDeliveryMethod(
        config=EmailDeliveryConfig(
            recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]),
            subject="Subject",
            body="Value {{#node1.value#}}",
        )
    )
    # Seed the pool so node1.value resolves to "OK".
    variable_pool = VariablePool()
    variable_pool.add(["node1", "value"], "OK")
    context = DeliveryTestContext(
        tenant_id="tenant-1",
        app_id="app-1",
        node_id="node-1",
        node_title="Human Input",
        rendered_content="content",
        variable_pool=variable_pool,
    )

    handler.send_test(context=context, method=method)

    assert mail.sent[0]["html"] == "Value OK"

View File

@ -0,0 +1,290 @@
import dataclasses
from datetime import datetime, timedelta
from unittest.mock import MagicMock
import pytest
import services.human_input_service as human_input_service_module
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormSubmissionRepository,
)
from core.workflow.nodes.human_input.entities import (
FormDefinition,
FormInput,
UserAction,
)
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus
from models.human_input import RecipientType
from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE
@pytest.fixture
def mock_session_factory():
    """Return (factory, session): a session-factory mock whose context manager yields `session`."""
    db_session = MagicMock()
    cm = MagicMock()
    cm.__enter__.return_value = db_session
    cm.__exit__.return_value = None
    session_factory = MagicMock(return_value=cm)
    return session_factory, db_session
@pytest.fixture
def sample_form_record():
    """A WAITING runtime form record created "now" and expiring one hour later.

    Uses a single `now` so created_at and the two expiration_time fields are
    mutually consistent (the original read the clock four times).
    """
    # datetime.utcnow() is deprecated since Python 3.12; build the same naive
    # UTC timestamp from the aware clock instead.
    from datetime import timezone

    now = datetime.now(timezone.utc).replace(tzinfo=None)
    return HumanInputFormRecord(
        form_id="form-id",
        workflow_run_id="workflow-run-id",
        node_id="node-id",
        tenant_id="tenant-id",
        app_id="app-id",
        form_kind=HumanInputFormKind.RUNTIME,
        definition=FormDefinition(
            form_content="hello",
            inputs=[],
            user_actions=[UserAction(id="submit", title="Submit")],
            rendered_content="<p>hello</p>",
            expiration_time=now + timedelta(hours=1),
        ),
        rendered_content="<p>hello</p>",
        created_at=now,
        expiration_time=now + timedelta(hours=1),
        status=HumanInputFormStatus.WAITING,
        selected_action_id=None,
        submitted_data=None,
        submitted_at=None,
        submission_user_id=None,
        submission_end_user_id=None,
        completed_by_recipient_id=None,
        recipient_id="recipient-id",
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        access_token="token",
    )
def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory):
    """enqueue_resume dispatches a resume task on the workflow queue for a workflow-mode app."""
    factory, db_session = mock_session_factory
    run_repo = MagicMock()
    run_repo.get_workflow_run_by_id_without_tenant.return_value = MagicMock(app_id="app-id")
    mocker.patch(
        "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
        return_value=run_repo,
    )
    db_session.execute.return_value.scalar_one_or_none.return_value = MagicMock(mode="workflow")
    resume_task = mocker.patch("services.human_input_service.resume_app_execution")

    HumanInputService(factory).enqueue_resume("workflow-run-id")

    resume_task.apply_async.assert_called_once()
    dispatch_kwargs = resume_task.apply_async.call_args.kwargs
    assert dispatch_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
    assert dispatch_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_record, mock_session_factory):
    """A form inside its own expiration window still expires once the global timeout elapses."""
    session_factory, _ = mock_session_factory
    service = HumanInputService(session_factory)
    # datetime.utcnow() is deprecated since Python 3.12; derive the same naive
    # UTC timestamp from the aware clock.
    from datetime import timezone

    now = datetime.now(timezone.utc).replace(tzinfo=None)
    expired_record = dataclasses.replace(
        sample_form_record,
        created_at=now - timedelta(hours=2),  # older than the 1h global timeout below
        expiration_time=now + timedelta(hours=2),  # per-form expiry not yet reached
    )
    monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600)
    with pytest.raises(FormExpiredError):
        service.ensure_form_active(Form(expired_record))
def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory):
    """Advanced-chat apps also dispatch the resume task on the workflow queue."""
    factory, db_session = mock_session_factory
    run_repo = MagicMock()
    run_repo.get_workflow_run_by_id_without_tenant.return_value = MagicMock(app_id="app-id")
    mocker.patch(
        "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
        return_value=run_repo,
    )
    db_session.execute.return_value.scalar_one_or_none.return_value = MagicMock(mode="advanced-chat")
    resume_task = mocker.patch("services.human_input_service.resume_app_execution")

    HumanInputService(factory).enqueue_resume("workflow-run-id")

    resume_task.apply_async.assert_called_once()
    dispatch_kwargs = resume_task.apply_async.call_args.kwargs
    assert dispatch_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
    assert dispatch_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory):
    """No resume task is dispatched for app modes without workflow execution (e.g. completion)."""
    factory, db_session = mock_session_factory
    run_repo = MagicMock()
    run_repo.get_workflow_run_by_id_without_tenant.return_value = MagicMock(app_id="app-id")
    mocker.patch(
        "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
        return_value=run_repo,
    )
    db_session.execute.return_value.scalar_one_or_none.return_value = MagicMock(mode="completion")
    resume_task = mocker.patch("services.human_input_service.resume_app_execution")

    HumanInputService(factory).enqueue_resume("workflow-run-id")

    resume_task.apply_async.assert_not_called()
def test_get_form_definition_by_token_for_console_uses_repository(sample_form_record, mock_session_factory):
    """Console lookups go through the injected repository and surface the stored definition."""
    factory, _ = mock_session_factory
    repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    console_record = dataclasses.replace(sample_form_record, recipient_type=RecipientType.CONSOLE)
    repo.get_by_token.return_value = console_record
    service = HumanInputService(factory, form_repository=repo)

    form = service.get_form_definition_by_token_for_console("token")

    repo.get_by_token.assert_called_once_with("token")
    assert form is not None
    assert form.get_definition() == console_record.definition
def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker):
    """A valid submission is persisted via the repository and triggers a workflow resume."""
    factory, _ = mock_session_factory
    repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    repo.get_by_token.return_value = sample_form_record
    repo.mark_submitted.return_value = sample_form_record
    service = HumanInputService(factory, form_repository=repo)
    enqueue_spy = mocker.patch.object(service, "enqueue_resume")

    service.submit_form_by_token(
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        form_token="token",
        selected_action_id="submit",
        form_data={"field": "value"},
        submission_end_user_id="end-user-id",
    )

    repo.get_by_token.assert_called_once_with("token")
    repo.mark_submitted.assert_called_once()
    submitted = repo.mark_submitted.call_args.kwargs
    assert submitted["form_id"] == sample_form_record.form_id
    assert submitted["recipient_id"] == sample_form_record.recipient_id
    assert submitted["selected_action_id"] == "submit"
    assert submitted["form_data"] == {"field": "value"}
    assert submitted["submission_end_user_id"] == "end-user-id"
    enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
def test_submit_form_by_token_skips_enqueue_for_delivery_test(sample_form_record, mock_session_factory, mocker):
    """Delivery-test forms have no workflow run, so submission must not enqueue a resume."""
    factory, _ = mock_session_factory
    repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    test_record = dataclasses.replace(
        sample_form_record,
        form_kind=HumanInputFormKind.DELIVERY_TEST,
        workflow_run_id=None,
    )
    repo.get_by_token.return_value = test_record
    repo.mark_submitted.return_value = test_record
    service = HumanInputService(factory, form_repository=repo)
    enqueue_spy = mocker.patch.object(service, "enqueue_resume")

    service.submit_form_by_token(
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        form_token="token",
        selected_action_id="submit",
        form_data={"field": "value"},
    )

    enqueue_spy.assert_not_called()
def test_submit_form_by_token_passes_submission_user_id(sample_form_record, mock_session_factory, mocker):
    """An account submission forwards submission_user_id and leaves the end-user id unset."""
    factory, _ = mock_session_factory
    repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    repo.get_by_token.return_value = sample_form_record
    repo.mark_submitted.return_value = sample_form_record
    service = HumanInputService(factory, form_repository=repo)
    enqueue_spy = mocker.patch.object(service, "enqueue_resume")

    service.submit_form_by_token(
        recipient_type=RecipientType.STANDALONE_WEB_APP,
        form_token="token",
        selected_action_id="submit",
        form_data={"field": "value"},
        submission_user_id="account-id",
    )

    submitted = repo.mark_submitted.call_args.kwargs
    assert submitted["submission_user_id"] == "account-id"
    assert submitted["submission_end_user_id"] is None
    enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_factory):
    """An action id outside the form definition is rejected before anything is persisted."""
    factory, _ = mock_session_factory
    repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    repo.get_by_token.return_value = dataclasses.replace(sample_form_record)
    service = HumanInputService(factory, form_repository=repo)

    with pytest.raises(InvalidFormDataError) as exc_info:
        service.submit_form_by_token(
            recipient_type=RecipientType.STANDALONE_WEB_APP,
            form_token="token",
            selected_action_id="invalid",
            form_data={},
        )

    assert "Invalid action" in str(exc_info.value)
    repo.mark_submitted.assert_not_called()
def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_factory):
    """Submissions lacking a required input are rejected before anything is persisted."""
    factory, _ = mock_session_factory
    repo = MagicMock(spec=HumanInputFormSubmissionRepository)
    definition_with_input = FormDefinition(
        form_content="hello",
        inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content")],
        user_actions=sample_form_record.definition.user_actions,
        rendered_content="<p>hello</p>",
        expiration_time=sample_form_record.expiration_time,
    )
    repo.get_by_token.return_value = dataclasses.replace(sample_form_record, definition=definition_with_input)
    service = HumanInputService(factory, form_repository=repo)

    with pytest.raises(InvalidFormDataError) as exc_info:
        service.submit_form_by_token(
            recipient_type=RecipientType.STANDALONE_WEB_APP,
            form_token="token",
            selected_action_id="submit",
            form_data={},
        )

    assert "Missing required inputs" in str(exc_info.value)
    repo.mark_submitted.assert_not_called()

View File

@ -0,0 +1,61 @@
from __future__ import annotations
import pytest
from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData
from services import message_service
class _FakeMessage:
def __init__(self, message_id: str):
self.id = message_id
self.extra_contents = None
def set_extra_contents(self, contents):
self.extra_contents = contents
def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
    """Each message receives the serialized extra contents for its id; empty lists stay empty."""
    messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")]
    payload = HumanInputContent(
        workflow_run_id="workflow-run-1",
        submitted=True,
        form_submission_data=HumanInputFormSubmissionData(
            node_id="node-1",
            node_title="Approval",
            rendered_content="Rendered",
            action_id="approve",
            action_text="Approve",
        ),
    )

    class _StubRepo:
        # Returns one content list per requested message id, in order.
        def get_by_message_ids(self, message_ids):
            return [[payload], []]

    monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: _StubRepo())

    message_service.attach_message_extra_contents(messages)

    assert messages[0].extra_contents == [
        {
            "type": "human_input",
            "workflow_run_id": "workflow-run-1",
            "submitted": True,
            "form_submission_data": {
                "node_id": "node-1",
                "node_title": "Approval",
                "rendered_content": "Rendered",
                "action_id": "approve",
                "action_text": "Approve",
            },
        }
    ]
    assert messages[1].extra_contents == []

View File

@ -35,7 +35,6 @@ class TestDataFactory:
app_id: str = "app-789",
workflow_id: str = "workflow-101",
status: str | WorkflowExecutionStatus = "paused",
pause_id: str | None = None,
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowRun object."""
@ -45,7 +44,6 @@ class TestDataFactory:
mock_run.app_id = app_id
mock_run.workflow_id = workflow_id
mock_run.status = status
mock_run.pause_id = pause_id
for key, value in kwargs.items():
setattr(mock_run, key, value)

View File

@ -0,0 +1,162 @@
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from models.model import App
from models.tools import WorkflowToolProvider
from services.tools import workflow_tools_manage_service
class DummyWorkflow:
    """Read-only workflow double exposing a graph dict and a version string."""

    def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None:
        self.version = version
        self._graph_dict = graph_dict

    @property
    def graph_dict(self) -> dict:
        return self._graph_dict
class FakeQuery:
    """Chainable query stub: where() is a no-op and first() yields a canned result."""

    def __init__(self, result) -> None:
        self._result = result

    def where(self, *args, **kwargs):
        # Filtering is ignored; the canned result is returned regardless.
        return self

    def first(self):
        return self._result
class DummySession:
    """Context-manager session stub that records objects passed to add()."""

    def __init__(self) -> None:
        self.added: list[object] = []

    def __enter__(self) -> "DummySession":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Never swallow exceptions.
        return False

    def add(self, obj) -> None:
        self.added.append(obj)

    def begin(self):
        # Mirrors SQLAlchemy's Session.begin() transactional context manager.
        return DummyBegin(self)
class DummyBegin:
    """Transaction context stub that simply hands back the wrapped session."""

    def __init__(self, session: DummySession) -> None:
        self._session = session

    def __enter__(self) -> DummySession:
        return self._session

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Propagate any exception raised inside the "transaction".
        return False
class DummySessionContext:
    """Context manager stub yielding a pre-built session on entry."""

    def __init__(self, session: DummySession) -> None:
        self._session = session

    def __enter__(self) -> DummySession:
        return self._session

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Propagate exceptions to the caller.
        return False
class DummySessionFactory:
    """Factory double whose create_session() always wraps the same session."""

    def __init__(self, session: DummySession) -> None:
        self._session = session

    def create_session(self) -> DummySessionContext:
        # Every call produces a fresh context around the shared session.
        return DummySessionContext(self._session)
def _build_fake_session(app) -> SimpleNamespace:
    """Fake db.session: WorkflowToolProvider queries resolve to None, App queries to `app`."""

    def query(model):
        canned = {WorkflowToolProvider: None, App: app}
        # Unknown models fall through to None, matching a miss.
        return FakeQuery(canned.get(model))

    return SimpleNamespace(query=query)
def _build_parameters() -> list[WorkflowToolParameterConfiguration]:
    """Single LLM-form parameter — the minimum a workflow tool needs."""
    parameter = WorkflowToolParameterConfiguration(
        name="input",
        description="input",
        form=ToolParameter.ToolParameterForm.LLM,
    )
    return [parameter]
def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch):
    """Creating a workflow tool from an app whose graph has a human-input node must fail."""
    workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]})
    app = SimpleNamespace(workflow=workflow)
    monkeypatch.setattr(workflow_tools_manage_service.db, "session", _build_fake_session(app))
    mock_from_db = MagicMock()
    monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
    with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
        workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
            user_id="user-id",
            tenant_id="tenant-id",
            workflow_app_id="app-id",
            name="tool_name",
            label="Tool",
            icon={"type": "emoji", "emoji": "tool"},
            description="desc",
            parameters=_build_parameters(),
        )
    assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
    # Validation must fail before any provider controller is constructed.
    mock_from_db.assert_not_called()
    # NOTE(review): the original also built a `mock_invalidate = MagicMock()` that was
    # never patched onto anything, so its assert_not_called() was vacuous; removed.
def test_create_workflow_tool_success(monkeypatch):
    """A human-input-free workflow app yields a persisted provider and a controller build."""
    workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]})
    app = SimpleNamespace(workflow=workflow)
    fake_db = MagicMock()
    fake_db.session = _build_fake_session(app)
    monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
    recording_session = DummySession()
    monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: recording_session)
    mock_from_db = MagicMock()
    monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)

    icon = {"type": "emoji", "emoji": "tool"}
    result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
        user_id="user-id",
        tenant_id="tenant-id",
        workflow_app_id="app-id",
        name="tool_name",
        label="Tool",
        icon=icon,
        description="desc",
        parameters=_build_parameters(),
    )

    assert result == {"result": "success"}
    # Exactly one provider row was added to the session.
    (provider,) = recording_session.added
    assert provider.name == "tool_name"
    assert provider.label == "Tool"
    assert provider.icon == json.dumps(icon)
    assert provider.version == workflow.version
    mock_from_db.assert_called_once()

View File

@ -0,0 +1,226 @@
from __future__ import annotations
import json
import queue
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from threading import Event
import pytest
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services.workflow_event_snapshot_service import (
BufferState,
MessageContext,
_build_snapshot_events,
_resolve_task_id,
)
@dataclass(frozen=True)
class _FakePauseEntity(WorkflowPauseEntity):
    """Immutable WorkflowPauseEntity double carrying only snapshot-relevant data."""

    # Constructor fields; their order defines the generated __init__ signature.
    pause_id: str
    workflow_run_id: str
    paused_at_value: datetime
    pause_reasons: Sequence[HumanInputRequired]

    @property
    def id(self) -> str:
        # Entity id as required by the WorkflowPauseEntity interface.
        return self.pause_id

    @property
    def workflow_execution_id(self) -> str:
        return self.workflow_run_id

    def get_state(self) -> bytes:
        # Snapshot building must never touch the serialized runtime state;
        # raising here makes any accidental access fail loudly.
        raise AssertionError("state is not required for snapshot tests")

    @property
    def resumed_at(self) -> datetime | None:
        # The fake always represents a still-paused run.
        return None

    @property
    def paused_at(self) -> datetime:
        return self.paused_at_value

    def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
        return self.pause_reasons
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
    """WorkflowRun row for run-1 with the given status and fixed bookkeeping fields."""
    fields = dict(
        id="run-1",
        tenant_id="tenant-1",
        app_id="app-1",
        workflow_id="workflow-1",
        type="workflow",
        triggered_from="app-run",
        version="v1",
        graph=None,
        inputs=json.dumps({"input": "value"}),
        status=status,
        outputs=json.dumps({}),
        error=None,
        elapsed_time=0.0,
        total_tokens=0,
        total_steps=0,
        created_by_role=CreatorUserRole.END_USER,
        created_by="user-1",
        created_at=datetime(2024, 1, 1, tzinfo=UTC),
    )
    return WorkflowRun(**fields)
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
    """Snapshot of human-input node exec-1 that finished five seconds after it started."""
    return WorkflowNodeExecutionSnapshot(
        execution_id="exec-1",
        node_id="node-1",
        node_type="human-input",
        title="Human Input",
        index=1,
        status=status.value,
        elapsed_time=0.5,
        created_at=datetime(2024, 1, 1, tzinfo=UTC),
        finished_at=datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC),
        iteration_id=None,
        loop_id=None,
    )
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
    """Resumption context for run-1 with node-1 paused and a known outputs dict."""
    config = WorkflowUIBasedAppConfig(
        tenant_id="tenant-1",
        app_id="app-1",
        app_mode=AppMode.WORKFLOW,
        workflow_id="workflow-1",
    )
    entity = WorkflowAppGenerateEntity(
        task_id=task_id,
        app_config=config,
        inputs={},
        files=[],
        user_id="user-1",
        stream=True,
        invoke_from=InvokeFrom.EXPLORE,
        call_depth=0,
        workflow_execution_id="run-1",
    )
    state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
    state.register_paused_node("node-1")
    state.outputs = {"result": "value"}
    return WorkflowResumptionContext(
        generate_entity=_WorkflowGenerateEntityWrapper(entity=entity),
        serialized_graph_runtime_state=state.dumps(),
    )
def test_build_snapshot_events_includes_pause_event() -> None:
    """A paused run yields start/node events plus a trailing workflow_paused payload."""
    workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
    pause_entity = _FakePauseEntity(
        pause_id="pause-1",
        workflow_run_id="run-1",
        paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
        pause_reasons=[
            HumanInputRequired(
                form_id="form-1",
                form_content="content",
                node_id="node-1",
                node_title="Human Input",
            )
        ],
    )

    events = _build_snapshot_events(
        workflow_run=workflow_run,
        node_snapshots=[_build_snapshot(WorkflowNodeExecutionStatus.PAUSED)],
        task_id="task-ctx",
        message_context=None,
        pause_entity=pause_entity,
        resumption_context=_build_resumption_context("task-ctx"),
    )

    expected_order = ["workflow_started", "node_started", "node_finished", "workflow_paused"]
    assert [e["event"] for e in events] == expected_order
    assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
    paused = events[-1]["data"]
    assert paused["paused_nodes"] == ["node-1"]
    assert paused["outputs"] == {"result": "value"}
    assert paused["status"] == WorkflowExecutionStatus.PAUSED.value
    assert paused["created_at"] == int(workflow_run.created_at.timestamp())
    assert paused["elapsed_time"] == workflow_run.elapsed_time
    assert paused["total_tokens"] == workflow_run.total_tokens
    assert paused["total_steps"] == workflow_run.total_steps
def test_build_snapshot_events_applies_message_context() -> None:
    """Message context injects a message_replace event and stamps ids onto every event."""
    events = _build_snapshot_events(
        workflow_run=_build_workflow_run(WorkflowExecutionStatus.RUNNING),
        node_snapshots=[_build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED)],
        task_id="task-1",
        message_context=MessageContext(
            conversation_id="conv-1",
            message_id="msg-1",
            created_at=1700000000,
            answer="snapshot message",
        ),
        pause_entity=None,
        resumption_context=None,
    )

    assert [e["event"] for e in events] == [
        "workflow_started",
        "message_replace",
        "node_started",
        "node_finished",
    ]
    assert events[1]["answer"] == "snapshot message"
    for event in events:
        # Every replayed event carries the conversation/message identity.
        assert event["conversation_id"] == "conv-1"
        assert event["message_id"] == "msg-1"
        assert event["created_at"] == 1700000000
@pytest.mark.parametrize(
    ("context_task_id", "buffered_task_id", "expected"),
    [
        ("task-ctx", "task-buffer", "task-ctx"),
        (None, "task-buffer", "task-buffer"),
        (None, None, "run-1"),
    ],
)
def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None:
    """Task-id resolution prefers the resumption context, then the buffer hint, then the run id."""
    context = _build_resumption_context(context_task_id) if context_task_id else None
    buffer_state = BufferState(
        queue=queue.Queue(),
        stop_event=Event(),
        done_event=Event(),
        task_id_ready=Event(),
        task_id_hint=buffered_task_id,
    )
    if buffered_task_id:
        # Signal the hint is available, as the writer thread would.
        buffer_state.task_id_ready.set()

    assert _resolve_task_id(context, buffer_state, "run-1", wait_timeout=0.0) == expected

View File

@ -0,0 +1,184 @@
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import sessionmaker
from core.workflow.enums import NodeType
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
MemberRecipient,
)
from services import workflow_service as workflow_service_module
from services.workflow_service import WorkflowService
def _make_service() -> WorkflowService:
    """WorkflowService bound to an unbound (engine-less) session maker."""
    factory = sessionmaker()
    return WorkflowService(session_maker=factory)
def _build_node_config(delivery_methods):
    """Graph node config for a human-input node carrying the given delivery methods."""
    data = HumanInputNodeData(
        title="Human Input",
        delivery_methods=delivery_methods,
        form_content="Test content",
        inputs=[],
        user_actions=[],
    ).model_dump(mode="json")
    # Graph serialization stores the node type alongside the node data.
    data["type"] = NodeType.HUMAN_INPUT.value
    return {"id": "node-1", "data": data}
def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod:
    """Email delivery method with one external recipient and a fresh random id."""
    recipients = EmailRecipients(
        whole_workspace=False,
        items=[ExternalRecipient(email="tester@example.com")],
    )
    config = EmailDeliveryConfig(
        recipients=recipients,
        subject="Test subject",
        body="Test body",
        debug_mode=debug_mode,
    )
    return EmailDeliveryMethod(id=uuid.uuid4(), enabled=enabled, config=config)
def test_human_input_delivery_requires_draft_workflow():
    """Delivery testing without a draft workflow raises immediately."""
    service = _make_service()
    service.get_draft_workflow = MagicMock(return_value=None)  # type: ignore[method-assign]

    with pytest.raises(ValueError, match="Workflow not initialized"):
        service.test_human_input_delivery(
            app_model=SimpleNamespace(tenant_id="tenant-1", id="app-1"),
            account=SimpleNamespace(id="account-1"),
            node_id="node-1",
            delivery_method_id="delivery-1",
        )
def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch):
    """A disabled delivery method can still be test-sent from the editor."""
    service = _make_service()
    method = _make_email_method(enabled=False)
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = _build_node_config([method])
    service.get_draft_workflow = MagicMock(return_value=workflow)  # type: ignore[method-assign]
    service._build_human_input_variable_pool = MagicMock(return_value=MagicMock())  # type: ignore[attr-defined]
    node_stub = MagicMock()
    node_stub._render_form_content_before_submission.return_value = "rendered"
    node_stub._resolve_default_values.return_value = {}
    service._build_human_input_node = MagicMock(return_value=node_stub)  # type: ignore[attr-defined]
    service._create_human_input_delivery_test_form = MagicMock(return_value=("form-1", {}))  # type: ignore[attr-defined]
    delivery_test_service = MagicMock()
    monkeypatch.setattr(
        workflow_service_module,
        "HumanInputDeliveryTestService",
        MagicMock(return_value=delivery_test_service),
    )

    service.test_human_input_delivery(
        app_model=SimpleNamespace(tenant_id="tenant-1", id="app-1"),
        account=SimpleNamespace(id="account-1"),
        node_id="node-1",
        delivery_method_id=str(method.id),
    )

    delivery_test_service.send_test.assert_called_once()
def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch):
    """Manual inputs flow into the variable pool and the test service is invoked once."""
    service = _make_service()
    method = _make_email_method(enabled=True)
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = _build_node_config([method])
    service.get_draft_workflow = MagicMock(return_value=workflow)  # type: ignore[method-assign]
    service._build_human_input_variable_pool = MagicMock(return_value=MagicMock())  # type: ignore[attr-defined]
    node_stub = MagicMock()
    node_stub._render_form_content_before_submission.return_value = "rendered"
    node_stub._resolve_default_values.return_value = {}
    service._build_human_input_node = MagicMock(return_value=node_stub)  # type: ignore[attr-defined]
    service._create_human_input_delivery_test_form = MagicMock(return_value=("form-1", {}))  # type: ignore[attr-defined]
    delivery_test_service = MagicMock()
    monkeypatch.setattr(
        workflow_service_module,
        "HumanInputDeliveryTestService",
        MagicMock(return_value=delivery_test_service),
    )

    service.test_human_input_delivery(
        app_model=SimpleNamespace(tenant_id="tenant-1", id="app-1"),
        account=SimpleNamespace(id="account-1"),
        node_id="node-1",
        delivery_method_id=str(method.id),
        inputs={"#node-1.output#": "value"},
    )

    pool_kwargs = service._build_human_input_variable_pool.call_args.kwargs
    assert pool_kwargs["manual_inputs"] == {"#node-1.output#": "value"}
    delivery_test_service.send_test.assert_called_once()
def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch):
    """In debug mode the recipient list is replaced with just the requesting account."""
    service = _make_service()
    method = _make_email_method(enabled=True, debug_mode=True)
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = _build_node_config([method])
    service.get_draft_workflow = MagicMock(return_value=workflow)  # type: ignore[method-assign]
    service._build_human_input_variable_pool = MagicMock(return_value=MagicMock())  # type: ignore[attr-defined]
    node_stub = MagicMock()
    node_stub._render_form_content_before_submission.return_value = "rendered"
    node_stub._resolve_default_values.return_value = {}
    service._build_human_input_node = MagicMock(return_value=node_stub)  # type: ignore[attr-defined]
    service._create_human_input_delivery_test_form = MagicMock(return_value=("form-1", {}))  # type: ignore[attr-defined]
    delivery_test_service = MagicMock()
    monkeypatch.setattr(
        workflow_service_module,
        "HumanInputDeliveryTestService",
        MagicMock(return_value=delivery_test_service),
    )
    account = SimpleNamespace(id="account-1")

    service.test_human_input_delivery(
        app_model=SimpleNamespace(tenant_id="tenant-1", id="app-1"),
        account=account,
        node_id="node-1",
        delivery_method_id=str(method.id),
    )

    delivery_test_service.send_test.assert_called_once()
    sent_method = delivery_test_service.send_test.call_args.kwargs["method"]
    assert isinstance(sent_method, EmailDeliveryMethod)
    assert sent_method.config.debug_mode is True
    assert sent_method.config.recipients.whole_workspace is False
    (only_recipient,) = sent_method.config.recipients.items
    assert isinstance(only_recipient, MemberRecipient)
    assert only_recipient.user_id == account.id

View File

@ -5,6 +5,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
@ -52,6 +53,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
call_args = mock_session.scalar.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
compiled = call_args.compile()
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
def test_get_node_last_execution_not_found(self, repository):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
@ -71,28 +75,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
assert result is None
mock_session.scalar.assert_called_once()
def test_get_executions_by_workflow_run(self, repository, mock_execution):
"""Test getting all executions for a workflow run."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
executions = [mock_execution]
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == executions
mock_session.execute.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.execute.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
def test_get_executions_by_workflow_run_empty(self, repository):
"""Test getting executions for a workflow run when none exist."""
# Arrange

View File

@ -1,9 +1,15 @@
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.workflow.enums import NodeType
from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import FormInputType
from models.model import App
from models.workflow import Workflow
from services import workflow_service as workflow_service_module
from services.workflow_service import WorkflowService
@ -161,3 +167,120 @@ class TestWorkflowService:
assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()
def test_submit_human_input_form_preview_uses_rendered_content(
    self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Previewing a form submission renders outputs and persists them as draft variables."""
    service = workflow_service
    node_data = HumanInputNodeData(
        title="Human Input",
        form_content="<p>{{#$output.name#}}</p>",
        inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
        user_actions=[UserAction(id="approve", title="Approve")],
    )
    node = MagicMock()
    node.node_data = node_data
    # Stub both render phases: the pre-submission preview and the final
    # render that substitutes submitted output values.
    node.render_form_content_before_submission.return_value = "<p>preview</p>"
    node.render_form_content_with_outputs.return_value = "<p>rendered</p>"
    service._build_human_input_variable_pool = MagicMock(return_value=MagicMock())  # type: ignore[method-assign]
    service._build_human_input_node = MagicMock(return_value=node)  # type: ignore[method-assign]
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
    workflow.get_enclosing_node_type_and_id.return_value = None
    service.get_draft_workflow = MagicMock(return_value=workflow)  # type: ignore[method-assign]
    # Collects whatever the preview path saves as draft variables.
    saved_outputs: dict[str, object] = {}

    class DummySession:
        # Stand-in for the SQLAlchemy Session used by the preview code path.
        def __init__(self, *args, **kwargs):
            self.commit = MagicMock()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            return False

        def begin(self):
            return nullcontext()

    class DummySaver:
        # Captures the outputs that would be written as draft variables.
        def __init__(self, *args, **kwargs):
            pass

        def save(self, outputs, process_data):
            saved_outputs.update(outputs)

    monkeypatch.setattr(workflow_service_module, "Session", DummySession)
    monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver)
    monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
    app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
    account = SimpleNamespace(id="account-1")
    result = service.submit_human_input_form_preview(
        app_model=app_model,
        account=account,
        node_id="node-1",
        form_inputs={"name": "Ada", "extra": "ignored"},
        inputs={"#node-0.result#": "LLM output"},
        action="approve",
    )
    # The variable pool must be built from the manually supplied upstream inputs.
    service._build_human_input_variable_pool.assert_called_once_with(
        app_model=app_model,
        workflow=workflow,
        node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}},
        manual_inputs={"#node-0.result#": "LLM output"},
    )
    node.render_form_content_with_outputs.assert_called_once()
    called_args = node.render_form_content_with_outputs.call_args.args
    # Rendering starts from the pre-submission preview content.
    assert called_args[0] == "<p>preview</p>"
    assert called_args[2] == node_data.outputs_field_names()
    rendered_outputs = called_args[1]
    # Form fields not declared in the node's inputs are passed through, not dropped.
    assert rendered_outputs["name"] == "Ada"
    assert rendered_outputs["extra"] == "ignored"
    assert "extra" in saved_outputs
    assert "extra" in result
    assert saved_outputs["name"] == "Ada"
    assert result["name"] == "Ada"
    assert result["__action_id"] == "approve"
    assert "__rendered_content" in result
def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None:
    """Submitting a preview without the required form inputs raises ValueError.

    The node declares a required "name" input; submitting an empty
    ``form_inputs`` mapping must be rejected with a "Missing required inputs"
    message before any rendering or persistence happens.
    """
    service = workflow_service
    node_data = HumanInputNodeData(
        title="Human Input",
        form_content="<p>{{#$output.name#}}</p>",
        inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
        user_actions=[UserAction(id="approve", title="Approve")],
    )
    node = MagicMock()
    node.node_data = node_data
    # Mock the node's public rendering API (the original mocked the
    # nonexistent private names `_render_form_content_*`), keeping this
    # consistent with test_submit_human_input_form_preview_uses_rendered_content.
    node.render_form_content_before_submission.return_value = "<p>preview</p>"
    node.render_form_content_with_outputs.return_value = "<p>rendered</p>"
    service._build_human_input_variable_pool = MagicMock(return_value=MagicMock())  # type: ignore[method-assign]
    service._build_human_input_node = MagicMock(return_value=node)  # type: ignore[method-assign]
    workflow = MagicMock()
    workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
    service.get_draft_workflow = MagicMock(return_value=workflow)  # type: ignore[method-assign]
    app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
    account = SimpleNamespace(id="account-1")
    with pytest.raises(ValueError) as exc_info:
        service.submit_human_input_form_preview(
            app_model=app_model,
            account=account,
            node_id="node-1",
            form_inputs={},
            inputs={},
            action="approve",
        )
    assert "Missing required inputs" in str(exc_info.value)

View File

@ -0,0 +1,210 @@
from __future__ import annotations
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any
import pytest
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from tasks import human_input_timeout_tasks as task_module
class _FakeScalarResult:
def __init__(self, items: list[Any]):
self._items = items
def all(self) -> list[Any]:
return self._items
class _FakeSession:
def __init__(self, items: list[Any], capture: dict[str, Any]):
self._items = items
self._capture = capture
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def scalars(self, stmt):
self._capture["stmt"] = stmt
return _FakeScalarResult(self._items)
class _FakeSessionFactory:
def __init__(self, items: list[Any], capture: dict[str, Any]):
self._items = items
self._capture = capture
self._capture["session_factory"] = self
def __call__(self):
session = _FakeSession(self._items, self._capture)
self._capture["session"] = session
return session
class _FakeFormRepo:
def __init__(self, _session_factory, form_map: dict[str, Any] | None = None):
self.calls: list[dict[str, Any]] = []
self._form_map = form_map or {}
def mark_timeout(self, *, form_id: str, timeout_status: HumanInputFormStatus, reason: str | None = None):
self.calls.append(
{
"form_id": form_id,
"timeout_status": timeout_status,
"reason": reason,
}
)
form = self._form_map.get(form_id)
return SimpleNamespace(
form_id=form_id,
workflow_run_id=getattr(form, "workflow_run_id", None),
node_id=getattr(form, "node_id", None),
)
class _FakeService:
def __init__(self, _session_factory, form_repository=None):
self.enqueued: list[str] = []
def enqueue_resume(self, workflow_run_id: str | None) -> None:
if workflow_run_id is not None:
self.enqueued.append(workflow_run_id)
def _build_form(
    *,
    form_id: str,
    form_kind: HumanInputFormKind,
    created_at: datetime,
    expiration_time: datetime,
    workflow_run_id: str | None,
    node_id: str,
) -> SimpleNamespace:
    """Assemble a lightweight form record, always in the WAITING state."""
    attrs = {
        "id": form_id,
        "form_kind": form_kind,
        "created_at": created_at,
        "expiration_time": expiration_time,
        "workflow_run_id": workflow_run_id,
        "node_id": node_id,
        "status": HumanInputFormStatus.WAITING,
    }
    return SimpleNamespace(**attrs)
def test_is_global_timeout_uses_created_at():
    """The global timeout is measured from created_at and needs a workflow run."""
    check = task_module._is_global_timeout
    now = datetime(2025, 1, 1, 12, 0, 0)
    form = SimpleNamespace(created_at=now - timedelta(seconds=61), workflow_run_id="run-1")

    # Older than the 60s budget and tied to a run -> timed out.
    assert check(form, 60, now=now) is True

    # Without a workflow run the global timeout never applies.
    form.workflow_run_id = None
    assert check(form, 60, now=now) is False

    # Back on a run but still inside the budget -> not timed out.
    form.workflow_run_id = "run-1"
    form.created_at = now - timedelta(seconds=59)
    assert check(form, 60, now=now) is False

    # A budget of zero disables the global timeout entirely.
    assert check(form, 0, now=now) is False
def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pytest.MonkeyPatch):
    """End-to-end check of the timeout sweep: marking, routing, and querying.

    Three waiting forms exercise each branch:
    - ``form-global``: created before the global budget -> EXPIRED, routed
      through ``_handle_global_timeout``;
    - ``form-node``: past its own expiration -> TIMEOUT, workflow resume enqueued;
    - ``form-delivery``: a delivery-test form -> TIMEOUT with no resume.
    Also asserts the SELECT pre-filters on created_at/expiration_time and
    orders deterministically by id.
    """
    now = datetime(2025, 1, 1, 12, 0, 0)
    monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
    # Global budget of one hour; form-global below was created two hours ago.
    monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600)
    monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
    forms = [
        _build_form(
            form_id="form-global",
            form_kind=HumanInputFormKind.RUNTIME,
            created_at=now - timedelta(hours=2),
            expiration_time=now + timedelta(hours=1),
            workflow_run_id="run-global",
            node_id="node-global",
        ),
        _build_form(
            form_id="form-node",
            form_kind=HumanInputFormKind.RUNTIME,
            created_at=now - timedelta(minutes=5),
            expiration_time=now - timedelta(seconds=1),
            workflow_run_id="run-node",
            node_id="node-node",
        ),
        _build_form(
            form_id="form-delivery",
            form_kind=HumanInputFormKind.DELIVERY_TEST,
            created_at=now - timedelta(minutes=1),
            expiration_time=now - timedelta(seconds=1),
            workflow_run_id=None,
            node_id="node-delivery",
        ),
    ]
    capture: dict[str, Any] = {}
    # The task builds its own sessionmaker; swap in the fake that yields `forms`.
    monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture))
    form_map = {form.id: form for form in forms}
    repo = _FakeFormRepo(None, form_map=form_map)

    def _repo_factory(_session_factory):
        # Always hand back the shared fake repo so its calls can be inspected.
        return repo

    service = _FakeService(None)

    def _service_factory(_session_factory, form_repository=None):
        # Same idea for the service: one shared instance records resumes.
        return service

    global_calls: list[dict[str, Any]] = []
    monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _repo_factory)
    monkeypatch.setattr(task_module, "HumanInputService", _service_factory)
    monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **kwargs: global_calls.append(kwargs))
    task_module.check_and_handle_human_input_timeouts(limit=100)
    # Every form is marked with the status/reason matching its branch.
    assert {(call["form_id"], call["timeout_status"], call["reason"]) for call in repo.calls} == {
        ("form-global", HumanInputFormStatus.EXPIRED, "global_timeout"),
        ("form-node", HumanInputFormStatus.TIMEOUT, "node_timeout"),
        ("form-delivery", HumanInputFormStatus.TIMEOUT, "delivery_test_timeout"),
    }
    # Only the node-timeout form with a workflow run triggers a resume.
    assert service.enqueued == ["run-node"]
    assert global_calls == [
        {
            "form_id": "form-global",
            "workflow_run_id": "run-global",
            "node_id": "node-global",
            "session_factory": capture.get("session_factory"),
        }
    ]
    stmt = capture.get("stmt")
    assert stmt is not None
    stmt_text = str(stmt)
    # Candidate selection happens in SQL, not in Python.
    assert "created_at <=" in stmt_text
    assert "expiration_time <=" in stmt_text
    assert "ORDER BY human_input_forms.id" in stmt_text
def test_check_and_handle_human_input_timeouts_omits_global_filter_when_disabled(monkeypatch: pytest.MonkeyPatch):
    """With a zero global budget the query must not filter on created_at."""
    now = datetime(2025, 1, 1, 12, 0, 0)
    captured: dict[str, Any] = {}
    monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
    # Disable the global timeout entirely.
    monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0)
    monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
    monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory([], captured))
    monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _FakeFormRepo)
    monkeypatch.setattr(task_module, "HumanInputService", _FakeService)
    monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **_kwargs: None)

    task_module.check_and_handle_human_input_timeouts(limit=1)

    stmt = captured.get("stmt")
    assert stmt is not None
    assert "created_at <=" not in str(stmt)

View File

@ -0,0 +1,123 @@
from collections.abc import Sequence
from types import SimpleNamespace
import pytest
from tasks import mail_human_input_delivery_task as task_module
class _DummyMail:
def __init__(self):
self.sent: list[dict[str, str]] = []
self._inited = True
def is_inited(self) -> bool:
return self._inited
def send(self, *, to: str, subject: str, html: str):
self.sent.append({"to": to, "subject": subject, "html": html})
class _DummySession:
def __init__(self, form):
self._form = form
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return False
def get(self, _model, _form_id):
return self._form
def _build_job(recipient_count: int = 1) -> task_module._EmailDeliveryJob:
    """Create a delivery job with ``recipient_count`` synthetic recipients."""
    recipients = [
        task_module._EmailRecipient(email=f"user{idx}@example.com", token=f"token-{idx}")
        for idx in range(recipient_count)
    ]
    return task_module._EmailDeliveryJob(
        form_id="form-1",
        subject="Subject",
        body="Body for {{#url}}",
        form_content="content",
        recipients=recipients,
    )
def test_dispatch_human_input_email_task_sends_to_each_recipient(monkeypatch: pytest.MonkeyPatch):
    """One email per recipient is sent when the delivery feature is enabled."""
    fake_mail = _DummyMail()
    target_form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None)
    monkeypatch.setattr(task_module, "mail", fake_mail)
    # Feature flag on for this tenant: delivery should proceed.
    monkeypatch.setattr(
        task_module.FeatureService,
        "get_features",
        lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
    )
    jobs: Sequence[task_module._EmailDeliveryJob] = [_build_job(recipient_count=2)]
    monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: jobs)

    task_module.dispatch_human_input_email_task(
        form_id="form-1",
        node_title="Approve",
        session_factory=lambda: _DummySession(target_form),
    )

    # Both recipients received the same subject and rendered body.
    assert len(fake_mail.sent) == 2
    assert all(payload["subject"] == "Subject" for payload in fake_mail.sent)
    assert all("Body for" in payload["html"] for payload in fake_mail.sent)
def test_dispatch_human_input_email_task_skips_when_feature_disabled(monkeypatch: pytest.MonkeyPatch):
    """No mail is sent when the tenant's email-delivery feature is off."""
    fake_mail = _DummyMail()
    target_form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None)
    monkeypatch.setattr(task_module, "mail", fake_mail)
    # Feature flag off: the task must bail out before sending anything.
    monkeypatch.setattr(
        task_module.FeatureService,
        "get_features",
        lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False),
    )
    monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [])

    task_module.dispatch_human_input_email_task(
        form_id="form-1",
        node_title="Approve",
        session_factory=lambda: _DummySession(target_form),
    )

    assert fake_mail.sent == []
def test_dispatch_human_input_email_task_replaces_body_variables(monkeypatch: pytest.MonkeyPatch):
    """Variable references in the body are resolved from the run's variable pool."""
    fake_mail = _DummyMail()
    target_form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id="run-1")
    job = task_module._EmailDeliveryJob(
        form_id="form-1",
        subject="Subject",
        body="Body {{#node1.value#}}",
        form_content="content",
        recipients=[task_module._EmailRecipient(email="user@example.com", token="token-1")],
    )
    # Pool resolves {{#node1.value#}} -> "OK".
    pool = task_module.VariablePool()
    pool.add(["node1", "value"], "OK")
    monkeypatch.setattr(task_module, "mail", fake_mail)
    monkeypatch.setattr(
        task_module.FeatureService,
        "get_features",
        lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
    )
    monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [job])
    monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: pool)

    task_module.dispatch_human_input_email_task(
        form_id="form-1",
        node_title="Approve",
        session_factory=lambda: _DummySession(target_form),
    )

    assert fake_mail.sent[0]["html"] == "Body OK"

View File

@ -0,0 +1,39 @@
from __future__ import annotations
import json
import uuid
from unittest.mock import MagicMock
import pytest
from models.model import AppMode
from tasks.app_generate.workflow_execute_task import _publish_streaming_response
@pytest.fixture
def mock_topic(mocker) -> MagicMock:
    """Patch the generator's response-topic lookup and expose the fake topic."""
    fake_topic = MagicMock()
    mocker.patch(
        "tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic",
        return_value=fake_topic,
    )
    return fake_topic
def test_publish_streaming_response_with_uuid(mock_topic: MagicMock):
    """Every streamed chunk is JSON-encoded to bytes and published in order."""
    run_id = uuid.uuid4()
    chunks = iter([{"event": "foo"}, "ping"])

    _publish_streaming_response(chunks, run_id, app_mode=AppMode.ADVANCED_CHAT)

    published = [call.args[0] for call in mock_topic.publish.call_args_list]
    assert published == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()]
def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock):
    """A workflow run id supplied as a string is accepted and handled identically."""
    run_id = uuid.uuid4()
    stream = iter([{"event": "bar"}])

    _publish_streaming_response(stream, str(run_id), app_mode=AppMode.ADVANCED_CHAT)

    mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode())

View File

@ -0,0 +1,488 @@
# """
# Unit tests for workflow node execution Celery tasks.
# NOTE: this entire module is currently commented out (disabled).
# These tests verify the asynchronous storage functionality for workflow node execution data,
# including truncation and offloading logic.
# """
# import json
# from unittest.mock import MagicMock, Mock, patch
# from uuid import uuid4
# import pytest
# from core.workflow.entities.workflow_node_execution import (
# WorkflowNodeExecution,
# WorkflowNodeExecutionStatus,
# )
# from core.workflow.enums import NodeType
# from libs.datetime_utils import naive_utc_now
# from models import WorkflowNodeExecutionModel
# from models.enums import ExecutionOffLoadType
# from models.model import UploadFile
# from models.workflow import WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
# from tasks.workflow_node_execution_tasks import (
# _create_truncator,
# _json_encode,
# _replace_or_append_offload,
# _truncate_and_upload_async,
# save_workflow_node_execution_data_task,
# save_workflow_node_execution_task,
# )
# @pytest.fixture
# def sample_execution_data():
# """Sample execution data for testing."""
# execution = WorkflowNodeExecution(
# id=str(uuid4()),
# node_execution_id=str(uuid4()),
# workflow_id=str(uuid4()),
# workflow_execution_id=str(uuid4()),
# index=1,
# node_id="test_node",
# node_type=NodeType.LLM,
# title="Test Node",
# inputs={"input_key": "input_value"},
# outputs={"output_key": "output_value"},
# process_data={"process_key": "process_value"},
# status=WorkflowNodeExecutionStatus.RUNNING,
# created_at=naive_utc_now(),
# )
# return execution.model_dump()
# @pytest.fixture
# def mock_db_model():
# """Mock database model for testing."""
# db_model = Mock(spec=WorkflowNodeExecutionModel)
# db_model.id = "test-execution-id"
# db_model.offload_data = []
# return db_model
# @pytest.fixture
# def mock_file_service():
# """Mock file service for testing."""
# file_service = Mock()
# mock_upload_file = Mock(spec=UploadFile)
# mock_upload_file.id = "mock-file-id"
# file_service.upload_file.return_value = mock_upload_file
# return file_service
# class TestSaveWorkflowNodeExecutionDataTask:
# """Test cases for save_workflow_node_execution_data_task."""
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_execution_data_task_success(
# self, mock_select, mock_sessionmaker, sample_execution_data, mock_db_model
# ):
# """Test successful execution of save_workflow_node_execution_data_task."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
# # Execute task
# result = save_workflow_node_execution_data_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify success
# assert result is True
# mock_session.merge.assert_called_once_with(mock_db_model)
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_execution_data_task_execution_not_found(self, mock_select, mock_sessionmaker,
# sample_execution_data):
# """Test task when execution is not found in database."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.execute.return_value.scalars.return_value.first.return_value = None
# # Execute task
# result = save_workflow_node_execution_data_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify failure
# assert result is False
# mock_session.merge.assert_not_called()
# mock_session.commit.assert_not_called()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_execution_data_task_with_truncation(self, mock_select, mock_sessionmaker, mock_db_model):
# """Test task with data that requires truncation."""
# # Create execution with large data
# large_data = {"large_field": "x" * 10000}
# execution = WorkflowNodeExecution(
# id=str(uuid4()),
# node_execution_id=str(uuid4()),
# workflow_id=str(uuid4()),
# workflow_execution_id=str(uuid4()),
# index=1,
# node_id="test_node",
# node_type=NodeType.LLM,
# title="Test Node",
# inputs=large_data,
# outputs=large_data,
# process_data=large_data,
# status=WorkflowNodeExecutionStatus.RUNNING,
# created_at=naive_utc_now(),
# )
# execution_data = execution.model_dump()
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
# # Create mock upload file
# mock_upload_file = Mock(spec=UploadFile)
# mock_upload_file.id = "mock-file-id"
# # Execute task
# with patch("tasks.workflow_node_execution_tasks._truncate_and_upload_async") as mock_truncate:
# # Mock truncation results
# mock_truncate.return_value = {
# "truncated_value": {"large_field": "[TRUNCATED]"},
# "file": mock_upload_file,
# "offload": WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# node_execution_id=execution.id,
# type_=ExecutionOffLoadType.INPUTS,
# file_id=mock_upload_file.id,
# ),
# }
# result = save_workflow_node_execution_data_task(
# execution_data=execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify success and truncation was called
# assert result is True
# assert mock_truncate.call_count == 3 # inputs, outputs, process_data
# mock_session.merge.assert_called_once_with(mock_db_model)
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# def test_save_execution_data_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
# """Test task retry mechanism on exception."""
# # Setup mock to raise exception
# mock_sessionmaker.side_effect = Exception("Database error")
# # Create a mock task instance with proper retry behavior
# with patch.object(save_workflow_node_execution_data_task, "retry") as mock_retry:
# mock_retry.side_effect = Exception("Retry called")
# # Execute task and expect retry
# with pytest.raises(Exception, match="Retry called"):
# save_workflow_node_execution_data_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify retry was called
# mock_retry.assert_called_once()
# class TestTruncateAndUploadAsync:
# """Test cases for _truncate_and_upload_async function."""
# def test_truncate_and_upload_with_none_values(self, mock_file_service):
# """Test _truncate_and_upload_async with None values."""
# # The function handles None values internally, so we test with empty dict instead
# result = _truncate_and_upload_async(
# values={},
# execution_id="test-id",
# type_=ExecutionOffLoadType.INPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "account"},
# file_service=mock_file_service,
# )
# # Empty dict should not require truncation
# assert result is None
# mock_file_service.upload_file.assert_not_called()
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
# def test_truncate_and_upload_no_truncation_needed(self, mock_create_truncator, mock_file_service):
# """Test _truncate_and_upload_async when no truncation is needed."""
# # Mock truncator to return no truncation
# mock_truncator = Mock()
# mock_truncator.truncate_variable_mapping.return_value = ({"small": "data"}, False)
# mock_create_truncator.return_value = mock_truncator
# small_values = {"small": "data"}
# result = _truncate_and_upload_async(
# values=small_values,
# execution_id="test-id",
# type_=ExecutionOffLoadType.INPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "account"},
# file_service=mock_file_service,
# )
# assert result is None
# mock_file_service.upload_file.assert_not_called()
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
# @patch("models.Account")
# @patch("models.Tenant")
# def test_truncate_and_upload_with_account_user(
# self, mock_tenant_class, mock_account_class, mock_create_truncator, mock_file_service
# ):
# """Test _truncate_and_upload_async with account user."""
# # Mock truncator to return truncation needed
# mock_truncator = Mock()
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
# mock_create_truncator.return_value = mock_truncator
# # Mock user and tenant creation
# mock_account = Mock()
# mock_account.id = "test-user"
# mock_account_class.return_value = mock_account
# mock_tenant = Mock()
# mock_tenant.id = "test-tenant"
# mock_tenant_class.return_value = mock_tenant
# large_values = {"large": "x" * 10000}
# result = _truncate_and_upload_async(
# values=large_values,
# execution_id="test-id",
# type_=ExecutionOffLoadType.INPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "account"},
# file_service=mock_file_service,
# )
# # Verify result structure
# assert result is not None
# assert "truncated_value" in result
# assert "file" in result
# assert "offload" in result
# assert result["truncated_value"] == {"truncated": "data"}
# # Verify file upload was called
# mock_file_service.upload_file.assert_called_once()
# upload_call = mock_file_service.upload_file.call_args
# assert upload_call[1]["filename"] == "node_execution_test-id_inputs.json"
# assert upload_call[1]["mimetype"] == "application/json"
# assert upload_call[1]["user"] == mock_account
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
# @patch("models.EndUser")
# def test_truncate_and_upload_with_end_user(self, mock_end_user_class, mock_create_truncator, mock_file_service):
# """Test _truncate_and_upload_async with end user."""
# # Mock truncator to return truncation needed
# mock_truncator = Mock()
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
# mock_create_truncator.return_value = mock_truncator
# # Mock end user creation
# mock_end_user = Mock()
# mock_end_user.id = "test-user"
# mock_end_user.tenant_id = "test-tenant"
# mock_end_user_class.return_value = mock_end_user
# large_values = {"large": "x" * 10000}
# result = _truncate_and_upload_async(
# values=large_values,
# execution_id="test-id",
# type_=ExecutionOffLoadType.OUTPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "end_user"},
# file_service=mock_file_service,
# )
# # Verify result structure
# assert result is not None
# assert result["truncated_value"] == {"truncated": "data"}
# # Verify file upload was called with end user
# mock_file_service.upload_file.assert_called_once()
# upload_call = mock_file_service.upload_file.call_args
# assert upload_call[1]["filename"] == "node_execution_test-id_outputs.json"
# assert upload_call[1]["user"] == mock_end_user
# class TestHelperFunctions:
# """Test cases for helper functions."""
# @patch("tasks.workflow_node_execution_tasks.dify_config")
# def test_create_truncator(self, mock_config):
# """Test _create_truncator function."""
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500
# truncator = _create_truncator()
# # Verify truncator was created with correct config
# assert truncator is not None
# def test_json_encode(self):
# """Test _json_encode function."""
# test_data = {"key": "value", "number": 42}
# result = _json_encode(test_data)
# assert isinstance(result, str)
# decoded = json.loads(result)
# assert decoded == test_data
# def test_replace_or_append_offload_replace_existing(self):
# """Test _replace_or_append_offload replaces existing offload of same type."""
# existing_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.INPUTS,
# file_id="old-file-id",
# )
# new_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.INPUTS,
# file_id="new-file-id",
# )
# result = _replace_or_append_offload([existing_offload], new_offload)
# assert len(result) == 1
# assert result[0].file_id == "new-file-id"
# def test_replace_or_append_offload_append_new_type(self):
# """Test _replace_or_append_offload appends new offload of different type."""
# existing_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.INPUTS,
# file_id="inputs-file-id",
# )
# new_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.OUTPUTS,
# file_id="outputs-file-id",
# )
# result = _replace_or_append_offload([existing_offload], new_offload)
# assert len(result) == 2
# file_ids = [offload.file_id for offload in result]
# assert "inputs-file-id" in file_ids
# assert "outputs-file-id" in file_ids
# class TestSaveWorkflowNodeExecutionTask:
# """Test cases for save_workflow_node_execution_task."""
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_workflow_node_execution_task_create_new(self, mock_select, mock_sessionmaker,
# sample_execution_data):
# """Test creating a new workflow node execution."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.scalar.return_value = None # No existing execution
# # Execute task
# result = save_workflow_node_execution_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
# creator_user_id="test-user-id",
# creator_user_role="account",
# )
# # Verify success
# assert result is True
# mock_session.add.assert_called_once()
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_workflow_node_execution_task_update_existing(
# self, mock_select, mock_sessionmaker, sample_execution_data
# ):
# """Test updating an existing workflow node execution."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# existing_execution = Mock(spec=WorkflowNodeExecutionModel)
# mock_session.scalar.return_value = existing_execution
# # Execute task
# result = save_workflow_node_execution_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
# creator_user_id="test-user-id",
# creator_user_role="account",
# )
# # Verify success
# assert result is True
# mock_session.add.assert_not_called() # Should not add new, just update existing
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# def test_save_workflow_node_execution_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
# """Test task retry mechanism on exception."""
# # Setup mock to raise exception
# mock_sessionmaker.side_effect = Exception("Database error")
# # Create a mock task instance with proper retry behavior
# with patch.object(save_workflow_node_execution_task, "retry") as mock_retry:
# mock_retry.side_effect = Exception("Retry called")
# # Execute task and expect retry
# with pytest.raises(Exception, match="Retry called"):
# save_workflow_node_execution_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
# creator_user_id="test-user-id",
# creator_user_role="account",
# )
# # Verify retry was called
# mock_retry.assert_called_once()