mirror of
https://github.com/langgenius/dify.git
synced 2026-02-22 19:15:47 +08:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
@ -164,6 +164,62 @@ def test_db_extras_options_merging(monkeypatch: pytest.MonkeyPatch):
|
||||
assert "timezone=UTC" in options
|
||||
|
||||
|
||||
def test_pubsub_redis_url_default(monkeypatch: pytest.MonkeyPatch):
|
||||
os.environ.clear()
|
||||
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_PORT", "5432")
|
||||
monkeypatch.setenv("DB_DATABASE", "dify")
|
||||
monkeypatch.setenv("REDIS_HOST", "redis.example.com")
|
||||
monkeypatch.setenv("REDIS_PORT", "6380")
|
||||
monkeypatch.setenv("REDIS_USERNAME", "user")
|
||||
monkeypatch.setenv("REDIS_PASSWORD", "pass@word")
|
||||
monkeypatch.setenv("REDIS_DB", "2")
|
||||
monkeypatch.setenv("REDIS_USE_SSL", "true")
|
||||
|
||||
config = DifyConfig()
|
||||
|
||||
assert config.normalized_pubsub_redis_url == "rediss://user:pass%40word@redis.example.com:6380/2"
|
||||
assert config.PUBSUB_REDIS_CHANNEL_TYPE == "pubsub"
|
||||
|
||||
|
||||
def test_pubsub_redis_url_override(monkeypatch: pytest.MonkeyPatch):
|
||||
os.environ.clear()
|
||||
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_PORT", "5432")
|
||||
monkeypatch.setenv("DB_DATABASE", "dify")
|
||||
monkeypatch.setenv("PUBSUB_REDIS_URL", "redis://pubsub-host:6381/5")
|
||||
|
||||
config = DifyConfig()
|
||||
|
||||
assert config.normalized_pubsub_redis_url == "redis://pubsub-host:6381/5"
|
||||
|
||||
|
||||
def test_pubsub_redis_url_required_when_default_unavailable(monkeypatch: pytest.MonkeyPatch):
|
||||
os.environ.clear()
|
||||
|
||||
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
|
||||
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
|
||||
monkeypatch.setenv("DB_USERNAME", "postgres")
|
||||
monkeypatch.setenv("DB_PASSWORD", "postgres")
|
||||
monkeypatch.setenv("DB_HOST", "localhost")
|
||||
monkeypatch.setenv("DB_PORT", "5432")
|
||||
monkeypatch.setenv("DB_DATABASE", "dify")
|
||||
monkeypatch.setenv("REDIS_HOST", "")
|
||||
|
||||
with pytest.raises(ValueError, match="PUBSUB_REDIS_URL must be set"):
|
||||
_ = DifyConfig().normalized_pubsub_redis_url
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("broker_url", "expected_host", "expected_port", "expected_username", "expected_password", "expected_db"),
|
||||
[
|
||||
|
||||
@ -51,6 +51,8 @@ def _patch_redis_clients_on_loaded_modules():
|
||||
continue
|
||||
if hasattr(module, "redis_client"):
|
||||
module.redis_client = redis_mock
|
||||
if hasattr(module, "pubsub_redis_client"):
|
||||
module.pubsub_redis_client = redis_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -68,7 +70,10 @@ def _provide_app_context(app: Flask):
|
||||
def _patch_redis_clients():
|
||||
"""Patch redis_client to MagicMock only for unit test executions."""
|
||||
|
||||
with patch.object(ext_redis, "redis_client", redis_mock):
|
||||
with (
|
||||
patch.object(ext_redis, "redis_client", redis_mock),
|
||||
patch.object(ext_redis, "pubsub_redis_client", redis_mock),
|
||||
):
|
||||
_patch_redis_clients_on_loaded_modules()
|
||||
yield
|
||||
|
||||
|
||||
@ -16,11 +16,9 @@ if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _load_app_module():
|
||||
@pytest.fixture(scope="module")
|
||||
def app_module():
|
||||
module_name = "controllers.console.app.app"
|
||||
if module_name in sys.modules:
|
||||
return sys.modules[module_name]
|
||||
|
||||
root = Path(__file__).resolve().parents[5]
|
||||
module_path = root / "controllers" / "console" / "app" / "app.py"
|
||||
|
||||
@ -59,8 +57,12 @@ def _load_app_module():
|
||||
|
||||
stub_namespace = _StubNamespace()
|
||||
|
||||
original_console = sys.modules.get("controllers.console")
|
||||
original_app_pkg = sys.modules.get("controllers.console.app")
|
||||
original_modules: dict[str, ModuleType | None] = {
|
||||
"controllers.console": sys.modules.get("controllers.console"),
|
||||
"controllers.console.app": sys.modules.get("controllers.console.app"),
|
||||
"controllers.common.schema": sys.modules.get("controllers.common.schema"),
|
||||
module_name: sys.modules.get(module_name),
|
||||
}
|
||||
stubbed_modules: list[tuple[str, ModuleType | None]] = []
|
||||
|
||||
console_module = ModuleType("controllers.console")
|
||||
@ -105,35 +107,35 @@ def _load_app_module():
|
||||
module = util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
try:
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(module)
|
||||
yield module
|
||||
finally:
|
||||
for name, original in reversed(stubbed_modules):
|
||||
if original is not None:
|
||||
sys.modules[name] = original
|
||||
else:
|
||||
sys.modules.pop(name, None)
|
||||
if original_console is not None:
|
||||
sys.modules["controllers.console"] = original_console
|
||||
else:
|
||||
sys.modules.pop("controllers.console", None)
|
||||
if original_app_pkg is not None:
|
||||
sys.modules["controllers.console.app"] = original_app_pkg
|
||||
else:
|
||||
sys.modules.pop("controllers.console.app", None)
|
||||
|
||||
return module
|
||||
for name, original in original_modules.items():
|
||||
if original is not None:
|
||||
sys.modules[name] = original
|
||||
else:
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
|
||||
_app_module = _load_app_module()
|
||||
AppDetailWithSite = _app_module.AppDetailWithSite
|
||||
AppPagination = _app_module.AppPagination
|
||||
AppPartial = _app_module.AppPartial
|
||||
@pytest.fixture(scope="module")
|
||||
def app_models(app_module):
|
||||
return SimpleNamespace(
|
||||
AppDetailWithSite=app_module.AppDetailWithSite,
|
||||
AppPagination=app_module.AppPagination,
|
||||
AppPartial=app_module.AppPartial,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_signed_url(monkeypatch):
|
||||
def patch_signed_url(monkeypatch, app_module):
|
||||
"""Ensure icon URL generation uses a deterministic helper for tests."""
|
||||
|
||||
def _fake_signed_url(key: str | None) -> str | None:
|
||||
@ -141,7 +143,7 @@ def patch_signed_url(monkeypatch):
|
||||
return None
|
||||
return f"signed:{key}"
|
||||
|
||||
monkeypatch.setattr(_app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
|
||||
monkeypatch.setattr(app_module.file_helpers, "get_signed_file_url", _fake_signed_url)
|
||||
|
||||
|
||||
def _ts(hour: int = 12) -> datetime:
|
||||
@ -169,7 +171,8 @@ def _dummy_workflow():
|
||||
)
|
||||
|
||||
|
||||
def test_app_partial_serialization_uses_aliases():
|
||||
def test_app_partial_serialization_uses_aliases(app_models):
|
||||
AppPartial = app_models.AppPartial
|
||||
created_at = _ts()
|
||||
app_obj = SimpleNamespace(
|
||||
id="app-1",
|
||||
@ -204,7 +207,8 @@ def test_app_partial_serialization_uses_aliases():
|
||||
assert serialized["tags"][0]["name"] == "Utilities"
|
||||
|
||||
|
||||
def test_app_detail_with_site_includes_nested_serialization():
|
||||
def test_app_detail_with_site_includes_nested_serialization(app_models):
|
||||
AppDetailWithSite = app_models.AppDetailWithSite
|
||||
timestamp = _ts(14)
|
||||
site = SimpleNamespace(
|
||||
code="site-code",
|
||||
@ -253,7 +257,8 @@ def test_app_detail_with_site_includes_nested_serialization():
|
||||
assert serialized["site"]["created_at"] == int(timestamp.timestamp())
|
||||
|
||||
|
||||
def test_app_pagination_aliases_per_page_and_has_next():
|
||||
def test_app_pagination_aliases_per_page_and_has_next(app_models):
|
||||
AppPagination = app_models.AppPagination
|
||||
item_one = SimpleNamespace(
|
||||
id="app-10",
|
||||
name="Paginated One",
|
||||
|
||||
@ -0,0 +1,229 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console import wraps as console_wraps
|
||||
from controllers.console.app import workflow as workflow_module
|
||||
from controllers.console.app import wraps as app_wraps
|
||||
from libs import login as login_lib
|
||||
from models.account import Account, AccountStatus, TenantAccountRole
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(name="tester", email="tester@example.com")
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.role = TenantAccountRole.OWNER
|
||||
account.id = "account-123" # type: ignore[assignment]
|
||||
account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined]
|
||||
account._get_current_object = lambda: account # type: ignore[attr-defined]
|
||||
return account
|
||||
|
||||
|
||||
def _make_app(mode: AppMode) -> SimpleNamespace:
|
||||
return SimpleNamespace(id="app-123", tenant_id="tenant-123", mode=mode.value)
|
||||
|
||||
|
||||
def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None:
|
||||
# Skip setup and auth guardrails
|
||||
monkeypatch.setattr("configs.dify_config.EDITION", "CLOUD")
|
||||
monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
|
||||
monkeypatch.setattr(login_lib, "current_user", account)
|
||||
monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
|
||||
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(app_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
|
||||
monkeypatch.delenv("INIT_PASSWORD", raising=False)
|
||||
|
||||
# Avoid hitting the database when resolving the app model
|
||||
monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreviewCase:
|
||||
resource_cls: type
|
||||
path: str
|
||||
mode: AppMode
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
PreviewCase(
|
||||
resource_cls=workflow_module.AdvancedChatDraftHumanInputFormPreviewApi,
|
||||
path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-42/form/preview",
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
),
|
||||
PreviewCase(
|
||||
resource_cls=workflow_module.WorkflowDraftHumanInputFormPreviewApi,
|
||||
path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-42/form/preview",
|
||||
mode=AppMode.WORKFLOW,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_human_input_preview_delegates_to_service(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, case: PreviewCase
|
||||
) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(case.mode)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
preview_payload = {
|
||||
"form_id": "node-42",
|
||||
"form_content": "<div>example</div>",
|
||||
"inputs": [{"name": "topic"}],
|
||||
"actions": [{"id": "continue"}],
|
||||
}
|
||||
service_instance = MagicMock()
|
||||
service_instance.get_human_input_form_preview.return_value = preview_payload
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
|
||||
|
||||
with app.test_request_context(case.path, method="POST", json={"inputs": {"topic": "tech"}}):
|
||||
response = case.resource_cls().post(app_id=app_model.id, node_id="node-42")
|
||||
|
||||
assert response == preview_payload
|
||||
service_instance.get_human_input_form_preview.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-42",
|
||||
inputs={"topic": "tech"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubmitCase:
|
||||
resource_cls: type
|
||||
path: str
|
||||
mode: AppMode
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
SubmitCase(
|
||||
resource_cls=workflow_module.AdvancedChatDraftHumanInputFormRunApi,
|
||||
path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-99/form/run",
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
),
|
||||
SubmitCase(
|
||||
resource_cls=workflow_module.WorkflowDraftHumanInputFormRunApi,
|
||||
path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-99/form/run",
|
||||
mode=AppMode.WORKFLOW,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_human_input_submit_forwards_payload(app: Flask, monkeypatch: pytest.MonkeyPatch, case: SubmitCase) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(case.mode)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
result_payload = {"node_id": "node-99", "outputs": {"__rendered_content": "<p>done</p>"}, "action": "approve"}
|
||||
service_instance = MagicMock()
|
||||
service_instance.submit_human_input_form_preview.return_value = result_payload
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
|
||||
|
||||
with app.test_request_context(
|
||||
case.path,
|
||||
method="POST",
|
||||
json={"form_inputs": {"answer": "42"}, "inputs": {"#node-1.result#": "LLM output"}, "action": "approve"},
|
||||
):
|
||||
response = case.resource_cls().post(app_id=app_model.id, node_id="node-99")
|
||||
|
||||
assert response == result_payload
|
||||
service_instance.submit_human_input_form_preview.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-99",
|
||||
form_inputs={"answer": "42"},
|
||||
inputs={"#node-1.result#": "LLM output"},
|
||||
action="approve",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeliveryTestCase:
|
||||
resource_cls: type
|
||||
path: str
|
||||
mode: AppMode
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
DeliveryTestCase(
|
||||
resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi,
|
||||
path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test",
|
||||
mode=AppMode.ADVANCED_CHAT,
|
||||
),
|
||||
DeliveryTestCase(
|
||||
resource_cls=workflow_module.WorkflowDraftHumanInputDeliveryTestApi,
|
||||
path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-7/delivery-test",
|
||||
mode=AppMode.WORKFLOW,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_human_input_delivery_test_calls_service(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch, case: DeliveryTestCase
|
||||
) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(case.mode)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
service_instance = MagicMock()
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
|
||||
|
||||
with app.test_request_context(
|
||||
case.path,
|
||||
method="POST",
|
||||
json={"delivery_method_id": "delivery-123"},
|
||||
):
|
||||
response = case.resource_cls().post(app_id=app_model.id, node_id="node-7")
|
||||
|
||||
assert response == {}
|
||||
service_instance.test_human_input_delivery.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-7",
|
||||
delivery_method_id="delivery-123",
|
||||
inputs={},
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_delivery_test_maps_validation_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(AppMode.ADVANCED_CHAT)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
service_instance = MagicMock()
|
||||
service_instance.test_human_input_delivery.side_effect = ValueError("bad delivery method")
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=service_instance))
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-123/workflows/draft/human-input/nodes/node-1/delivery-test",
|
||||
method="POST",
|
||||
json={"delivery_method_id": "bad"},
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
workflow_module.WorkflowDraftHumanInputDeliveryTestApi().post(app_id=app_model.id, node_id="node-1")
|
||||
|
||||
|
||||
def test_human_input_preview_rejects_non_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
app_model = _make_app(AppMode.ADVANCED_CHAT)
|
||||
_patch_console_guards(monkeypatch, account, app_model)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-1/form/preview",
|
||||
method="POST",
|
||||
json={"inputs": ["not-a-dict"]},
|
||||
):
|
||||
with pytest.raises(ValidationError):
|
||||
workflow_module.AdvancedChatDraftHumanInputFormPreviewApi().post(app_id=app_model.id, node_id="node-1")
|
||||
@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console import wraps as console_wraps
|
||||
from controllers.console.app import workflow_run as workflow_run_module
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from core.workflow.nodes.human_input.enums import FormInputType
|
||||
from libs import login as login_lib
|
||||
from models.account import Account, AccountStatus, TenantAccountRole
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(name="tester", email="tester@example.com")
|
||||
account.status = AccountStatus.ACTIVE
|
||||
account.role = TenantAccountRole.OWNER
|
||||
account.id = "account-123" # type: ignore[assignment]
|
||||
account._current_tenant = SimpleNamespace(id="tenant-123") # type: ignore[attr-defined]
|
||||
account._get_current_object = lambda: account # type: ignore[attr-defined]
|
||||
return account
|
||||
|
||||
|
||||
def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account) -> None:
|
||||
monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
|
||||
monkeypatch.setattr(login_lib, "current_user", account)
|
||||
monkeypatch.setattr(login_lib, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
|
||||
monkeypatch.setattr(console_wraps, "current_account_with_tenant", lambda: (account, account.current_tenant_id))
|
||||
monkeypatch.setattr(workflow_run_module, "current_user", account)
|
||||
monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
|
||||
|
||||
|
||||
class _PauseEntity:
|
||||
def __init__(self, paused_at: datetime, reasons: list[HumanInputRequired]):
|
||||
self.paused_at = paused_at
|
||||
self._reasons = reasons
|
||||
|
||||
def get_pause_reasons(self):
|
||||
return self._reasons
|
||||
|
||||
|
||||
def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
_patch_console_guards(monkeypatch, account)
|
||||
monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
|
||||
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
workflow_run.tenant_id = "tenant-123"
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
|
||||
fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
|
||||
monkeypatch.setattr(workflow_run_module, "db", fake_db)
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
node_id="node-1",
|
||||
node_title="Ask Name",
|
||||
form_token="backstage-token",
|
||||
)
|
||||
pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason])
|
||||
|
||||
repo = Mock()
|
||||
repo.get_workflow_pause.return_value = pause_entity
|
||||
monkeypatch.setattr(
|
||||
workflow_run_module.DifyAPIRepositoryFactory,
|
||||
"create_api_workflow_run_repository",
|
||||
lambda *_, **__: repo,
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
|
||||
response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")
|
||||
|
||||
assert status == 200
|
||||
assert response["paused_at"] == "2024-01-01T12:00:00Z"
|
||||
assert response["paused_nodes"][0]["node_id"] == "node-1"
|
||||
assert response["paused_nodes"][0]["pause_type"]["type"] == "human_input"
|
||||
assert (
|
||||
response["paused_nodes"][0]["pause_type"]["backstage_input_url"]
|
||||
== "https://web.example.com/form/backstage-token"
|
||||
)
|
||||
assert "pending_human_inputs" not in response
|
||||
|
||||
|
||||
def test_pause_details_tenant_isolation(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
account = _make_account()
|
||||
_patch_console_guards(monkeypatch, account)
|
||||
monkeypatch.setattr(workflow_run_module.dify_config, "APP_WEB_URL", "https://web.example.com")
|
||||
|
||||
workflow_run = Mock(spec=WorkflowRun)
|
||||
workflow_run.tenant_id = "tenant-456"
|
||||
workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
workflow_run.created_at = datetime(2024, 1, 1, 12, 0, 0)
|
||||
fake_db = SimpleNamespace(engine=Mock(), session=SimpleNamespace(get=lambda *_: workflow_run))
|
||||
monkeypatch.setattr(workflow_run_module, "db", fake_db)
|
||||
|
||||
with pytest.raises(NotFoundError):
|
||||
with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"):
|
||||
response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1")
|
||||
@ -0,0 +1,25 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
|
||||
|
||||
def test_workflow_run_status_field_with_enum() -> None:
|
||||
field = WorkflowRunStatusField()
|
||||
obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED)
|
||||
|
||||
assert field.output("status", obj) == "paused"
|
||||
|
||||
|
||||
def test_workflow_run_outputs_field_paused_returns_empty() -> None:
|
||||
field = WorkflowRunOutputsField()
|
||||
obj = SimpleNamespace(status=WorkflowExecutionStatus.PAUSED, outputs_dict={"foo": "bar"})
|
||||
|
||||
assert field.output("outputs", obj) == {}
|
||||
|
||||
|
||||
def test_workflow_run_outputs_field_running_returns_outputs() -> None:
|
||||
field = WorkflowRunOutputsField()
|
||||
obj = SimpleNamespace(status=WorkflowExecutionStatus.RUNNING, outputs_dict={"foo": "bar"})
|
||||
|
||||
assert field.output("outputs", obj) == {"foo": "bar"}
|
||||
456
api/tests/unit_tests/controllers/web/test_human_input_form.py
Normal file
456
api/tests/unit_tests/controllers/web/test_human_input_form.py
Normal file
@ -0,0 +1,456 @@
|
||||
"""Unit tests for controllers.web.human_input_form endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import controllers.web.human_input_form as human_input_module
|
||||
import controllers.web.site as site_module
|
||||
from controllers.web.error import WebFormRateLimitExceededError
|
||||
from models.human_input import RecipientType
|
||||
from services.human_input_service import FormExpiredError
|
||||
|
||||
HumanInputFormApi = human_input_module.HumanInputFormApi
|
||||
TenantStatus = human_input_module.TenantStatus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
"""Configure a minimal Flask app for request contexts."""
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
"""Simple stand-in for db.session that returns pre-seeded objects."""
|
||||
|
||||
def __init__(self, mapping: dict[str, Any]):
|
||||
self._mapping = mapping
|
||||
self._model_name: str | None = None
|
||||
|
||||
def query(self, model):
|
||||
self._model_name = model.__name__
|
||||
return self
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
assert self._model_name is not None
|
||||
return self._mapping.get(self._model_name)
|
||||
|
||||
|
||||
class _FakeDB:
|
||||
"""Minimal db stub exposing engine and session."""
|
||||
|
||||
def __init__(self, session: _FakeSession):
|
||||
self.session = session
|
||||
self.engine = object()
|
||||
|
||||
|
||||
def test_get_form_includes_site(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""GET returns form definition merged with site payload."""
|
||||
|
||||
expiration_time = datetime(2099, 1, 1, tzinfo=UTC)
|
||||
|
||||
class _FakeDefinition:
|
||||
def model_dump(self):
|
||||
return {
|
||||
"form_content": "Raw content",
|
||||
"rendered_content": "Rendered {{#$output.name#}}",
|
||||
"inputs": [{"type": "text", "output_variable_name": "name", "default": None}],
|
||||
"default_values": {"name": "Alice", "age": 30, "meta": {"k": "v"}},
|
||||
"user_actions": [{"id": "approve", "title": "Approve", "button_style": "default"}],
|
||||
}
|
||||
|
||||
class _FakeForm:
|
||||
def __init__(self, expiration: datetime):
|
||||
self.workflow_run_id = "workflow-1"
|
||||
self.app_id = "app-1"
|
||||
self.tenant_id = "tenant-1"
|
||||
self.expiration_time = expiration
|
||||
self.recipient_type = RecipientType.BACKSTAGE
|
||||
|
||||
def get_definition(self):
|
||||
return _FakeDefinition()
|
||||
|
||||
form = _FakeForm(expiration_time)
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = False
|
||||
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.NORMAL,
|
||||
plan="basic",
|
||||
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False},
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
workflow_run = SimpleNamespace(app_id="app-1")
|
||||
site_model = SimpleNamespace(
|
||||
title="My Site",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#fff",
|
||||
description="desc",
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
# Patch service to return fake form.
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
|
||||
# Patch db session.
|
||||
db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model}))
|
||||
monkeypatch.setattr(human_input_module, "db", db_stub)
|
||||
|
||||
monkeypatch.setattr(
|
||||
site_module.FeatureService,
|
||||
"get_features",
|
||||
lambda tenant_id: SimpleNamespace(can_replace_logo=True),
|
||||
)
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
response = HumanInputFormApi().get("token-1")
|
||||
|
||||
body = json.loads(response.get_data(as_text=True))
|
||||
assert set(body.keys()) == {
|
||||
"site",
|
||||
"form_content",
|
||||
"inputs",
|
||||
"resolved_default_values",
|
||||
"user_actions",
|
||||
"expiration_time",
|
||||
}
|
||||
assert body["form_content"] == "Rendered {{#$output.name#}}"
|
||||
assert body["inputs"] == [{"type": "text", "output_variable_name": "name", "default": None}]
|
||||
assert body["resolved_default_values"] == {"name": "Alice", "age": "30", "meta": '{"k": "v"}'}
|
||||
assert body["user_actions"] == [{"id": "approve", "title": "Approve", "button_style": "default"}]
|
||||
assert body["expiration_time"] == int(expiration_time.timestamp())
|
||||
assert body["site"] == {
|
||||
"app_id": "app-1",
|
||||
"end_user_id": None,
|
||||
"enable_site": True,
|
||||
"site": {
|
||||
"title": "My Site",
|
||||
"chat_color_theme": "light",
|
||||
"chat_color_theme_inverted": False,
|
||||
"icon_type": "emoji",
|
||||
"icon": "robot",
|
||||
"icon_background": "#fff",
|
||||
"icon_url": None,
|
||||
"description": "desc",
|
||||
"copyright": None,
|
||||
"privacy_policy": None,
|
||||
"custom_disclaimer": None,
|
||||
"default_language": "en",
|
||||
"prompt_public": False,
|
||||
"show_workflow_steps": True,
|
||||
"use_icon_as_answer_icon": False,
|
||||
},
|
||||
"model_config": None,
|
||||
"plan": "basic",
|
||||
"can_replace_logo": True,
|
||||
"custom_config": {
|
||||
"remove_webapp_brand": True,
|
||||
"replace_webapp_logo": None,
|
||||
},
|
||||
}
|
||||
service_mock.get_form_by_token.assert_called_once_with("token-1")
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
|
||||
|
||||
|
||||
def test_get_form_allows_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""GET returns form payload for backstage token."""
|
||||
|
||||
expiration_time = datetime(2099, 1, 2, tzinfo=UTC)
|
||||
|
||||
class _FakeDefinition:
|
||||
def model_dump(self):
|
||||
return {
|
||||
"form_content": "Raw content",
|
||||
"rendered_content": "Rendered",
|
||||
"inputs": [],
|
||||
"default_values": {},
|
||||
"user_actions": [],
|
||||
}
|
||||
|
||||
class _FakeForm:
|
||||
def __init__(self, expiration: datetime):
|
||||
self.workflow_run_id = "workflow-1"
|
||||
self.app_id = "app-1"
|
||||
self.tenant_id = "tenant-1"
|
||||
self.expiration_time = expiration
|
||||
|
||||
def get_definition(self):
|
||||
return _FakeDefinition()
|
||||
|
||||
form = _FakeForm(expiration_time)
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = False
|
||||
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
tenant = SimpleNamespace(
|
||||
id="tenant-1",
|
||||
status=TenantStatus.NORMAL,
|
||||
plan="basic",
|
||||
custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": False},
|
||||
)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True)
|
||||
workflow_run = SimpleNamespace(app_id="app-1")
|
||||
site_model = SimpleNamespace(
|
||||
title="My Site",
|
||||
icon_type="emoji",
|
||||
icon="robot",
|
||||
icon_background="#fff",
|
||||
description="desc",
|
||||
default_language="en",
|
||||
chat_color_theme="light",
|
||||
chat_color_theme_inverted=False,
|
||||
copyright=None,
|
||||
privacy_policy=None,
|
||||
custom_disclaimer=None,
|
||||
prompt_public=False,
|
||||
show_workflow_steps=True,
|
||||
use_icon_as_answer_icon=False,
|
||||
)
|
||||
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
|
||||
db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": site_model}))
|
||||
monkeypatch.setattr(human_input_module, "db", db_stub)
|
||||
|
||||
monkeypatch.setattr(
|
||||
site_module.FeatureService,
|
||||
"get_features",
|
||||
lambda tenant_id: SimpleNamespace(can_replace_logo=True),
|
||||
)
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
response = HumanInputFormApi().get("token-1")
|
||||
|
||||
body = json.loads(response.get_data(as_text=True))
|
||||
assert set(body.keys()) == {
|
||||
"site",
|
||||
"form_content",
|
||||
"inputs",
|
||||
"resolved_default_values",
|
||||
"user_actions",
|
||||
"expiration_time",
|
||||
}
|
||||
assert body["form_content"] == "Rendered"
|
||||
assert body["inputs"] == []
|
||||
assert body["resolved_default_values"] == {}
|
||||
assert body["user_actions"] == []
|
||||
assert body["expiration_time"] == int(expiration_time.timestamp())
|
||||
assert body["site"] == {
|
||||
"app_id": "app-1",
|
||||
"end_user_id": None,
|
||||
"enable_site": True,
|
||||
"site": {
|
||||
"title": "My Site",
|
||||
"chat_color_theme": "light",
|
||||
"chat_color_theme_inverted": False,
|
||||
"icon_type": "emoji",
|
||||
"icon": "robot",
|
||||
"icon_background": "#fff",
|
||||
"icon_url": None,
|
||||
"description": "desc",
|
||||
"copyright": None,
|
||||
"privacy_policy": None,
|
||||
"custom_disclaimer": None,
|
||||
"default_language": "en",
|
||||
"prompt_public": False,
|
||||
"show_workflow_steps": True,
|
||||
"use_icon_as_answer_icon": False,
|
||||
},
|
||||
"model_config": None,
|
||||
"plan": "basic",
|
||||
"can_replace_logo": True,
|
||||
"custom_config": {
|
||||
"remove_webapp_brand": True,
|
||||
"replace_webapp_logo": None,
|
||||
},
|
||||
}
|
||||
service_mock.get_form_by_token.assert_called_once_with("token-1")
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
|
||||
|
||||
|
||||
def test_get_form_raises_forbidden_when_site_missing(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""GET raises Forbidden if site cannot be resolved."""
|
||||
|
||||
expiration_time = datetime(2099, 1, 3, tzinfo=UTC)
|
||||
|
||||
class _FakeDefinition:
|
||||
def model_dump(self):
|
||||
return {
|
||||
"form_content": "Raw content",
|
||||
"rendered_content": "Rendered",
|
||||
"inputs": [],
|
||||
"default_values": {},
|
||||
"user_actions": [],
|
||||
}
|
||||
|
||||
class _FakeForm:
|
||||
def __init__(self, expiration: datetime):
|
||||
self.workflow_run_id = "workflow-1"
|
||||
self.app_id = "app-1"
|
||||
self.tenant_id = "tenant-1"
|
||||
self.expiration_time = expiration
|
||||
|
||||
def get_definition(self):
|
||||
return _FakeDefinition()
|
||||
|
||||
form = _FakeForm(expiration_time)
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = False
|
||||
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
tenant = SimpleNamespace(status=TenantStatus.NORMAL)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant)
|
||||
workflow_run = SimpleNamespace(app_id="app-1")
|
||||
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
|
||||
db_stub = _FakeDB(_FakeSession({"WorkflowRun": workflow_run, "App": app_model, "Site": None}))
|
||||
monkeypatch.setattr(human_input_module, "db", db_stub)
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(Forbidden):
|
||||
HumanInputFormApi().get("token-1")
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
|
||||
|
||||
|
||||
def test_submit_form_accepts_backstage_token(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""POST forwards backstage submissions to the service."""
|
||||
|
||||
class _FakeForm:
|
||||
recipient_type = RecipientType.BACKSTAGE
|
||||
|
||||
form = _FakeForm()
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = False
|
||||
monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/form/human_input/token-1",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
response, status = HumanInputFormApi().post("token-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {}
|
||||
service_mock.submit_form_by_token.assert_called_once_with(
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
form_token="token-1",
|
||||
selected_action_id="approve",
|
||||
form_data={"content": "ok"},
|
||||
submission_end_user_id=None,
|
||||
)
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
|
||||
|
||||
|
||||
def test_submit_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""POST rejects submissions when rate limit is exceeded."""
|
||||
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = True
|
||||
monkeypatch.setattr(human_input_module, "_FORM_SUBMIT_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = None
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
|
||||
|
||||
with app.test_request_context(
|
||||
"/api/form/human_input/token-1",
|
||||
method="POST",
|
||||
json={"inputs": {"content": "ok"}, "action": "approve"},
|
||||
):
|
||||
with pytest.raises(WebFormRateLimitExceededError):
|
||||
HumanInputFormApi().post("token-1")
|
||||
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_not_called()
|
||||
service_mock.get_form_by_token.assert_not_called()
|
||||
|
||||
|
||||
def test_get_form_rate_limited(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
"""GET rejects requests when rate limit is exceeded."""
|
||||
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = True
|
||||
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = None
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(WebFormRateLimitExceededError):
|
||||
HumanInputFormApi().get("token-1")
|
||||
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_not_called()
|
||||
service_mock.get_form_by_token.assert_not_called()
|
||||
|
||||
|
||||
def test_get_form_raises_expired(monkeypatch: pytest.MonkeyPatch, app: Flask):
|
||||
class _FakeForm:
|
||||
pass
|
||||
|
||||
form = _FakeForm()
|
||||
limiter_mock = MagicMock()
|
||||
limiter_mock.is_rate_limited.return_value = False
|
||||
monkeypatch.setattr(human_input_module, "_FORM_ACCESS_RATE_LIMITER", limiter_mock)
|
||||
monkeypatch.setattr(human_input_module, "extract_remote_ip", lambda req: "203.0.113.10")
|
||||
service_mock = MagicMock()
|
||||
service_mock.get_form_by_token.return_value = form
|
||||
service_mock.ensure_form_active.side_effect = FormExpiredError("form-id")
|
||||
monkeypatch.setattr(human_input_module, "HumanInputService", lambda engine: service_mock)
|
||||
monkeypatch.setattr(human_input_module, "db", _FakeDB(_FakeSession({})))
|
||||
|
||||
with app.test_request_context("/api/form/human_input/token-1", method="GET"):
|
||||
with pytest.raises(FormExpiredError):
|
||||
HumanInputFormApi().get("token-1")
|
||||
|
||||
service_mock.ensure_form_active.assert_called_once_with(form)
|
||||
limiter_mock.is_rate_limited.assert_called_once_with("203.0.113.10")
|
||||
limiter_mock.increment_rate_limit.assert_called_once_with("203.0.113.10")
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
@ -12,6 +13,8 @@ import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent
|
||||
|
||||
# Ensure flask_restx.api finds MethodView during import.
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
@ -137,6 +140,12 @@ def test_message_list_mapping(app: Flask) -> None:
|
||||
status="success",
|
||||
error=None,
|
||||
message_metadata_dict={"meta": "value"},
|
||||
extra_contents=[
|
||||
HumanInputContent(
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
submitted=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
pagination = SimpleNamespace(limit=20, has_more=False, data=[message])
|
||||
@ -169,6 +178,8 @@ def test_message_list_mapping(app: Flask) -> None:
|
||||
|
||||
assert item["agent_thoughts"][0]["chain_id"] == "chain-1"
|
||||
assert item["agent_thoughts"][0]["created_at"] == int(thought_created_at.timestamp())
|
||||
assert item["extra_contents"][0]["workflow_run_id"] == message.extra_contents[0].workflow_run_id
|
||||
assert item["extra_contents"][0]["submitted"] == message.extra_contents[0].submitted
|
||||
|
||||
assert item["message_files"][0]["id"] == "file-dict"
|
||||
assert item["message_files"][1]["id"] == "file-obj"
|
||||
|
||||
@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from models.enums import MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.model import EndUser
|
||||
|
||||
|
||||
def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline:
|
||||
pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline.__new__(
|
||||
pipeline_module.AdvancedChatAppGenerateTaskPipeline
|
||||
)
|
||||
pipeline._workflow_run_id = "run-1"
|
||||
pipeline._message_id = "message-1"
|
||||
pipeline._workflow_tenant_id = "tenant-1"
|
||||
return pipeline
|
||||
|
||||
|
||||
def test_persist_human_input_extra_content_adds_record(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")
|
||||
|
||||
captured_session: dict[str, mock.Mock] = {}
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
session = mock.Mock()
|
||||
session.scalar.return_value = None
|
||||
captured_session["session"] = session
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
pipeline._persist_human_input_extra_content(node_id="node-1")
|
||||
|
||||
session = captured_session["session"]
|
||||
session.add.assert_called_once()
|
||||
content = session.add.call_args.args[0]
|
||||
assert isinstance(content, HumanInputContent)
|
||||
assert content.workflow_run_id == "run-1"
|
||||
assert content.message_id == "message-1"
|
||||
assert content.form_id == "form-1"
|
||||
|
||||
|
||||
def test_persist_human_input_extra_content_skips_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: None)
|
||||
|
||||
called = {"value": False}
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
called["value"] = True
|
||||
session = mock.Mock()
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
pipeline._persist_human_input_extra_content(node_id="node-1")
|
||||
|
||||
assert called["value"] is False
|
||||
|
||||
|
||||
def test_persist_human_input_extra_content_skips_when_existing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
pipeline = _build_pipeline()
|
||||
monkeypatch.setattr(pipeline, "_load_human_input_form_id", lambda **kwargs: "form-1")
|
||||
|
||||
captured_session: dict[str, mock.Mock] = {}
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
session = mock.Mock()
|
||||
session.scalar.return_value = HumanInputContent(
|
||||
workflow_run_id="run-1",
|
||||
message_id="message-1",
|
||||
form_id="form-1",
|
||||
)
|
||||
captured_session["session"] = session
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
pipeline._persist_human_input_extra_content(node_id="node-1")
|
||||
|
||||
session = captured_session["session"]
|
||||
session.add.assert_not_called()
|
||||
|
||||
|
||||
def test_handle_workflow_paused_event_persists_human_input_extra_content() -> None:
|
||||
pipeline = _build_pipeline()
|
||||
pipeline._application_generate_entity = SimpleNamespace(task_id="task-1")
|
||||
pipeline._workflow_response_converter = mock.Mock()
|
||||
pipeline._workflow_response_converter.workflow_pause_to_stream_response.return_value = []
|
||||
pipeline._ensure_graph_runtime_initialized = mock.Mock(
|
||||
return_value=SimpleNamespace(
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
),
|
||||
)
|
||||
pipeline._save_message = mock.Mock()
|
||||
message = SimpleNamespace(status=MessageStatus.NORMAL)
|
||||
pipeline._get_message = mock.Mock(return_value=message)
|
||||
pipeline._persist_human_input_extra_content = mock.Mock()
|
||||
pipeline._base_task_pipeline = mock.Mock()
|
||||
pipeline._base_task_pipeline.queue_manager = mock.Mock()
|
||||
pipeline._message_saved_on_pause = False
|
||||
|
||||
@contextmanager
|
||||
def fake_session():
|
||||
session = mock.Mock()
|
||||
yield session
|
||||
|
||||
pipeline._database_session = fake_session # type: ignore[method-assign]
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
actions=[],
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
form_token="token-1",
|
||||
resolved_default_values={},
|
||||
)
|
||||
event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"])
|
||||
|
||||
list(pipeline._handle_workflow_paused_event(event))
|
||||
|
||||
pipeline._persist_human_input_extra_content.assert_called_once_with(form_id="form-1", node_id="node-1")
|
||||
assert message.status == MessageStatus.PAUSED
|
||||
|
||||
|
||||
def test_resume_appends_chunks_to_paused_answer() -> None:
|
||||
app_config = SimpleNamespace(app_id="app-1", tenant_id="tenant-1", sensitive_word_avoidance=None)
|
||||
application_generate_entity = SimpleNamespace(
|
||||
app_config=app_config,
|
||||
files=[],
|
||||
workflow_run_id="run-1",
|
||||
query="hello",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
inputs={},
|
||||
task_id="task-1",
|
||||
)
|
||||
queue_manager = SimpleNamespace(graph_runtime_state=None)
|
||||
conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat")
|
||||
message = SimpleNamespace(
|
||||
id="message-1",
|
||||
created_at=datetime(2024, 1, 1),
|
||||
query="hello",
|
||||
answer="before",
|
||||
status=MessageStatus.PAUSED,
|
||||
)
|
||||
user = EndUser()
|
||||
user.id = "user-1"
|
||||
user.session_id = "session-1"
|
||||
workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={})
|
||||
|
||||
pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline(
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
user=user,
|
||||
stream=True,
|
||||
dialogue_count=1,
|
||||
draft_var_saver_factory=SimpleNamespace(),
|
||||
)
|
||||
|
||||
pipeline._get_message = mock.Mock(return_value=message)
|
||||
pipeline._recorded_files = []
|
||||
|
||||
list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after")))
|
||||
pipeline._save_message(session=mock.Mock())
|
||||
|
||||
assert message.answer == "beforeafter"
|
||||
assert message.status == MessageStatus.NORMAL
|
||||
@ -0,0 +1,87 @@
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def _build_converter():
|
||||
system_variables = SystemVariable(
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
workflow_id="wf-1",
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
app_entity = SimpleNamespace(
|
||||
task_id="task-1",
|
||||
app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"),
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
files=[],
|
||||
inputs={},
|
||||
workflow_execution_id="run-1",
|
||||
call_depth=0,
|
||||
)
|
||||
account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com")
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=app_entity,
|
||||
user=account,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_form_filled_stream_response_contains_rendered_content():
|
||||
converter = _build_converter()
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="task-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_id="wf-1",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
queue_event = QueueHumanInputFormFilledEvent(
|
||||
node_execution_id="exec-1",
|
||||
node_id="node-1",
|
||||
node_type="human-input",
|
||||
node_title="Human Input",
|
||||
rendered_content="# Title\nvalue",
|
||||
action_id="Approve",
|
||||
action_text="Approve",
|
||||
)
|
||||
|
||||
resp = converter.human_input_form_filled_to_stream_response(event=queue_event, task_id="task-1")
|
||||
|
||||
assert resp.workflow_run_id == "run-1"
|
||||
assert resp.data.node_id == "node-1"
|
||||
assert resp.data.node_title == "Human Input"
|
||||
assert resp.data.rendered_content.startswith("# Title")
|
||||
assert resp.data.action_id == "Approve"
|
||||
|
||||
|
||||
def test_human_input_form_timeout_stream_response_contains_timeout_metadata():
|
||||
converter = _build_converter()
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="task-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_id="wf-1",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
queue_event = QueueHumanInputFormTimeoutEvent(
|
||||
node_id="node-1",
|
||||
node_type="human-input",
|
||||
node_title="Human Input",
|
||||
expiration_time=datetime(2025, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
resp = converter.human_input_form_timeout_to_stream_response(event=queue_event, task_id="task-1")
|
||||
|
||||
assert resp.workflow_run_id == "run-1"
|
||||
assert resp.data.node_id == "node-1"
|
||||
assert resp.data.node_title == "Human Input"
|
||||
assert resp.data.expiration_time == 1735689600
|
||||
@ -0,0 +1,56 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def _build_converter() -> WorkflowResponseConverter:
|
||||
"""Construct a minimal WorkflowResponseConverter for testing."""
|
||||
system_variables = SystemVariable(
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
app_id="app-1",
|
||||
workflow_id="wf-1",
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
app_entity = SimpleNamespace(
|
||||
task_id="task-1",
|
||||
app_config=SimpleNamespace(app_id="app-1", tenant_id="tenant-1"),
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
files=[],
|
||||
inputs={},
|
||||
workflow_execution_id="run-1",
|
||||
call_depth=0,
|
||||
)
|
||||
account = SimpleNamespace(id="acc-1", name="tester", email="tester@example.com")
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=app_entity,
|
||||
user=account,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_start_stream_response_carries_resumption_reason():
|
||||
converter = _build_converter()
|
||||
resp = converter.workflow_start_to_stream_response(
|
||||
task_id="task-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_id="wf-1",
|
||||
reason=WorkflowStartReason.RESUMPTION,
|
||||
)
|
||||
assert resp.data.reason is WorkflowStartReason.RESUMPTION
|
||||
|
||||
|
||||
def test_workflow_start_stream_response_carries_initial_reason():
|
||||
converter = _build_converter()
|
||||
resp = converter.workflow_start_to_stream_response(
|
||||
task_id="task-1",
|
||||
workflow_run_id="run-1",
|
||||
workflow_id="wf-1",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
assert resp.data.reason is WorkflowStartReason.INITIAL
|
||||
@ -23,6 +23,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -124,7 +125,12 @@ class TestWorkflowResponseConverter:
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -160,7 +166,12 @@ class TestWorkflowResponseConverter:
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -191,7 +202,12 @@ class TestWorkflowResponseConverter:
|
||||
"""Test node finish response when process_data is None."""
|
||||
converter = self.create_workflow_response_converter()
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -225,7 +241,12 @@ class TestWorkflowResponseConverter:
|
||||
original_data = {"large_field": "x" * 10000, "metadata": "info"}
|
||||
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -261,7 +282,12 @@ class TestWorkflowResponseConverter:
|
||||
|
||||
original_data = {"small": "data"}
|
||||
|
||||
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="bootstrap",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="wf-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
start_event = self.create_node_started_event()
|
||||
converter.workflow_node_start_to_stream_response(
|
||||
event=start_event,
|
||||
@ -400,6 +426,7 @@ class TestWorkflowResponseConverterServiceApiTruncation:
|
||||
task_id="test-task-id",
|
||||
workflow_run_id="test-workflow-run-id",
|
||||
workflow_id="test-workflow-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
return converter
|
||||
|
||||
|
||||
@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig
|
||||
from core.app.apps import message_based_app_generator
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.task_pipeline import message_cycle_manager
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from models.model import AppMode, Conversation, Message
|
||||
|
||||
|
||||
def _make_app_config() -> WorkflowUIBasedAppConfig:
|
||||
return WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
app_mode=AppMode.ADVANCED_CHAT,
|
||||
workflow_id="workflow-id",
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_generate_entity(app_config: WorkflowUIBasedAppConfig) -> AdvancedChatAppGenerateEntity:
|
||||
return AdvancedChatAppGenerateEntity(
|
||||
task_id="task-id",
|
||||
app_config=app_config,
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
parent_message_id=None,
|
||||
user_id="user-id",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
workflow_run_id="workflow-run-id",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_db_session(monkeypatch):
|
||||
session = MagicMock()
|
||||
|
||||
def refresh_side_effect(obj):
|
||||
if isinstance(obj, Conversation) and obj.id is None:
|
||||
obj.id = "generated-conversation-id"
|
||||
if isinstance(obj, Message) and obj.id is None:
|
||||
obj.id = "generated-message-id"
|
||||
|
||||
session.refresh.side_effect = refresh_side_effect
|
||||
session.add.return_value = None
|
||||
session.commit.return_value = None
|
||||
|
||||
monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session))
|
||||
return session
|
||||
|
||||
|
||||
def test_init_generate_records_sets_conversation_metadata():
|
||||
app_config = _make_app_config()
|
||||
entity = _make_generate_entity(app_config)
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
conversation, _ = generator._init_generate_records(entity, conversation=None)
|
||||
|
||||
assert entity.conversation_id == "generated-conversation-id"
|
||||
assert conversation.id == "generated-conversation-id"
|
||||
assert entity.is_new_conversation is True
|
||||
|
||||
|
||||
def test_init_generate_records_marks_existing_conversation():
|
||||
app_config = _make_app_config()
|
||||
entity = _make_generate_entity(app_config)
|
||||
|
||||
existing_conversation = Conversation(
|
||||
app_id=app_config.app_id,
|
||||
app_model_config_id=None,
|
||||
model_provider=None,
|
||||
override_model_configs=None,
|
||||
model_id=None,
|
||||
mode=app_config.app_mode.value,
|
||||
name="existing",
|
||||
inputs={},
|
||||
introduction="",
|
||||
system_instruction="",
|
||||
system_instruction_tokens=0,
|
||||
status="normal",
|
||||
invoke_from=InvokeFrom.WEB_APP.value,
|
||||
from_source="api",
|
||||
from_end_user_id="user-id",
|
||||
from_account_id=None,
|
||||
)
|
||||
existing_conversation.id = "existing-conversation-id"
|
||||
|
||||
generator = AdvancedChatAppGenerator()
|
||||
|
||||
conversation, _ = generator._init_generate_records(entity, conversation=existing_conversation)
|
||||
|
||||
assert entity.conversation_id == "existing-conversation-id"
|
||||
assert conversation is existing_conversation
|
||||
assert entity.is_new_conversation is False
|
||||
|
||||
|
||||
def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch):
|
||||
app_config = _make_app_config()
|
||||
entity = _make_generate_entity(app_config)
|
||||
entity.conversation_id = "existing-conversation-id"
|
||||
entity.is_new_conversation = True
|
||||
entity.extras = {"auto_generate_conversation_name": True}
|
||||
|
||||
captured = {}
|
||||
|
||||
class DummyThread:
|
||||
def __init__(self, **kwargs):
|
||||
self.kwargs = kwargs
|
||||
self.started = False
|
||||
|
||||
def start(self):
|
||||
self.started = True
|
||||
|
||||
def fake_thread(**kwargs):
|
||||
thread = DummyThread(**kwargs)
|
||||
captured["thread"] = thread
|
||||
return thread
|
||||
|
||||
monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread)
|
||||
|
||||
manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock())
|
||||
thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello")
|
||||
|
||||
assert thread is captured["thread"]
|
||||
assert thread.started is True
|
||||
assert entity.is_new_conversation is False
|
||||
@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
AppAdditionalFeatures,
|
||||
EasyUIBasedAppConfig,
|
||||
EasyUIBasedAppModelConfigFrom,
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
)
|
||||
from core.app.apps import message_based_app_generator
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||
from models.model import AppMode, Conversation, Message
|
||||
|
||||
|
||||
class DummyModelConf:
|
||||
def __init__(self, provider: str = "mock-provider", model: str = "mock-model") -> None:
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
|
||||
|
||||
class DummyCompletionGenerateEntity:
|
||||
__slots__ = ("app_config", "invoke_from", "user_id", "query", "inputs", "files", "model_conf")
|
||||
app_config: EasyUIBasedAppConfig
|
||||
invoke_from: InvokeFrom
|
||||
user_id: str
|
||||
query: str
|
||||
inputs: dict
|
||||
files: list
|
||||
model_conf: DummyModelConf
|
||||
|
||||
def __init__(self, app_config: EasyUIBasedAppConfig) -> None:
|
||||
self.app_config = app_config
|
||||
self.invoke_from = InvokeFrom.WEB_APP
|
||||
self.user_id = "user-id"
|
||||
self.query = "hello"
|
||||
self.inputs = {}
|
||||
self.files = []
|
||||
self.model_conf = DummyModelConf()
|
||||
|
||||
|
||||
def _make_app_config(app_mode: AppMode) -> EasyUIBasedAppConfig:
|
||||
return EasyUIBasedAppConfig(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
app_mode=app_mode,
|
||||
app_model_config_from=EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG,
|
||||
app_model_config_id="model-config-id",
|
||||
app_model_config_dict={},
|
||||
model=ModelConfigEntity(provider="mock-provider", model="mock-model", mode="chat"),
|
||||
prompt_template=PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="Hello",
|
||||
),
|
||||
additional_features=AppAdditionalFeatures(),
|
||||
variables=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_chat_generate_entity(app_config: EasyUIBasedAppConfig) -> ChatAppGenerateEntity:
|
||||
return ChatAppGenerateEntity.model_construct(
|
||||
task_id="task-id",
|
||||
app_config=app_config,
|
||||
model_conf=DummyModelConf(),
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
inputs={},
|
||||
query="hello",
|
||||
files=[],
|
||||
parent_message_id=None,
|
||||
user_id="user-id",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
extras={},
|
||||
call_depth=0,
|
||||
trace_manager=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_db_session(monkeypatch):
|
||||
session = MagicMock()
|
||||
|
||||
def refresh_side_effect(obj):
|
||||
if isinstance(obj, Conversation) and obj.id is None:
|
||||
obj.id = "generated-conversation-id"
|
||||
if isinstance(obj, Message) and obj.id is None:
|
||||
obj.id = "generated-message-id"
|
||||
|
||||
session.refresh.side_effect = refresh_side_effect
|
||||
session.add.return_value = None
|
||||
session.commit.return_value = None
|
||||
|
||||
monkeypatch.setattr(message_based_app_generator, "db", SimpleNamespace(session=session))
|
||||
return session
|
||||
|
||||
|
||||
def test_init_generate_records_skips_conversation_fields_for_non_conversation_entity():
|
||||
app_config = _make_app_config(AppMode.COMPLETION)
|
||||
entity = DummyCompletionGenerateEntity(app_config=app_config)
|
||||
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
conversation, message = generator._init_generate_records(entity, conversation=None)
|
||||
|
||||
assert conversation.id == "generated-conversation-id"
|
||||
assert message.id == "generated-message-id"
|
||||
assert hasattr(entity, "conversation_id") is False
|
||||
assert hasattr(entity, "is_new_conversation") is False
|
||||
|
||||
|
||||
def test_init_generate_records_sets_conversation_fields_for_chat_entity():
|
||||
app_config = _make_app_config(AppMode.CHAT)
|
||||
entity = _make_chat_generate_entity(app_config)
|
||||
|
||||
generator = MessageBasedAppGenerator()
|
||||
|
||||
conversation, _ = generator._init_generate_records(entity, conversation=None)
|
||||
|
||||
assert entity.conversation_id == "generated-conversation-id"
|
||||
assert entity.is_new_conversation is True
|
||||
assert conversation.id == "generated-conversation-id"
|
||||
287
api/tests/unit_tests/core/app/apps/test_pause_resume.py
Normal file
287
api/tests/unit_tests/core/app/apps/test_pause_resume.py
Normal file
@ -0,0 +1,287 @@
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
API_DIR = str(Path(__file__).resolve().parents[5])
|
||||
if API_DIR not in sys.path:
|
||||
sys.path.insert(0, API_DIR)
|
||||
|
||||
import core.workflow.nodes.human_input.entities # noqa: F401
|
||||
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
|
||||
from core.app.apps.workflow import app_generator as wf_app_gen_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
if "core.ops.ops_trace_manager" not in sys.modules:
|
||||
ops_stub = ModuleType("core.ops.ops_trace_manager")
|
||||
|
||||
class _StubTraceQueueManager:
|
||||
def __init__(self, *_, **__):
|
||||
pass
|
||||
|
||||
ops_stub.TraceQueueManager = _StubTraceQueueManager
|
||||
sys.modules["core.ops.ops_trace_manager"] = ops_stub
|
||||
|
||||
|
||||
class _StubToolNodeData(BaseNodeData):
|
||||
pause_on: bool = False
|
||||
|
||||
|
||||
class _StubToolNode(Node[_StubToolNodeData]):
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def init_node_data(self, data):
|
||||
self._node_data = _StubToolNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self):
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def _run(self):
|
||||
if self.node_data.pause_on:
|
||||
yield PauseRequestedEvent(reason=SchedulingPause(message="test pause"))
|
||||
return
|
||||
|
||||
result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"value": f"{self.id}-done"},
|
||||
)
|
||||
yield self._convert_node_run_result_to_graph_node_event(result)
|
||||
|
||||
|
||||
def _patch_tool_node(mocker):
|
||||
original_create_node = DifyNodeFactory.create_node
|
||||
|
||||
def _patched_create_node(self, node_config: dict[str, object]) -> Node:
|
||||
node_data = node_config.get("data", {})
|
||||
if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value:
|
||||
return _StubToolNode(
|
||||
id=str(node_config["id"]),
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
return original_create_node(self, node_config)
|
||||
|
||||
mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node)
|
||||
|
||||
|
||||
def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]:
|
||||
node_data = data.model_dump()
|
||||
node_data["type"] = node_type.value
|
||||
return node_data
|
||||
|
||||
|
||||
def _build_graph_config(*, pause_on: str | None) -> dict[str, object]:
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
tool_data_a = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_a")
|
||||
tool_data_b = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_b")
|
||||
tool_data_c = _StubToolNodeData(title="tool", pause_on=pause_on == "tool_c")
|
||||
end_data = EndNodeData(
|
||||
title="end",
|
||||
outputs=[OutputVariableEntity(variable="result", value_selector=["tool_c", "value"])],
|
||||
desc=None,
|
||||
)
|
||||
|
||||
nodes = [
|
||||
{"id": "start", "data": _node_data(NodeType.START, start_data)},
|
||||
{"id": "tool_a", "data": _node_data(NodeType.TOOL, tool_data_a)},
|
||||
{"id": "tool_b", "data": _node_data(NodeType.TOOL, tool_data_b)},
|
||||
{"id": "tool_c", "data": _node_data(NodeType.TOOL, tool_data_c)},
|
||||
{"id": "end", "data": _node_data(NodeType.END, end_data)},
|
||||
]
|
||||
edges = [
|
||||
{"source": "start", "target": "tool_a"},
|
||||
{"source": "tool_a", "target": "tool_b"},
|
||||
{"source": "tool_b", "target": "tool_c"},
|
||||
{"source": "tool_c", "target": "end"},
|
||||
]
|
||||
return {"nodes": nodes, "edges": edges}
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph:
|
||||
graph_config = _build_graph_config(pause_on=pause_on)
|
||||
params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="service-api",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
return Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
|
||||
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.system_variables.workflow_execution_id = run_id
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]:
|
||||
command_channel = InMemoryChannel()
|
||||
graph = _build_graph(runtime_state, pause_on=pause_on)
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
events: list[GraphEngineEvent] = []
|
||||
for event in engine.run():
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
|
||||
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
|
||||
return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)]
|
||||
|
||||
|
||||
def test_workflow_app_pause_resume_matches_baseline(mocker):
|
||||
_patch_tool_node(mocker)
|
||||
|
||||
baseline_state = _build_runtime_state("baseline")
|
||||
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_nodes = _node_successes(baseline_events)
|
||||
baseline_outputs = baseline_state.outputs
|
||||
|
||||
paused_state = _build_runtime_state("paused-run")
|
||||
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
paused_nodes = _node_successes(paused_events)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
generator = wf_app_gen_module.WorkflowAppGenerator()
|
||||
|
||||
def _fake_generate(**kwargs):
|
||||
state: GraphRuntimeState = kwargs["graph_runtime_state"]
|
||||
events = _run_with_optional_pause(state, pause_on=None)
|
||||
return _node_successes(events)
|
||||
|
||||
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
|
||||
|
||||
resumed_nodes = generator.resume(
|
||||
app_model=SimpleNamespace(mode="workflow"),
|
||||
workflow=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
|
||||
graph_runtime_state=resumed_state,
|
||||
workflow_execution_repository=SimpleNamespace(),
|
||||
workflow_node_execution_repository=SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert paused_nodes + resumed_nodes == baseline_nodes
|
||||
assert resumed_state.outputs == baseline_outputs
|
||||
|
||||
|
||||
def test_advanced_chat_pause_resume_matches_baseline(mocker):
|
||||
_patch_tool_node(mocker)
|
||||
|
||||
baseline_state = _build_runtime_state("adv-baseline")
|
||||
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_nodes = _node_successes(baseline_events)
|
||||
baseline_outputs = baseline_state.outputs
|
||||
|
||||
paused_state = _build_runtime_state("adv-paused")
|
||||
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
paused_nodes = _node_successes(paused_events)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
generator = adv_app_gen_module.AdvancedChatAppGenerator()
|
||||
|
||||
def _fake_generate(**kwargs):
|
||||
state: GraphRuntimeState = kwargs["graph_runtime_state"]
|
||||
events = _run_with_optional_pause(state, pause_on=None)
|
||||
return _node_successes(events)
|
||||
|
||||
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
|
||||
|
||||
resumed_nodes = generator.resume(
|
||||
app_model=SimpleNamespace(mode="workflow"),
|
||||
workflow=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
conversation=SimpleNamespace(id="conv"),
|
||||
message=SimpleNamespace(id="msg"),
|
||||
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
|
||||
workflow_execution_repository=SimpleNamespace(),
|
||||
workflow_node_execution_repository=SimpleNamespace(),
|
||||
graph_runtime_state=resumed_state,
|
||||
)
|
||||
|
||||
assert paused_nodes + resumed_nodes == baseline_nodes
|
||||
assert resumed_state.outputs == baseline_outputs
|
||||
|
||||
|
||||
def test_resume_emits_resumption_start_reason(mocker) -> None:
|
||||
_patch_tool_node(mocker)
|
||||
|
||||
paused_state = _build_runtime_state("resume-reason")
|
||||
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
|
||||
initial_start = next(event for event in paused_events if isinstance(event, GraphRunStartedEvent))
|
||||
assert initial_start.reason == WorkflowStartReason.INITIAL
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(paused_state.dumps())
|
||||
resumed_events = _run_with_optional_pause(resumed_state, pause_on=None)
|
||||
resume_start = next(event for event in resumed_events if isinstance(event, GraphRunStartedEvent))
|
||||
assert resume_start.reason == WorkflowStartReason.RESUMPTION
|
||||
80
api/tests/unit_tests/core/app/apps/test_streaming_utils.py
Normal file
80
api/tests/unit_tests/core/app/apps/test_streaming_utils.py
Normal file
@ -0,0 +1,80 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class FakeSubscription:
|
||||
def __init__(self, message_queue: queue.Queue[bytes], state: dict[str, bool]) -> None:
|
||||
self._queue = message_queue
|
||||
self._state = state
|
||||
self._closed = False
|
||||
|
||||
def __enter__(self):
|
||||
self._state["subscribed"] = True
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
if self._closed:
|
||||
return None
|
||||
try:
|
||||
if timeout is None:
|
||||
return self._queue.get()
|
||||
return self._queue.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
|
||||
class FakeTopic:
|
||||
def __init__(self) -> None:
|
||||
self._queue: queue.Queue[bytes] = queue.Queue()
|
||||
self._state = {"subscribed": False}
|
||||
|
||||
def subscribe(self) -> FakeSubscription:
|
||||
return FakeSubscription(self._queue, self._state)
|
||||
|
||||
def publish(self, payload: bytes) -> None:
|
||||
self._queue.put(payload)
|
||||
|
||||
@property
|
||||
def subscribed(self) -> bool:
|
||||
return self._state["subscribed"]
|
||||
|
||||
|
||||
def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch):
|
||||
topic = FakeTopic()
|
||||
|
||||
def fake_get_response_topic(cls, app_mode, workflow_run_id):
|
||||
return topic
|
||||
|
||||
monkeypatch.setattr(MessageBasedAppGenerator, "get_response_topic", classmethod(fake_get_response_topic))
|
||||
|
||||
def on_subscribe() -> None:
|
||||
assert topic.subscribed is True
|
||||
event = {"event": StreamEvent.WORKFLOW_FINISHED.value}
|
||||
topic.publish(json.dumps(event).encode())
|
||||
|
||||
generator = MessageBasedAppGenerator.retrieve_events(
|
||||
AppMode.WORKFLOW,
|
||||
"workflow-run-id",
|
||||
idle_timeout=0.5,
|
||||
on_subscribe=on_subscribe,
|
||||
)
|
||||
|
||||
assert next(generator) == StreamEvent.PING.value
|
||||
event = next(generator)
|
||||
assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value
|
||||
with pytest.raises(StopIteration):
|
||||
next(generator)
|
||||
@ -1,3 +1,6 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
|
||||
|
||||
@ -17,3 +20,193 @@ def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
|
||||
args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}
|
||||
|
||||
assert WorkflowAppGenerator()._should_prepare_user_inputs(args)
|
||||
|
||||
|
||||
def test_resume_delegates_to_generate(mocker):
|
||||
generator = WorkflowAppGenerator()
|
||||
mock_generate = mocker.patch.object(generator, "_generate", return_value="ok")
|
||||
|
||||
application_generate_entity = SimpleNamespace(stream=False, invoke_from="debugger")
|
||||
runtime_state = MagicMock(name="runtime-state")
|
||||
pause_config = MagicMock(name="pause-config")
|
||||
|
||||
result = generator.resume(
|
||||
app_model=MagicMock(),
|
||||
workflow=MagicMock(),
|
||||
user=MagicMock(),
|
||||
application_generate_entity=application_generate_entity,
|
||||
graph_runtime_state=runtime_state,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
graph_engine_layers=("layer",),
|
||||
pause_state_config=pause_config,
|
||||
variable_loader=MagicMock(),
|
||||
)
|
||||
|
||||
assert result == "ok"
|
||||
mock_generate.assert_called_once()
|
||||
kwargs = mock_generate.call_args.kwargs
|
||||
assert kwargs["graph_runtime_state"] is runtime_state
|
||||
assert kwargs["pause_state_config"] is pause_config
|
||||
assert kwargs["streaming"] is False
|
||||
assert kwargs["invoke_from"] == "debugger"
|
||||
|
||||
|
||||
def test_generate_appends_pause_layer_and_forwards_state(mocker):
|
||||
generator = WorkflowAppGenerator()
|
||||
|
||||
mock_queue_manager = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=mock_queue_manager)
|
||||
|
||||
fake_current_app = MagicMock()
|
||||
fake_current_app._get_current_object.return_value = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.current_app", fake_current_app)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
|
||||
return_value="converted",
|
||||
)
|
||||
mocker.patch.object(WorkflowAppGenerator, "_handle_response", return_value="response")
|
||||
mocker.patch.object(WorkflowAppGenerator, "_get_draft_var_saver_factory", return_value=MagicMock())
|
||||
|
||||
pause_layer = MagicMock(name="pause-layer")
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.PauseStatePersistenceLayer",
|
||||
return_value=pause_layer,
|
||||
)
|
||||
|
||||
dummy_session = MagicMock()
|
||||
dummy_session.close = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.db.session", dummy_session)
|
||||
|
||||
worker_kwargs: dict[str, object] = {}
|
||||
|
||||
class DummyThread:
|
||||
def __init__(self, target, kwargs):
|
||||
worker_kwargs["target"] = target
|
||||
worker_kwargs["kwargs"] = kwargs
|
||||
|
||||
def start(self):
|
||||
return None
|
||||
|
||||
mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", DummyThread)
|
||||
|
||||
app_model = SimpleNamespace(mode="workflow")
|
||||
app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="wf")
|
||||
application_generate_entity = SimpleNamespace(
|
||||
task_id="task",
|
||||
user_id="user",
|
||||
invoke_from="service-api",
|
||||
app_config=app_config,
|
||||
files=[],
|
||||
stream=True,
|
||||
workflow_execution_id="run",
|
||||
)
|
||||
|
||||
graph_runtime_state = MagicMock()
|
||||
|
||||
result = generator._generate(
|
||||
app_model=app_model,
|
||||
workflow=MagicMock(),
|
||||
user=MagicMock(),
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from="service-api",
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
streaming=True,
|
||||
graph_engine_layers=("base-layer",),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner"),
|
||||
)
|
||||
|
||||
assert result == "converted"
|
||||
assert worker_kwargs["kwargs"]["graph_engine_layers"] == ("base-layer", pause_layer)
|
||||
assert worker_kwargs["kwargs"]["graph_runtime_state"] is graph_runtime_state
|
||||
|
||||
|
||||
def test_resume_path_runs_worker_with_runtime_state(mocker):
|
||||
generator = WorkflowAppGenerator()
|
||||
runtime_state = MagicMock(name="runtime-state")
|
||||
|
||||
pause_layer = MagicMock(name="pause-layer")
|
||||
mocker.patch("core.app.apps.workflow.app_generator.PauseStatePersistenceLayer", return_value=pause_layer)
|
||||
|
||||
queue_manager = MagicMock()
|
||||
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppQueueManager", return_value=queue_manager)
|
||||
|
||||
mocker.patch.object(generator, "_handle_response", return_value="raw-response")
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.WorkflowAppGenerateResponseConverter.convert",
|
||||
side_effect=lambda response, invoke_from: response,
|
||||
)
|
||||
|
||||
fake_db = SimpleNamespace(session=MagicMock(), engine=MagicMock())
|
||||
mocker.patch("core.app.apps.workflow.app_generator.db", fake_db)
|
||||
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow", tenant_id="tenant", app_id="app", graph_dict={}, type="workflow", version="1"
|
||||
)
|
||||
end_user = SimpleNamespace(session_id="end-user-session")
|
||||
app_record = SimpleNamespace(id="app")
|
||||
|
||||
session = MagicMock()
|
||||
session.__enter__.return_value = session
|
||||
session.__exit__.return_value = False
|
||||
session.scalar.side_effect = [workflow, end_user, app_record]
|
||||
mocker.patch("core.app.apps.workflow.app_generator.session_factory", return_value=session)
|
||||
|
||||
runner_instance = MagicMock()
|
||||
|
||||
def runner_ctor(**kwargs):
|
||||
assert kwargs["graph_runtime_state"] is runtime_state
|
||||
return runner_instance
|
||||
|
||||
mocker.patch("core.app.apps.workflow.app_generator.WorkflowAppRunner", side_effect=runner_ctor)
|
||||
|
||||
class ImmediateThread:
|
||||
def __init__(self, target, kwargs):
|
||||
target(**kwargs)
|
||||
|
||||
def start(self):
|
||||
return None
|
||||
|
||||
mocker.patch("core.app.apps.workflow.app_generator.threading.Thread", ImmediateThread)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
mocker.patch(
|
||||
"core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository",
|
||||
return_value=MagicMock(),
|
||||
)
|
||||
|
||||
pause_config = SimpleNamespace(session_factory=MagicMock(), state_owner_user_id="owner")
|
||||
|
||||
app_model = SimpleNamespace(mode="workflow")
|
||||
app_config = SimpleNamespace(app_id="app", tenant_id="tenant", workflow_id="workflow")
|
||||
application_generate_entity = SimpleNamespace(
|
||||
task_id="task",
|
||||
user_id="user",
|
||||
invoke_from="service-api",
|
||||
app_config=app_config,
|
||||
files=[],
|
||||
stream=True,
|
||||
workflow_execution_id="run",
|
||||
trace_manager=MagicMock(),
|
||||
)
|
||||
|
||||
result = generator.resume(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=MagicMock(),
|
||||
application_generate_entity=application_generate_entity,
|
||||
graph_runtime_state=runtime_state,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
|
||||
assert result == "raw-response"
|
||||
runner_instance.run.assert_called_once()
|
||||
queue_manager.graph_runtime_state = runtime_state
|
||||
|
||||
@ -0,0 +1,59 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
|
||||
|
||||
class _DummyQueueManager:
|
||||
def __init__(self):
|
||||
self.published = []
|
||||
|
||||
def publish(self, event, _from):
|
||||
self.published.append(event)
|
||||
|
||||
|
||||
class _DummyRuntimeState:
|
||||
def get_paused_nodes(self):
|
||||
return ["node-1"]
|
||||
|
||||
|
||||
class _DummyGraphEngine:
|
||||
def __init__(self):
|
||||
self.graph_runtime_state = _DummyRuntimeState()
|
||||
|
||||
|
||||
class _DummyWorkflowEntry:
|
||||
def __init__(self):
|
||||
self.graph_engine = _DummyGraphEngine()
|
||||
|
||||
|
||||
def test_handle_pause_event_enqueues_email_task(monkeypatch: pytest.MonkeyPatch):
|
||||
queue_manager = _DummyQueueManager()
|
||||
runner = WorkflowBasedAppRunner(queue_manager=queue_manager, app_id="app-id")
|
||||
workflow_entry = _DummyWorkflowEntry()
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-123",
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
actions=[],
|
||||
node_id="node-1",
|
||||
node_title="Review",
|
||||
)
|
||||
event = GraphRunPausedEvent(reasons=[reason], outputs={})
|
||||
|
||||
email_task = MagicMock()
|
||||
monkeypatch.setattr("core.app.apps.workflow_app_runner.dispatch_human_input_email_task", email_task)
|
||||
|
||||
runner._handle_event(workflow_entry, event)
|
||||
|
||||
email_task.apply_async.assert_called_once()
|
||||
kwargs = email_task.apply_async.call_args.kwargs["kwargs"]
|
||||
assert kwargs["form_id"] == "form-123"
|
||||
assert kwargs["node_title"] == "Review"
|
||||
|
||||
assert any(isinstance(evt, QueueWorkflowPausedEvent) for evt in queue_manager.published)
|
||||
183
api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
Normal file
183
api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
Normal file
@ -0,0 +1,183 @@
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.common import workflow_response_converter
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueWorkflowPausedEvent
|
||||
from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph_events.graph import GraphRunPausedEvent
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from core.workflow.nodes.human_input.enums import FormInputType
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.account import Account
|
||||
|
||||
|
||||
class _RecordingWorkflowAppRunner(WorkflowAppRunner):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.published_events = []
|
||||
|
||||
def _publish_event(self, event):
|
||||
self.published_events.append(event)
|
||||
|
||||
|
||||
class _FakeRuntimeState:
|
||||
def get_paused_nodes(self):
|
||||
return ["node-pause-1"]
|
||||
|
||||
|
||||
def _build_runner():
|
||||
app_entity = SimpleNamespace(
|
||||
app_config=SimpleNamespace(app_id="app-id"),
|
||||
inputs={},
|
||||
files=[],
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
single_iteration_run=None,
|
||||
single_loop_run=None,
|
||||
workflow_execution_id="run-id",
|
||||
user_id="user-id",
|
||||
)
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={},
|
||||
tenant_id="tenant-id",
|
||||
environment_variables={},
|
||||
id="workflow-id",
|
||||
)
|
||||
queue_manager = SimpleNamespace(publish=lambda event, pub_from: None)
|
||||
return _RecordingWorkflowAppRunner(
|
||||
application_generate_entity=app_entity,
|
||||
queue_manager=queue_manager,
|
||||
variable_loader=MagicMock(),
|
||||
workflow=workflow,
|
||||
system_user_id="sys-user",
|
||||
root_node_id=None,
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
graph_engine_layers=(),
|
||||
graph_runtime_state=None,
|
||||
)
|
||||
|
||||
|
||||
def test_graph_run_paused_event_emits_queue_pause_event():
|
||||
runner = _build_runner()
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
actions=[],
|
||||
node_id="node-human",
|
||||
node_title="Human Step",
|
||||
form_token="tok",
|
||||
)
|
||||
event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"})
|
||||
workflow_entry = SimpleNamespace(
|
||||
graph_engine=SimpleNamespace(graph_runtime_state=_FakeRuntimeState()),
|
||||
)
|
||||
|
||||
runner._handle_event(workflow_entry, event)
|
||||
|
||||
assert len(runner.published_events) == 1
|
||||
queue_event = runner.published_events[0]
|
||||
assert isinstance(queue_event, QueueWorkflowPausedEvent)
|
||||
assert queue_event.reasons == [reason]
|
||||
assert queue_event.outputs == {"foo": "bar"}
|
||||
assert queue_event.paused_nodes == ["node-pause-1"]
|
||||
|
||||
|
||||
def _build_converter():
|
||||
application_generate_entity = SimpleNamespace(
|
||||
inputs={},
|
||||
files=[],
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
|
||||
)
|
||||
system_variables = SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app-id",
|
||||
workflow_id="workflow-id",
|
||||
workflow_execution_id="run-id",
|
||||
)
|
||||
user = MagicMock(spec=Account)
|
||||
user.id = "account-id"
|
||||
user.name = "Tester"
|
||||
user.email = "tester@example.com"
|
||||
return WorkflowResponseConverter(
|
||||
application_generate_entity=application_generate_entity,
|
||||
user=user,
|
||||
system_variables=system_variables,
|
||||
)
|
||||
|
||||
|
||||
def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.MonkeyPatch):
|
||||
converter = _build_converter()
|
||||
converter.workflow_start_to_stream_response(
|
||||
task_id="task",
|
||||
workflow_run_id="run-id",
|
||||
workflow_id="workflow-id",
|
||||
reason=WorkflowStartReason.INITIAL,
|
||||
)
|
||||
|
||||
expiration_time = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
|
||||
class _FakeSession:
|
||||
def execute(self, _stmt):
|
||||
return [("form-1", expiration_time)]
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(workflow_response_converter, "Session", lambda **_: _FakeSession())
|
||||
monkeypatch.setattr(workflow_response_converter, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
reason = HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="Rendered",
|
||||
inputs=[
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
|
||||
],
|
||||
actions=[UserAction(id="approve", title="Approve")],
|
||||
display_in_ui=True,
|
||||
node_id="node-id",
|
||||
node_title="Human Step",
|
||||
form_token="token",
|
||||
)
|
||||
queue_event = QueueWorkflowPausedEvent(
|
||||
reasons=[reason],
|
||||
outputs={"answer": "value"},
|
||||
paused_nodes=["node-id"],
|
||||
)
|
||||
|
||||
runtime_state = SimpleNamespace(total_tokens=0, node_run_steps=0)
|
||||
responses = converter.workflow_pause_to_stream_response(
|
||||
event=queue_event,
|
||||
task_id="task",
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
assert isinstance(responses[-1], WorkflowPauseStreamResponse)
|
||||
pause_resp = responses[-1]
|
||||
assert pause_resp.workflow_run_id == "run-id"
|
||||
assert pause_resp.data.paused_nodes == ["node-id"]
|
||||
assert pause_resp.data.outputs == {}
|
||||
assert pause_resp.data.reasons[0]["form_id"] == "form-1"
|
||||
assert pause_resp.data.reasons[0]["display_in_ui"] is True
|
||||
|
||||
assert isinstance(responses[0], HumanInputRequiredResponse)
|
||||
hi_resp = responses[0]
|
||||
assert hi_resp.data.form_id == "form-1"
|
||||
assert hi_resp.data.node_id == "node-id"
|
||||
assert hi_resp.data.node_title == "Human Step"
|
||||
assert hi_resp.data.inputs[0].output_variable_name == "field"
|
||||
assert hi_resp.data.actions[0].id == "approve"
|
||||
assert hi_resp.data.display_in_ui is True
|
||||
assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
|
||||
@ -0,0 +1,96 @@
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueWorkflowStartedEvent
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.account import Account
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _build_workflow_app_config() -> WorkflowUIBasedAppConfig:
|
||||
return WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-id",
|
||||
)
|
||||
|
||||
|
||||
def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity:
|
||||
return WorkflowAppGenerateEntity(
|
||||
task_id="task-id",
|
||||
app_config=_build_workflow_app_config(),
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-id",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
workflow_execution_id=run_id,
|
||||
)
|
||||
|
||||
|
||||
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(workflow_execution_id=run_id),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _noop_session():
|
||||
yield MagicMock()
|
||||
|
||||
|
||||
def _build_pipeline(run_id: str) -> WorkflowAppGenerateTaskPipeline:
|
||||
queue_manager = MagicMock(spec=AppQueueManager)
|
||||
queue_manager.invoke_from = InvokeFrom.SERVICE_API
|
||||
queue_manager.graph_runtime_state = _build_runtime_state(run_id)
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-id"
|
||||
workflow.features_dict = {}
|
||||
user = Account(name="user", email="user@example.com")
|
||||
pipeline = WorkflowAppGenerateTaskPipeline(
|
||||
application_generate_entity=_build_generate_entity(run_id),
|
||||
workflow=workflow,
|
||||
queue_manager=queue_manager,
|
||||
user=user,
|
||||
stream=False,
|
||||
draft_var_saver_factory=MagicMock(),
|
||||
)
|
||||
pipeline._database_session = _noop_session
|
||||
return pipeline
|
||||
|
||||
|
||||
def test_workflow_app_log_saved_only_on_initial_start() -> None:
|
||||
run_id = "run-initial"
|
||||
pipeline = _build_pipeline(run_id)
|
||||
pipeline._save_workflow_app_log = MagicMock()
|
||||
|
||||
event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.INITIAL)
|
||||
list(pipeline._handle_workflow_started_event(event))
|
||||
|
||||
pipeline._save_workflow_app_log.assert_called_once()
|
||||
_, kwargs = pipeline._save_workflow_app_log.call_args
|
||||
assert kwargs["workflow_run_id"] == run_id
|
||||
assert pipeline._workflow_execution_id == run_id
|
||||
|
||||
|
||||
def test_workflow_app_log_skipped_on_resumption_start() -> None:
|
||||
run_id = "run-resume"
|
||||
pipeline = _build_pipeline(run_id)
|
||||
pipeline._save_workflow_app_log = MagicMock()
|
||||
|
||||
event = QueueWorkflowStartedEvent(reason=WorkflowStartReason.RESUMPTION)
|
||||
list(pipeline._handle_workflow_started_event(event))
|
||||
|
||||
pipeline._save_workflow_app_log.assert_not_called()
|
||||
assert pipeline._workflow_execution_id == run_id
|
||||
@ -0,0 +1,143 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
WorkflowAppGenerateEntity,
|
||||
)
|
||||
from core.app.layers.pause_state_persist_layer import (
|
||||
WorkflowResumptionContext,
|
||||
_AdvancedChatAppGenerateEntityWrapper,
|
||||
_WorkflowGenerateEntityWrapper,
|
||||
)
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class TraceQueueManagerStub(TraceQueueManager):
|
||||
"""Minimal TraceQueueManager stub that avoids Flask dependencies."""
|
||||
|
||||
def __init__(self):
|
||||
# Skip parent initialization to avoid starting timers or accessing Flask globals.
|
||||
pass
|
||||
|
||||
|
||||
def _build_workflow_app_config(app_mode: AppMode) -> WorkflowUIBasedAppConfig:
|
||||
return WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
app_mode=app_mode,
|
||||
workflow_id=f"{app_mode.value}-workflow-id",
|
||||
)
|
||||
|
||||
|
||||
def _create_workflow_generate_entity(trace_manager: TraceQueueManager | None = None) -> WorkflowAppGenerateEntity:
|
||||
return WorkflowAppGenerateEntity(
|
||||
task_id="workflow-task",
|
||||
app_config=_build_workflow_app_config(AppMode.WORKFLOW),
|
||||
inputs={"topic": "serialization"},
|
||||
files=[],
|
||||
user_id="user-workflow",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=1,
|
||||
trace_manager=trace_manager,
|
||||
workflow_execution_id="workflow-exec-id",
|
||||
extras={"external_trace_id": "trace-id"},
|
||||
)
|
||||
|
||||
|
||||
def _create_advanced_chat_generate_entity(
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
) -> AdvancedChatAppGenerateEntity:
|
||||
return AdvancedChatAppGenerateEntity(
|
||||
task_id="advanced-task",
|
||||
app_config=_build_workflow_app_config(AppMode.ADVANCED_CHAT),
|
||||
conversation_id="conversation-id",
|
||||
inputs={"topic": "roundtrip"},
|
||||
files=[],
|
||||
user_id="user-advanced",
|
||||
stream=False,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
query="Explain serialization",
|
||||
extras={"auto_generate_conversation_name": True},
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id="workflow-run-id",
|
||||
)
|
||||
|
||||
|
||||
def test_workflow_app_generate_entity_roundtrip_excludes_trace_manager():
|
||||
entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
|
||||
serialized = entity.model_dump_json()
|
||||
payload = json.loads(serialized)
|
||||
|
||||
assert "trace_manager" not in payload
|
||||
|
||||
restored = WorkflowAppGenerateEntity.model_validate_json(serialized)
|
||||
|
||||
assert restored.model_dump() == entity.model_dump()
|
||||
assert restored.trace_manager is None
|
||||
|
||||
|
||||
def test_advanced_chat_generate_entity_roundtrip_excludes_trace_manager():
|
||||
entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
|
||||
serialized = entity.model_dump_json()
|
||||
payload = json.loads(serialized)
|
||||
|
||||
assert "trace_manager" not in payload
|
||||
|
||||
restored = AdvancedChatAppGenerateEntity.model_validate_json(serialized)
|
||||
|
||||
assert restored.model_dump() == entity.model_dump()
|
||||
assert restored.trace_manager is None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResumptionContextCase:
|
||||
name: str
|
||||
context_factory: Callable[[], tuple[WorkflowResumptionContext, type]]
|
||||
|
||||
|
||||
def _workflow_resumption_case() -> tuple[WorkflowResumptionContext, type]:
|
||||
entity = _create_workflow_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
context = WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=json.dumps({"state": "workflow"}),
|
||||
generate_entity=_WorkflowGenerateEntityWrapper(entity=entity),
|
||||
)
|
||||
return context, WorkflowAppGenerateEntity
|
||||
|
||||
|
||||
def _advanced_chat_resumption_case() -> tuple[WorkflowResumptionContext, type]:
|
||||
entity = _create_advanced_chat_generate_entity(trace_manager=TraceQueueManagerStub())
|
||||
context = WorkflowResumptionContext(
|
||||
serialized_graph_runtime_state=json.dumps({"state": "advanced"}),
|
||||
generate_entity=_AdvancedChatAppGenerateEntityWrapper(entity=entity),
|
||||
)
|
||||
return context, AdvancedChatAppGenerateEntity
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
pytest.param(ResumptionContextCase("workflow", _workflow_resumption_case), id="workflow"),
|
||||
pytest.param(ResumptionContextCase("advanced_chat", _advanced_chat_resumption_case), id="advanced_chat"),
|
||||
],
|
||||
)
|
||||
def test_workflow_resumption_context_roundtrip(case: ResumptionContextCase):
|
||||
context, expected_type = case.context_factory()
|
||||
|
||||
serialized = context.dumps()
|
||||
restored = WorkflowResumptionContext.loads(serialized)
|
||||
|
||||
assert restored.serialized_graph_runtime_state == context.serialized_graph_runtime_state
|
||||
entity = restored.get_generate_entity()
|
||||
assert isinstance(entity, expected_type)
|
||||
assert entity.model_dump() == context.get_generate_entity().model_dump()
|
||||
assert entity.trace_manager is None
|
||||
@ -0,0 +1,72 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker):
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.ADVANCED_CHAT
|
||||
app.workflow = workflow
|
||||
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.db",
|
||||
SimpleNamespace(engine=MagicMock()),
|
||||
)
|
||||
generator_spy = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_chat_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
conversation_id="conv-1",
|
||||
query="hello",
|
||||
stream=False,
|
||||
inputs={"k": "v"},
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert isinstance(pause_state_config, PauseStateLayerConfig)
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
|
||||
|
||||
def test_invoke_workflow_app_injects_pause_state_config(mocker):
|
||||
workflow = MagicMock()
|
||||
workflow.created_by = "owner-id"
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.WORKFLOW
|
||||
app.workflow = workflow
|
||||
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.db",
|
||||
SimpleNamespace(engine=MagicMock()),
|
||||
)
|
||||
generator_spy = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_workflow_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
stream=False,
|
||||
inputs={"k": "v"},
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert isinstance(pause_state_config, PauseStateLayerConfig)
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
@ -0,0 +1,574 @@
|
||||
"""Unit tests for HumanInputFormRepositoryImpl private helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.repositories.human_input_repository import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormRepositoryImpl,
|
||||
HumanInputFormSubmissionRepository,
|
||||
_WorkspaceMemberInfo,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
FormDefinition,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.human_input import (
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
)
|
||||
|
||||
|
||||
def _build_repository() -> HumanInputFormRepositoryImpl:
|
||||
return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id")
|
||||
|
||||
|
||||
def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]:
|
||||
created: list[SimpleNamespace] = []
|
||||
|
||||
def fake_new(cls, form_id: str, delivery_id: str, payload): # type: ignore[no-untyped-def]
|
||||
recipient = SimpleNamespace(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=payload.TYPE,
|
||||
recipient_payload=payload.model_dump_json(),
|
||||
)
|
||||
created.append(recipient)
|
||||
return recipient
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRecipient, "new", classmethod(fake_new))
|
||||
return created
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_selectinload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Avoid SQLAlchemy mapper configuration in tests using fake sessions."""
|
||||
|
||||
class _FakeSelect:
|
||||
def options(self, *_args, **_kwargs): # type: ignore[no-untyped-def]
|
||||
return self
|
||||
|
||||
def where(self, *_args, **_kwargs): # type: ignore[no-untyped-def]
|
||||
return self
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.repositories.human_input_repository.selectinload", lambda *args, **kwargs: "_loader_option"
|
||||
)
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *args, **kwargs: _FakeSelect())
|
||||
|
||||
|
||||
class TestHumanInputFormRepositoryImplHelpers:
|
||||
def test_build_email_recipients_with_member_and_external(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
_patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert restrict_to_user_ids == ["member-1"]
|
||||
return [_WorkspaceMemberInfo(user_id="member-1", email="member@example.com")]
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
|
||||
|
||||
recipients = repo._build_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(user_id="member-1"),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(recipients) == 2
|
||||
member_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_MEMBER)
|
||||
external_recipient = next(r for r in recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL)
|
||||
|
||||
member_payload = EmailMemberRecipientPayload.model_validate_json(member_recipient.recipient_payload)
|
||||
assert member_payload.user_id == "member-1"
|
||||
assert member_payload.email == "member@example.com"
|
||||
|
||||
external_payload = EmailExternalRecipientPayload.model_validate_json(external_recipient.recipient_payload)
|
||||
assert external_payload.email == "external@example.com"
|
||||
|
||||
def test_build_email_recipients_skips_unknown_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
created = _patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert restrict_to_user_ids == ["missing-member"]
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
|
||||
|
||||
recipients = repo._build_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(user_id="missing-member"),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(recipients) == 1
|
||||
assert recipients[0].recipient_type == RecipientType.EMAIL_EXTERNAL
|
||||
assert len(created) == 1 # only external recipient created via factory
|
||||
|
||||
def test_build_email_recipients_whole_workspace_uses_all_members(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
_patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
return [
|
||||
_WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"),
|
||||
_WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query)
|
||||
|
||||
recipients = repo._build_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=True,
|
||||
items=[],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(recipients) == 2
|
||||
emails = {EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email for r in recipients}
|
||||
assert emails == {"member1@example.com", "member2@example.com"}
|
||||
|
||||
def test_build_email_recipients_dedupes_external_by_email(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
created = _patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert restrict_to_user_ids == []
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
|
||||
|
||||
recipients = repo._build_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
ExternalRecipient(email="external@example.com"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(recipients) == 1
|
||||
assert len(created) == 1
|
||||
|
||||
def test_build_email_recipients_prefers_member_over_external_by_email(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
_patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session, restrict_to_user_ids): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
assert restrict_to_user_ids == ["member-1"]
|
||||
return [_WorkspaceMemberInfo(user_id="member-1", email="shared@example.com")]
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_workspace_members_by_ids", fake_query)
|
||||
|
||||
recipients = repo._build_email_recipients(
|
||||
session=session_stub,
|
||||
form_id="form-id",
|
||||
delivery_id="delivery-id",
|
||||
recipients_config=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(user_id="member-1"),
|
||||
ExternalRecipient(email="shared@example.com"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
assert len(recipients) == 1
|
||||
assert recipients[0].recipient_type == RecipientType.EMAIL_MEMBER
|
||||
|
||||
def test_delivery_method_to_model_includes_external_recipients_with_whole_workspace(
|
||||
self,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
repo = _build_repository()
|
||||
session_stub = object()
|
||||
_patch_recipient_factory(monkeypatch)
|
||||
|
||||
def fake_query(self, session): # type: ignore[no-untyped-def]
|
||||
assert session is session_stub
|
||||
return [
|
||||
_WorkspaceMemberInfo(user_id="member-1", email="member1@example.com"),
|
||||
_WorkspaceMemberInfo(user_id="member-2", email="member2@example.com"),
|
||||
]
|
||||
|
||||
monkeypatch.setattr(HumanInputFormRepositoryImpl, "_query_all_workspace_members", fake_query)
|
||||
|
||||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=True,
|
||||
items=[ExternalRecipient(email="external@example.com")],
|
||||
),
|
||||
subject="subject",
|
||||
body="body",
|
||||
)
|
||||
)
|
||||
|
||||
result = repo._delivery_method_to_model(session=session_stub, form_id="form-id", delivery_method=method)
|
||||
|
||||
assert len(result.recipients) == 3
|
||||
member_emails = {
|
||||
EmailMemberRecipientPayload.model_validate_json(r.recipient_payload).email
|
||||
for r in result.recipients
|
||||
if r.recipient_type == RecipientType.EMAIL_MEMBER
|
||||
}
|
||||
assert member_emails == {"member1@example.com", "member2@example.com"}
|
||||
external_payload = EmailExternalRecipientPayload.model_validate_json(
|
||||
next(r for r in result.recipients if r.recipient_type == RecipientType.EMAIL_EXTERNAL).recipient_payload
|
||||
)
|
||||
assert external_payload.email == "external@example.com"
|
||||
|
||||
|
||||
def _make_form_definition() -> str:
|
||||
return FormDefinition(
|
||||
form_content="hello",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=datetime.utcnow(),
|
||||
).model_dump_json()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DummyForm:
|
||||
id: str
|
||||
workflow_run_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
form_definition: str
|
||||
rendered_content: str
|
||||
expiration_time: datetime
|
||||
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
|
||||
created_at: datetime = dataclasses.field(default_factory=naive_utc_now)
|
||||
selected_action_id: str | None = None
|
||||
submitted_data: str | None = None
|
||||
submitted_at: datetime | None = None
|
||||
submission_user_id: str | None = None
|
||||
submission_end_user_id: str | None = None
|
||||
completed_by_recipient_id: str | None = None
|
||||
status: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DummyRecipient:
|
||||
id: str
|
||||
form_id: str
|
||||
recipient_type: RecipientType
|
||||
access_token: str
|
||||
form: _DummyForm | None = None
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, obj):
|
||||
self._obj = obj
|
||||
|
||||
def first(self):
|
||||
if isinstance(self._obj, list):
|
||||
return self._obj[0] if self._obj else None
|
||||
return self._obj
|
||||
|
||||
def all(self):
|
||||
if isinstance(self._obj, list):
|
||||
return list(self._obj)
|
||||
if self._obj is None:
|
||||
return []
|
||||
return [self._obj]
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scalars_result=None,
|
||||
scalars_results: list[object] | None = None,
|
||||
forms: dict[str, _DummyForm] | None = None,
|
||||
recipients: dict[str, _DummyRecipient] | None = None,
|
||||
):
|
||||
if scalars_results is not None:
|
||||
self._scalars_queue = list(scalars_results)
|
||||
elif scalars_result is not None:
|
||||
self._scalars_queue = [scalars_result]
|
||||
else:
|
||||
self._scalars_queue = []
|
||||
self.forms = forms or {}
|
||||
self.recipients = recipients or {}
|
||||
|
||||
def scalars(self, _query):
|
||||
if self._scalars_queue:
|
||||
result = self._scalars_queue.pop(0)
|
||||
else:
|
||||
result = None
|
||||
return _FakeScalarResult(result)
|
||||
|
||||
def get(self, model_cls, obj_id): # type: ignore[no-untyped-def]
|
||||
if getattr(model_cls, "__name__", None) == "HumanInputForm":
|
||||
return self.forms.get(obj_id)
|
||||
if getattr(model_cls, "__name__", None) == "HumanInputFormRecipient":
|
||||
return self.recipients.get(obj_id)
|
||||
return None
|
||||
|
||||
def add(self, _obj):
|
||||
return None
|
||||
|
||||
def flush(self):
|
||||
return None
|
||||
|
||||
def refresh(self, _obj):
|
||||
return None
|
||||
|
||||
def begin(self):
|
||||
return self
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
def _session_factory(session: _FakeSession):
|
||||
class _SessionContext:
|
||||
def __enter__(self):
|
||||
return session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
def _factory(*_args, **_kwargs):
|
||||
return _SessionContext()
|
||||
|
||||
return _factory
|
||||
|
||||
|
||||
class TestHumanInputFormRepositoryImplPublicMethods:
|
||||
def test_get_form_returns_entity_and_recipients(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
)
|
||||
session = _FakeSession(scalars_results=[form, [recipient]])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
entity = repo.get_form(form.workflow_run_id, form.node_id)
|
||||
|
||||
assert entity is not None
|
||||
assert entity.id == form.id
|
||||
assert entity.web_app_token == "token-123"
|
||||
assert len(entity.recipients) == 1
|
||||
assert entity.recipients[0].token == "token-123"
|
||||
|
||||
def test_get_form_returns_none_when_missing(self):
|
||||
session = _FakeSession(scalars_results=[None])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
assert repo.get_form("run-1", "node-1") is None
|
||||
|
||||
def test_get_form_returns_unsubmitted_state(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
session = _FakeSession(scalars_results=[form, []])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
entity = repo.get_form(form.workflow_run_id, form.node_id)
|
||||
|
||||
assert entity is not None
|
||||
assert entity.submitted is False
|
||||
assert entity.selected_action_id is None
|
||||
assert entity.submitted_data is None
|
||||
|
||||
def test_get_form_returns_submission_when_completed(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
selected_action_id="approve",
|
||||
submitted_data='{"field": "value"}',
|
||||
submitted_at=naive_utc_now(),
|
||||
)
|
||||
session = _FakeSession(scalars_results=[form, []])
|
||||
repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id")
|
||||
|
||||
entity = repo.get_form(form.workflow_run_id, form.node_id)
|
||||
|
||||
assert entity is not None
|
||||
assert entity.submitted is True
|
||||
assert entity.selected_action_id == "approve"
|
||||
assert entity.submitted_data == {"field": "value"}
|
||||
|
||||
|
||||
class TestHumanInputFormSubmissionRepository:
|
||||
def test_get_by_token_returns_record(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
form=form,
|
||||
)
|
||||
session = _FakeSession(scalars_result=recipient)
|
||||
repo = HumanInputFormSubmissionRepository(_session_factory(session))
|
||||
|
||||
record = repo.get_by_token("token-123")
|
||||
|
||||
assert record is not None
|
||||
assert record.form_id == form.id
|
||||
assert record.recipient_type == RecipientType.STANDALONE_WEB_APP
|
||||
assert record.submitted is False
|
||||
|
||||
def test_get_by_form_id_and_recipient_type_uses_recipient(self):
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=naive_utc_now(),
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
form=form,
|
||||
)
|
||||
session = _FakeSession(scalars_result=recipient)
|
||||
repo = HumanInputFormSubmissionRepository(_session_factory(session))
|
||||
|
||||
record = repo.get_by_form_id_and_recipient_type(
|
||||
form_id=form.id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
)
|
||||
|
||||
assert record is not None
|
||||
assert record.recipient_id == recipient.id
|
||||
assert record.access_token == recipient.access_token
|
||||
|
||||
def test_mark_submitted_updates_fields(self, monkeypatch: pytest.MonkeyPatch):
|
||||
fixed_now = datetime(2024, 1, 1, 0, 0, 0)
|
||||
monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now)
|
||||
|
||||
form = _DummyForm(
|
||||
id="form-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
form_definition=_make_form_definition(),
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=fixed_now,
|
||||
)
|
||||
recipient = _DummyRecipient(
|
||||
id="recipient-1",
|
||||
form_id="form-1",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token-123",
|
||||
)
|
||||
session = _FakeSession(
|
||||
forms={form.id: form},
|
||||
recipients={recipient.id: recipient},
|
||||
)
|
||||
repo = HumanInputFormSubmissionRepository(_session_factory(session))
|
||||
|
||||
record: HumanInputFormRecord = repo.mark_submitted(
|
||||
form_id=form.id,
|
||||
recipient_id=recipient.id,
|
||||
selected_action_id="approve",
|
||||
form_data={"field": "value"},
|
||||
submission_user_id="user-1",
|
||||
submission_end_user_id="end-user-1",
|
||||
)
|
||||
|
||||
assert form.selected_action_id == "approve"
|
||||
assert form.completed_by_recipient_id == recipient.id
|
||||
assert form.submission_user_id == "user-1"
|
||||
assert form.submission_end_user_id == "end-user-1"
|
||||
assert form.submitted_at == fixed_now
|
||||
assert record.submitted is True
|
||||
assert record.selected_action_id == "approve"
|
||||
assert record.submitted_data == {"field": "value"}
|
||||
@ -0,0 +1,33 @@
|
||||
import pytest
|
||||
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
|
||||
|
||||
def test_ensure_no_human_input_nodes_passes_for_non_human_input():
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "start_node",
|
||||
"data": {"type": "start"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph)
|
||||
|
||||
|
||||
def test_ensure_no_human_input_nodes_raises_for_human_input():
|
||||
graph = {
|
||||
"nodes": [
|
||||
{
|
||||
"id": "human_input_node",
|
||||
"data": {"type": "human-input"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
|
||||
WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph)
|
||||
|
||||
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
|
||||
@ -55,6 +55,43 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
||||
assert exc_info.value.args == ("oops",)
|
||||
|
||||
|
||||
def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch):
|
||||
entity = ToolEntity(
|
||||
identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"),
|
||||
parameters=[],
|
||||
description=None,
|
||||
has_runtime_parameters=False,
|
||||
)
|
||||
runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE)
|
||||
tool = WorkflowTool(
|
||||
workflow_app_id="",
|
||||
workflow_as_tool_id="",
|
||||
version="1",
|
||||
workflow_entities={},
|
||||
workflow_call_depth=1,
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
|
||||
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
mock_user = Mock()
|
||||
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
|
||||
|
||||
generate_mock = MagicMock(return_value={"data": {}})
|
||||
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", generate_mock)
|
||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||
|
||||
list(tool.invoke("test_user", {}))
|
||||
|
||||
call_kwargs = generate_mock.call_args.kwargs
|
||||
assert "pause_state_config" in call_kwargs
|
||||
assert call_kwargs["pause_state_config"] is None
|
||||
|
||||
|
||||
def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that WorkflowTool should generate variable messages when there are outputs"""
|
||||
entity = ToolEntity(
|
||||
|
||||
@ -118,7 +118,6 @@ class TestGraphRuntimeState:
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
|
||||
assert isinstance(queue, InMemoryReadyQueue)
|
||||
assert state.ready_queue is queue
|
||||
|
||||
def test_graph_execution_lazy_instantiation(self):
|
||||
state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time())
|
||||
|
||||
@ -0,0 +1,88 @@
|
||||
"""
|
||||
Tests for PauseReason discriminated union serialization/deserialization.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.workflow.entities.pause_reason import (
|
||||
HumanInputRequired,
|
||||
PauseReason,
|
||||
SchedulingPause,
|
||||
)
|
||||
|
||||
|
||||
class _Holder(BaseModel):
|
||||
"""Helper model that embeds PauseReason for union tests."""
|
||||
|
||||
reason: PauseReason
|
||||
|
||||
|
||||
class TestPauseReasonDiscriminator:
|
||||
"""Test suite for PauseReason union discriminator."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("dict_value", "expected"),
|
||||
[
|
||||
pytest.param(
|
||||
{
|
||||
"reason": {
|
||||
"TYPE": "human_input_required",
|
||||
"form_id": "form_id",
|
||||
"form_content": "form_content",
|
||||
"node_id": "node_id",
|
||||
"node_title": "node_title",
|
||||
},
|
||||
},
|
||||
HumanInputRequired(
|
||||
form_id="form_id",
|
||||
form_content="form_content",
|
||||
node_id="node_id",
|
||||
node_title="node_title",
|
||||
),
|
||||
id="HumanInputRequired",
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"reason": {
|
||||
"TYPE": "scheduled_pause",
|
||||
"message": "Hold on",
|
||||
}
|
||||
},
|
||||
SchedulingPause(message="Hold on"),
|
||||
id="SchedulingPause",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_model_validate(self, dict_value, expected):
|
||||
"""Ensure scheduled pause payloads with lowercase TYPE deserialize."""
|
||||
holder = _Holder.model_validate(dict_value)
|
||||
|
||||
assert type(holder.reason) == type(expected)
|
||||
assert holder.reason == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reason",
|
||||
[
|
||||
HumanInputRequired(
|
||||
form_id="form_id",
|
||||
form_content="form_content",
|
||||
node_id="node_id",
|
||||
node_title="node_title",
|
||||
),
|
||||
SchedulingPause(message="Hold on"),
|
||||
],
|
||||
ids=lambda x: type(x).__name__,
|
||||
)
|
||||
def test_model_construct(self, reason):
|
||||
holder = _Holder(reason=reason)
|
||||
assert holder.reason == reason
|
||||
|
||||
def test_model_construct_with_invalid_type(self):
|
||||
with pytest.raises(ValidationError):
|
||||
holder = _Holder(reason=object()) # type: ignore
|
||||
|
||||
def test_unknown_type_fails_validation(self):
|
||||
"""Unknown TYPE values should raise a validation error."""
|
||||
with pytest.raises(ValidationError):
|
||||
_Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}})
|
||||
@ -0,0 +1,131 @@
|
||||
"""Utilities for testing HumanInputNode without database dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRecipientEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class _InMemoryFormRecipient(HumanInputFormRecipientEntity):
|
||||
"""Minimal recipient entity required by the repository interface."""
|
||||
|
||||
def __init__(self, recipient_id: str, token: str) -> None:
|
||||
self._id = recipient_id
|
||||
self._token = token
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
return self._token
|
||||
|
||||
|
||||
@dataclass
|
||||
class _InMemoryFormEntity(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
token: str | None = None
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
is_submitted: bool = False
|
||||
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
expiration: datetime = naive_utc_now()
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.form_id
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
return self.token
|
||||
|
||||
@property
|
||||
def recipients(self) -> list[HumanInputFormRecipientEntity]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self.action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self.data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self.status_value
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self.expiration
|
||||
|
||||
|
||||
class InMemoryHumanInputFormRepository(HumanInputFormRepository):
|
||||
"""Pure in-memory repository used by workflow graph engine tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._form_counter = 0
|
||||
self.created_params: list[FormCreateParams] = []
|
||||
self.created_forms: list[_InMemoryFormEntity] = []
|
||||
self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {}
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
self.created_params.append(params)
|
||||
self._form_counter += 1
|
||||
form_id = f"form-{self._form_counter}"
|
||||
token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}"
|
||||
entity = _InMemoryFormEntity(
|
||||
form_id=form_id,
|
||||
rendered=params.rendered_content,
|
||||
token=token,
|
||||
)
|
||||
self.created_forms.append(entity)
|
||||
self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity
|
||||
return entity
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
return self._forms_by_key.get((workflow_execution_id, node_id))
|
||||
|
||||
# Convenience helpers for tests -------------------------------------
|
||||
|
||||
def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None:
|
||||
"""Simulate a human submission for the next repository lookup."""
|
||||
|
||||
if not self.created_forms:
|
||||
raise AssertionError("no form has been created to attach submission data")
|
||||
entity = self.created_forms[-1]
|
||||
entity.action_id = action_id
|
||||
entity.data = form_data or {}
|
||||
entity.is_submitted = True
|
||||
entity.status_value = HumanInputFormStatus.SUBMITTED
|
||||
entity.expiration = naive_utc_now() + timedelta(days=1)
|
||||
|
||||
def clear_submission(self) -> None:
|
||||
if not self.created_forms:
|
||||
return
|
||||
for form in self.created_forms:
|
||||
form.action_id = None
|
||||
form.data = None
|
||||
form.is_submitted = False
|
||||
form.status_value = HumanInputFormStatus.WAITING
|
||||
@ -0,0 +1,74 @@
|
||||
import queue
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher
|
||||
from core.workflow.graph_events import NodeRunSucceededEvent
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
|
||||
class StubExecutionCoordinator:
|
||||
def __init__(self, paused: bool) -> None:
|
||||
self._paused = paused
|
||||
self.mark_complete_called = False
|
||||
self.failed_error: Exception | None = None
|
||||
|
||||
@property
|
||||
def aborted(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def paused(self) -> bool:
|
||||
return self._paused
|
||||
|
||||
@property
|
||||
def execution_complete(self) -> bool:
|
||||
return False
|
||||
|
||||
def check_scaling(self) -> None:
|
||||
return None
|
||||
|
||||
def process_commands(self) -> None:
|
||||
return None
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
self.mark_complete_called = True
|
||||
|
||||
def mark_failed(self, error: Exception) -> None:
|
||||
self.failed_error = error
|
||||
|
||||
|
||||
class StubEventHandler:
|
||||
def __init__(self) -> None:
|
||||
self.events: list[object] = []
|
||||
|
||||
def dispatch(self, event: object) -> None:
|
||||
self.events.append(event)
|
||||
|
||||
|
||||
def test_dispatcher_drains_events_when_paused() -> None:
|
||||
event_queue: queue.Queue = queue.Queue()
|
||||
event = NodeRunSucceededEvent(
|
||||
id="exec-1",
|
||||
node_id="node-1",
|
||||
node_type=NodeType.START,
|
||||
start_at=datetime.utcnow(),
|
||||
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
|
||||
)
|
||||
event_queue.put(event)
|
||||
|
||||
handler = StubEventHandler()
|
||||
coordinator = StubExecutionCoordinator(paused=True)
|
||||
dispatcher = Dispatcher(
|
||||
event_queue=event_queue,
|
||||
event_handler=handler,
|
||||
execution_coordinator=coordinator,
|
||||
event_emitter=None,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
|
||||
assert handler.events == [event]
|
||||
assert coordinator.mark_complete_called is True
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
|
||||
@ -48,3 +50,13 @@ def test_handle_pause_noop_when_execution_running() -> None:
|
||||
|
||||
worker_pool.stop.assert_not_called()
|
||||
state_manager.clear_executing.assert_not_called()
|
||||
|
||||
|
||||
def test_has_executing_nodes_requires_pause() -> None:
|
||||
graph_execution = GraphExecution(workflow_id="workflow")
|
||||
graph_execution.start()
|
||||
|
||||
coordinator, _, _ = _build_coordinator(graph_execution)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
coordinator.has_executing_nodes()
|
||||
|
||||
@ -0,0 +1,189 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.graph_state_manager import GraphStateManager
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_llm_node(
|
||||
*,
|
||||
node_id: str,
|
||||
runtime_state: GraphRuntimeState,
|
||||
graph_init_params: GraphInitParams,
|
||||
mock_config: MockConfig,
|
||||
) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
title=f"LLM {node_id}",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=f"Prompt {node_id}",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
)
|
||||
llm_config = {"id": node_id, "data": llm_data.model_dump()}
|
||||
return MockLLMNode(
|
||||
id=llm_config["id"],
|
||||
config=llm_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
mock_config = MockConfig()
|
||||
llm_a = _build_llm_node(
|
||||
node_id="llm_a",
|
||||
runtime_state=runtime_state,
|
||||
graph_init_params=graph_init_params,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
llm_b = _build_llm_node(
|
||||
node_id="llm_b",
|
||||
runtime_state=runtime_state,
|
||||
graph_init_params=graph_init_params,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
end_data = EndNodeData(title="End", outputs=[], desc=None)
|
||||
end_config = {"id": "end", "data": end_data.model_dump()}
|
||||
end_node = EndNode(
|
||||
id=end_config["id"],
|
||||
config=end_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
builder = (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(llm_a, from_node_id="start")
|
||||
.add_node(llm_b, from_node_id="start")
|
||||
.add_node(end_node, from_node_id="llm_a")
|
||||
)
|
||||
return builder.connect(tail="llm_b", head="end").build()
|
||||
|
||||
|
||||
def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]:
|
||||
return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()}
|
||||
|
||||
|
||||
def test_runtime_state_snapshot_restores_graph_states() -> None:
|
||||
runtime_state = _build_runtime_state()
|
||||
graph = _build_graph(runtime_state)
|
||||
runtime_state.attach_graph(graph)
|
||||
|
||||
graph.nodes["llm_a"].state = NodeState.TAKEN
|
||||
graph.nodes["llm_b"].state = NodeState.SKIPPED
|
||||
|
||||
for edge in graph.edges.values():
|
||||
if edge.tail == "start" and edge.head == "llm_a":
|
||||
edge.state = NodeState.TAKEN
|
||||
elif edge.tail == "start" and edge.head == "llm_b":
|
||||
edge.state = NodeState.SKIPPED
|
||||
elif edge.head == "end" and edge.tail == "llm_a":
|
||||
edge.state = NodeState.TAKEN
|
||||
elif edge.head == "end" and edge.tail == "llm_b":
|
||||
edge.state = NodeState.SKIPPED
|
||||
|
||||
snapshot = runtime_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resumed_graph = _build_graph(resumed_state)
|
||||
resumed_state.attach_graph(resumed_graph)
|
||||
|
||||
assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN
|
||||
assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED
|
||||
assert _edge_state_map(resumed_graph) == _edge_state_map(graph)
|
||||
|
||||
|
||||
def test_join_readiness_uses_restored_edge_states() -> None:
|
||||
runtime_state = _build_runtime_state()
|
||||
graph = _build_graph(runtime_state)
|
||||
runtime_state.attach_graph(graph)
|
||||
|
||||
ready_queue = InMemoryReadyQueue()
|
||||
state_manager = GraphStateManager(graph, ready_queue)
|
||||
|
||||
for edge in graph.get_incoming_edges("end"):
|
||||
if edge.tail == "llm_a":
|
||||
edge.state = NodeState.TAKEN
|
||||
if edge.tail == "llm_b":
|
||||
edge.state = NodeState.UNKNOWN
|
||||
|
||||
assert state_manager.is_node_ready("end") is False
|
||||
|
||||
for edge in graph.get_incoming_edges("end"):
|
||||
if edge.tail == "llm_b":
|
||||
edge.state = NodeState.TAKEN
|
||||
|
||||
assert state_manager.is_node_ready("end") is True
|
||||
|
||||
snapshot = runtime_state.dumps()
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resumed_graph = _build_graph(resumed_state)
|
||||
resumed_state.attach_graph(resumed_graph)
|
||||
|
||||
resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue())
|
||||
assert resumed_state_manager.is_node_ready("end") is True
|
||||
@ -1,5 +1,7 @@
|
||||
import datetime
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
@ -14,11 +16,12 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
@ -28,15 +31,21 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
|
||||
def _build_branching_graph(
|
||||
mock_config: MockConfig,
|
||||
form_repository: HumanInputFormRepository,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
@ -49,12 +58,18 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
if graph_runtime_state is None:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="test-execution-id",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
@ -93,15 +108,21 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
required_variables=["human.input_ready"],
|
||||
pause_reason="Awaiting human input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="primary", title="Primary"),
|
||||
UserAction(id="secondary", title="Secondary"),
|
||||
],
|
||||
)
|
||||
|
||||
human_config = {"id": "human", "data": human_data.model_dump()}
|
||||
human_node = HumanInputNode(
|
||||
id=human_config["id"],
|
||||
config=human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
||||
@ -219,8 +240,18 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
for scenario in branch_scenarios:
|
||||
runner = TableTestRunner()
|
||||
|
||||
def initial_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_branching_graph(mock_config)
|
||||
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_create_repo.get_form.return_value = None
|
||||
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
mock_form_entity.id = "test_form_id"
|
||||
mock_form_entity.web_app_token = "test_web_app_token"
|
||||
mock_form_entity.recipients = []
|
||||
mock_form_entity.rendered_content = "rendered"
|
||||
mock_form_entity.submitted = False
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_branching_graph(mock_config, mock_create_repo)
|
||||
|
||||
initial_case = WorkflowTestCase(
|
||||
description="HumanInput pause before branching decision",
|
||||
@ -242,23 +273,16 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
assert initial_result.success, initial_result.event_mismatch_details
|
||||
assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events)
|
||||
|
||||
graph_runtime_state = initial_result.graph_runtime_state
|
||||
graph = initial_result.graph
|
||||
assert graph_runtime_state is not None
|
||||
assert graph is not None
|
||||
|
||||
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
|
||||
graph_runtime_state.variable_pool.add(("human", "edge_source_handle"), scenario["handle"])
|
||||
graph_runtime_state.graph_execution.pause_reason = None
|
||||
|
||||
pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"])
|
||||
post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"])
|
||||
expected_pre_chunk_events_in_resumption = [
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
]
|
||||
|
||||
expected_resume_sequence: list[type] = (
|
||||
[
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
]
|
||||
expected_pre_chunk_events_in_resumption
|
||||
+ [NodeRunStreamChunkEvent] * pre_chunk_count
|
||||
+ [
|
||||
NodeRunSucceededEvent,
|
||||
@ -273,11 +297,25 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
]
|
||||
)
|
||||
|
||||
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
submitted_form = MagicMock(spec=HumanInputFormEntity)
|
||||
submitted_form.id = mock_form_entity.id
|
||||
submitted_form.web_app_token = mock_form_entity.web_app_token
|
||||
submitted_form.recipients = []
|
||||
submitted_form.rendered_content = mock_form_entity.rendered_content
|
||||
submitted_form.submitted = True
|
||||
submitted_form.selected_action_id = scenario["handle"]
|
||||
submitted_form.submitted_data = {}
|
||||
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
|
||||
mock_get_repo.get_form.return_value = submitted_form
|
||||
|
||||
def resume_graph_factory(
|
||||
graph_snapshot: Graph = graph,
|
||||
state_snapshot: GraphRuntimeState = graph_runtime_state,
|
||||
initial_result=initial_result, mock_get_repo=mock_get_repo
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
return graph_snapshot, state_snapshot
|
||||
assert initial_result.graph_runtime_state is not None
|
||||
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
|
||||
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
|
||||
return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state)
|
||||
|
||||
resume_case = WorkflowTestCase(
|
||||
description=f"HumanInput resumes via {scenario['handle']} branch",
|
||||
@ -321,7 +359,8 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None:
|
||||
for index, event in enumerate(resume_events)
|
||||
if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index
|
||||
]
|
||||
assert pre_indices == list(range(2, 2 + pre_chunk_count))
|
||||
expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption)
|
||||
assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index))
|
||||
|
||||
resume_chunk_indices = [
|
||||
index
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import datetime
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
@ -13,11 +15,12 @@ from core.workflow.graph_events import (
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input import HumanInputNode
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
@ -27,15 +30,21 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .test_mock_config import MockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
from .test_table_runner import TableTestRunner, WorkflowTestCase
|
||||
|
||||
|
||||
def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]:
|
||||
def _build_llm_human_llm_graph(
|
||||
mock_config: MockConfig,
|
||||
form_repository: HumanInputFormRepository,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> tuple[Graph, GraphRuntimeState]:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
@ -48,12 +57,15 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
if graph_runtime_state is None:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id,"
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
@ -92,15 +104,21 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
required_variables=["human.input_ready"],
|
||||
pause_reason="Awaiting human input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="accept", title="Accept"),
|
||||
UserAction(id="reject", title="Reject"),
|
||||
],
|
||||
)
|
||||
|
||||
human_config = {"id": "human", "data": human_data.model_dump()}
|
||||
human_node = HumanInputNode(
|
||||
id=human_config["id"],
|
||||
config=human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
|
||||
@ -130,7 +148,7 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
.add_root(start_node)
|
||||
.add_node(llm_first)
|
||||
.add_node(human_node)
|
||||
.add_node(llm_second)
|
||||
.add_node(llm_second, source_handle="accept")
|
||||
.add_node(end_node)
|
||||
.build()
|
||||
)
|
||||
@ -167,8 +185,18 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
GraphRunPausedEvent, # graph run pauses awaiting resume
|
||||
]
|
||||
|
||||
mock_create_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_create_repo.get_form.return_value = None
|
||||
mock_form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
mock_form_entity.id = "test_form_id"
|
||||
mock_form_entity.web_app_token = "test_web_app_token"
|
||||
mock_form_entity.recipients = []
|
||||
mock_form_entity.rendered_content = "rendered"
|
||||
mock_form_entity.submitted = False
|
||||
mock_create_repo.create_form.return_value = mock_form_entity
|
||||
|
||||
def graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
return _build_llm_human_llm_graph(mock_config)
|
||||
return _build_llm_human_llm_graph(mock_config, mock_create_repo)
|
||||
|
||||
initial_case = WorkflowTestCase(
|
||||
description="HumanInput pause preserves LLM streaming order",
|
||||
@ -210,6 +238,8 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
expected_resume_sequence: list[type] = [
|
||||
GraphRunStartedEvent, # resumed graph run begins
|
||||
NodeRunStartedEvent, # human node restarts
|
||||
# Form Filled should be generated first, then the node execution ends and stream chunk is generated.
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 1
|
||||
NodeRunStreamChunkEvent, # cached llm_initial chunk 2
|
||||
NodeRunStreamChunkEvent, # cached llm_initial final chunk
|
||||
@ -225,12 +255,27 @@ def test_human_input_llm_streaming_order_across_pause() -> None:
|
||||
GraphRunSucceededEvent, # graph run succeeds after resume
|
||||
]
|
||||
|
||||
mock_get_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
submitted_form = MagicMock(spec=HumanInputFormEntity)
|
||||
submitted_form.id = mock_form_entity.id
|
||||
submitted_form.web_app_token = mock_form_entity.web_app_token
|
||||
submitted_form.recipients = []
|
||||
submitted_form.rendered_content = mock_form_entity.rendered_content
|
||||
submitted_form.submitted = True
|
||||
submitted_form.selected_action_id = "accept"
|
||||
submitted_form.submitted_data = {}
|
||||
submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
|
||||
mock_get_repo.get_form.return_value = submitted_form
|
||||
|
||||
def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]:
|
||||
assert graph_runtime_state is not None
|
||||
assert graph is not None
|
||||
graph_runtime_state.variable_pool.add(("human", "input_ready"), True)
|
||||
graph_runtime_state.graph_execution.pause_reason = None
|
||||
return graph, graph_runtime_state
|
||||
# restruct the graph runtime state
|
||||
serialized_runtime_state = initial_result.graph_runtime_state.dumps()
|
||||
resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state)
|
||||
return _build_llm_human_llm_graph(
|
||||
mock_config,
|
||||
mock_get_repo,
|
||||
resume_runtime_state,
|
||||
)
|
||||
|
||||
resume_case = WorkflowTestCase(
|
||||
description="HumanInput resume continues LLM streaming order",
|
||||
|
||||
@ -0,0 +1,270 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.config import GraphEngineConfig
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
class PauseStateStore(Protocol):
|
||||
def save(self, runtime_state: GraphRuntimeState) -> None: ...
|
||||
|
||||
def load(self) -> GraphRuntimeState: ...
|
||||
|
||||
|
||||
class InMemoryPauseStore:
|
||||
def __init__(self) -> None:
|
||||
self._snapshot: str | None = None
|
||||
|
||||
def save(self, runtime_state: GraphRuntimeState) -> None:
|
||||
self._snapshot = runtime_state.dumps()
|
||||
|
||||
def load(self) -> GraphRuntimeState:
|
||||
assert self._snapshot is not None
|
||||
return GraphRuntimeState.from_snapshot(self._snapshot)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StaticForm(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
is_submitted: bool
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
expiration: datetime = naive_utc_now() + timedelta(days=1)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.form_id
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
return "token"
|
||||
|
||||
@property
|
||||
def recipients(self) -> list:
|
||||
return []
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self.action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self.data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self.status_value
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self.expiration
|
||||
|
||||
|
||||
class StaticRepo(HumanInputFormRepository):
|
||||
def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
|
||||
self._forms_by_node_id = dict(forms_by_node_id)
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
return self._forms_by_node_id.get(node_id)
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
raise AssertionError("create_form should not be called in resume scenario")
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
|
||||
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
|
||||
human_a = HumanInputNode(
|
||||
id=human_a_config["id"],
|
||||
config=human_a_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
|
||||
human_b = HumanInputNode(
|
||||
id=human_b_config["id"],
|
||||
config=human_b_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="End",
|
||||
outputs=[
|
||||
OutputVariableEntity(variable="res_a", value_selector=["human_a", "__action_id"]),
|
||||
OutputVariableEntity(variable="res_b", value_selector=["human_b", "__action_id"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_config = {"id": "end", "data": end_data.model_dump()}
|
||||
end_node = EndNode(
|
||||
id=end_config["id"],
|
||||
config=end_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
builder = (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(human_a, from_node_id="start")
|
||||
.add_node(human_b, from_node_id="start")
|
||||
.add_node(end_node, from_node_id="human_a", source_handle="approve")
|
||||
)
|
||||
return builder.connect(tail="human_b", head="end", source_handle="approve").build()
|
||||
|
||||
|
||||
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[object]:
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
),
|
||||
)
|
||||
return list(engine.run())
|
||||
|
||||
|
||||
def _form(submitted: bool, action_id: str | None) -> StaticForm:
|
||||
return StaticForm(
|
||||
form_id="form",
|
||||
rendered="rendered",
|
||||
is_submitted=submitted,
|
||||
action_id=action_id,
|
||||
data={},
|
||||
status_value=HumanInputFormStatus.SUBMITTED if submitted else HumanInputFormStatus.WAITING,
|
||||
)
|
||||
|
||||
|
||||
def test_parallel_human_input_join_completes_after_second_resume() -> None:
|
||||
pause_store: PauseStateStore = InMemoryPauseStore()
|
||||
|
||||
initial_state = _build_runtime_state()
|
||||
initial_repo = StaticRepo(
|
||||
{
|
||||
"human_a": _form(submitted=False, action_id=None),
|
||||
"human_b": _form(submitted=False, action_id=None),
|
||||
}
|
||||
)
|
||||
initial_graph = _build_graph(initial_state, initial_repo)
|
||||
initial_events = _run_graph(initial_graph, initial_state)
|
||||
|
||||
assert isinstance(initial_events[-1], GraphRunPausedEvent)
|
||||
pause_store.save(initial_state)
|
||||
|
||||
first_resume_state = pause_store.load()
|
||||
first_resume_repo = StaticRepo(
|
||||
{
|
||||
"human_a": _form(submitted=True, action_id="approve"),
|
||||
"human_b": _form(submitted=False, action_id=None),
|
||||
}
|
||||
)
|
||||
first_resume_graph = _build_graph(first_resume_state, first_resume_repo)
|
||||
first_resume_events = _run_graph(first_resume_graph, first_resume_state)
|
||||
|
||||
assert isinstance(first_resume_events[0], GraphRunStartedEvent)
|
||||
assert first_resume_events[0].reason is WorkflowStartReason.RESUMPTION
|
||||
assert isinstance(first_resume_events[-1], GraphRunPausedEvent)
|
||||
pause_store.save(first_resume_state)
|
||||
|
||||
second_resume_state = pause_store.load()
|
||||
second_resume_repo = StaticRepo(
|
||||
{
|
||||
"human_a": _form(submitted=True, action_id="approve"),
|
||||
"human_b": _form(submitted=True, action_id="approve"),
|
||||
}
|
||||
)
|
||||
second_resume_graph = _build_graph(second_resume_state, second_resume_repo)
|
||||
second_resume_events = _run_graph(second_resume_graph, second_resume_state)
|
||||
|
||||
assert isinstance(second_resume_events[0], GraphRunStartedEvent)
|
||||
assert second_resume_events[0].reason is WorkflowStartReason.RESUMPTION
|
||||
assert isinstance(second_resume_events[-1], GraphRunSucceededEvent)
|
||||
assert any(isinstance(event, NodeRunSucceededEvent) and event.node_id == "end" for event in second_resume_events)
|
||||
@ -0,0 +1,333 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.config import GraphEngineConfig
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .test_mock_config import MockConfig, NodeMockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class StaticForm(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
is_submitted: bool
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
expiration: datetime = naive_utc_now() + timedelta(days=1)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.form_id
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
return "token"
|
||||
|
||||
@property
|
||||
def recipients(self) -> list:
|
||||
return []
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self.action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self.data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self.status_value
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self.expiration
|
||||
|
||||
|
||||
class StaticRepo(HumanInputFormRepository):
|
||||
def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
|
||||
self._forms_by_node_id = dict(forms_by_node_id)
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
return self._forms_by_node_id.get(node_id)
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
raise AssertionError("create_form should not be called in resume scenario")
|
||||
|
||||
|
||||
class DelayedHumanInputNode(HumanInputNode):
|
||||
def __init__(self, delay_seconds: float, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self._delay_seconds = delay_seconds
|
||||
|
||||
def _run(self):
|
||||
if self._delay_seconds > 0:
|
||||
time.sleep(self._delay_seconds)
|
||||
yield from super()._run()
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Human input required",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
|
||||
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
|
||||
human_a = HumanInputNode(
|
||||
id=human_a_config["id"],
|
||||
config=human_a_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
|
||||
human_b = DelayedHumanInputNode(
|
||||
id=human_b_config["id"],
|
||||
config=human_b_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
delay_seconds=0.2,
|
||||
)
|
||||
|
||||
llm_data = LLMNodeData(
|
||||
title="LLM A",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text="Prompt A",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
structured_output_enabled=False,
|
||||
)
|
||||
llm_config = {"id": "llm_a", "data": llm_data.model_dump()}
|
||||
llm_a = MockLLMNode(
|
||||
id=llm_config["id"],
|
||||
config=llm_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(human_a, from_node_id="start")
|
||||
.add_node(human_b, from_node_id="start")
|
||||
.add_node(llm_a, from_node_id="human_a", source_handle="approve")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def test_parallel_human_input_pause_preserves_node_finished() -> None:
|
||||
runtime_state = _build_runtime_state()
|
||||
|
||||
runtime_state.graph_execution.start()
|
||||
runtime_state.register_paused_node("human_a")
|
||||
runtime_state.register_paused_node("human_b")
|
||||
|
||||
submitted = StaticForm(
|
||||
form_id="form-a",
|
||||
rendered="rendered",
|
||||
is_submitted=True,
|
||||
action_id="approve",
|
||||
data={},
|
||||
status_value=HumanInputFormStatus.SUBMITTED,
|
||||
)
|
||||
pending = StaticForm(
|
||||
form_id="form-b",
|
||||
rendered="rendered",
|
||||
is_submitted=False,
|
||||
action_id=None,
|
||||
data=None,
|
||||
status_value=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
repo = StaticRepo({"human_a": submitted, "human_b": pending})
|
||||
|
||||
mock_config = MockConfig()
|
||||
mock_config.simulate_delays = True
|
||||
mock_config.set_node_config(
|
||||
"llm_a",
|
||||
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
|
||||
)
|
||||
|
||||
graph = _build_graph(runtime_state, repo, mock_config)
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
),
|
||||
)
|
||||
|
||||
events = list(engine.run())
|
||||
|
||||
llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
|
||||
llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
|
||||
human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
|
||||
graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
|
||||
graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
|
||||
assert graph_started
|
||||
assert graph_paused
|
||||
assert human_b_pause
|
||||
assert llm_started
|
||||
assert llm_succeeded
|
||||
|
||||
|
||||
def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None:
|
||||
base_state = _build_runtime_state()
|
||||
base_state.graph_execution.start()
|
||||
base_state.register_paused_node("human_a")
|
||||
base_state.register_paused_node("human_b")
|
||||
snapshot = base_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
submitted = StaticForm(
|
||||
form_id="form-a",
|
||||
rendered="rendered",
|
||||
is_submitted=True,
|
||||
action_id="approve",
|
||||
data={},
|
||||
status_value=HumanInputFormStatus.SUBMITTED,
|
||||
)
|
||||
pending = StaticForm(
|
||||
form_id="form-b",
|
||||
rendered="rendered",
|
||||
is_submitted=False,
|
||||
action_id=None,
|
||||
data=None,
|
||||
status_value=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
repo = StaticRepo({"human_a": submitted, "human_b": pending})
|
||||
|
||||
mock_config = MockConfig()
|
||||
mock_config.simulate_delays = True
|
||||
mock_config.set_node_config(
|
||||
"llm_a",
|
||||
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
|
||||
)
|
||||
|
||||
graph = _build_graph(resumed_state, repo, mock_config)
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=resumed_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
),
|
||||
)
|
||||
|
||||
events = list(engine.run())
|
||||
|
||||
start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent))
|
||||
assert start_event.reason is WorkflowStartReason.RESUMPTION
|
||||
|
||||
llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
|
||||
llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
|
||||
human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
|
||||
graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
|
||||
|
||||
assert graph_paused
|
||||
assert human_b_pause
|
||||
assert llm_started
|
||||
assert llm_succeeded
|
||||
@ -0,0 +1,309 @@
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.config import GraphEngineConfig
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.llm.entities import (
|
||||
ContextConfig,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
VisionConfig,
|
||||
)
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .test_mock_config import MockConfig, NodeMockConfig
|
||||
from .test_mock_nodes import MockLLMNode
|
||||
|
||||
|
||||
@dataclass
|
||||
class StaticForm(HumanInputFormEntity):
|
||||
form_id: str
|
||||
rendered: str
|
||||
is_submitted: bool
|
||||
action_id: str | None = None
|
||||
data: Mapping[str, Any] | None = None
|
||||
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
|
||||
expiration: datetime = naive_utc_now() + timedelta(days=1)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.form_id
|
||||
|
||||
@property
|
||||
def web_app_token(self) -> str | None:
|
||||
return "token"
|
||||
|
||||
@property
|
||||
def recipients(self) -> list:
|
||||
return []
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self.rendered
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self.action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self.data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.is_submitted
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self.status_value
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self.expiration
|
||||
|
||||
|
||||
class StaticRepo(HumanInputFormRepository):
|
||||
def __init__(self, form: HumanInputFormEntity) -> None:
|
||||
self._form = form
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
if node_id != "human_pause":
|
||||
return None
|
||||
return self._form
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
raise AssertionError("create_form should not be called in this test")
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
|
||||
start_node = StartNode(
|
||||
id=start_config["id"],
|
||||
config=start_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
llm_a_data = LLMNodeData(
|
||||
title="LLM A",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text="Prompt A",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
structured_output_enabled=False,
|
||||
)
|
||||
llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()}
|
||||
llm_a = MockLLMNode(
|
||||
id=llm_a_config["id"],
|
||||
config=llm_a_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
llm_b_data = LLMNodeData(
|
||||
title="LLM B",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text="Prompt B",
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
],
|
||||
context=ContextConfig(enabled=False, variable_selector=None),
|
||||
vision=VisionConfig(enabled=False),
|
||||
reasoning_format="tagged",
|
||||
structured_output_enabled=False,
|
||||
)
|
||||
llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()}
|
||||
llm_b = MockLLMNode(
|
||||
id=llm_b_config["id"],
|
||||
config=llm_b_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Pause here",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
human_config = {"id": "human_pause", "data": human_data.model_dump()}
|
||||
human_node = HumanInputNode(
|
||||
id=human_config["id"],
|
||||
config=human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
end_human_data = EndNodeData(title="End Human", outputs=[], desc=None)
|
||||
end_human_config = {"id": "end_human", "data": end_human_data.model_dump()}
|
||||
end_human = EndNode(
|
||||
id=end_human_config["id"],
|
||||
config=end_human_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(llm_a, from_node_id="start")
|
||||
.add_node(human_node, from_node_id="start")
|
||||
.add_node(llm_b, from_node_id="llm_a")
|
||||
.add_node(end_human, from_node_id="human_pause", source_handle="approve")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None:
|
||||
for event in events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
|
||||
return event
|
||||
return None
|
||||
|
||||
|
||||
def test_pause_defers_ready_nodes_until_resume() -> None:
|
||||
runtime_state = _build_runtime_state()
|
||||
|
||||
paused_form = StaticForm(
|
||||
form_id="form-pause",
|
||||
rendered="rendered",
|
||||
is_submitted=False,
|
||||
status_value=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
pause_repo = StaticRepo(paused_form)
|
||||
|
||||
mock_config = MockConfig()
|
||||
mock_config.simulate_delays = True
|
||||
mock_config.set_node_config(
|
||||
"llm_a",
|
||||
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
|
||||
)
|
||||
mock_config.set_node_config(
|
||||
"llm_b",
|
||||
NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0),
|
||||
)
|
||||
|
||||
graph = _build_graph(runtime_state, pause_repo, mock_config)
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
),
|
||||
)
|
||||
|
||||
paused_events = list(engine.run())
|
||||
|
||||
assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events)
|
||||
assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events)
|
||||
assert _get_node_started_event(paused_events, "llm_b") is None
|
||||
|
||||
snapshot = runtime_state.dumps()
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
submitted_form = StaticForm(
|
||||
form_id="form-pause",
|
||||
rendered="rendered",
|
||||
is_submitted=True,
|
||||
action_id="approve",
|
||||
data={},
|
||||
status_value=HumanInputFormStatus.SUBMITTED,
|
||||
)
|
||||
resume_repo = StaticRepo(submitted_form)
|
||||
|
||||
resumed_graph = _build_graph(resumed_state, resume_repo, mock_config)
|
||||
resumed_engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=resumed_graph,
|
||||
graph_runtime_state=resumed_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(
|
||||
min_workers=2,
|
||||
max_workers=2,
|
||||
scale_up_threshold=1,
|
||||
scale_down_idle_time=30.0,
|
||||
),
|
||||
)
|
||||
|
||||
resumed_events = list(resumed_engine.run())
|
||||
|
||||
start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent))
|
||||
assert start_event.reason is WorkflowStartReason.RESUMPTION
|
||||
|
||||
llm_b_started = _get_node_started_event(resumed_events, "llm_b")
|
||||
assert llm_b_started is not None
|
||||
assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events)
|
||||
@ -0,0 +1,217 @@
|
||||
import datetime
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_events.graph import GraphRunStartedEvent
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
|
||||
def _build_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="test-execution-id",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository:
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
form_entity.submitted = True
|
||||
form_entity.selected_action_id = action_id
|
||||
form_entity.submitted_data = {}
|
||||
form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1)
|
||||
repo.get_form.return_value = form_entity
|
||||
return repo
|
||||
|
||||
|
||||
def _mock_form_repository_without_submission() -> HumanInputFormRepository:
|
||||
repo = MagicMock(spec=HumanInputFormRepository)
|
||||
form_entity = MagicMock(spec=HumanInputFormEntity)
|
||||
form_entity.id = "test-form-id"
|
||||
form_entity.web_app_token = "test-form-token"
|
||||
form_entity.recipients = []
|
||||
form_entity.rendered_content = "rendered"
|
||||
form_entity.submitted = False
|
||||
repo.create_form.return_value = form_entity
|
||||
repo.get_form.return_value = None
|
||||
return repo
|
||||
|
||||
|
||||
def _build_human_input_graph(
|
||||
runtime_state: GraphRuntimeState,
|
||||
form_repository: HumanInputFormRepository,
|
||||
) -> Graph:
|
||||
graph_config: dict[str, object] = {"nodes": [], "edges": []}
|
||||
params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="service-api",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": start_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
human_data = HumanInputNodeData(
|
||||
title="human",
|
||||
form_content="Awaiting human input",
|
||||
inputs=[],
|
||||
user_actions=[
|
||||
UserAction(id="continue", title="Continue"),
|
||||
],
|
||||
)
|
||||
human_node = HumanInputNode(
|
||||
id="human",
|
||||
config={"id": "human", "data": human_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="end",
|
||||
outputs=[
|
||||
OutputVariableEntity(variable="result", value_selector=["human", "action_id"]),
|
||||
],
|
||||
desc=None,
|
||||
)
|
||||
end_node = EndNode(
|
||||
id="end",
|
||||
config={"id": "end", "data": end_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(human_node)
|
||||
.add_node(end_node, from_node_id="human", source_handle="continue")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]:
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
)
|
||||
return list(engine.run())
|
||||
|
||||
|
||||
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
|
||||
return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)]
|
||||
|
||||
|
||||
def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None:
|
||||
for event in events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
|
||||
return event
|
||||
return None
|
||||
|
||||
|
||||
def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any:
|
||||
segment = variable_pool.get(selector)
|
||||
assert segment is not None
|
||||
return getattr(segment, "value", segment)
|
||||
|
||||
|
||||
def test_engine_resume_restores_state_and_completion():
|
||||
# Baseline run without pausing
|
||||
baseline_state = _build_runtime_state()
|
||||
baseline_repo = _mock_form_repository_with_submission(action_id="continue")
|
||||
baseline_graph = _build_human_input_graph(baseline_state, baseline_repo)
|
||||
baseline_events = _run_graph(baseline_graph, baseline_state)
|
||||
assert baseline_events
|
||||
first_paused_event = baseline_events[0]
|
||||
assert isinstance(first_paused_event, GraphRunStartedEvent)
|
||||
assert first_paused_event.reason is WorkflowStartReason.INITIAL
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_success_nodes = _node_successes(baseline_events)
|
||||
|
||||
# Run with pause
|
||||
paused_state = _build_runtime_state()
|
||||
pause_repo = _mock_form_repository_without_submission()
|
||||
paused_graph = _build_human_input_graph(paused_state, pause_repo)
|
||||
paused_events = _run_graph(paused_graph, paused_state)
|
||||
assert paused_events
|
||||
first_paused_event = paused_events[0]
|
||||
assert isinstance(first_paused_event, GraphRunStartedEvent)
|
||||
assert first_paused_event.reason is WorkflowStartReason.INITIAL
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
# Resume from snapshot
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
resume_repo = _mock_form_repository_with_submission(action_id="continue")
|
||||
resumed_graph = _build_human_input_graph(resumed_state, resume_repo)
|
||||
resumed_events = _run_graph(resumed_graph, resumed_state)
|
||||
assert resumed_events
|
||||
first_resumed_event = resumed_events[0]
|
||||
assert isinstance(first_resumed_event, GraphRunStartedEvent)
|
||||
assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION
|
||||
assert isinstance(resumed_events[-1], GraphRunSucceededEvent)
|
||||
|
||||
combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events)
|
||||
assert combined_success_nodes == baseline_success_nodes
|
||||
|
||||
paused_human_started = _node_start_event(paused_events, "human")
|
||||
resumed_human_started = _node_start_event(resumed_events, "human")
|
||||
assert paused_human_started is not None
|
||||
assert resumed_human_started is not None
|
||||
assert paused_human_started.id == resumed_human_started.id
|
||||
|
||||
assert baseline_state.outputs == resumed_state.outputs
|
||||
assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
|
||||
resumed_state.variable_pool, ("human", "__action_id")
|
||||
)
|
||||
assert baseline_state.graph_execution.completed
|
||||
assert resumed_state.graph_execution.completed
|
||||
@ -7,6 +7,7 @@ from core.workflow.nodes.base.node import Node
|
||||
# Ensures that all node classes are imported.
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed.
|
||||
_ = NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
|
||||
@ -45,7 +46,9 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
|
||||
assert isinstance(cls.node_type, NodeType)
|
||||
assert isinstance(node_version, str)
|
||||
node_type_and_version = (node_type, node_version)
|
||||
assert node_type_and_version not in type_version_set
|
||||
assert node_type_and_version not in type_version_set, (
|
||||
f"Duplicate node type and version for class: {cls=} {node_type_and_version=}"
|
||||
)
|
||||
type_version_set.add(node_type_and_version)
|
||||
|
||||
|
||||
|
||||
@ -0,0 +1 @@
|
||||
# Unit tests for human input node
|
||||
@ -0,0 +1,16 @@
|
||||
from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
|
||||
def test_render_body_template_replaces_variable_values():
|
||||
config = EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(),
|
||||
subject="Subject",
|
||||
body="Hello {{#node1.value#}} {{#url#}}",
|
||||
)
|
||||
variable_pool = VariablePool()
|
||||
variable_pool.add(["node1", "value"], "World")
|
||||
|
||||
result = config.render_body_template(body=config.body, url="https://example.com", variable_pool=variable_pool)
|
||||
|
||||
assert result == "Hello World https://example.com"
|
||||
@ -0,0 +1,597 @@
|
||||
"""
|
||||
Unit tests for human input node entities.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.node_events import PauseRequestedEvent
|
||||
from core.workflow.node_events.node import StreamCompletedEvent
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
FormInput,
|
||||
FormInputDefault,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
UserAction,
|
||||
WebAppDeliveryMethod,
|
||||
_WebAppDeliveryConfig,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import (
|
||||
ButtonStyle,
|
||||
DeliveryMethodType,
|
||||
EmailRecipientType,
|
||||
FormInputType,
|
||||
PlaceholderType,
|
||||
TimeoutUnit,
|
||||
)
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository
|
||||
|
||||
|
||||
class TestDeliveryMethod:
|
||||
"""Test DeliveryMethod entity."""
|
||||
|
||||
def test_webapp_delivery_method(self):
|
||||
"""Test webapp delivery method creation."""
|
||||
delivery_method = WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())
|
||||
|
||||
assert delivery_method.type == DeliveryMethodType.WEBAPP
|
||||
assert delivery_method.enabled is True
|
||||
assert isinstance(delivery_method.config, _WebAppDeliveryConfig)
|
||||
|
||||
def test_email_delivery_method(self):
|
||||
"""Test email delivery method creation."""
|
||||
recipients = EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"),
|
||||
ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"),
|
||||
],
|
||||
)
|
||||
|
||||
config = EmailDeliveryConfig(
|
||||
recipients=recipients, subject="Test Subject", body="Test body with {{#url#}} placeholder"
|
||||
)
|
||||
|
||||
delivery_method = EmailDeliveryMethod(enabled=True, config=config)
|
||||
|
||||
assert delivery_method.type == DeliveryMethodType.EMAIL
|
||||
assert delivery_method.enabled is True
|
||||
assert isinstance(delivery_method.config, EmailDeliveryConfig)
|
||||
assert delivery_method.config.subject == "Test Subject"
|
||||
assert len(delivery_method.config.recipients.items) == 2
|
||||
|
||||
|
||||
class TestFormInput:
|
||||
"""Test FormInput entity."""
|
||||
|
||||
def test_text_input_with_constant_default(self):
|
||||
"""Test text input with constant default value."""
|
||||
default = FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter your response here...")
|
||||
|
||||
form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default)
|
||||
|
||||
assert form_input.type == FormInputType.TEXT_INPUT
|
||||
assert form_input.output_variable_name == "user_input"
|
||||
assert form_input.default.type == PlaceholderType.CONSTANT
|
||||
assert form_input.default.value == "Enter your response here..."
|
||||
|
||||
def test_text_input_with_variable_default(self):
|
||||
"""Test text input with variable default value."""
|
||||
default = FormInputDefault(type=PlaceholderType.VARIABLE, selector=["node_123", "output_var"])
|
||||
|
||||
form_input = FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="user_input", default=default)
|
||||
|
||||
assert form_input.default.type == PlaceholderType.VARIABLE
|
||||
assert form_input.default.selector == ["node_123", "output_var"]
|
||||
|
||||
def test_form_input_without_default(self):
|
||||
"""Test form input without default value."""
|
||||
form_input = FormInput(type=FormInputType.PARAGRAPH, output_variable_name="description")
|
||||
|
||||
assert form_input.type == FormInputType.PARAGRAPH
|
||||
assert form_input.output_variable_name == "description"
|
||||
assert form_input.default is None
|
||||
|
||||
|
||||
class TestUserAction:
|
||||
"""Test UserAction entity."""
|
||||
|
||||
def test_user_action_creation(self):
|
||||
"""Test user action creation."""
|
||||
action = UserAction(id="approve", title="Approve", button_style=ButtonStyle.PRIMARY)
|
||||
|
||||
assert action.id == "approve"
|
||||
assert action.title == "Approve"
|
||||
assert action.button_style == ButtonStyle.PRIMARY
|
||||
|
||||
def test_user_action_default_button_style(self):
|
||||
"""Test user action with default button style."""
|
||||
action = UserAction(id="cancel", title="Cancel")
|
||||
|
||||
assert action.button_style == ButtonStyle.DEFAULT
|
||||
|
||||
def test_user_action_length_boundaries(self):
|
||||
"""Test user action id and title length boundaries."""
|
||||
action = UserAction(id="a" * 20, title="b" * 20)
|
||||
|
||||
assert action.id == "a" * 20
|
||||
assert action.title == "b" * 20
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("field_name", "value"),
|
||||
[
|
||||
("id", "a" * 21),
|
||||
("title", "b" * 21),
|
||||
],
|
||||
)
|
||||
def test_user_action_length_limits(self, field_name: str, value: str):
|
||||
"""User action fields should enforce max length."""
|
||||
data = {"id": "approve", "title": "Approve"}
|
||||
data[field_name] = value
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
UserAction(**data)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any(error["loc"] == (field_name,) and error["type"] == "string_too_long" for error in errors)
|
||||
|
||||
|
||||
class TestHumanInputNodeData:
|
||||
"""Test HumanInputNodeData entity."""
|
||||
|
||||
def test_valid_node_data_creation(self):
|
||||
"""Test creating valid human input node data."""
|
||||
delivery_methods = [WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig())]
|
||||
|
||||
inputs = [
|
||||
FormInput(
|
||||
type=FormInputType.TEXT_INPUT,
|
||||
output_variable_name="content",
|
||||
default=FormInputDefault(type=PlaceholderType.CONSTANT, value="Enter content..."),
|
||||
)
|
||||
]
|
||||
|
||||
user_actions = [UserAction(id="submit", title="Submit", button_style=ButtonStyle.PRIMARY)]
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input Test",
|
||||
desc="Test node description",
|
||||
delivery_methods=delivery_methods,
|
||||
form_content="# Test Form\n\nPlease provide input:\n\n{{#$output.content#}}",
|
||||
inputs=inputs,
|
||||
user_actions=user_actions,
|
||||
timeout=24,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
)
|
||||
|
||||
assert node_data.title == "Human Input Test"
|
||||
assert node_data.desc == "Test node description"
|
||||
assert len(node_data.delivery_methods) == 1
|
||||
assert node_data.form_content.startswith("# Test Form")
|
||||
assert len(node_data.inputs) == 1
|
||||
assert len(node_data.user_actions) == 1
|
||||
assert node_data.timeout == 24
|
||||
assert node_data.timeout_unit == TimeoutUnit.HOUR
|
||||
|
||||
def test_node_data_with_multiple_delivery_methods(self):
|
||||
"""Test node data with multiple delivery methods."""
|
||||
delivery_methods = [
|
||||
WebAppDeliveryMethod(enabled=True, config=_WebAppDeliveryConfig()),
|
||||
EmailDeliveryMethod(
|
||||
enabled=False, # Disabled method should be fine
|
||||
config=EmailDeliveryConfig(
|
||||
subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True)
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Test Node", delivery_methods=delivery_methods, timeout=1, timeout_unit=TimeoutUnit.DAY
|
||||
)
|
||||
|
||||
assert len(node_data.delivery_methods) == 2
|
||||
assert node_data.timeout == 1
|
||||
assert node_data.timeout_unit == TimeoutUnit.DAY
|
||||
|
||||
def test_node_data_defaults(self):
|
||||
"""Test node data with default values."""
|
||||
node_data = HumanInputNodeData(title="Test Node")
|
||||
|
||||
assert node_data.title == "Test Node"
|
||||
assert node_data.desc is None
|
||||
assert node_data.delivery_methods == []
|
||||
assert node_data.form_content == ""
|
||||
assert node_data.inputs == []
|
||||
assert node_data.user_actions == []
|
||||
assert node_data.timeout == 36
|
||||
assert node_data.timeout_unit == TimeoutUnit.HOUR
|
||||
|
||||
def test_duplicate_input_output_variable_name_raises_validation_error(self):
|
||||
"""Duplicate form input output_variable_name should raise validation error."""
|
||||
duplicate_inputs = [
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
|
||||
FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValidationError, match="duplicated output_variable_name 'content'"):
|
||||
HumanInputNodeData(title="Test Node", inputs=duplicate_inputs)
|
||||
|
||||
def test_duplicate_user_action_ids_raise_validation_error(self):
|
||||
"""Duplicate user action ids should raise validation error."""
|
||||
duplicate_actions = [
|
||||
UserAction(id="submit", title="Submit"),
|
||||
UserAction(id="submit", title="Submit Again"),
|
||||
]
|
||||
|
||||
with pytest.raises(ValidationError, match="duplicated user action id 'submit'"):
|
||||
HumanInputNodeData(title="Test Node", user_actions=duplicate_actions)
|
||||
|
||||
def test_extract_outputs_field_names(self):
|
||||
content = r"""This is titile {{#start.title#}}
|
||||
|
||||
A content is required:
|
||||
|
||||
{{#$output.content#}}
|
||||
|
||||
A ending is required:
|
||||
|
||||
{{#$output.ending#}}
|
||||
"""
|
||||
|
||||
node_data = HumanInputNodeData(title="Human Input", form_content=content)
|
||||
field_names = node_data.outputs_field_names()
|
||||
assert field_names == ["content", "ending"]
|
||||
|
||||
|
||||
class TestRecipients:
|
||||
"""Test email recipient entities."""
|
||||
|
||||
def test_member_recipient(self):
|
||||
"""Test member recipient creation."""
|
||||
recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")
|
||||
|
||||
assert recipient.type == EmailRecipientType.MEMBER
|
||||
assert recipient.user_id == "user-123"
|
||||
|
||||
def test_external_recipient(self):
|
||||
"""Test external recipient creation."""
|
||||
recipient = ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com")
|
||||
|
||||
assert recipient.type == EmailRecipientType.EXTERNAL
|
||||
assert recipient.email == "test@example.com"
|
||||
|
||||
def test_email_recipients_whole_workspace(self):
|
||||
"""Test email recipients with whole workspace enabled."""
|
||||
recipients = EmailRecipients(
|
||||
whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")]
|
||||
)
|
||||
|
||||
assert recipients.whole_workspace is True
|
||||
assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True
|
||||
|
||||
def test_email_recipients_specific_users(self):
|
||||
"""Test email recipients with specific users."""
|
||||
recipients = EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[
|
||||
MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"),
|
||||
ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"),
|
||||
],
|
||||
)
|
||||
|
||||
assert recipients.whole_workspace is False
|
||||
assert len(recipients.items) == 2
|
||||
assert recipients.items[0].user_id == "user-123"
|
||||
assert recipients.items[1].email == "external@example.com"
|
||||
|
||||
|
||||
class TestHumanInputNodeVariableResolution:
|
||||
"""Tests for resolving variable-based defaults in HumanInputNode."""
|
||||
|
||||
def test_resolves_variable_defaults(self):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(("start", "name"), "Jane Doe")
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Provide your name",
|
||||
inputs=[
|
||||
FormInput(
|
||||
type=FormInputType.TEXT_INPUT,
|
||||
output_variable_name="user_name",
|
||||
default=FormInputDefault(type=PlaceholderType.VARIABLE, selector=["start", "name"]),
|
||||
),
|
||||
FormInput(
|
||||
type=FormInputType.TEXT_INPUT,
|
||||
output_variable_name="user_email",
|
||||
default=FormInputDefault(type=PlaceholderType.CONSTANT, value="foo@example.com"),
|
||||
),
|
||||
],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
)
|
||||
config = {"id": "human", "data": node_data.model_dump()}
|
||||
|
||||
mock_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_repo.get_form.return_value = None
|
||||
mock_repo.create_form.return_value = SimpleNamespace(
|
||||
id="form-1",
|
||||
rendered_content="Provide your name",
|
||||
web_app_token="token",
|
||||
recipients=[],
|
||||
submitted=False,
|
||||
)
|
||||
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=mock_repo,
|
||||
)
|
||||
|
||||
run_result = node._run()
|
||||
pause_event = next(run_result)
|
||||
|
||||
assert isinstance(pause_event, PauseRequestedEvent)
|
||||
expected_values = {"user_name": "Jane Doe"}
|
||||
assert pause_event.reason.resolved_default_values == expected_values
|
||||
|
||||
params = mock_repo.create_form.call_args.args[0]
|
||||
assert params.resolved_default_values == expected_values
|
||||
|
||||
def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-2",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Provide your name",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
)
|
||||
config = {"id": "human", "data": node_data.model_dump()}
|
||||
|
||||
mock_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_repo.get_form.return_value = None
|
||||
mock_repo.create_form.return_value = SimpleNamespace(
|
||||
id="form-2",
|
||||
rendered_content="Provide your name",
|
||||
web_app_token="console-token",
|
||||
recipients=[SimpleNamespace(token="recipient-token")],
|
||||
submitted=False,
|
||||
)
|
||||
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=mock_repo,
|
||||
)
|
||||
|
||||
run_result = node._run()
|
||||
pause_event = next(run_result)
|
||||
|
||||
assert isinstance(pause_event, PauseRequestedEvent)
|
||||
assert pause_event.reason.form_token == "console-token"
|
||||
|
||||
def test_debugger_debug_mode_overrides_email_recipients(self):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user-123",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-3",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user-123",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Provide your name",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
delivery_methods=[
|
||||
EmailDeliveryMethod(
|
||||
enabled=True,
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")],
|
||||
),
|
||||
subject="Subject",
|
||||
body="Body",
|
||||
debug_mode=True,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
config = {"id": "human", "data": node_data.model_dump()}
|
||||
|
||||
mock_repo = MagicMock(spec=HumanInputFormRepository)
|
||||
mock_repo.get_form.return_value = None
|
||||
mock_repo.create_form.return_value = SimpleNamespace(
|
||||
id="form-3",
|
||||
rendered_content="Provide your name",
|
||||
web_app_token="token",
|
||||
recipients=[],
|
||||
submitted=False,
|
||||
)
|
||||
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=mock_repo,
|
||||
)
|
||||
|
||||
run_result = node._run()
|
||||
pause_event = next(run_result)
|
||||
assert isinstance(pause_event, PauseRequestedEvent)
|
||||
|
||||
params = mock_repo.create_form.call_args.args[0]
|
||||
assert len(params.delivery_methods) == 1
|
||||
method = params.delivery_methods[0]
|
||||
assert isinstance(method, EmailDeliveryMethod)
|
||||
assert method.config.debug_mode is True
|
||||
assert method.config.recipients.whole_workspace is False
|
||||
assert len(method.config.recipients.items) == 1
|
||||
recipient = method.config.recipients.items[0]
|
||||
assert isinstance(recipient, MemberRecipient)
|
||||
assert recipient.user_id == "user-123"
|
||||
|
||||
|
||||
class TestValidation:
|
||||
"""Test validation scenarios."""
|
||||
|
||||
def test_invalid_form_input_type(self):
|
||||
"""Test validation with invalid form input type."""
|
||||
with pytest.raises(ValidationError):
|
||||
FormInput(
|
||||
type="invalid-type", # Invalid type
|
||||
output_variable_name="test",
|
||||
)
|
||||
|
||||
def test_invalid_button_style(self):
|
||||
"""Test validation with invalid button style."""
|
||||
with pytest.raises(ValidationError):
|
||||
UserAction(
|
||||
id="test",
|
||||
title="Test",
|
||||
button_style="invalid-style", # Invalid style
|
||||
)
|
||||
|
||||
def test_invalid_timeout_unit(self):
|
||||
"""Test validation with invalid timeout unit."""
|
||||
with pytest.raises(ValidationError):
|
||||
HumanInputNodeData(
|
||||
title="Test",
|
||||
timeout_unit="invalid-unit", # Invalid unit
|
||||
)
|
||||
|
||||
|
||||
class TestHumanInputNodeRenderedContent:
|
||||
"""Tests for rendering submitted content."""
|
||||
|
||||
def test_replaces_outputs_placeholders_after_submission(self):
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(
|
||||
user_id="user",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
workflow_execution_id="exec-1",
|
||||
),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="Name: {{#$output.name#}}",
|
||||
inputs=[
|
||||
FormInput(
|
||||
type=FormInputType.TEXT_INPUT,
|
||||
output_variable_name="name",
|
||||
)
|
||||
],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
config = {"id": "human", "data": node_data.model_dump()}
|
||||
|
||||
form_repository = InMemoryHumanInputFormRepository()
|
||||
node = HumanInputNode(
|
||||
id=config["id"],
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
form_repository=form_repository,
|
||||
)
|
||||
|
||||
pause_gen = node._run()
|
||||
pause_event = next(pause_gen)
|
||||
assert isinstance(pause_event, PauseRequestedEvent)
|
||||
with pytest.raises(StopIteration):
|
||||
next(pause_gen)
|
||||
|
||||
form_repository.set_submission(action_id="approve", form_data={"name": "Alice"})
|
||||
|
||||
events = list(node._run())
|
||||
last_event = events[-1]
|
||||
assert isinstance(last_event, StreamCompletedEvent)
|
||||
node_run_result = last_event.node_run_result
|
||||
assert node_run_result.outputs["__rendered_content"] == "Name: Alice"
|
||||
@ -0,0 +1,172 @@
|
||||
import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_events import (
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
class _FakeFormRepository:
|
||||
def __init__(self, form):
|
||||
self._form = form
|
||||
|
||||
def get_form(self, *_args, **_kwargs):
|
||||
return self._form
|
||||
|
||||
|
||||
def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode:
|
||||
system_variables = SystemVariable.default()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
config = {
|
||||
"id": "node-1",
|
||||
"type": NodeType.HUMAN_INPUT.value,
|
||||
"data": {
|
||||
"title": "Human Input",
|
||||
"form_content": form_content,
|
||||
"inputs": [
|
||||
{
|
||||
"type": "text_input",
|
||||
"output_variable_name": "name",
|
||||
"default": {"type": "constant", "value": ""},
|
||||
}
|
||||
],
|
||||
"user_actions": [
|
||||
{
|
||||
"id": "Accept",
|
||||
"title": "Approve",
|
||||
"button_style": "default",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
fake_form = SimpleNamespace(
|
||||
id="form-1",
|
||||
rendered_content=form_content,
|
||||
submitted=True,
|
||||
selected_action_id="Accept",
|
||||
submitted_data={"name": "Alice"},
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
expiration_time=naive_utc_now() + datetime.timedelta(days=1),
|
||||
)
|
||||
|
||||
repo = _FakeFormRepository(fake_form)
|
||||
return HumanInputNode(
|
||||
id="node-1",
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
|
||||
def _build_timeout_node() -> HumanInputNode:
|
||||
system_variables = SystemVariable.default()
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]),
|
||||
start_at=0.0,
|
||||
)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
config = {
|
||||
"id": "node-1",
|
||||
"type": NodeType.HUMAN_INPUT.value,
|
||||
"data": {
|
||||
"title": "Human Input",
|
||||
"form_content": "Please enter your name:\n\n{{#$output.name#}}",
|
||||
"inputs": [
|
||||
{
|
||||
"type": "text_input",
|
||||
"output_variable_name": "name",
|
||||
"default": {"type": "constant", "value": ""},
|
||||
}
|
||||
],
|
||||
"user_actions": [
|
||||
{
|
||||
"id": "Accept",
|
||||
"title": "Approve",
|
||||
"button_style": "default",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
fake_form = SimpleNamespace(
|
||||
id="form-1",
|
||||
rendered_content="content",
|
||||
submitted=False,
|
||||
selected_action_id=None,
|
||||
submitted_data=None,
|
||||
status=HumanInputFormStatus.TIMEOUT,
|
||||
expiration_time=naive_utc_now() - datetime.timedelta(minutes=1),
|
||||
)
|
||||
|
||||
repo = _FakeFormRepository(fake_form)
|
||||
return HumanInputNode(
|
||||
id="node-1",
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
form_repository=repo,
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_node_emits_form_filled_event_before_succeeded():
|
||||
node = _build_node()
|
||||
|
||||
events = list(node.run())
|
||||
|
||||
assert isinstance(events[0], NodeRunStartedEvent)
|
||||
assert isinstance(events[1], NodeRunHumanInputFormFilledEvent)
|
||||
|
||||
filled_event = events[1]
|
||||
assert filled_event.node_title == "Human Input"
|
||||
assert filled_event.rendered_content.endswith("Alice")
|
||||
assert filled_event.action_id == "Accept"
|
||||
assert filled_event.action_text == "Approve"
|
||||
|
||||
|
||||
def test_human_input_node_emits_timeout_event_before_succeeded():
|
||||
node = _build_timeout_node()
|
||||
|
||||
events = list(node.run())
|
||||
|
||||
assert isinstance(events[0], NodeRunStartedEvent)
|
||||
assert isinstance(events[1], NodeRunHumanInputFormTimeoutEvent)
|
||||
|
||||
timeout_event = events[1]
|
||||
assert timeout_event.node_title == "Human Input"
|
||||
@ -104,6 +104,7 @@ class TestCelerySSLConfiguration:
|
||||
def test_celery_init_applies_ssl_to_broker_and_backend(self):
|
||||
"""Test that SSL options are applied to both broker and backend when using Redis."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL = 1
|
||||
mock_config.CELERY_BROKER_URL = "redis://localhost:6379/0"
|
||||
mock_config.CELERY_BACKEND = "redis"
|
||||
mock_config.CELERY_RESULT_BACKEND = "redis://localhost:6379/0"
|
||||
|
||||
20
api/tests/unit_tests/extensions/test_pubsub_channel.py
Normal file
20
api/tests/unit_tests/extensions/test_pubsub_channel.py
Normal file
@ -0,0 +1,20 @@
|
||||
from configs import dify_config
|
||||
from extensions import ext_redis
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
|
||||
|
||||
def test_get_pubsub_broadcast_channel_defaults_to_pubsub(monkeypatch):
|
||||
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub")
|
||||
|
||||
channel = ext_redis.get_pubsub_broadcast_channel()
|
||||
|
||||
assert isinstance(channel, RedisBroadcastChannel)
|
||||
|
||||
|
||||
def test_get_pubsub_broadcast_channel_sharded(monkeypatch):
|
||||
monkeypatch.setattr(dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded")
|
||||
|
||||
channel = ext_redis.get_pubsub_broadcast_channel()
|
||||
|
||||
assert isinstance(channel, ShardedRedisBroadcastChannel)
|
||||
1
api/tests/unit_tests/libs/_human_input/__init__.py
Normal file
1
api/tests/unit_tests/libs/_human_input/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Treat this directory as a package so support modules can be imported relatively.
|
||||
249
api/tests/unit_tests/libs/_human_input/support.py
Normal file
249
api/tests/unit_tests/libs/_human_input/support.py
Normal file
@ -0,0 +1,249 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput
|
||||
from core.workflow.nodes.human_input.enums import TimeoutUnit
|
||||
|
||||
|
||||
# Exceptions
|
||||
class HumanInputError(Exception):
|
||||
error_code: str = "unknown"
|
||||
|
||||
def __init__(self, message: str = "", error_code: str | None = None):
|
||||
super().__init__(message)
|
||||
self.message = message or self.__class__.__name__
|
||||
if error_code:
|
||||
self.error_code = error_code
|
||||
|
||||
|
||||
class FormNotFoundError(HumanInputError):
|
||||
error_code = "form_not_found"
|
||||
|
||||
|
||||
class FormExpiredError(HumanInputError):
|
||||
error_code = "human_input_form_expired"
|
||||
|
||||
|
||||
class FormAlreadySubmittedError(HumanInputError):
|
||||
error_code = "human_input_form_submitted"
|
||||
|
||||
|
||||
class InvalidFormDataError(HumanInputError):
|
||||
error_code = "invalid_form_data"
|
||||
|
||||
|
||||
# Models
|
||||
@dataclass
|
||||
class HumanInputForm:
|
||||
form_id: str
|
||||
workflow_run_id: str
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
app_id: str | None
|
||||
form_content: str
|
||||
inputs: list[FormInput]
|
||||
user_actions: list[dict[str, Any]]
|
||||
timeout: int
|
||||
timeout_unit: TimeoutUnit
|
||||
form_token: str | None = None
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
expires_at: datetime | None = None
|
||||
submitted_at: datetime | None = None
|
||||
submitted_data: dict[str, Any] | None = None
|
||||
submitted_action: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.expires_at is None:
|
||||
self.calculate_expiration()
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
return self.expires_at is not None and datetime.utcnow() > self.expires_at
|
||||
|
||||
@property
|
||||
def is_submitted(self) -> bool:
|
||||
return self.submitted_at is not None
|
||||
|
||||
def mark_submitted(self, inputs: dict[str, Any], action: str) -> None:
|
||||
self.submitted_data = inputs
|
||||
self.submitted_action = action
|
||||
self.submitted_at = datetime.utcnow()
|
||||
|
||||
def submit(self, inputs: dict[str, Any], action: str) -> None:
|
||||
self.mark_submitted(inputs, action)
|
||||
|
||||
def calculate_expiration(self) -> None:
|
||||
start = self.created_at
|
||||
if self.timeout_unit == TimeoutUnit.HOUR:
|
||||
self.expires_at = start + timedelta(hours=self.timeout)
|
||||
elif self.timeout_unit == TimeoutUnit.DAY:
|
||||
self.expires_at = start + timedelta(days=self.timeout)
|
||||
else:
|
||||
raise ValueError(f"Unsupported timeout unit {self.timeout_unit}")
|
||||
|
||||
def to_response_dict(self, *, include_site_info: bool) -> dict[str, Any]:
|
||||
inputs_response = [
|
||||
{
|
||||
"type": form_input.type.name.lower().replace("_", "-"),
|
||||
"output_variable_name": form_input.output_variable_name,
|
||||
}
|
||||
for form_input in self.inputs
|
||||
]
|
||||
response = {
|
||||
"form_content": self.form_content,
|
||||
"inputs": inputs_response,
|
||||
"user_actions": self.user_actions,
|
||||
}
|
||||
if include_site_info:
|
||||
response["site"] = {"app_id": self.app_id, "title": "Workflow Form"}
|
||||
return response
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormSubmissionData:
|
||||
form_id: str
|
||||
inputs: dict[str, Any]
|
||||
action: str
|
||||
submitted_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, form_id: str, request: FormSubmissionRequest) -> FormSubmissionData: # type: ignore
|
||||
return cls(form_id=form_id, inputs=request.inputs, action=request.action)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FormSubmissionRequest:
|
||||
inputs: dict[str, Any]
|
||||
action: str
|
||||
|
||||
|
||||
# Repository
|
||||
class InMemoryFormRepository:
|
||||
"""
|
||||
Simple in-memory repository used by unit tests.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._forms: dict[str, HumanInputForm] = {}
|
||||
|
||||
@property
|
||||
def forms(self) -> dict[str, HumanInputForm]:
|
||||
return self._forms
|
||||
|
||||
def save(self, form: HumanInputForm) -> None:
|
||||
self._forms[form.form_id] = form
|
||||
|
||||
def get_by_id(self, form_id: str) -> HumanInputForm | None:
|
||||
return self._forms.get(form_id)
|
||||
|
||||
def get_by_token(self, token: str) -> HumanInputForm | None:
|
||||
for form in self._forms.values():
|
||||
if form.form_token == token:
|
||||
return form
|
||||
return None
|
||||
|
||||
def delete(self, form_id: str) -> None:
|
||||
self._forms.pop(form_id, None)
|
||||
|
||||
|
||||
# Service
|
||||
class FormService:
|
||||
"""Service layer for managing human input forms in tests."""
|
||||
|
||||
def __init__(self, repository: InMemoryFormRepository):
|
||||
self.repository = repository
|
||||
|
||||
def create_form(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
workflow_run_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
app_id: str | None,
|
||||
form_content: str,
|
||||
inputs,
|
||||
user_actions,
|
||||
timeout: int,
|
||||
timeout_unit: TimeoutUnit,
|
||||
form_token: str | None = None,
|
||||
) -> HumanInputForm:
|
||||
form = HumanInputForm(
|
||||
form_id=form_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id=node_id,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
form_content=form_content,
|
||||
inputs=list(inputs),
|
||||
user_actions=[{"id": action.id, "title": action.title} for action in user_actions],
|
||||
timeout=timeout,
|
||||
timeout_unit=timeout_unit,
|
||||
form_token=form_token,
|
||||
)
|
||||
form.calculate_expiration()
|
||||
self.repository.save(form)
|
||||
return form
|
||||
|
||||
def get_form_by_id(self, form_id: str) -> HumanInputForm:
|
||||
form = self.repository.get_by_id(form_id)
|
||||
if form is None:
|
||||
raise FormNotFoundError()
|
||||
return form
|
||||
|
||||
def get_form_by_token(self, token: str) -> HumanInputForm:
|
||||
form = self.repository.get_by_token(token)
|
||||
if form is None:
|
||||
raise FormNotFoundError()
|
||||
return form
|
||||
|
||||
def get_form_definition(self, form_id: str, *, is_token: bool) -> dict:
|
||||
form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
|
||||
if form.is_expired:
|
||||
raise FormExpiredError()
|
||||
if form.is_submitted:
|
||||
raise FormAlreadySubmittedError()
|
||||
|
||||
definition = {
|
||||
"form_content": form.form_content,
|
||||
"inputs": form.inputs,
|
||||
"user_actions": form.user_actions,
|
||||
}
|
||||
if is_token:
|
||||
definition["site"] = {"title": "Workflow Form"}
|
||||
return definition
|
||||
|
||||
def submit_form(self, form_id: str, submission_data: FormSubmissionData, *, is_token: bool) -> None:
|
||||
form = self.get_form_by_token(form_id) if is_token else self.get_form_by_id(form_id)
|
||||
if form.is_expired:
|
||||
raise FormExpiredError()
|
||||
if form.is_submitted:
|
||||
raise FormAlreadySubmittedError()
|
||||
|
||||
self._validate_submission(form=form, submission_data=submission_data)
|
||||
form.mark_submitted(inputs=submission_data.inputs, action=submission_data.action)
|
||||
self.repository.save(form)
|
||||
|
||||
def cleanup_expired_forms(self) -> int:
|
||||
expired_ids = [form_id for form_id, form in list(self.repository.forms.items()) if form.is_expired]
|
||||
for form_id in expired_ids:
|
||||
self.repository.delete(form_id)
|
||||
return len(expired_ids)
|
||||
|
||||
def _validate_submission(self, form: HumanInputForm, submission_data: FormSubmissionData) -> None:
|
||||
defined_actions = {action["id"] for action in form.user_actions}
|
||||
if submission_data.action not in defined_actions:
|
||||
raise InvalidFormDataError(f"Invalid action: {submission_data.action}")
|
||||
|
||||
missing_inputs = []
|
||||
for form_input in form.inputs:
|
||||
if form_input.output_variable_name not in submission_data.inputs:
|
||||
missing_inputs.append(form_input.output_variable_name)
|
||||
|
||||
if missing_inputs:
|
||||
raise InvalidFormDataError(f"Missing required inputs: {', '.join(missing_inputs)}")
|
||||
|
||||
# Extra inputs are allowed; no further validation required.
|
||||
326
api/tests/unit_tests/libs/_human_input/test_form_service.py
Normal file
326
api/tests/unit_tests/libs/_human_input/test_form_service.py
Normal file
@ -0,0 +1,326 @@
|
||||
"""
|
||||
Unit tests for FormService.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormInput,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import (
|
||||
FormInputType,
|
||||
TimeoutUnit,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .support import (
|
||||
FormAlreadySubmittedError,
|
||||
FormExpiredError,
|
||||
FormNotFoundError,
|
||||
FormService,
|
||||
FormSubmissionData,
|
||||
InMemoryFormRepository,
|
||||
InvalidFormDataError,
|
||||
)
|
||||
|
||||
|
||||
class TestFormService:
|
||||
"""Test FormService functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def repository(self):
|
||||
"""Create in-memory repository for testing."""
|
||||
return InMemoryFormRepository()
|
||||
|
||||
@pytest.fixture
|
||||
def form_service(self, repository):
|
||||
"""Create FormService with in-memory repository."""
|
||||
return FormService(repository)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_form_data(self):
|
||||
"""Create sample form data."""
|
||||
return {
|
||||
"form_id": "form-123",
|
||||
"workflow_run_id": "run-456",
|
||||
"node_id": "node-789",
|
||||
"tenant_id": "tenant-abc",
|
||||
"app_id": "app-def",
|
||||
"form_content": "# Test Form\n\nInput: {{#$output.input#}}",
|
||||
"inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)],
|
||||
"user_actions": [UserAction(id="submit", title="Submit")],
|
||||
"timeout": 1,
|
||||
"timeout_unit": TimeoutUnit.HOUR,
|
||||
"form_token": "token-xyz",
|
||||
}
|
||||
|
||||
def test_create_form(self, form_service, sample_form_data):
|
||||
"""Test form creation."""
|
||||
form = form_service.create_form(**sample_form_data)
|
||||
|
||||
assert form.form_id == "form-123"
|
||||
assert form.workflow_run_id == "run-456"
|
||||
assert form.node_id == "node-789"
|
||||
assert form.tenant_id == "tenant-abc"
|
||||
assert form.app_id == "app-def"
|
||||
assert form.form_token == "token-xyz"
|
||||
assert form.timeout == 1
|
||||
assert form.timeout_unit == TimeoutUnit.HOUR
|
||||
assert form.expires_at is not None
|
||||
assert not form.is_expired
|
||||
assert not form.is_submitted
|
||||
|
||||
def test_get_form_by_id(self, form_service, sample_form_data):
|
||||
"""Test getting form by ID."""
|
||||
# Create form first
|
||||
created_form = form_service.create_form(**sample_form_data)
|
||||
|
||||
# Retrieve form
|
||||
retrieved_form = form_service.get_form_by_id("form-123")
|
||||
|
||||
assert retrieved_form.form_id == created_form.form_id
|
||||
assert retrieved_form.workflow_run_id == created_form.workflow_run_id
|
||||
|
||||
def test_get_form_by_id_not_found(self, form_service):
|
||||
"""Test getting non-existent form by ID."""
|
||||
with pytest.raises(FormNotFoundError) as exc_info:
|
||||
form_service.get_form_by_id("non-existent-form")
|
||||
|
||||
assert exc_info.value.error_code == "form_not_found"
|
||||
|
||||
def test_get_form_by_token(self, form_service, sample_form_data):
|
||||
"""Test getting form by token."""
|
||||
# Create form first
|
||||
created_form = form_service.create_form(**sample_form_data)
|
||||
|
||||
# Retrieve form by token
|
||||
retrieved_form = form_service.get_form_by_token("token-xyz")
|
||||
|
||||
assert retrieved_form.form_id == created_form.form_id
|
||||
assert retrieved_form.form_token == "token-xyz"
|
||||
|
||||
def test_get_form_by_token_not_found(self, form_service):
|
||||
"""Test getting non-existent form by token."""
|
||||
with pytest.raises(FormNotFoundError) as exc_info:
|
||||
form_service.get_form_by_token("non-existent-token")
|
||||
|
||||
assert exc_info.value.error_code == "form_not_found"
|
||||
|
||||
def test_get_form_definition_by_id(self, form_service, sample_form_data):
|
||||
"""Test getting form definition by ID."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Get form definition
|
||||
definition = form_service.get_form_definition("form-123", is_token=False)
|
||||
|
||||
assert "form_content" in definition
|
||||
assert "inputs" in definition
|
||||
assert definition["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}"
|
||||
assert len(definition["inputs"]) == 1
|
||||
assert "site" not in definition # Should not include site info for ID-based access
|
||||
|
||||
def test_get_form_definition_by_token(self, form_service, sample_form_data):
|
||||
"""Test getting form definition by token."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Get form definition
|
||||
definition = form_service.get_form_definition("token-xyz", is_token=True)
|
||||
|
||||
assert "form_content" in definition
|
||||
assert "inputs" in definition
|
||||
assert "site" in definition # Should include site info for token-based access
|
||||
|
||||
def test_get_form_definition_expired_form(self, form_service, sample_form_data):
|
||||
"""Test getting definition for expired form."""
|
||||
# Create form with past expiry
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Manually expire the form by modifying expiry time
|
||||
form = form_service.get_form_by_id("form-123")
|
||||
form.expires_at = datetime.utcnow() - timedelta(hours=1)
|
||||
form_service.repository.save(form)
|
||||
|
||||
# Should raise FormExpiredError
|
||||
with pytest.raises(FormExpiredError) as exc_info:
|
||||
form_service.get_form_definition("form-123", is_token=False)
|
||||
|
||||
assert exc_info.value.error_code == "human_input_form_expired"
|
||||
|
||||
def test_get_form_definition_submitted_form(self, form_service, sample_form_data):
|
||||
"""Test getting definition for already submitted form."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Submit the form
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
# Should raise FormAlreadySubmittedError
|
||||
with pytest.raises(FormAlreadySubmittedError) as exc_info:
|
||||
form_service.get_form_definition("form-123", is_token=False)
|
||||
|
||||
assert exc_info.value.error_code == "human_input_form_submitted"
|
||||
|
||||
def test_submit_form_success(self, form_service, sample_form_data):
|
||||
"""Test successful form submission."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Submit form
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
|
||||
|
||||
# Should not raise any exception
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
# Verify form is marked as submitted
|
||||
form = form_service.get_form_by_id("form-123")
|
||||
assert form.is_submitted
|
||||
assert form.submitted_data == {"input": "test value"}
|
||||
assert form.submitted_action == "submit"
|
||||
assert form.submitted_at is not None
|
||||
|
||||
def test_submit_form_missing_inputs(self, form_service, sample_form_data):
|
||||
"""Test form submission with missing inputs."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Submit form with missing required input
|
||||
submission_data = FormSubmissionData(
|
||||
form_id="form-123",
|
||||
inputs={}, # Missing required "input" field
|
||||
action="submit",
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
assert "Missing required inputs" in exc_info.value.message
|
||||
assert "input" in exc_info.value.message
|
||||
|
||||
def test_submit_form_invalid_action(self, form_service, sample_form_data):
|
||||
"""Test form submission with invalid action."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Submit form with invalid action
|
||||
submission_data = FormSubmissionData(
|
||||
form_id="form-123",
|
||||
inputs={"input": "test value"},
|
||||
action="invalid_action", # Not in the allowed actions
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
assert "Invalid action" in exc_info.value.message
|
||||
assert "invalid_action" in exc_info.value.message
|
||||
|
||||
def test_submit_form_expired(self, form_service, sample_form_data):
|
||||
"""Test submitting expired form."""
|
||||
# Create form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
# Manually expire the form
|
||||
form = form_service.get_form_by_id("form-123")
|
||||
form.expires_at = datetime.utcnow() - timedelta(hours=1)
|
||||
form_service.repository.save(form)
|
||||
|
||||
# Try to submit expired form
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "test value"}, action="submit")
|
||||
|
||||
with pytest.raises(FormExpiredError) as exc_info:
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
assert exc_info.value.error_code == "human_input_form_expired"
|
||||
|
||||
def test_submit_form_already_submitted(self, form_service, sample_form_data):
|
||||
"""Test submitting form that's already submitted."""
|
||||
# Create and submit form first
|
||||
form_service.create_form(**sample_form_data)
|
||||
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={"input": "first submission"}, action="submit")
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
|
||||
# Try to submit again
|
||||
second_submission = FormSubmissionData(
|
||||
form_id="form-123", inputs={"input": "second submission"}, action="submit"
|
||||
)
|
||||
|
||||
with pytest.raises(FormAlreadySubmittedError) as exc_info:
|
||||
form_service.submit_form("form-123", second_submission, is_token=False)
|
||||
|
||||
assert exc_info.value.error_code == "human_input_form_submitted"
|
||||
|
||||
def test_cleanup_expired_forms(self, form_service, sample_form_data):
|
||||
"""Test cleanup of expired forms."""
|
||||
# Create multiple forms
|
||||
for i in range(3):
|
||||
data = sample_form_data.copy()
|
||||
data["form_id"] = f"form-{i}"
|
||||
data["form_token"] = f"token-{i}"
|
||||
form_service.create_form(**data)
|
||||
|
||||
# Manually expire some forms
|
||||
for i in range(2): # Expire first 2 forms
|
||||
form = form_service.get_form_by_id(f"form-{i}")
|
||||
form.expires_at = naive_utc_now() - timedelta(hours=1)
|
||||
form_service.repository.save(form)
|
||||
|
||||
# Clean up expired forms
|
||||
cleaned_count = form_service.cleanup_expired_forms()
|
||||
|
||||
assert cleaned_count == 2
|
||||
|
||||
# Verify expired forms are gone
|
||||
with pytest.raises(FormNotFoundError):
|
||||
form_service.get_form_by_id("form-0")
|
||||
|
||||
with pytest.raises(FormNotFoundError):
|
||||
form_service.get_form_by_id("form-1")
|
||||
|
||||
# Verify non-expired form still exists
|
||||
form = form_service.get_form_by_id("form-2")
|
||||
assert form.form_id == "form-2"
|
||||
|
||||
|
||||
class TestFormValidation:
|
||||
"""Test form validation logic."""
|
||||
|
||||
def test_validate_submission_with_extra_inputs(self):
|
||||
"""Test validation allows extra inputs that aren't defined in form."""
|
||||
repository = InMemoryFormRepository()
|
||||
form_service = FormService(repository)
|
||||
|
||||
# Create form with one input
|
||||
form_data = {
|
||||
"form_id": "form-123",
|
||||
"workflow_run_id": "run-456",
|
||||
"node_id": "node-789",
|
||||
"tenant_id": "tenant-abc",
|
||||
"app_id": "app-def",
|
||||
"form_content": "Test form",
|
||||
"inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="required_input", default=None)],
|
||||
"user_actions": [UserAction(id="submit", title="Submit")],
|
||||
"timeout": 1,
|
||||
"timeout_unit": TimeoutUnit.HOUR,
|
||||
}
|
||||
|
||||
form_service.create_form(**form_data)
|
||||
|
||||
# Submit with extra input (should be allowed)
|
||||
submission_data = FormSubmissionData(
|
||||
form_id="form-123",
|
||||
inputs={
|
||||
"required_input": "value1",
|
||||
"extra_input": "value2", # Extra input not defined in form
|
||||
},
|
||||
action="submit",
|
||||
)
|
||||
|
||||
# Should not raise any exception
|
||||
form_service.submit_form("form-123", submission_data, is_token=False)
|
||||
232
api/tests/unit_tests/libs/_human_input/test_models.py
Normal file
232
api/tests/unit_tests/libs/_human_input/test_models.py
Normal file
@ -0,0 +1,232 @@
|
||||
"""
|
||||
Unit tests for human input form models.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormInput,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import (
|
||||
FormInputType,
|
||||
TimeoutUnit,
|
||||
)
|
||||
|
||||
from .support import FormSubmissionData, FormSubmissionRequest, HumanInputForm
|
||||
|
||||
|
||||
class TestHumanInputForm:
|
||||
"""Test HumanInputForm model."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_form_data(self):
|
||||
"""Create sample form data."""
|
||||
return {
|
||||
"form_id": "form-123",
|
||||
"workflow_run_id": "run-456",
|
||||
"node_id": "node-789",
|
||||
"tenant_id": "tenant-abc",
|
||||
"app_id": "app-def",
|
||||
"form_content": "# Test Form\n\nInput: {{#$output.input#}}",
|
||||
"inputs": [FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="input", default=None)],
|
||||
"user_actions": [UserAction(id="submit", title="Submit")],
|
||||
"timeout": 2,
|
||||
"timeout_unit": TimeoutUnit.HOUR,
|
||||
"form_token": "token-xyz",
|
||||
}
|
||||
|
||||
def test_form_creation(self, sample_form_data):
|
||||
"""Test form creation."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
assert form.form_id == "form-123"
|
||||
assert form.workflow_run_id == "run-456"
|
||||
assert form.node_id == "node-789"
|
||||
assert form.tenant_id == "tenant-abc"
|
||||
assert form.app_id == "app-def"
|
||||
assert form.form_token == "token-xyz"
|
||||
assert form.timeout == 2
|
||||
assert form.timeout_unit == TimeoutUnit.HOUR
|
||||
assert form.created_at is not None
|
||||
assert form.expires_at is not None
|
||||
assert form.submitted_at is None
|
||||
assert form.submitted_data is None
|
||||
assert form.submitted_action is None
|
||||
|
||||
def test_form_expiry_calculation_hours(self, sample_form_data):
|
||||
"""Test form expiry calculation for hours."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
# Should expire 2 hours after creation
|
||||
expected_expiry = form.created_at + timedelta(hours=2)
|
||||
assert abs((form.expires_at - expected_expiry).total_seconds()) < 1 # Within 1 second
|
||||
|
||||
def test_form_expiry_calculation_days(self, sample_form_data):
|
||||
"""Test form expiry calculation for days."""
|
||||
sample_form_data["timeout"] = 3
|
||||
sample_form_data["timeout_unit"] = TimeoutUnit.DAY
|
||||
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
# Should expire 3 days after creation
|
||||
expected_expiry = form.created_at + timedelta(days=3)
|
||||
assert abs((form.expires_at - expected_expiry).total_seconds()) < 1 # Within 1 second
|
||||
|
||||
def test_form_expiry_property_not_expired(self, sample_form_data):
|
||||
"""Test is_expired property for non-expired form."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
assert not form.is_expired
|
||||
|
||||
def test_form_expiry_property_expired(self, sample_form_data):
|
||||
"""Test is_expired property for expired form."""
|
||||
# Create form with past expiry
|
||||
past_time = datetime.utcnow() - timedelta(hours=1)
|
||||
sample_form_data["created_at"] = past_time
|
||||
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
# Manually set expiry to past time
|
||||
form.expires_at = past_time
|
||||
|
||||
assert form.is_expired
|
||||
|
||||
def test_form_submission_property_not_submitted(self, sample_form_data):
|
||||
"""Test is_submitted property for non-submitted form."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
assert not form.is_submitted
|
||||
|
||||
def test_form_submission_property_submitted(self, sample_form_data):
|
||||
"""Test is_submitted property for submitted form."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
form.submit({"input": "test value"}, "submit")
|
||||
|
||||
assert form.is_submitted
|
||||
assert form.submitted_at is not None
|
||||
assert form.submitted_data == {"input": "test value"}
|
||||
assert form.submitted_action == "submit"
|
||||
|
||||
def test_form_submit_method(self, sample_form_data):
|
||||
"""Test form submit method."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
submission_time_before = datetime.utcnow()
|
||||
form.submit({"input": "test value"}, "submit")
|
||||
submission_time_after = datetime.utcnow()
|
||||
|
||||
assert form.is_submitted
|
||||
assert form.submitted_data == {"input": "test value"}
|
||||
assert form.submitted_action == "submit"
|
||||
assert submission_time_before <= form.submitted_at <= submission_time_after
|
||||
|
||||
def test_form_to_response_dict_without_site_info(self, sample_form_data):
|
||||
"""Test converting form to response dict without site info."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
response = form.to_response_dict(include_site_info=False)
|
||||
|
||||
assert "form_content" in response
|
||||
assert "inputs" in response
|
||||
assert "site" not in response
|
||||
assert response["form_content"] == "# Test Form\n\nInput: {{#$output.input#}}"
|
||||
assert len(response["inputs"]) == 1
|
||||
assert response["inputs"][0]["type"] == "text-input"
|
||||
assert response["inputs"][0]["output_variable_name"] == "input"
|
||||
|
||||
def test_form_to_response_dict_with_site_info(self, sample_form_data):
|
||||
"""Test converting form to response dict with site info."""
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
response = form.to_response_dict(include_site_info=True)
|
||||
|
||||
assert "form_content" in response
|
||||
assert "inputs" in response
|
||||
assert "site" in response
|
||||
assert response["site"]["app_id"] == "app-def"
|
||||
assert response["site"]["title"] == "Workflow Form"
|
||||
|
||||
def test_form_without_web_app_token(self, sample_form_data):
|
||||
"""Test form creation without web app token."""
|
||||
sample_form_data["form_token"] = None
|
||||
|
||||
form = HumanInputForm(**sample_form_data)
|
||||
|
||||
assert form.form_token is None
|
||||
assert form.form_id == "form-123" # Other fields should still work
|
||||
|
||||
def test_form_with_explicit_timestamps(self):
|
||||
"""Test form creation with explicit timestamps."""
|
||||
created_time = datetime(2024, 1, 15, 10, 30, 0)
|
||||
expires_time = datetime(2024, 1, 15, 12, 30, 0)
|
||||
|
||||
form = HumanInputForm(
|
||||
form_id="form-123",
|
||||
workflow_run_id="run-456",
|
||||
node_id="node-789",
|
||||
tenant_id="tenant-abc",
|
||||
app_id="app-def",
|
||||
form_content="Test content",
|
||||
inputs=[],
|
||||
user_actions=[],
|
||||
timeout=2,
|
||||
timeout_unit=TimeoutUnit.HOUR,
|
||||
created_at=created_time,
|
||||
expires_at=expires_time,
|
||||
)
|
||||
|
||||
assert form.created_at == created_time
|
||||
assert form.expires_at == expires_time
|
||||
|
||||
|
||||
class TestFormSubmissionData:
|
||||
"""Test FormSubmissionData model."""
|
||||
|
||||
def test_submission_data_creation(self):
|
||||
"""Test submission data creation."""
|
||||
submission_data = FormSubmissionData(
|
||||
form_id="form-123", inputs={"field1": "value1", "field2": "value2"}, action="submit"
|
||||
)
|
||||
|
||||
assert submission_data.form_id == "form-123"
|
||||
assert submission_data.inputs == {"field1": "value1", "field2": "value2"}
|
||||
assert submission_data.action == "submit"
|
||||
assert submission_data.submitted_at is not None
|
||||
|
||||
def test_submission_data_from_request(self):
|
||||
"""Test creating submission data from API request."""
|
||||
request = FormSubmissionRequest(inputs={"input": "test value"}, action="confirm")
|
||||
|
||||
submission_data = FormSubmissionData.from_request("form-456", request)
|
||||
|
||||
assert submission_data.form_id == "form-456"
|
||||
assert submission_data.inputs == {"input": "test value"}
|
||||
assert submission_data.action == "confirm"
|
||||
assert submission_data.submitted_at is not None
|
||||
|
||||
def test_submission_data_with_empty_inputs(self):
|
||||
"""Test submission data with empty inputs."""
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={}, action="cancel")
|
||||
|
||||
assert submission_data.inputs == {}
|
||||
assert submission_data.action == "cancel"
|
||||
|
||||
def test_submission_data_timestamps(self):
|
||||
"""Test submission data timestamp handling."""
|
||||
before_time = datetime.utcnow()
|
||||
|
||||
submission_data = FormSubmissionData(form_id="form-123", inputs={"test": "value"}, action="submit")
|
||||
|
||||
after_time = datetime.utcnow()
|
||||
|
||||
assert before_time <= submission_data.submitted_at <= after_time
|
||||
|
||||
def test_submission_data_with_explicit_timestamp(self):
|
||||
"""Test submission data with explicit timestamp."""
|
||||
specific_time = datetime(2024, 1, 15, 14, 30, 0)
|
||||
|
||||
submission_data = FormSubmissionData(
|
||||
form_id="form-123", inputs={"test": "value"}, action="submit", submitted_at=specific_time
|
||||
)
|
||||
|
||||
assert submission_data.submitted_at == specific_time
|
||||
@ -181,6 +181,7 @@ class TestShardedTopic:
|
||||
subscription = sharded_topic.subscribe()
|
||||
|
||||
assert isinstance(subscription, _RedisShardedSubscription)
|
||||
assert subscription._client is mock_redis_client
|
||||
assert subscription._pubsub is mock_redis_client.pubsub.return_value
|
||||
assert subscription._topic == "test-sharded-topic"
|
||||
|
||||
@ -200,6 +201,11 @@ class SubscriptionTestCase:
|
||||
class TestRedisSubscription:
|
||||
"""Test cases for the _RedisSubscription class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self) -> MagicMock:
|
||||
client = MagicMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pubsub(self) -> MagicMock:
|
||||
"""Create a mock PubSub instance for testing."""
|
||||
@ -211,9 +217,12 @@ class TestRedisSubscription:
|
||||
return pubsub
|
||||
|
||||
@pytest.fixture
|
||||
def subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisSubscription, None, None]:
|
||||
def subscription(
|
||||
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||
) -> Generator[_RedisSubscription, None, None]:
|
||||
"""Create a _RedisSubscription instance for testing."""
|
||||
subscription = _RedisSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
@ -228,13 +237,15 @@ class TestRedisSubscription:
|
||||
|
||||
# ==================== Lifecycle Tests ====================
|
||||
|
||||
def test_subscription_initialization(self, mock_pubsub: MagicMock):
|
||||
def test_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||
"""Test that subscription is properly initialized."""
|
||||
subscription = _RedisSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
|
||||
assert subscription._client is mock_redis_client
|
||||
assert subscription._pubsub is mock_pubsub
|
||||
assert subscription._topic == "test-topic"
|
||||
assert not subscription._closed.is_set()
|
||||
@ -486,9 +497,12 @@ class TestRedisSubscription:
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
|
||||
def test_subscription_scenarios(
|
||||
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||
):
|
||||
"""Test various subscription scenarios using table-driven approach."""
|
||||
subscription = _RedisSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-topic",
|
||||
)
|
||||
@ -572,7 +586,7 @@ class TestRedisSubscription:
|
||||
# Close should still work
|
||||
subscription.close() # Should not raise
|
||||
|
||||
def test_channel_name_variations(self, mock_pubsub: MagicMock):
|
||||
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||
"""Test various channel name formats."""
|
||||
channel_names = [
|
||||
"simple",
|
||||
@ -586,6 +600,7 @@ class TestRedisSubscription:
|
||||
|
||||
for channel_name in channel_names:
|
||||
subscription = _RedisSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic=channel_name,
|
||||
)
|
||||
@ -604,6 +619,11 @@ class TestRedisSubscription:
|
||||
class TestRedisShardedSubscription:
|
||||
"""Test cases for the _RedisShardedSubscription class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self) -> MagicMock:
|
||||
client = MagicMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pubsub(self) -> MagicMock:
|
||||
"""Create a mock PubSub instance for testing."""
|
||||
@ -615,9 +635,12 @@ class TestRedisShardedSubscription:
|
||||
return pubsub
|
||||
|
||||
@pytest.fixture
|
||||
def sharded_subscription(self, mock_pubsub: MagicMock) -> Generator[_RedisShardedSubscription, None, None]:
|
||||
def sharded_subscription(
|
||||
self, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||
) -> Generator[_RedisShardedSubscription, None, None]:
|
||||
"""Create a _RedisShardedSubscription instance for testing."""
|
||||
subscription = _RedisShardedSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-sharded-topic",
|
||||
)
|
||||
@ -634,13 +657,15 @@ class TestRedisShardedSubscription:
|
||||
|
||||
# ==================== Lifecycle Tests ====================
|
||||
|
||||
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock):
|
||||
def test_sharded_subscription_initialization(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||
"""Test that sharded subscription is properly initialized."""
|
||||
subscription = _RedisShardedSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-sharded-topic",
|
||||
)
|
||||
|
||||
assert subscription._client is mock_redis_client
|
||||
assert subscription._pubsub is mock_pubsub
|
||||
assert subscription._topic == "test-sharded-topic"
|
||||
assert not subscription._closed.is_set()
|
||||
@ -808,6 +833,37 @@ class TestRedisShardedSubscription:
|
||||
assert not sharded_subscription._queue.empty()
|
||||
assert sharded_subscription._queue.get_nowait() == b"test sharded payload"
|
||||
|
||||
def test_get_message_uses_target_node_for_cluster_client(self, mock_pubsub: MagicMock, monkeypatch):
|
||||
"""Test that cluster clients use target_node for sharded messages."""
|
||||
|
||||
class DummyRedisCluster:
|
||||
def __init__(self):
|
||||
self.get_node_from_key = MagicMock(return_value="node-1")
|
||||
|
||||
monkeypatch.setattr("libs.broadcast_channel.redis.sharded_channel.RedisCluster", DummyRedisCluster)
|
||||
|
||||
client = DummyRedisCluster()
|
||||
subscription = _RedisShardedSubscription(
|
||||
client=client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-sharded-topic",
|
||||
)
|
||||
mock_pubsub.get_sharded_message.return_value = {
|
||||
"type": "smessage",
|
||||
"channel": "test-sharded-topic",
|
||||
"data": b"payload",
|
||||
}
|
||||
|
||||
result = subscription._get_message()
|
||||
|
||||
client.get_node_from_key.assert_called_once_with("test-sharded-topic")
|
||||
mock_pubsub.get_sharded_message.assert_called_once_with(
|
||||
ignore_subscribe_messages=False,
|
||||
timeout=1,
|
||||
target_node="node-1",
|
||||
)
|
||||
assert result == mock_pubsub.get_sharded_message.return_value
|
||||
|
||||
def test_listener_thread_ignores_subscribe_messages(
|
||||
self, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
|
||||
):
|
||||
@ -913,9 +969,12 @@ class TestRedisShardedSubscription:
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_sharded_subscription_scenarios(self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock):
|
||||
def test_sharded_subscription_scenarios(
|
||||
self, test_case: SubscriptionTestCase, mock_pubsub: MagicMock, mock_redis_client: MagicMock
|
||||
):
|
||||
"""Test various sharded subscription scenarios using table-driven approach."""
|
||||
subscription = _RedisShardedSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic="test-sharded-topic",
|
||||
)
|
||||
@ -999,7 +1058,7 @@ class TestRedisShardedSubscription:
|
||||
# Close should still work
|
||||
sharded_subscription.close() # Should not raise
|
||||
|
||||
def test_channel_name_variations(self, mock_pubsub: MagicMock):
|
||||
def test_channel_name_variations(self, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||
"""Test various sharded channel name formats."""
|
||||
channel_names = [
|
||||
"simple",
|
||||
@ -1013,6 +1072,7 @@ class TestRedisShardedSubscription:
|
||||
|
||||
for channel_name in channel_names:
|
||||
subscription = _RedisShardedSubscription(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic=channel_name,
|
||||
)
|
||||
@ -1060,6 +1120,11 @@ class TestRedisSubscriptionCommon:
|
||||
"""Parameterized fixture providing subscription type and class."""
|
||||
return request.param
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(self) -> MagicMock:
|
||||
client = MagicMock()
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pubsub(self) -> MagicMock:
|
||||
"""Create a mock PubSub instance for testing."""
|
||||
@ -1075,11 +1140,12 @@ class TestRedisSubscriptionCommon:
|
||||
return pubsub
|
||||
|
||||
@pytest.fixture
|
||||
def subscription(self, subscription_params, mock_pubsub: MagicMock):
|
||||
def subscription(self, subscription_params, mock_pubsub: MagicMock, mock_redis_client: MagicMock):
|
||||
"""Create a subscription instance based on parameterized type."""
|
||||
subscription_type, subscription_class = subscription_params
|
||||
topic_name = f"test-{subscription_type}-topic"
|
||||
subscription = subscription_class(
|
||||
client=mock_redis_client,
|
||||
pubsub=mock_pubsub,
|
||||
topic=topic_name,
|
||||
)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.helper import escape_like_pattern, extract_tenant_id
|
||||
from libs.helper import OptionalTimestampField, escape_like_pattern, extract_tenant_id
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
|
||||
@ -65,6 +67,19 @@ class TestExtractTenantId:
|
||||
extract_tenant_id(dict_user)
|
||||
|
||||
|
||||
class TestOptionalTimestampField:
|
||||
def test_format_returns_none_for_none(self):
|
||||
field = OptionalTimestampField()
|
||||
|
||||
assert field.format(None) is None
|
||||
|
||||
def test_format_returns_unix_timestamp_for_datetime(self):
|
||||
field = OptionalTimestampField()
|
||||
value = datetime(2024, 1, 2, 3, 4, 5)
|
||||
|
||||
assert field.format(value) == int(value.timestamp())
|
||||
|
||||
|
||||
class TestEscapeLikePattern:
|
||||
"""Test cases for the escape_like_pattern utility function."""
|
||||
|
||||
|
||||
68
api/tests/unit_tests/libs/test_rate_limiter.py
Normal file
68
api/tests/unit_tests/libs/test_rate_limiter.py
Normal file
@ -0,0 +1,68 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from libs import helper as helper_module
|
||||
|
||||
|
||||
class _FakeRedis:
|
||||
def __init__(self) -> None:
|
||||
self._zsets: dict[str, dict[str, float]] = {}
|
||||
self._expiry: dict[str, int] = {}
|
||||
|
||||
def zadd(self, key: str, mapping: dict[str, float]) -> int:
|
||||
zset = self._zsets.setdefault(key, {})
|
||||
for member, score in mapping.items():
|
||||
zset[str(member)] = float(score)
|
||||
return len(mapping)
|
||||
|
||||
def zremrangebyscore(self, key: str, min_score: str | float, max_score: str | float) -> int:
|
||||
zset = self._zsets.get(key, {})
|
||||
min_value = float("-inf") if min_score == "-inf" else float(min_score)
|
||||
max_value = float("inf") if max_score == "+inf" else float(max_score)
|
||||
to_delete = [member for member, score in zset.items() if min_value <= score <= max_value]
|
||||
for member in to_delete:
|
||||
del zset[member]
|
||||
return len(to_delete)
|
||||
|
||||
def zcard(self, key: str) -> int:
|
||||
return len(self._zsets.get(key, {}))
|
||||
|
||||
def expire(self, key: str, ttl: int) -> bool:
|
||||
self._expiry[key] = ttl
|
||||
return True
|
||||
|
||||
|
||||
def test_rate_limiter_counts_attempts_within_same_second(monkeypatch):
|
||||
fake_redis = _FakeRedis()
|
||||
monkeypatch.setattr(helper_module.time, "time", lambda: 1000)
|
||||
|
||||
limiter = helper_module.RateLimiter(
|
||||
prefix="test_rate_limit",
|
||||
max_attempts=2,
|
||||
time_window=60,
|
||||
redis_client=fake_redis,
|
||||
)
|
||||
|
||||
limiter.increment_rate_limit("203.0.113.10")
|
||||
limiter.increment_rate_limit("203.0.113.10")
|
||||
|
||||
assert limiter.is_rate_limited("203.0.113.10") is True
|
||||
|
||||
|
||||
def test_rate_limiter_uses_injected_redis(monkeypatch):
|
||||
redis_client = MagicMock()
|
||||
redis_client.zcard.return_value = 1
|
||||
monkeypatch.setattr(helper_module.time, "time", lambda: 1000)
|
||||
|
||||
limiter = helper_module.RateLimiter(
|
||||
prefix="test_rate_limit",
|
||||
max_attempts=1,
|
||||
time_window=60,
|
||||
redis_client=redis_client,
|
||||
)
|
||||
|
||||
limiter.increment_rate_limit("203.0.113.10")
|
||||
limiter.is_rate_limited("203.0.113.10")
|
||||
|
||||
assert redis_client.zadd.called is True
|
||||
assert redis_client.zremrangebyscore.called is True
|
||||
assert redis_client.zcard.called is True
|
||||
@ -1296,6 +1296,7 @@ class TestConversationStatusCount:
|
||||
assert result["success"] == 1 # One SUCCEEDED
|
||||
assert result["failed"] == 1 # One FAILED
|
||||
assert result["partial_success"] == 1 # One PARTIAL_SUCCEEDED
|
||||
assert result["paused"] == 0
|
||||
|
||||
def test_status_count_app_id_filtering(self):
|
||||
"""Test that status_count filters workflow runs by app_id for security."""
|
||||
@ -1350,6 +1351,7 @@ class TestConversationStatusCount:
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["partial_success"] == 0
|
||||
assert result["paused"] == 0
|
||||
|
||||
def test_status_count_handles_invalid_workflow_status(self):
|
||||
"""Test that status_count gracefully handles invalid workflow status values."""
|
||||
@ -1404,3 +1406,57 @@ class TestConversationStatusCount:
|
||||
assert result["success"] == 0
|
||||
assert result["failed"] == 0
|
||||
assert result["partial_success"] == 0
|
||||
assert result["paused"] == 0
|
||||
|
||||
def test_status_count_paused(self):
|
||||
"""Test status_count includes paused workflow runs."""
|
||||
# Arrange
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
|
||||
app_id = str(uuid4())
|
||||
conversation_id = str(uuid4())
|
||||
workflow_run_id = str(uuid4())
|
||||
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
mode=AppMode.CHAT,
|
||||
name="Test Conversation",
|
||||
status="normal",
|
||||
from_source="api",
|
||||
)
|
||||
conversation.id = conversation_id
|
||||
|
||||
mock_messages = [
|
||||
MagicMock(
|
||||
conversation_id=conversation_id,
|
||||
workflow_run_id=workflow_run_id,
|
||||
),
|
||||
]
|
||||
|
||||
mock_workflow_runs = [
|
||||
MagicMock(
|
||||
id=workflow_run_id,
|
||||
status=WorkflowExecutionStatus.PAUSED.value,
|
||||
app_id=app_id,
|
||||
),
|
||||
]
|
||||
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
|
||||
def mock_scalars_side_effect(query):
|
||||
mock_result = MagicMock()
|
||||
if "messages" in str(query):
|
||||
mock_result.all.return_value = mock_messages
|
||||
elif "workflow_runs" in str(query):
|
||||
mock_result.all.return_value = mock_workflow_runs
|
||||
else:
|
||||
mock_result.all.return_value = []
|
||||
return mock_result
|
||||
|
||||
mock_scalars.side_effect = mock_scalars_side_effect
|
||||
|
||||
# Act
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert
|
||||
assert result["paused"] == 1
|
||||
|
||||
@ -0,0 +1,40 @@
|
||||
"""Unit tests for DifyAPISQLAlchemyWorkflowNodeExecutionRepository implementation."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
|
||||
|
||||
class TestDifyAPISQLAlchemyWorkflowNodeExecutionRepository:
|
||||
def test_get_executions_by_workflow_run_keeps_paused_records(self):
|
||||
mock_session = Mock(spec=Session)
|
||||
execute_result = Mock()
|
||||
execute_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = execute_result
|
||||
|
||||
session_maker = Mock(spec=sessionmaker)
|
||||
context_manager = Mock()
|
||||
context_manager.__enter__ = Mock(return_value=mock_session)
|
||||
context_manager.__exit__ = Mock(return_value=None)
|
||||
session_maker.return_value = context_manager
|
||||
|
||||
repository = DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker)
|
||||
|
||||
repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-123",
|
||||
workflow_run_id="workflow-run-123",
|
||||
)
|
||||
|
||||
stmt = mock_session.execute.call_args[0][0]
|
||||
where_clauses = list(getattr(stmt, "_where_criteria", []) or [])
|
||||
where_strs = [str(clause).lower() for clause in where_clauses]
|
||||
|
||||
assert any("tenant_id" in clause for clause in where_strs)
|
||||
assert any("app_id" in clause for clause in where_strs)
|
||||
assert any("workflow_run_id" in clause for clause in where_strs)
|
||||
assert not any("paused" in clause for clause in where_strs)
|
||||
@ -1,5 +1,6 @@
|
||||
"""Unit tests for DifyAPISQLAlchemyWorkflowRunRepository implementation."""
|
||||
|
||||
import secrets
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@ -7,12 +8,17 @@ import pytest
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction
|
||||
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus
|
||||
from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
from models.workflow import WorkflowPause as WorkflowPauseModel
|
||||
from models.workflow import WorkflowRun
|
||||
from models.workflow import WorkflowPauseReason, WorkflowRun
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from repositories.sqlalchemy_api_workflow_run_repository import (
|
||||
DifyAPISQLAlchemyWorkflowRunRepository,
|
||||
_build_human_input_required_reason,
|
||||
_PrivateWorkflowPauseEntity,
|
||||
_WorkflowRunError,
|
||||
)
|
||||
@ -205,11 +211,11 @@ class TestCreateWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
):
|
||||
"""Test workflow pause creation when workflow not in RUNNING status."""
|
||||
# Arrange
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.PAUSED
|
||||
sample_workflow_run.status = WorkflowExecutionStatus.SUCCEEDED
|
||||
mock_session.get.return_value = sample_workflow_run
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING status can be paused"):
|
||||
with pytest.raises(_WorkflowRunError, match="Only WorkflowRun with RUNNING or PAUSED status can be paused"):
|
||||
repository.create_workflow_pause(
|
||||
workflow_run_id="workflow-run-123",
|
||||
state_owner_user_id="user-123",
|
||||
@ -295,6 +301,7 @@ class TestResumeWorkflowPause(TestDifyAPISQLAlchemyWorkflowRunRepository):
|
||||
sample_workflow_pause.resumed_at = None
|
||||
|
||||
mock_session.scalar.return_value = sample_workflow_run
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.naive_utc_now") as mock_now:
|
||||
mock_now.return_value = datetime.now(UTC)
|
||||
@ -455,3 +462,53 @@ class TestPrivateWorkflowPauseEntity(TestDifyAPISQLAlchemyWorkflowRunRepository)
|
||||
assert result1 == expected_state
|
||||
assert result2 == expected_state
|
||||
mock_storage.load.assert_called_once() # Only called once due to caching
|
||||
|
||||
|
||||
class TestBuildHumanInputRequiredReason:
|
||||
def test_prefers_backstage_token_when_available(self):
|
||||
expiration_time = datetime.now(UTC)
|
||||
form_definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
default_values={"name": "Alice"},
|
||||
node_title="Ask Name",
|
||||
display_in_ui=True,
|
||||
)
|
||||
form_model = HumanInputForm(
|
||||
id="form-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_run_id="run-1",
|
||||
node_id="node-1",
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content="rendered",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
reason_model = WorkflowPauseReason(
|
||||
pause_id="pause-1",
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED,
|
||||
form_id="form-1",
|
||||
node_id="node-1",
|
||||
message="",
|
||||
)
|
||||
access_token = secrets.token_urlsafe(8)
|
||||
backstage_recipient = HumanInputFormRecipient(
|
||||
form_id="form-1",
|
||||
delivery_id="delivery-1",
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload().model_dump_json(),
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient])
|
||||
|
||||
assert isinstance(reason, HumanInputRequired)
|
||||
assert reason.form_token == access_token
|
||||
assert reason.node_title == "Ask Name"
|
||||
assert reason.form_content == "content"
|
||||
assert reason.inputs[0].output_variable_name == "name"
|
||||
assert reason.actions[0].id == "approve"
|
||||
|
||||
@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain
|
||||
from core.entities.execution_extra_content import HumanInputFormSubmissionData
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormDefinition,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from models.execution_extra_content import HumanInputContent as HumanInputContentModel
|
||||
from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, values: Sequence[HumanInputContentModel]):
|
||||
self._values = list(values)
|
||||
|
||||
def all(self) -> list[HumanInputContentModel]:
|
||||
return list(self._values)
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, values: Sequence[Sequence[object]]):
|
||||
self._values = list(values)
|
||||
|
||||
def scalars(self, _stmt):
|
||||
if not self._values:
|
||||
return _FakeScalarResult([])
|
||||
return _FakeScalarResult(self._values.pop(0))
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FakeSessionMaker:
|
||||
session: _FakeSession
|
||||
|
||||
def __call__(self) -> _FakeSession:
|
||||
return self.session
|
||||
|
||||
|
||||
def _build_form(action_id: str, action_title: str, rendered_content: str) -> HumanInputForm:
|
||||
expiration_time = datetime.now(UTC) + timedelta(days=1)
|
||||
definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id=action_id, title=action_title)],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
node_title="Approval",
|
||||
display_in_ui=True,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
id=f"form-{action_id}",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
workflow_run_id="workflow-run",
|
||||
node_id="node-id",
|
||||
form_definition=definition.model_dump_json(),
|
||||
rendered_content=rendered_content,
|
||||
status=HumanInputFormStatus.SUBMITTED,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
form.selected_action_id = action_id
|
||||
return form
|
||||
|
||||
|
||||
def _build_content(message_id: str, action_id: str, action_title: str) -> HumanInputContentModel:
|
||||
form = _build_form(
|
||||
action_id=action_id,
|
||||
action_title=action_title,
|
||||
rendered_content=f"Rendered {action_title}",
|
||||
)
|
||||
content = HumanInputContentModel(
|
||||
id=f"content-{message_id}",
|
||||
form_id=form.id,
|
||||
message_id=message_id,
|
||||
workflow_run_id=form.workflow_run_id,
|
||||
)
|
||||
content.form = form
|
||||
return content
|
||||
|
||||
|
||||
def test_get_by_message_ids_groups_contents_by_message() -> None:
|
||||
message_ids = ["msg-1", "msg-2"]
|
||||
contents = [_build_content("msg-1", "approve", "Approve")]
|
||||
repository = SQLAlchemyExecutionExtraContentRepository(
|
||||
session_maker=_FakeSessionMaker(session=_FakeSession(values=[contents, []]))
|
||||
)
|
||||
|
||||
result = repository.get_by_message_ids(message_ids)
|
||||
|
||||
assert len(result) == 2
|
||||
assert [content.model_dump(mode="json", exclude_none=True) for content in result[0]] == [
|
||||
HumanInputContentDomain(
|
||||
workflow_run_id="workflow-run",
|
||||
submitted=True,
|
||||
form_submission_data=HumanInputFormSubmissionData(
|
||||
node_id="node-id",
|
||||
node_title="Approval",
|
||||
rendered_content="Rendered Approve",
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
),
|
||||
).model_dump(mode="json", exclude_none=True)
|
||||
]
|
||||
assert result[1] == []
|
||||
|
||||
|
||||
def test_get_by_message_ids_returns_unsubmitted_form_definition() -> None:
|
||||
expiration_time = datetime.now(UTC) + timedelta(days=1)
|
||||
definition = FormDefinition(
|
||||
form_content="content",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
rendered_content="rendered",
|
||||
expiration_time=expiration_time,
|
||||
default_values={"name": "John"},
|
||||
node_title="Approval",
|
||||
display_in_ui=True,
|
||||
)
|
||||
form = HumanInputForm(
|
||||
id="form-1",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
workflow_run_id="workflow-run",
|
||||
node_id="node-id",
|
||||
form_definition=definition.model_dump_json(),
|
||||
rendered_content="Rendered block",
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
expiration_time=expiration_time,
|
||||
)
|
||||
content = HumanInputContentModel(
|
||||
id="content-msg-1",
|
||||
form_id=form.id,
|
||||
message_id="msg-1",
|
||||
workflow_run_id=form.workflow_run_id,
|
||||
)
|
||||
content.form = form
|
||||
|
||||
recipient = HumanInputFormRecipient(
|
||||
form_id=form.id,
|
||||
delivery_id="delivery-1",
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
recipient_payload=ConsoleRecipientPayload(account_id=None).model_dump_json(),
|
||||
access_token="token-1",
|
||||
)
|
||||
|
||||
repository = SQLAlchemyExecutionExtraContentRepository(
|
||||
session_maker=_FakeSessionMaker(session=_FakeSession(values=[[content], [recipient]]))
|
||||
)
|
||||
|
||||
result = repository.get_by_message_ids(["msg-1"])
|
||||
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == 1
|
||||
domain_content = result[0][0]
|
||||
assert domain_content.submitted is False
|
||||
assert domain_content.workflow_run_id == "workflow-run"
|
||||
assert domain_content.form_definition is not None
|
||||
assert domain_content.form_definition.expiration_time == int(form.expiration_time.timestamp())
|
||||
assert domain_content.form_definition is not None
|
||||
form_definition = domain_content.form_definition
|
||||
assert form_definition.form_id == "form-1"
|
||||
assert form_definition.node_id == "node-id"
|
||||
assert form_definition.node_title == "Approval"
|
||||
assert form_definition.form_content == "Rendered block"
|
||||
assert form_definition.display_in_ui is True
|
||||
assert form_definition.form_token == "token-1"
|
||||
assert form_definition.resolved_default_values == {"name": "John"}
|
||||
assert form_definition.expiration_time == int(form.expiration_time.timestamp())
|
||||
65
api/tests/unit_tests/services/test_app_generate_service.py
Normal file
65
api/tests/unit_tests/services/test_app_generate_service.py
Normal file
@ -0,0 +1,65 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import services.app_generate_service as app_generate_service_module
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
||||
class _DummyRateLimit:
|
||||
def __init__(self, client_id: str, max_active_requests: int) -> None:
|
||||
self.client_id = client_id
|
||||
self.max_active_requests = max_active_requests
|
||||
|
||||
@staticmethod
|
||||
def gen_request_key() -> str:
|
||||
return "dummy-request-id"
|
||||
|
||||
def enter(self, request_id: str | None = None) -> str:
|
||||
return request_id or "dummy-request-id"
|
||||
|
||||
def exit(self, request_id: str) -> None:
|
||||
return None
|
||||
|
||||
def generate(self, generator, request_id: str):
|
||||
return generator
|
||||
|
||||
|
||||
def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch):
|
||||
monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False)
|
||||
mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-id"
|
||||
workflow.created_by = "owner-id"
|
||||
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
|
||||
generator_spy = mocker.patch(
|
||||
"services.app_generate_service.WorkflowAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
app_model = MagicMock()
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
app_model.id = "app-id"
|
||||
app_model.tenant_id = "tenant-id"
|
||||
app_model.max_active_requests = 0
|
||||
app_model.is_agent = False
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "user-id"
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"inputs": {"k": "v"}},
|
||||
invoke_from=MagicMock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert pause_state_config is not None
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation:
|
||||
within conversations.
|
||||
"""
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_by_first_id_without_first_id(
|
||||
self, mock_get_conversation, mock_db_session, mock_create_extra_repo
|
||||
):
|
||||
"""
|
||||
Test message pagination without specifying first_id.
|
||||
|
||||
@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act - Call the pagination method without first_id
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation:
|
||||
# Verify conversation was looked up with correct parameters
|
||||
mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test message pagination with first_id specified.
|
||||
|
||||
@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.first.return_value = first_message # First message returned
|
||||
mock_query.all.return_value = messages # Remaining messages returned
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act - Call the pagination method with first_id
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation:
|
||||
assert result.data == []
|
||||
assert result.has_more is False
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test that has_more flag is correctly set when there are more messages.
|
||||
|
||||
@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation:
|
||||
assert len(result.data) == limit # Extra message should be removed
|
||||
assert result.has_more is True # Flag should be set
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test message pagination with ascending order.
|
||||
|
||||
@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
|
||||
@ -0,0 +1,104 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services import feature_service as feature_service_module
|
||||
from services.feature_service import FeatureModel, FeatureService
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HumanInputEmailDeliveryCase:
|
||||
name: str
|
||||
enterprise_enabled: bool
|
||||
billing_enabled: bool
|
||||
tenant_id: str | None
|
||||
billing_feature_enabled: bool
|
||||
plan: str
|
||||
expected: bool
|
||||
|
||||
|
||||
CASES = [
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="enterprise_enabled",
|
||||
enterprise_enabled=True,
|
||||
billing_enabled=True,
|
||||
tenant_id=None,
|
||||
billing_feature_enabled=False,
|
||||
plan=CloudPlan.SANDBOX,
|
||||
expected=True,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="billing_disabled",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=False,
|
||||
tenant_id=None,
|
||||
billing_feature_enabled=False,
|
||||
plan=CloudPlan.SANDBOX,
|
||||
expected=True,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="billing_enabled_requires_tenant",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=True,
|
||||
tenant_id=None,
|
||||
billing_feature_enabled=True,
|
||||
plan=CloudPlan.PROFESSIONAL,
|
||||
expected=False,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="billing_feature_off",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=True,
|
||||
tenant_id="tenant-1",
|
||||
billing_feature_enabled=False,
|
||||
plan=CloudPlan.PROFESSIONAL,
|
||||
expected=False,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="professional_plan",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=True,
|
||||
tenant_id="tenant-1",
|
||||
billing_feature_enabled=True,
|
||||
plan=CloudPlan.PROFESSIONAL,
|
||||
expected=True,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="team_plan",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=True,
|
||||
tenant_id="tenant-1",
|
||||
billing_feature_enabled=True,
|
||||
plan=CloudPlan.TEAM,
|
||||
expected=True,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="sandbox_plan",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=True,
|
||||
tenant_id="tenant-1",
|
||||
billing_feature_enabled=True,
|
||||
plan=CloudPlan.SANDBOX,
|
||||
expected=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", CASES, ids=lambda case: case.name)
|
||||
def test_resolve_human_input_email_delivery_enabled_matrix(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
case: HumanInputEmailDeliveryCase,
|
||||
):
|
||||
monkeypatch.setattr(feature_service_module.dify_config, "ENTERPRISE_ENABLED", case.enterprise_enabled)
|
||||
monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", case.billing_enabled)
|
||||
features = FeatureModel()
|
||||
features.billing.enabled = case.billing_feature_enabled
|
||||
features.billing.subscription.plan = case.plan
|
||||
|
||||
result = FeatureService._resolve_human_input_email_delivery_enabled(
|
||||
features=features,
|
||||
tenant_id=case.tenant_id,
|
||||
)
|
||||
|
||||
assert result is case.expected
|
||||
@ -0,0 +1,97 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
)
|
||||
from core.workflow.runtime import VariablePool
|
||||
from services import human_input_delivery_test_service as service_module
|
||||
from services.human_input_delivery_test_service import (
|
||||
DeliveryTestContext,
|
||||
DeliveryTestError,
|
||||
EmailDeliveryTestHandler,
|
||||
)
|
||||
|
||||
|
||||
def _make_email_method() -> EmailDeliveryMethod:
|
||||
return EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[ExternalRecipient(email="tester@example.com")],
|
||||
),
|
||||
subject="Test subject",
|
||||
body="Test body",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False),
|
||||
)
|
||||
|
||||
handler = EmailDeliveryTestHandler(session_factory=object())
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
rendered_content="content",
|
||||
)
|
||||
method = _make_email_method()
|
||||
|
||||
with pytest.raises(DeliveryTestError, match="Email delivery is not available"):
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
|
||||
def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch):
|
||||
class DummyMail:
|
||||
def __init__(self):
|
||||
self.sent: list[dict[str, str]] = []
|
||||
|
||||
def is_inited(self) -> bool:
|
||||
return True
|
||||
|
||||
def send(self, *, to: str, subject: str, html: str):
|
||||
self.sent.append({"to": to, "subject": subject, "html": html})
|
||||
|
||||
mail = DummyMail()
|
||||
monkeypatch.setattr(service_module, "mail", mail)
|
||||
monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
|
||||
)
|
||||
|
||||
handler = EmailDeliveryTestHandler(session_factory=object())
|
||||
handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment]
|
||||
|
||||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]),
|
||||
subject="Subject",
|
||||
body="Value {{#node1.value#}}",
|
||||
)
|
||||
)
|
||||
variable_pool = VariablePool()
|
||||
variable_pool.add(["node1", "value"], "OK")
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
rendered_content="content",
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
assert mail.sent[0]["html"] == "Value OK"
|
||||
290
api/tests/unit_tests/services/test_human_input_service.py
Normal file
290
api/tests/unit_tests/services/test_human_input_service.py
Normal file
@ -0,0 +1,290 @@
|
||||
import dataclasses
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import services.human_input_service as human_input_service_module
|
||||
from core.repositories.human_input_repository import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormSubmissionRepository,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormDefinition,
|
||||
FormInput,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus
|
||||
from models.human_input import RecipientType
|
||||
from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError
|
||||
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
session = MagicMock()
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
session_cm.__exit__.return_value = None
|
||||
|
||||
factory = MagicMock()
|
||||
factory.return_value = session_cm
|
||||
return factory, session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_form_record():
|
||||
return HumanInputFormRecord(
|
||||
form_id="form-id",
|
||||
workflow_run_id="workflow-run-id",
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
form_kind=HumanInputFormKind.RUNTIME,
|
||||
definition=FormDefinition(
|
||||
form_content="hello",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=datetime.utcnow() + timedelta(hours=1),
|
||||
),
|
||||
rendered_content="<p>hello</p>",
|
||||
created_at=datetime.utcnow(),
|
||||
expiration_time=datetime.utcnow() + timedelta(hours=1),
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
selected_action_id=None,
|
||||
submitted_data=None,
|
||||
submitted_at=None,
|
||||
submission_user_id=None,
|
||||
submission_end_user_id=None,
|
||||
completed_by_recipient_id=None,
|
||||
recipient_id="recipient-id",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token",
|
||||
)
|
||||
|
||||
|
||||
def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "workflow"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service.enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_called_once()
|
||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||
assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
|
||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
expired_record = dataclasses.replace(
|
||||
sample_form_record,
|
||||
created_at=datetime.utcnow() - timedelta(hours=2),
|
||||
expiration_time=datetime.utcnow() + timedelta(hours=2),
|
||||
)
|
||||
monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600)
|
||||
|
||||
with pytest.raises(FormExpiredError):
|
||||
service.ensure_form_active(Form(expired_record))
|
||||
|
||||
|
||||
def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "advanced-chat"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service.enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_called_once()
|
||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||
assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
|
||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "completion"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service.enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_not_called()
|
||||
|
||||
|
||||
def test_get_form_definition_by_token_for_console_uses_repository(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
console_record = dataclasses.replace(sample_form_record, recipient_type=RecipientType.CONSOLE)
|
||||
repo.get_by_token.return_value = console_record
|
||||
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
form = service.get_form_definition_by_token_for_console("token")
|
||||
|
||||
repo.get_by_token.assert_called_once_with("token")
|
||||
assert form is not None
|
||||
assert form.get_definition() == console_record.definition
|
||||
|
||||
|
||||
def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_token.return_value = sample_form_record
|
||||
repo.mark_submitted.return_value = sample_form_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
enqueue_spy = mocker.patch.object(service, "enqueue_resume")
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={"field": "value"},
|
||||
submission_end_user_id="end-user-id",
|
||||
)
|
||||
|
||||
repo.get_by_token.assert_called_once_with("token")
|
||||
repo.mark_submitted.assert_called_once()
|
||||
call_kwargs = repo.mark_submitted.call_args.kwargs
|
||||
assert call_kwargs["form_id"] == sample_form_record.form_id
|
||||
assert call_kwargs["recipient_id"] == sample_form_record.recipient_id
|
||||
assert call_kwargs["selected_action_id"] == "submit"
|
||||
assert call_kwargs["form_data"] == {"field": "value"}
|
||||
assert call_kwargs["submission_end_user_id"] == "end-user-id"
|
||||
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
|
||||
|
||||
|
||||
def test_submit_form_by_token_skips_enqueue_for_delivery_test(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
test_record = dataclasses.replace(
|
||||
sample_form_record,
|
||||
form_kind=HumanInputFormKind.DELIVERY_TEST,
|
||||
workflow_run_id=None,
|
||||
)
|
||||
repo.get_by_token.return_value = test_record
|
||||
repo.mark_submitted.return_value = test_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
enqueue_spy = mocker.patch.object(service, "enqueue_resume")
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={"field": "value"},
|
||||
)
|
||||
|
||||
enqueue_spy.assert_not_called()
|
||||
|
||||
|
||||
def test_submit_form_by_token_passes_submission_user_id(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_token.return_value = sample_form_record
|
||||
repo.mark_submitted.return_value = sample_form_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
enqueue_spy = mocker.patch.object(service, "enqueue_resume")
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={"field": "value"},
|
||||
submission_user_id="account-id",
|
||||
)
|
||||
|
||||
call_kwargs = repo.mark_submitted.call_args.kwargs
|
||||
assert call_kwargs["submission_user_id"] == "account-id"
|
||||
assert call_kwargs["submission_end_user_id"] is None
|
||||
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
|
||||
|
||||
|
||||
def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_token.return_value = dataclasses.replace(sample_form_record)
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="invalid",
|
||||
form_data={},
|
||||
)
|
||||
|
||||
assert "Invalid action" in str(exc_info.value)
|
||||
repo.mark_submitted.assert_not_called()
|
||||
|
||||
|
||||
def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
|
||||
definition_with_input = FormDefinition(
|
||||
form_content="hello",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content")],
|
||||
user_actions=sample_form_record.definition.user_actions,
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=sample_form_record.expiration_time,
|
||||
)
|
||||
form_with_input = dataclasses.replace(sample_form_record, definition=definition_with_input)
|
||||
repo.get_by_token.return_value = form_with_input
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={},
|
||||
)
|
||||
|
||||
assert "Missing required inputs" in str(exc_info.value)
|
||||
repo.mark_submitted.assert_not_called()
|
||||
@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData
|
||||
from services import message_service
|
||||
|
||||
|
||||
class _FakeMessage:
|
||||
def __init__(self, message_id: str):
|
||||
self.id = message_id
|
||||
self.extra_contents = None
|
||||
|
||||
def set_extra_contents(self, contents):
|
||||
self.extra_contents = contents
|
||||
|
||||
|
||||
def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")]
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_by_message_ids": lambda _self, message_ids: [
|
||||
[
|
||||
HumanInputContent(
|
||||
workflow_run_id="workflow-run-1",
|
||||
submitted=True,
|
||||
form_submission_data=HumanInputFormSubmissionData(
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
rendered_content="Rendered",
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
),
|
||||
)
|
||||
],
|
||||
[],
|
||||
]
|
||||
},
|
||||
)()
|
||||
|
||||
monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo)
|
||||
|
||||
message_service.attach_message_extra_contents(messages)
|
||||
|
||||
assert messages[0].extra_contents == [
|
||||
{
|
||||
"type": "human_input",
|
||||
"workflow_run_id": "workflow-run-1",
|
||||
"submitted": True,
|
||||
"form_submission_data": {
|
||||
"node_id": "node-1",
|
||||
"node_title": "Approval",
|
||||
"rendered_content": "Rendered",
|
||||
"action_id": "approve",
|
||||
"action_text": "Approve",
|
||||
},
|
||||
}
|
||||
]
|
||||
assert messages[1].extra_contents == []
|
||||
@ -35,7 +35,6 @@ class TestDataFactory:
|
||||
app_id: str = "app-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
status: str | WorkflowExecutionStatus = "paused",
|
||||
pause_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowRun object."""
|
||||
@ -45,7 +44,6 @@ class TestDataFactory:
|
||||
mock_run.app_id = app_id
|
||||
mock_run.workflow_id = workflow_id
|
||||
mock_run.status = status
|
||||
mock_run.pause_id = pause_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_run, key, value)
|
||||
|
||||
@ -0,0 +1,162 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
from services.tools import workflow_tools_manage_service
|
||||
|
||||
|
||||
class DummyWorkflow:
|
||||
def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None:
|
||||
self._graph_dict = graph_dict
|
||||
self.version = version
|
||||
|
||||
@property
|
||||
def graph_dict(self) -> dict:
|
||||
return self._graph_dict
|
||||
|
||||
|
||||
class FakeQuery:
|
||||
def __init__(self, result):
|
||||
self._result = result
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._result
|
||||
|
||||
|
||||
class DummySession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
|
||||
def __enter__(self) -> "DummySession":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
def add(self, obj) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def begin(self):
|
||||
return DummyBegin(self)
|
||||
|
||||
|
||||
class DummyBegin:
|
||||
def __init__(self, session: DummySession) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> DummySession:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class DummySessionContext:
|
||||
def __init__(self, session: DummySession) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> DummySession:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class DummySessionFactory:
|
||||
def __init__(self, session: DummySession) -> None:
|
||||
self._session = session
|
||||
|
||||
def create_session(self) -> DummySessionContext:
|
||||
return DummySessionContext(self._session)
|
||||
|
||||
|
||||
def _build_fake_session(app) -> SimpleNamespace:
|
||||
def query(model):
|
||||
if model is WorkflowToolProvider:
|
||||
return FakeQuery(None)
|
||||
if model is App:
|
||||
return FakeQuery(app)
|
||||
return FakeQuery(None)
|
||||
|
||||
return SimpleNamespace(query=query)
|
||||
|
||||
|
||||
def _build_parameters() -> list[WorkflowToolParameterConfiguration]:
|
||||
return [
|
||||
WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM),
|
||||
]
|
||||
|
||||
|
||||
def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch):
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]})
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
|
||||
fake_session = _build_fake_session(app)
|
||||
monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session)
|
||||
|
||||
mock_from_db = MagicMock()
|
||||
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
|
||||
mock_invalidate = MagicMock()
|
||||
|
||||
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
|
||||
workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
workflow_app_id="app-id",
|
||||
name="tool_name",
|
||||
label="Tool",
|
||||
icon={"type": "emoji", "emoji": "tool"},
|
||||
description="desc",
|
||||
parameters=_build_parameters(),
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
|
||||
mock_from_db.assert_not_called()
|
||||
mock_invalidate.assert_not_called()
|
||||
|
||||
|
||||
def test_create_workflow_tool_success(monkeypatch):
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]})
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_session = _build_fake_session(app)
|
||||
fake_db.session = fake_session
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
|
||||
|
||||
dummy_session = DummySession()
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session)
|
||||
|
||||
mock_from_db = MagicMock()
|
||||
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
|
||||
|
||||
icon = {"type": "emoji", "emoji": "tool"}
|
||||
|
||||
result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
workflow_app_id="app-id",
|
||||
name="tool_name",
|
||||
label="Tool",
|
||||
icon=icon,
|
||||
description="desc",
|
||||
parameters=_build_parameters(),
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
assert len(dummy_session.added) == 1
|
||||
created_provider = dummy_session.added[0]
|
||||
assert created_provider.name == "tool_name"
|
||||
assert created_provider.label == "Tool"
|
||||
assert created_provider.icon == json.dumps(icon)
|
||||
assert created_provider.version == workflow.version
|
||||
mock_from_db.assert_called_once()
|
||||
@ -0,0 +1,226 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from threading import Event
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services.workflow_event_snapshot_service import (
|
||||
BufferState,
|
||||
MessageContext,
|
||||
_build_snapshot_events,
|
||||
_resolve_task_id,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _FakePauseEntity(WorkflowPauseEntity):
|
||||
pause_id: str
|
||||
workflow_run_id: str
|
||||
paused_at_value: datetime
|
||||
pause_reasons: Sequence[HumanInputRequired]
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.pause_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str:
|
||||
return self.workflow_run_id
|
||||
|
||||
def get_state(self) -> bytes:
|
||||
raise AssertionError("state is not required for snapshot tests")
|
||||
|
||||
@property
|
||||
def resumed_at(self) -> datetime | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def paused_at(self) -> datetime:
|
||||
return self.paused_at_value
|
||||
|
||||
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
|
||||
return self.pause_reasons
|
||||
|
||||
|
||||
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
|
||||
return WorkflowRun(
|
||||
id="run-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="v1",
|
||||
graph=None,
|
||||
inputs=json.dumps({"input": "value"}),
|
||||
status=status,
|
||||
outputs=json.dumps({}),
|
||||
error=None,
|
||||
elapsed_time=0.0,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
|
||||
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
|
||||
created_at = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
|
||||
return WorkflowNodeExecutionSnapshot(
|
||||
execution_id="exec-1",
|
||||
node_id="node-1",
|
||||
node_type="human-input",
|
||||
title="Human Input",
|
||||
index=1,
|
||||
status=status.value,
|
||||
elapsed_time=0.5,
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
iteration_id=None,
|
||||
loop_id=None,
|
||||
)
|
||||
|
||||
|
||||
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-1",
|
||||
)
|
||||
generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=task_id,
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
call_depth=0,
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
runtime_state.register_paused_node("node-1")
|
||||
runtime_state.outputs = {"result": "value"}
|
||||
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
|
||||
return WorkflowResumptionContext(
|
||||
generate_entity=wrapper,
|
||||
serialized_graph_runtime_state=runtime_state.dumps(),
|
||||
)
|
||||
|
||||
|
||||
def test_build_snapshot_events_includes_pause_event() -> None:
|
||||
workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
|
||||
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
|
||||
resumption_context = _build_resumption_context("task-ctx")
|
||||
pause_entity = _FakePauseEntity(
|
||||
pause_id="pause-1",
|
||||
workflow_run_id="run-1",
|
||||
paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
pause_reasons=[
|
||||
HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
events = _build_snapshot_events(
|
||||
workflow_run=workflow_run,
|
||||
node_snapshots=[snapshot],
|
||||
task_id="task-ctx",
|
||||
message_context=None,
|
||||
pause_entity=pause_entity,
|
||||
resumption_context=resumption_context,
|
||||
)
|
||||
|
||||
assert [event["event"] for event in events] == [
|
||||
"workflow_started",
|
||||
"node_started",
|
||||
"node_finished",
|
||||
"workflow_paused",
|
||||
]
|
||||
assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
|
||||
pause_data = events[-1]["data"]
|
||||
assert pause_data["paused_nodes"] == ["node-1"]
|
||||
assert pause_data["outputs"] == {"result": "value"}
|
||||
assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
|
||||
assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
|
||||
assert pause_data["elapsed_time"] == workflow_run.elapsed_time
|
||||
assert pause_data["total_tokens"] == workflow_run.total_tokens
|
||||
assert pause_data["total_steps"] == workflow_run.total_steps
|
||||
|
||||
|
||||
def test_build_snapshot_events_applies_message_context() -> None:
|
||||
workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING)
|
||||
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
message_context = MessageContext(
|
||||
conversation_id="conv-1",
|
||||
message_id="msg-1",
|
||||
created_at=1700000000,
|
||||
answer="snapshot message",
|
||||
)
|
||||
|
||||
events = _build_snapshot_events(
|
||||
workflow_run=workflow_run,
|
||||
node_snapshots=[snapshot],
|
||||
task_id="task-1",
|
||||
message_context=message_context,
|
||||
pause_entity=None,
|
||||
resumption_context=None,
|
||||
)
|
||||
|
||||
assert [event["event"] for event in events] == [
|
||||
"workflow_started",
|
||||
"message_replace",
|
||||
"node_started",
|
||||
"node_finished",
|
||||
]
|
||||
assert events[1]["answer"] == "snapshot message"
|
||||
for event in events:
|
||||
assert event["conversation_id"] == "conv-1"
|
||||
assert event["message_id"] == "msg-1"
|
||||
assert event["created_at"] == 1700000000
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("context_task_id", "buffered_task_id", "expected"),
|
||||
[
|
||||
("task-ctx", "task-buffer", "task-ctx"),
|
||||
(None, "task-buffer", "task-buffer"),
|
||||
(None, None, "run-1"),
|
||||
],
|
||||
)
|
||||
def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None:
|
||||
resumption_context = _build_resumption_context(context_task_id) if context_task_id else None
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint=buffered_task_id,
|
||||
)
|
||||
if buffered_task_id:
|
||||
buffer_state.task_id_ready.set()
|
||||
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
|
||||
assert task_id == expected
|
||||
@ -0,0 +1,184 @@
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
)
|
||||
from services import workflow_service as workflow_service_module
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
def _make_service() -> WorkflowService:
|
||||
return WorkflowService(session_maker=sessionmaker())
|
||||
|
||||
|
||||
def _build_node_config(delivery_methods):
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
delivery_methods=delivery_methods,
|
||||
form_content="Test content",
|
||||
inputs=[],
|
||||
user_actions=[],
|
||||
).model_dump(mode="json")
|
||||
node_data["type"] = NodeType.HUMAN_INPUT.value
|
||||
return {"id": "node-1", "data": node_data}
|
||||
|
||||
|
||||
def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod:
|
||||
return EmailDeliveryMethod(
|
||||
id=uuid.uuid4(),
|
||||
enabled=enabled,
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[ExternalRecipient(email="tester@example.com")],
|
||||
),
|
||||
subject="Test subject",
|
||||
body="Test body",
|
||||
debug_mode=debug_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_delivery_requires_draft_workflow():
|
||||
service = _make_service()
|
||||
service.get_draft_workflow = MagicMock(return_value=None) # type: ignore[method-assign]
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not initialized"):
|
||||
service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
delivery_method_id="delivery-1",
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch):
|
||||
service = _make_service()
|
||||
delivery_method = _make_email_method(enabled=False)
|
||||
node_config = _build_node_config([delivery_method])
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = node_config
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined]
|
||||
node_stub = MagicMock()
|
||||
node_stub._render_form_content_before_submission.return_value = "rendered"
|
||||
node_stub._resolve_default_values.return_value = {}
|
||||
service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined]
|
||||
service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined]
|
||||
return_value=("form-1", {})
|
||||
)
|
||||
|
||||
test_service_instance = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
workflow_service_module,
|
||||
"HumanInputDeliveryTestService",
|
||||
MagicMock(return_value=test_service_instance),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
delivery_method_id=str(delivery_method.id),
|
||||
)
|
||||
|
||||
test_service_instance.send_test.assert_called_once()
|
||||
|
||||
|
||||
def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch):
|
||||
service = _make_service()
|
||||
delivery_method = _make_email_method(enabled=True)
|
||||
node_config = _build_node_config([delivery_method])
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = node_config
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined]
|
||||
node_stub = MagicMock()
|
||||
node_stub._render_form_content_before_submission.return_value = "rendered"
|
||||
node_stub._resolve_default_values.return_value = {}
|
||||
service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined]
|
||||
service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined]
|
||||
return_value=("form-1", {})
|
||||
)
|
||||
|
||||
test_service_instance = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
workflow_service_module,
|
||||
"HumanInputDeliveryTestService",
|
||||
MagicMock(return_value=test_service_instance),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
delivery_method_id=str(delivery_method.id),
|
||||
inputs={"#node-1.output#": "value"},
|
||||
)
|
||||
|
||||
pool_args = service._build_human_input_variable_pool.call_args.kwargs
|
||||
assert pool_args["manual_inputs"] == {"#node-1.output#": "value"}
|
||||
test_service_instance.send_test.assert_called_once()
|
||||
|
||||
|
||||
def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch):
|
||||
service = _make_service()
|
||||
delivery_method = _make_email_method(enabled=True, debug_mode=True)
|
||||
node_config = _build_node_config([delivery_method])
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = node_config
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined]
|
||||
node_stub = MagicMock()
|
||||
node_stub._render_form_content_before_submission.return_value = "rendered"
|
||||
node_stub._resolve_default_values.return_value = {}
|
||||
service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined]
|
||||
service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined]
|
||||
return_value=("form-1", {})
|
||||
)
|
||||
|
||||
test_service_instance = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
workflow_service_module,
|
||||
"HumanInputDeliveryTestService",
|
||||
MagicMock(return_value=test_service_instance),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
delivery_method_id=str(delivery_method.id),
|
||||
)
|
||||
|
||||
test_service_instance.send_test.assert_called_once()
|
||||
sent_method = test_service_instance.send_test.call_args.kwargs["method"]
|
||||
assert isinstance(sent_method, EmailDeliveryMethod)
|
||||
assert sent_method.config.debug_mode is True
|
||||
assert sent_method.config.recipients.whole_workspace is False
|
||||
assert len(sent_method.config.recipients.items) == 1
|
||||
recipient = sent_method.config.recipients.items[0]
|
||||
assert isinstance(recipient, MemberRecipient)
|
||||
assert recipient.user_id == account.id
|
||||
@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
@ -52,6 +53,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
call_args = mock_session.scalar.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
compiled = call_args.compile()
|
||||
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
|
||||
|
||||
def test_get_node_last_execution_not_found(self, repository):
|
||||
"""Test getting the last execution for a node when it doesn't exist."""
|
||||
# Arrange
|
||||
@ -71,28 +75,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_executions_by_workflow_run(self, repository, mock_execution):
|
||||
"""Test getting all executions for a workflow run."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
executions = [mock_execution]
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == executions
|
||||
mock_session.execute.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.execute.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
def test_get_executions_by_workflow_run_empty(self, repository):
|
||||
"""Test getting executions for a workflow run when none exist."""
|
||||
# Arrange
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
from contextlib import nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import FormInputType
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services import workflow_service as workflow_service_module
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
@ -161,3 +167,120 @@ class TestWorkflowService:
|
||||
assert workflows == []
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
def test_submit_human_input_form_preview_uses_rendered_content(
|
||||
self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
service = workflow_service
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="<p>{{#$output.name#}}</p>",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
node = MagicMock()
|
||||
node.node_data = node_data
|
||||
node.render_form_content_before_submission.return_value = "<p>preview</p>"
|
||||
node.render_form_content_with_outputs.return_value = "<p>rendered</p>"
|
||||
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
|
||||
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
|
||||
workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
|
||||
saved_outputs: dict[str, object] = {}
|
||||
|
||||
class DummySession:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.commit = MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def begin(self):
|
||||
return nullcontext()
|
||||
|
||||
class DummySaver:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def save(self, outputs, process_data):
|
||||
saved_outputs.update(outputs)
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "Session", DummySession)
|
||||
monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver)
|
||||
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
result = service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
form_inputs={"name": "Ada", "extra": "ignored"},
|
||||
inputs={"#node-0.result#": "LLM output"},
|
||||
action="approve",
|
||||
)
|
||||
|
||||
service._build_human_input_variable_pool.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}},
|
||||
manual_inputs={"#node-0.result#": "LLM output"},
|
||||
)
|
||||
|
||||
node.render_form_content_with_outputs.assert_called_once()
|
||||
called_args = node.render_form_content_with_outputs.call_args.args
|
||||
assert called_args[0] == "<p>preview</p>"
|
||||
assert called_args[2] == node_data.outputs_field_names()
|
||||
rendered_outputs = called_args[1]
|
||||
assert rendered_outputs["name"] == "Ada"
|
||||
assert rendered_outputs["extra"] == "ignored"
|
||||
assert "extra" in saved_outputs
|
||||
assert "extra" in result
|
||||
assert saved_outputs["name"] == "Ada"
|
||||
assert result["name"] == "Ada"
|
||||
assert result["__action_id"] == "approve"
|
||||
assert "__rendered_content" in result
|
||||
|
||||
def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None:
|
||||
service = workflow_service
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="<p>{{#$output.name#}}</p>",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
node = MagicMock()
|
||||
node.node_data = node_data
|
||||
node._render_form_content_before_submission.return_value = "<p>preview</p>"
|
||||
node._render_form_content_with_outputs.return_value = "<p>rendered</p>"
|
||||
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
|
||||
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
form_inputs={},
|
||||
inputs={},
|
||||
action="approve",
|
||||
)
|
||||
|
||||
assert "Missing required inputs" in str(exc_info.value)
|
||||
|
||||
210
api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py
Normal file
210
api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py
Normal file
@ -0,0 +1,210 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from tasks import human_input_timeout_tasks as task_module
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, items: list[Any]):
|
||||
self._items = items
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return self._items
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, items: list[Any], capture: dict[str, Any]):
|
||||
self._items = items
|
||||
self._capture = capture
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def scalars(self, stmt):
|
||||
self._capture["stmt"] = stmt
|
||||
return _FakeScalarResult(self._items)
|
||||
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __init__(self, items: list[Any], capture: dict[str, Any]):
|
||||
self._items = items
|
||||
self._capture = capture
|
||||
self._capture["session_factory"] = self
|
||||
|
||||
def __call__(self):
|
||||
session = _FakeSession(self._items, self._capture)
|
||||
self._capture["session"] = session
|
||||
return session
|
||||
|
||||
|
||||
class _FakeFormRepo:
|
||||
def __init__(self, _session_factory, form_map: dict[str, Any] | None = None):
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
self._form_map = form_map or {}
|
||||
|
||||
def mark_timeout(self, *, form_id: str, timeout_status: HumanInputFormStatus, reason: str | None = None):
|
||||
self.calls.append(
|
||||
{
|
||||
"form_id": form_id,
|
||||
"timeout_status": timeout_status,
|
||||
"reason": reason,
|
||||
}
|
||||
)
|
||||
form = self._form_map.get(form_id)
|
||||
return SimpleNamespace(
|
||||
form_id=form_id,
|
||||
workflow_run_id=getattr(form, "workflow_run_id", None),
|
||||
node_id=getattr(form, "node_id", None),
|
||||
)
|
||||
|
||||
|
||||
class _FakeService:
|
||||
def __init__(self, _session_factory, form_repository=None):
|
||||
self.enqueued: list[str] = []
|
||||
|
||||
def enqueue_resume(self, workflow_run_id: str | None) -> None:
|
||||
if workflow_run_id is not None:
|
||||
self.enqueued.append(workflow_run_id)
|
||||
|
||||
|
||||
def _build_form(
|
||||
*,
|
||||
form_id: str,
|
||||
form_kind: HumanInputFormKind,
|
||||
created_at: datetime,
|
||||
expiration_time: datetime,
|
||||
workflow_run_id: str | None,
|
||||
node_id: str,
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=form_id,
|
||||
form_kind=form_kind,
|
||||
created_at=created_at,
|
||||
expiration_time=expiration_time,
|
||||
workflow_run_id=workflow_run_id,
|
||||
node_id=node_id,
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
)
|
||||
|
||||
|
||||
def test_is_global_timeout_uses_created_at():
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
form = SimpleNamespace(created_at=now - timedelta(seconds=61), workflow_run_id="run-1")
|
||||
|
||||
assert task_module._is_global_timeout(form, 60, now=now) is True
|
||||
|
||||
form.workflow_run_id = None
|
||||
assert task_module._is_global_timeout(form, 60, now=now) is False
|
||||
|
||||
form.workflow_run_id = "run-1"
|
||||
form.created_at = now - timedelta(seconds=59)
|
||||
assert task_module._is_global_timeout(form, 60, now=now) is False
|
||||
|
||||
assert task_module._is_global_timeout(form, 0, now=now) is False
|
||||
|
||||
|
||||
def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pytest.MonkeyPatch):
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
|
||||
monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600)
|
||||
monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
forms = [
|
||||
_build_form(
|
||||
form_id="form-global",
|
||||
form_kind=HumanInputFormKind.RUNTIME,
|
||||
created_at=now - timedelta(hours=2),
|
||||
expiration_time=now + timedelta(hours=1),
|
||||
workflow_run_id="run-global",
|
||||
node_id="node-global",
|
||||
),
|
||||
_build_form(
|
||||
form_id="form-node",
|
||||
form_kind=HumanInputFormKind.RUNTIME,
|
||||
created_at=now - timedelta(minutes=5),
|
||||
expiration_time=now - timedelta(seconds=1),
|
||||
workflow_run_id="run-node",
|
||||
node_id="node-node",
|
||||
),
|
||||
_build_form(
|
||||
form_id="form-delivery",
|
||||
form_kind=HumanInputFormKind.DELIVERY_TEST,
|
||||
created_at=now - timedelta(minutes=1),
|
||||
expiration_time=now - timedelta(seconds=1),
|
||||
workflow_run_id=None,
|
||||
node_id="node-delivery",
|
||||
),
|
||||
]
|
||||
|
||||
capture: dict[str, Any] = {}
|
||||
monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture))
|
||||
|
||||
form_map = {form.id: form for form in forms}
|
||||
repo = _FakeFormRepo(None, form_map=form_map)
|
||||
|
||||
def _repo_factory(_session_factory):
|
||||
return repo
|
||||
|
||||
service = _FakeService(None)
|
||||
|
||||
def _service_factory(_session_factory, form_repository=None):
|
||||
return service
|
||||
|
||||
global_calls: list[dict[str, Any]] = []
|
||||
|
||||
monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _repo_factory)
|
||||
monkeypatch.setattr(task_module, "HumanInputService", _service_factory)
|
||||
monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **kwargs: global_calls.append(kwargs))
|
||||
|
||||
task_module.check_and_handle_human_input_timeouts(limit=100)
|
||||
|
||||
assert {(call["form_id"], call["timeout_status"], call["reason"]) for call in repo.calls} == {
|
||||
("form-global", HumanInputFormStatus.EXPIRED, "global_timeout"),
|
||||
("form-node", HumanInputFormStatus.TIMEOUT, "node_timeout"),
|
||||
("form-delivery", HumanInputFormStatus.TIMEOUT, "delivery_test_timeout"),
|
||||
}
|
||||
assert service.enqueued == ["run-node"]
|
||||
assert global_calls == [
|
||||
{
|
||||
"form_id": "form-global",
|
||||
"workflow_run_id": "run-global",
|
||||
"node_id": "node-global",
|
||||
"session_factory": capture.get("session_factory"),
|
||||
}
|
||||
]
|
||||
|
||||
stmt = capture.get("stmt")
|
||||
assert stmt is not None
|
||||
stmt_text = str(stmt)
|
||||
assert "created_at <=" in stmt_text
|
||||
assert "expiration_time <=" in stmt_text
|
||||
assert "ORDER BY human_input_forms.id" in stmt_text
|
||||
|
||||
|
||||
def test_check_and_handle_human_input_timeouts_omits_global_filter_when_disabled(monkeypatch: pytest.MonkeyPatch):
|
||||
now = datetime(2025, 1, 1, 12, 0, 0)
|
||||
monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
|
||||
monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0)
|
||||
monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
capture: dict[str, Any] = {}
|
||||
monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory([], capture))
|
||||
monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _FakeFormRepo)
|
||||
monkeypatch.setattr(task_module, "HumanInputService", _FakeService)
|
||||
monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **_kwargs: None)
|
||||
|
||||
task_module.check_and_handle_human_input_timeouts(limit=1)
|
||||
|
||||
stmt = capture.get("stmt")
|
||||
assert stmt is not None
|
||||
stmt_text = str(stmt)
|
||||
assert "created_at <=" not in stmt_text
|
||||
@ -0,0 +1,123 @@
|
||||
from collections.abc import Sequence
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from tasks import mail_human_input_delivery_task as task_module
|
||||
|
||||
|
||||
class _DummyMail:
|
||||
def __init__(self):
|
||||
self.sent: list[dict[str, str]] = []
|
||||
self._inited = True
|
||||
|
||||
def is_inited(self) -> bool:
|
||||
return self._inited
|
||||
|
||||
def send(self, *, to: str, subject: str, html: str):
|
||||
self.sent.append({"to": to, "subject": subject, "html": html})
|
||||
|
||||
|
||||
class _DummySession:
|
||||
def __init__(self, form):
|
||||
self._form = form
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
def get(self, _model, _form_id):
|
||||
return self._form
|
||||
|
||||
|
||||
def _build_job(recipient_count: int = 1) -> task_module._EmailDeliveryJob:
|
||||
recipients: list[task_module._EmailRecipient] = []
|
||||
for idx in range(recipient_count):
|
||||
recipients.append(task_module._EmailRecipient(email=f"user{idx}@example.com", token=f"token-{idx}"))
|
||||
|
||||
return task_module._EmailDeliveryJob(
|
||||
form_id="form-1",
|
||||
subject="Subject",
|
||||
body="Body for {{#url}}",
|
||||
form_content="content",
|
||||
recipients=recipients,
|
||||
)
|
||||
|
||||
|
||||
def test_dispatch_human_input_email_task_sends_to_each_recipient(monkeypatch: pytest.MonkeyPatch):
|
||||
mail = _DummyMail()
|
||||
form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None)
|
||||
|
||||
monkeypatch.setattr(task_module, "mail", mail)
|
||||
monkeypatch.setattr(
|
||||
task_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
|
||||
)
|
||||
jobs: Sequence[task_module._EmailDeliveryJob] = [_build_job(recipient_count=2)]
|
||||
monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: jobs)
|
||||
|
||||
task_module.dispatch_human_input_email_task(
|
||||
form_id="form-1",
|
||||
node_title="Approve",
|
||||
session_factory=lambda: _DummySession(form),
|
||||
)
|
||||
|
||||
assert len(mail.sent) == 2
|
||||
assert all(payload["subject"] == "Subject" for payload in mail.sent)
|
||||
assert all("Body for" in payload["html"] for payload in mail.sent)
|
||||
|
||||
|
||||
def test_dispatch_human_input_email_task_skips_when_feature_disabled(monkeypatch: pytest.MonkeyPatch):
|
||||
mail = _DummyMail()
|
||||
form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None)
|
||||
|
||||
monkeypatch.setattr(task_module, "mail", mail)
|
||||
monkeypatch.setattr(
|
||||
task_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False),
|
||||
)
|
||||
monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [])
|
||||
|
||||
task_module.dispatch_human_input_email_task(
|
||||
form_id="form-1",
|
||||
node_title="Approve",
|
||||
session_factory=lambda: _DummySession(form),
|
||||
)
|
||||
|
||||
assert mail.sent == []
|
||||
|
||||
|
||||
def test_dispatch_human_input_email_task_replaces_body_variables(monkeypatch: pytest.MonkeyPatch):
|
||||
mail = _DummyMail()
|
||||
form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id="run-1")
|
||||
job = task_module._EmailDeliveryJob(
|
||||
form_id="form-1",
|
||||
subject="Subject",
|
||||
body="Body {{#node1.value#}}",
|
||||
form_content="content",
|
||||
recipients=[task_module._EmailRecipient(email="user@example.com", token="token-1")],
|
||||
)
|
||||
|
||||
variable_pool = task_module.VariablePool()
|
||||
variable_pool.add(["node1", "value"], "OK")
|
||||
|
||||
monkeypatch.setattr(task_module, "mail", mail)
|
||||
monkeypatch.setattr(
|
||||
task_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
|
||||
)
|
||||
monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [job])
|
||||
monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: variable_pool)
|
||||
|
||||
task_module.dispatch_human_input_email_task(
|
||||
form_id="form-1",
|
||||
node_title="Approve",
|
||||
session_factory=lambda: _DummySession(form),
|
||||
)
|
||||
|
||||
assert mail.sent[0]["html"] == "Body OK"
|
||||
39
api/tests/unit_tests/tasks/test_workflow_execute_task.py
Normal file
39
api/tests/unit_tests/tasks/test_workflow_execute_task.py
Normal file
@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import AppMode
|
||||
from tasks.app_generate.workflow_execute_task import _publish_streaming_response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_topic(mocker) -> MagicMock:
|
||||
topic = MagicMock()
|
||||
mocker.patch(
|
||||
"tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic",
|
||||
return_value=topic,
|
||||
)
|
||||
return topic
|
||||
|
||||
|
||||
def test_publish_streaming_response_with_uuid(mock_topic: MagicMock):
|
||||
workflow_run_id = uuid.uuid4()
|
||||
response_stream = iter([{"event": "foo"}, "ping"])
|
||||
|
||||
_publish_streaming_response(response_stream, workflow_run_id, app_mode=AppMode.ADVANCED_CHAT)
|
||||
|
||||
payloads = [call.args[0] for call in mock_topic.publish.call_args_list]
|
||||
assert payloads == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()]
|
||||
|
||||
|
||||
def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock):
|
||||
workflow_run_id = uuid.uuid4()
|
||||
response_stream = iter([{"event": "bar"}])
|
||||
|
||||
_publish_streaming_response(response_stream, str(workflow_run_id), app_mode=AppMode.ADVANCED_CHAT)
|
||||
|
||||
mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode())
|
||||
488
api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py
Normal file
488
api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py
Normal file
@ -0,0 +1,488 @@
|
||||
# """
|
||||
# Unit tests for workflow node execution Celery tasks.
|
||||
|
||||
# These tests verify the asynchronous storage functionality for workflow node execution data,
|
||||
# including truncation and offloading logic.
|
||||
# """
|
||||
|
||||
# import json
|
||||
# from unittest.mock import MagicMock, Mock, patch
|
||||
# from uuid import uuid4
|
||||
|
||||
# import pytest
|
||||
|
||||
# from core.workflow.entities.workflow_node_execution import (
|
||||
# WorkflowNodeExecution,
|
||||
# WorkflowNodeExecutionStatus,
|
||||
# )
|
||||
# from core.workflow.enums import NodeType
|
||||
# from libs.datetime_utils import naive_utc_now
|
||||
# from models import WorkflowNodeExecutionModel
|
||||
# from models.enums import ExecutionOffLoadType
|
||||
# from models.model import UploadFile
|
||||
# from models.workflow import WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
|
||||
# from tasks.workflow_node_execution_tasks import (
|
||||
# _create_truncator,
|
||||
# _json_encode,
|
||||
# _replace_or_append_offload,
|
||||
# _truncate_and_upload_async,
|
||||
# save_workflow_node_execution_data_task,
|
||||
# save_workflow_node_execution_task,
|
||||
# )
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def sample_execution_data():
|
||||
# """Sample execution data for testing."""
|
||||
# execution = WorkflowNodeExecution(
|
||||
# id=str(uuid4()),
|
||||
# node_execution_id=str(uuid4()),
|
||||
# workflow_id=str(uuid4()),
|
||||
# workflow_execution_id=str(uuid4()),
|
||||
# index=1,
|
||||
# node_id="test_node",
|
||||
# node_type=NodeType.LLM,
|
||||
# title="Test Node",
|
||||
# inputs={"input_key": "input_value"},
|
||||
# outputs={"output_key": "output_value"},
|
||||
# process_data={"process_key": "process_value"},
|
||||
# status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
# created_at=naive_utc_now(),
|
||||
# )
|
||||
# return execution.model_dump()
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def mock_db_model():
|
||||
# """Mock database model for testing."""
|
||||
# db_model = Mock(spec=WorkflowNodeExecutionModel)
|
||||
# db_model.id = "test-execution-id"
|
||||
# db_model.offload_data = []
|
||||
# return db_model
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def mock_file_service():
|
||||
# """Mock file service for testing."""
|
||||
# file_service = Mock()
|
||||
# mock_upload_file = Mock(spec=UploadFile)
|
||||
# mock_upload_file.id = "mock-file-id"
|
||||
# file_service.upload_file.return_value = mock_upload_file
|
||||
# return file_service
|
||||
|
||||
|
||||
# class TestSaveWorkflowNodeExecutionDataTask:
|
||||
# """Test cases for save_workflow_node_execution_data_task."""
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_execution_data_task_success(
|
||||
# self, mock_select, mock_sessionmaker, sample_execution_data, mock_db_model
|
||||
# ):
|
||||
# """Test successful execution of save_workflow_node_execution_data_task."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_data_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify success
|
||||
# assert result is True
|
||||
# mock_session.merge.assert_called_once_with(mock_db_model)
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_execution_data_task_execution_not_found(self, mock_select, mock_sessionmaker,
|
||||
# sample_execution_data):
|
||||
# """Test task when execution is not found in database."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_data_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify failure
|
||||
# assert result is False
|
||||
# mock_session.merge.assert_not_called()
|
||||
# mock_session.commit.assert_not_called()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_execution_data_task_with_truncation(self, mock_select, mock_sessionmaker, mock_db_model):
|
||||
# """Test task with data that requires truncation."""
|
||||
# # Create execution with large data
|
||||
# large_data = {"large_field": "x" * 10000}
|
||||
# execution = WorkflowNodeExecution(
|
||||
# id=str(uuid4()),
|
||||
# node_execution_id=str(uuid4()),
|
||||
# workflow_id=str(uuid4()),
|
||||
# workflow_execution_id=str(uuid4()),
|
||||
# index=1,
|
||||
# node_id="test_node",
|
||||
# node_type=NodeType.LLM,
|
||||
# title="Test Node",
|
||||
# inputs=large_data,
|
||||
# outputs=large_data,
|
||||
# process_data=large_data,
|
||||
# status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
# created_at=naive_utc_now(),
|
||||
# )
|
||||
# execution_data = execution.model_dump()
|
||||
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
|
||||
|
||||
# # Create mock upload file
|
||||
# mock_upload_file = Mock(spec=UploadFile)
|
||||
# mock_upload_file.id = "mock-file-id"
|
||||
|
||||
# # Execute task
|
||||
# with patch("tasks.workflow_node_execution_tasks._truncate_and_upload_async") as mock_truncate:
|
||||
# # Mock truncation results
|
||||
# mock_truncate.return_value = {
|
||||
# "truncated_value": {"large_field": "[TRUNCATED]"},
|
||||
# "file": mock_upload_file,
|
||||
# "offload": WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# node_execution_id=execution.id,
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id=mock_upload_file.id,
|
||||
# ),
|
||||
# }
|
||||
|
||||
# result = save_workflow_node_execution_data_task(
|
||||
# execution_data=execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify success and truncation was called
|
||||
# assert result is True
|
||||
# assert mock_truncate.call_count == 3 # inputs, outputs, process_data
|
||||
# mock_session.merge.assert_called_once_with(mock_db_model)
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# def test_save_execution_data_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
|
||||
# """Test task retry mechanism on exception."""
|
||||
# # Setup mock to raise exception
|
||||
# mock_sessionmaker.side_effect = Exception("Database error")
|
||||
|
||||
# # Create a mock task instance with proper retry behavior
|
||||
# with patch.object(save_workflow_node_execution_data_task, "retry") as mock_retry:
|
||||
# mock_retry.side_effect = Exception("Retry called")
|
||||
|
||||
# # Execute task and expect retry
|
||||
# with pytest.raises(Exception, match="Retry called"):
|
||||
# save_workflow_node_execution_data_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify retry was called
|
||||
# mock_retry.assert_called_once()
|
||||
|
||||
|
||||
# class TestTruncateAndUploadAsync:
|
||||
# """Test cases for _truncate_and_upload_async function."""
|
||||
|
||||
# def test_truncate_and_upload_with_none_values(self, mock_file_service):
|
||||
# """Test _truncate_and_upload_async with None values."""
|
||||
# # The function handles None values internally, so we test with empty dict instead
|
||||
# result = _truncate_and_upload_async(
|
||||
# values={},
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "account"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# # Empty dict should not require truncation
|
||||
# assert result is None
|
||||
# mock_file_service.upload_file.assert_not_called()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
|
||||
# def test_truncate_and_upload_no_truncation_needed(self, mock_create_truncator, mock_file_service):
|
||||
# """Test _truncate_and_upload_async when no truncation is needed."""
|
||||
# # Mock truncator to return no truncation
|
||||
# mock_truncator = Mock()
|
||||
# mock_truncator.truncate_variable_mapping.return_value = ({"small": "data"}, False)
|
||||
# mock_create_truncator.return_value = mock_truncator
|
||||
|
||||
# small_values = {"small": "data"}
|
||||
# result = _truncate_and_upload_async(
|
||||
# values=small_values,
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "account"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# assert result is None
|
||||
# mock_file_service.upload_file.assert_not_called()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
|
||||
# @patch("models.Account")
|
||||
# @patch("models.Tenant")
|
||||
# def test_truncate_and_upload_with_account_user(
|
||||
# self, mock_tenant_class, mock_account_class, mock_create_truncator, mock_file_service
|
||||
# ):
|
||||
# """Test _truncate_and_upload_async with account user."""
|
||||
# # Mock truncator to return truncation needed
|
||||
# mock_truncator = Mock()
|
||||
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
|
||||
# mock_create_truncator.return_value = mock_truncator
|
||||
|
||||
# # Mock user and tenant creation
|
||||
# mock_account = Mock()
|
||||
# mock_account.id = "test-user"
|
||||
# mock_account_class.return_value = mock_account
|
||||
|
||||
# mock_tenant = Mock()
|
||||
# mock_tenant.id = "test-tenant"
|
||||
# mock_tenant_class.return_value = mock_tenant
|
||||
|
||||
# large_values = {"large": "x" * 10000}
|
||||
# result = _truncate_and_upload_async(
|
||||
# values=large_values,
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "account"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# # Verify result structure
|
||||
# assert result is not None
|
||||
# assert "truncated_value" in result
|
||||
# assert "file" in result
|
||||
# assert "offload" in result
|
||||
# assert result["truncated_value"] == {"truncated": "data"}
|
||||
|
||||
# # Verify file upload was called
|
||||
# mock_file_service.upload_file.assert_called_once()
|
||||
# upload_call = mock_file_service.upload_file.call_args
|
||||
# assert upload_call[1]["filename"] == "node_execution_test-id_inputs.json"
|
||||
# assert upload_call[1]["mimetype"] == "application/json"
|
||||
# assert upload_call[1]["user"] == mock_account
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
|
||||
# @patch("models.EndUser")
|
||||
# def test_truncate_and_upload_with_end_user(self, mock_end_user_class, mock_create_truncator, mock_file_service):
|
||||
# """Test _truncate_and_upload_async with end user."""
|
||||
# # Mock truncator to return truncation needed
|
||||
# mock_truncator = Mock()
|
||||
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
|
||||
# mock_create_truncator.return_value = mock_truncator
|
||||
|
||||
# # Mock end user creation
|
||||
# mock_end_user = Mock()
|
||||
# mock_end_user.id = "test-user"
|
||||
# mock_end_user.tenant_id = "test-tenant"
|
||||
# mock_end_user_class.return_value = mock_end_user
|
||||
|
||||
# large_values = {"large": "x" * 10000}
|
||||
# result = _truncate_and_upload_async(
|
||||
# values=large_values,
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.OUTPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "end_user"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# # Verify result structure
|
||||
# assert result is not None
|
||||
# assert result["truncated_value"] == {"truncated": "data"}
|
||||
|
||||
# # Verify file upload was called with end user
|
||||
# mock_file_service.upload_file.assert_called_once()
|
||||
# upload_call = mock_file_service.upload_file.call_args
|
||||
# assert upload_call[1]["filename"] == "node_execution_test-id_outputs.json"
|
||||
# assert upload_call[1]["user"] == mock_end_user
|
||||
|
||||
|
||||
# class TestHelperFunctions:
|
||||
# """Test cases for helper functions."""
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.dify_config")
|
||||
# def test_create_truncator(self, mock_config):
|
||||
# """Test _create_truncator function."""
|
||||
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
|
||||
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
|
||||
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500
|
||||
|
||||
# truncator = _create_truncator()
|
||||
|
||||
# # Verify truncator was created with correct config
|
||||
# assert truncator is not None
|
||||
|
||||
# def test_json_encode(self):
|
||||
# """Test _json_encode function."""
|
||||
# test_data = {"key": "value", "number": 42}
|
||||
# result = _json_encode(test_data)
|
||||
|
||||
# assert isinstance(result, str)
|
||||
# decoded = json.loads(result)
|
||||
# assert decoded == test_data
|
||||
|
||||
# def test_replace_or_append_offload_replace_existing(self):
|
||||
# """Test _replace_or_append_offload replaces existing offload of same type."""
|
||||
# existing_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id="old-file-id",
|
||||
# )
|
||||
|
||||
# new_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id="new-file-id",
|
||||
# )
|
||||
|
||||
# result = _replace_or_append_offload([existing_offload], new_offload)
|
||||
|
||||
# assert len(result) == 1
|
||||
# assert result[0].file_id == "new-file-id"
|
||||
|
||||
# def test_replace_or_append_offload_append_new_type(self):
|
||||
# """Test _replace_or_append_offload appends new offload of different type."""
|
||||
# existing_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id="inputs-file-id",
|
||||
# )
|
||||
|
||||
# new_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.OUTPUTS,
|
||||
# file_id="outputs-file-id",
|
||||
# )
|
||||
|
||||
# result = _replace_or_append_offload([existing_offload], new_offload)
|
||||
|
||||
# assert len(result) == 2
|
||||
# file_ids = [offload.file_id for offload in result]
|
||||
# assert "inputs-file-id" in file_ids
|
||||
# assert "outputs-file-id" in file_ids
|
||||
|
||||
|
||||
# class TestSaveWorkflowNodeExecutionTask:
|
||||
# """Test cases for save_workflow_node_execution_task."""
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_workflow_node_execution_task_create_new(self, mock_select, mock_sessionmaker,
|
||||
# sample_execution_data):
|
||||
# """Test creating a new workflow node execution."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.scalar.return_value = None # No existing execution
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
# creator_user_id="test-user-id",
|
||||
# creator_user_role="account",
|
||||
# )
|
||||
|
||||
# # Verify success
|
||||
# assert result is True
|
||||
# mock_session.add.assert_called_once()
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_workflow_node_execution_task_update_existing(
|
||||
# self, mock_select, mock_sessionmaker, sample_execution_data
|
||||
# ):
|
||||
# """Test updating an existing workflow node execution."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# existing_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
# mock_session.scalar.return_value = existing_execution
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
# creator_user_id="test-user-id",
|
||||
# creator_user_role="account",
|
||||
# )
|
||||
|
||||
# # Verify success
|
||||
# assert result is True
|
||||
# mock_session.add.assert_not_called() # Should not add new, just update existing
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# def test_save_workflow_node_execution_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
|
||||
# """Test task retry mechanism on exception."""
|
||||
# # Setup mock to raise exception
|
||||
# mock_sessionmaker.side_effect = Exception("Database error")
|
||||
|
||||
# # Create a mock task instance with proper retry behavior
|
||||
# with patch.object(save_workflow_node_execution_task, "retry") as mock_retry:
|
||||
# mock_retry.side_effect = Exception("Retry called")
|
||||
|
||||
# # Execute task and expect retry
|
||||
# with pytest.raises(Exception, match="Retry called"):
|
||||
# save_workflow_node_execution_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
# creator_user_id="test-user-id",
|
||||
# creator_user_role="account",
|
||||
# )
|
||||
|
||||
# # Verify retry was called
|
||||
# mock_retry.assert_called_once()
|
||||
Reference in New Issue
Block a user