mirror of
https://github.com/langgenius/dify.git
synced 2026-03-19 13:47:37 +08:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
210
api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py
Normal file
210
api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py
Normal file
@ -0,0 +1,210 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
from tasks import human_input_timeout_tasks as task_module
|
||||
|
||||
|
||||
class _FakeScalarResult:
|
||||
def __init__(self, items: list[Any]):
|
||||
self._items = items
|
||||
|
||||
def all(self) -> list[Any]:
|
||||
return self._items
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, items: list[Any], capture: dict[str, Any]):
|
||||
self._items = items
|
||||
self._capture = capture
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def scalars(self, stmt):
|
||||
self._capture["stmt"] = stmt
|
||||
return _FakeScalarResult(self._items)
|
||||
|
||||
|
||||
class _FakeSessionFactory:
|
||||
def __init__(self, items: list[Any], capture: dict[str, Any]):
|
||||
self._items = items
|
||||
self._capture = capture
|
||||
self._capture["session_factory"] = self
|
||||
|
||||
def __call__(self):
|
||||
session = _FakeSession(self._items, self._capture)
|
||||
self._capture["session"] = session
|
||||
return session
|
||||
|
||||
|
||||
class _FakeFormRepo:
|
||||
def __init__(self, _session_factory, form_map: dict[str, Any] | None = None):
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
self._form_map = form_map or {}
|
||||
|
||||
def mark_timeout(self, *, form_id: str, timeout_status: HumanInputFormStatus, reason: str | None = None):
|
||||
self.calls.append(
|
||||
{
|
||||
"form_id": form_id,
|
||||
"timeout_status": timeout_status,
|
||||
"reason": reason,
|
||||
}
|
||||
)
|
||||
form = self._form_map.get(form_id)
|
||||
return SimpleNamespace(
|
||||
form_id=form_id,
|
||||
workflow_run_id=getattr(form, "workflow_run_id", None),
|
||||
node_id=getattr(form, "node_id", None),
|
||||
)
|
||||
|
||||
|
||||
class _FakeService:
|
||||
def __init__(self, _session_factory, form_repository=None):
|
||||
self.enqueued: list[str] = []
|
||||
|
||||
def enqueue_resume(self, workflow_run_id: str | None) -> None:
|
||||
if workflow_run_id is not None:
|
||||
self.enqueued.append(workflow_run_id)
|
||||
|
||||
|
||||
def _build_form(
    *,
    form_id: str,
    form_kind: HumanInputFormKind,
    created_at: datetime,
    expiration_time: datetime,
    workflow_run_id: str | None,
    node_id: str,
) -> SimpleNamespace:
    """Create a lightweight stand-in for a human-input form row.

    The namespace carries only the attributes the timeout sweep reads;
    ``status`` always starts in the WAITING state.
    """
    attributes = {
        "id": form_id,
        "form_kind": form_kind,
        "created_at": created_at,
        "expiration_time": expiration_time,
        "workflow_run_id": workflow_run_id,
        "node_id": node_id,
        "status": HumanInputFormStatus.WAITING,
    }
    return SimpleNamespace(**attributes)
|
||||
|
||||
|
||||
def test_is_global_timeout_uses_created_at():
    """Global timeout compares the ``created_at`` age against the budget and
    only applies to forms attached to a workflow run."""
    is_timed_out = task_module._is_global_timeout
    now = datetime(2025, 1, 1, 12, 0, 0)
    form = SimpleNamespace(created_at=now - timedelta(seconds=61), workflow_run_id="run-1")

    # One second past the 60s budget -> timed out.
    assert is_timed_out(form, 60, now=now) is True

    # Without a workflow run there is nothing to globally time out.
    form.workflow_run_id = None
    assert is_timed_out(form, 60, now=now) is False

    # Back inside the budget -> not timed out.
    form.workflow_run_id = "run-1"
    form.created_at = now - timedelta(seconds=59)
    assert is_timed_out(form, 60, now=now) is False

    # A zero budget disables the global timeout entirely.
    assert is_timed_out(form, 0, now=now) is False
|
||||
|
||||
|
||||
def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pytest.MonkeyPatch):
    """End-to-end sweep check: forms past a deadline get the matching
    status/reason, node-level timeouts enqueue a workflow resume, and
    global timeouts are routed to ``_handle_global_timeout``.
    """
    # Freeze "now" and enable a 1-hour global timeout budget.
    now = datetime(2025, 1, 1, 12, 0, 0)
    monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
    monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600)
    monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))

    # Three waiting forms: one past the global budget (created 2h ago),
    # one past its own node-level expiration, and one delivery-test form
    # past expiration with no workflow run attached.
    forms = [
        _build_form(
            form_id="form-global",
            form_kind=HumanInputFormKind.RUNTIME,
            created_at=now - timedelta(hours=2),
            expiration_time=now + timedelta(hours=1),
            workflow_run_id="run-global",
            node_id="node-global",
        ),
        _build_form(
            form_id="form-node",
            form_kind=HumanInputFormKind.RUNTIME,
            created_at=now - timedelta(minutes=5),
            expiration_time=now - timedelta(seconds=1),
            workflow_run_id="run-node",
            node_id="node-node",
        ),
        _build_form(
            form_id="form-delivery",
            form_kind=HumanInputFormKind.DELIVERY_TEST,
            created_at=now - timedelta(minutes=1),
            expiration_time=now - timedelta(seconds=1),
            workflow_run_id=None,
            node_id="node-delivery",
        ),
    ]

    # The fake session factory records both itself and the issued statement
    # into ``capture`` so the assertions below can inspect the query.
    capture: dict[str, Any] = {}
    monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture))

    form_map = {form.id: form for form in forms}
    repo = _FakeFormRepo(None, form_map=form_map)

    # Factory shims so the task's constructor calls hand back our doubles.
    def _repo_factory(_session_factory):
        return repo

    service = _FakeService(None)

    def _service_factory(_session_factory, form_repository=None):
        return service

    # Collect the kwargs of every routed global-timeout call.
    global_calls: list[dict[str, Any]] = []

    monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _repo_factory)
    monkeypatch.setattr(task_module, "HumanInputService", _service_factory)
    monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **kwargs: global_calls.append(kwargs))

    task_module.check_and_handle_human_input_timeouts(limit=100)

    # Each form is marked with the status/reason matching its timeout kind.
    assert {(call["form_id"], call["timeout_status"], call["reason"]) for call in repo.calls} == {
        ("form-global", HumanInputFormStatus.EXPIRED, "global_timeout"),
        ("form-node", HumanInputFormStatus.TIMEOUT, "node_timeout"),
        ("form-delivery", HumanInputFormStatus.TIMEOUT, "delivery_test_timeout"),
    }
    # Only the node-level timeout resumes its workflow run.
    assert service.enqueued == ["run-node"]
    # The global timeout is delegated with the session factory the task built.
    assert global_calls == [
        {
            "form_id": "form-global",
            "workflow_run_id": "run-global",
            "node_id": "node-global",
            "session_factory": capture.get("session_factory"),
        }
    ]

    # The query filters on both timeout columns and orders deterministically.
    stmt = capture.get("stmt")
    assert stmt is not None
    stmt_text = str(stmt)
    assert "created_at <=" in stmt_text
    assert "expiration_time <=" in stmt_text
    assert "ORDER BY human_input_forms.id" in stmt_text
|
||||
|
||||
|
||||
def test_check_and_handle_human_input_timeouts_omits_global_filter_when_disabled(monkeypatch: pytest.MonkeyPatch):
    """With a zero global budget the sweep query must not filter on ``created_at``."""
    frozen_now = datetime(2025, 1, 1, 12, 0, 0)
    captured: dict[str, Any] = {}

    monkeypatch.setattr(task_module, "naive_utc_now", lambda: frozen_now)
    monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0)
    monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
    monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory([], captured))
    monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _FakeFormRepo)
    monkeypatch.setattr(task_module, "HumanInputService", _FakeService)
    monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **_kwargs: None)

    task_module.check_and_handle_human_input_timeouts(limit=1)

    # The fake session recorded the statement; it must omit the global filter.
    statement = captured.get("stmt")
    assert statement is not None
    assert "created_at <=" not in str(statement)
|
||||
@ -0,0 +1,123 @@
|
||||
from collections.abc import Sequence
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from tasks import mail_human_input_delivery_task as task_module
|
||||
|
||||
|
||||
class _DummyMail:
|
||||
def __init__(self):
|
||||
self.sent: list[dict[str, str]] = []
|
||||
self._inited = True
|
||||
|
||||
def is_inited(self) -> bool:
|
||||
return self._inited
|
||||
|
||||
def send(self, *, to: str, subject: str, html: str):
|
||||
self.sent.append({"to": to, "subject": subject, "html": html})
|
||||
|
||||
|
||||
class _DummySession:
|
||||
def __init__(self, form):
|
||||
self._form = form
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
def get(self, _model, _form_id):
|
||||
return self._form
|
||||
|
||||
|
||||
def _build_job(recipient_count: int = 1) -> task_module._EmailDeliveryJob:
    """Build a delivery job for form-1 with *recipient_count* synthetic recipients."""
    recipients = [
        task_module._EmailRecipient(email=f"user{idx}@example.com", token=f"token-{idx}")
        for idx in range(recipient_count)
    ]
    return task_module._EmailDeliveryJob(
        form_id="form-1",
        subject="Subject",
        body="Body for {{#url}}",
        form_content="content",
        recipients=recipients,
    )
|
||||
|
||||
|
||||
def test_dispatch_human_input_email_task_sends_to_each_recipient(monkeypatch: pytest.MonkeyPatch):
    """One email is sent per recipient when the email-delivery feature is on."""
    mail = _DummyMail()
    form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None)
    jobs: Sequence[task_module._EmailDeliveryJob] = [_build_job(recipient_count=2)]

    monkeypatch.setattr(task_module, "mail", mail)
    monkeypatch.setattr(
        task_module.FeatureService,
        "get_features",
        lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
    )
    monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: jobs)

    task_module.dispatch_human_input_email_task(
        form_id="form-1",
        node_title="Approve",
        session_factory=lambda: _DummySession(form),
    )

    # Two recipients -> two captured messages sharing subject and body.
    assert len(mail.sent) == 2
    assert all(entry["subject"] == "Subject" for entry in mail.sent)
    assert all("Body for" in entry["html"] for entry in mail.sent)
|
||||
|
||||
|
||||
def test_dispatch_human_input_email_task_skips_when_feature_disabled(monkeypatch: pytest.MonkeyPatch):
    """No mail goes out at all when the delivery feature flag is off."""
    mail = _DummyMail()
    form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None)

    monkeypatch.setattr(task_module, "mail", mail)
    monkeypatch.setattr(
        task_module.FeatureService,
        "get_features",
        lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False),
    )
    monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [])

    task_module.dispatch_human_input_email_task(
        form_id="form-1",
        node_title="Approve",
        session_factory=lambda: _DummySession(form),
    )

    # Nothing was captured by the mail double.
    assert mail.sent == []
|
||||
|
||||
|
||||
def test_dispatch_human_input_email_task_replaces_body_variables(monkeypatch: pytest.MonkeyPatch):
    """Variable selectors in the body are resolved from the run's variable pool."""
    mail = _DummyMail()
    form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id="run-1")

    # A single-recipient job whose body references a workflow variable.
    email_job = task_module._EmailDeliveryJob(
        form_id="form-1",
        subject="Subject",
        body="Body {{#node1.value#}}",
        form_content="content",
        recipients=[task_module._EmailRecipient(email="user@example.com", token="token-1")],
    )

    # Pool resolving node1.value -> "OK".
    pool = task_module.VariablePool()
    pool.add(["node1", "value"], "OK")

    monkeypatch.setattr(task_module, "mail", mail)
    monkeypatch.setattr(
        task_module.FeatureService,
        "get_features",
        lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
    )
    monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [email_job])
    monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: pool)

    task_module.dispatch_human_input_email_task(
        form_id="form-1",
        node_title="Approve",
        session_factory=lambda: _DummySession(form),
    )

    # The selector was substituted with the pool value.
    assert mail.sent[0]["html"] == "Body OK"
|
||||
39
api/tests/unit_tests/tasks/test_workflow_execute_task.py
Normal file
39
api/tests/unit_tests/tasks/test_workflow_execute_task.py
Normal file
@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import AppMode
|
||||
from tasks.app_generate.workflow_execute_task import _publish_streaming_response
|
||||
|
||||
|
||||
@pytest.fixture
def mock_topic(mocker) -> MagicMock:
    """Patch the app generator's response topic and yield the mock for assertions."""
    fake_topic = MagicMock()
    mocker.patch(
        "tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic",
        return_value=fake_topic,
    )
    return fake_topic
|
||||
|
||||
|
||||
def test_publish_streaming_response_with_uuid(mock_topic: MagicMock):
    """Each streamed item is JSON-encoded to bytes and published in order."""
    run_id = uuid.uuid4()
    stream = iter([{"event": "foo"}, "ping"])

    _publish_streaming_response(stream, run_id, app_mode=AppMode.ADVANCED_CHAT)

    published = [call.args[0] for call in mock_topic.publish.call_args_list]
    assert published == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()]
|
||||
|
||||
|
||||
def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock):
    """A workflow-run id given as ``str`` is accepted and publishing still works."""
    run_id = uuid.uuid4()
    stream = iter([{"event": "bar"}])

    # Pass the id as its string form rather than a UUID instance.
    _publish_streaming_response(stream, str(run_id), app_mode=AppMode.ADVANCED_CHAT)

    mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode())
|
||||
488
api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py
Normal file
488
api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py
Normal file
@ -0,0 +1,488 @@
|
||||
# NOTE(review): this entire test module below is commented out with no explanation;
# either restore the tests or delete the dead code rather than keeping it disabled.
# """
|
||||
# Unit tests for workflow node execution Celery tasks.
|
||||
|
||||
# These tests verify the asynchronous storage functionality for workflow node execution data,
|
||||
# including truncation and offloading logic.
|
||||
# """
|
||||
|
||||
# import json
|
||||
# from unittest.mock import MagicMock, Mock, patch
|
||||
# from uuid import uuid4
|
||||
|
||||
# import pytest
|
||||
|
||||
# from core.workflow.entities.workflow_node_execution import (
|
||||
# WorkflowNodeExecution,
|
||||
# WorkflowNodeExecutionStatus,
|
||||
# )
|
||||
# from core.workflow.enums import NodeType
|
||||
# from libs.datetime_utils import naive_utc_now
|
||||
# from models import WorkflowNodeExecutionModel
|
||||
# from models.enums import ExecutionOffLoadType
|
||||
# from models.model import UploadFile
|
||||
# from models.workflow import WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
|
||||
# from tasks.workflow_node_execution_tasks import (
|
||||
# _create_truncator,
|
||||
# _json_encode,
|
||||
# _replace_or_append_offload,
|
||||
# _truncate_and_upload_async,
|
||||
# save_workflow_node_execution_data_task,
|
||||
# save_workflow_node_execution_task,
|
||||
# )
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def sample_execution_data():
|
||||
# """Sample execution data for testing."""
|
||||
# execution = WorkflowNodeExecution(
|
||||
# id=str(uuid4()),
|
||||
# node_execution_id=str(uuid4()),
|
||||
# workflow_id=str(uuid4()),
|
||||
# workflow_execution_id=str(uuid4()),
|
||||
# index=1,
|
||||
# node_id="test_node",
|
||||
# node_type=NodeType.LLM,
|
||||
# title="Test Node",
|
||||
# inputs={"input_key": "input_value"},
|
||||
# outputs={"output_key": "output_value"},
|
||||
# process_data={"process_key": "process_value"},
|
||||
# status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
# created_at=naive_utc_now(),
|
||||
# )
|
||||
# return execution.model_dump()
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def mock_db_model():
|
||||
# """Mock database model for testing."""
|
||||
# db_model = Mock(spec=WorkflowNodeExecutionModel)
|
||||
# db_model.id = "test-execution-id"
|
||||
# db_model.offload_data = []
|
||||
# return db_model
|
||||
|
||||
|
||||
# @pytest.fixture
|
||||
# def mock_file_service():
|
||||
# """Mock file service for testing."""
|
||||
# file_service = Mock()
|
||||
# mock_upload_file = Mock(spec=UploadFile)
|
||||
# mock_upload_file.id = "mock-file-id"
|
||||
# file_service.upload_file.return_value = mock_upload_file
|
||||
# return file_service
|
||||
|
||||
|
||||
# class TestSaveWorkflowNodeExecutionDataTask:
|
||||
# """Test cases for save_workflow_node_execution_data_task."""
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_execution_data_task_success(
|
||||
# self, mock_select, mock_sessionmaker, sample_execution_data, mock_db_model
|
||||
# ):
|
||||
# """Test successful execution of save_workflow_node_execution_data_task."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_data_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify success
|
||||
# assert result is True
|
||||
# mock_session.merge.assert_called_once_with(mock_db_model)
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_execution_data_task_execution_not_found(self, mock_select, mock_sessionmaker,
|
||||
# sample_execution_data):
|
||||
# """Test task when execution is not found in database."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.execute.return_value.scalars.return_value.first.return_value = None
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_data_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify failure
|
||||
# assert result is False
|
||||
# mock_session.merge.assert_not_called()
|
||||
# mock_session.commit.assert_not_called()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_execution_data_task_with_truncation(self, mock_select, mock_sessionmaker, mock_db_model):
|
||||
# """Test task with data that requires truncation."""
|
||||
# # Create execution with large data
|
||||
# large_data = {"large_field": "x" * 10000}
|
||||
# execution = WorkflowNodeExecution(
|
||||
# id=str(uuid4()),
|
||||
# node_execution_id=str(uuid4()),
|
||||
# workflow_id=str(uuid4()),
|
||||
# workflow_execution_id=str(uuid4()),
|
||||
# index=1,
|
||||
# node_id="test_node",
|
||||
# node_type=NodeType.LLM,
|
||||
# title="Test Node",
|
||||
# inputs=large_data,
|
||||
# outputs=large_data,
|
||||
# process_data=large_data,
|
||||
# status=WorkflowNodeExecutionStatus.RUNNING,
|
||||
# created_at=naive_utc_now(),
|
||||
# )
|
||||
# execution_data = execution.model_dump()
|
||||
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
|
||||
|
||||
# # Create mock upload file
|
||||
# mock_upload_file = Mock(spec=UploadFile)
|
||||
# mock_upload_file.id = "mock-file-id"
|
||||
|
||||
# # Execute task
|
||||
# with patch("tasks.workflow_node_execution_tasks._truncate_and_upload_async") as mock_truncate:
|
||||
# # Mock truncation results
|
||||
# mock_truncate.return_value = {
|
||||
# "truncated_value": {"large_field": "[TRUNCATED]"},
|
||||
# "file": mock_upload_file,
|
||||
# "offload": WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# node_execution_id=execution.id,
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id=mock_upload_file.id,
|
||||
# ),
|
||||
# }
|
||||
|
||||
# result = save_workflow_node_execution_data_task(
|
||||
# execution_data=execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify success and truncation was called
|
||||
# assert result is True
|
||||
# assert mock_truncate.call_count == 3 # inputs, outputs, process_data
|
||||
# mock_session.merge.assert_called_once_with(mock_db_model)
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# def test_save_execution_data_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
|
||||
# """Test task retry mechanism on exception."""
|
||||
# # Setup mock to raise exception
|
||||
# mock_sessionmaker.side_effect = Exception("Database error")
|
||||
|
||||
# # Create a mock task instance with proper retry behavior
|
||||
# with patch.object(save_workflow_node_execution_data_task, "retry") as mock_retry:
|
||||
# mock_retry.side_effect = Exception("Retry called")
|
||||
|
||||
# # Execute task and expect retry
|
||||
# with pytest.raises(Exception, match="Retry called"):
|
||||
# save_workflow_node_execution_data_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# user_data={"user_id": "test-user-id", "user_type": "account"},
|
||||
# )
|
||||
|
||||
# # Verify retry was called
|
||||
# mock_retry.assert_called_once()
|
||||
|
||||
|
||||
# class TestTruncateAndUploadAsync:
|
||||
# """Test cases for _truncate_and_upload_async function."""
|
||||
|
||||
# def test_truncate_and_upload_with_none_values(self, mock_file_service):
|
||||
# """Test _truncate_and_upload_async with None values."""
|
||||
# # The function handles None values internally, so we test with empty dict instead
|
||||
# result = _truncate_and_upload_async(
|
||||
# values={},
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "account"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# # Empty dict should not require truncation
|
||||
# assert result is None
|
||||
# mock_file_service.upload_file.assert_not_called()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
|
||||
# def test_truncate_and_upload_no_truncation_needed(self, mock_create_truncator, mock_file_service):
|
||||
# """Test _truncate_and_upload_async when no truncation is needed."""
|
||||
# # Mock truncator to return no truncation
|
||||
# mock_truncator = Mock()
|
||||
# mock_truncator.truncate_variable_mapping.return_value = ({"small": "data"}, False)
|
||||
# mock_create_truncator.return_value = mock_truncator
|
||||
|
||||
# small_values = {"small": "data"}
|
||||
# result = _truncate_and_upload_async(
|
||||
# values=small_values,
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "account"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# assert result is None
|
||||
# mock_file_service.upload_file.assert_not_called()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
|
||||
# @patch("models.Account")
|
||||
# @patch("models.Tenant")
|
||||
# def test_truncate_and_upload_with_account_user(
|
||||
# self, mock_tenant_class, mock_account_class, mock_create_truncator, mock_file_service
|
||||
# ):
|
||||
# """Test _truncate_and_upload_async with account user."""
|
||||
# # Mock truncator to return truncation needed
|
||||
# mock_truncator = Mock()
|
||||
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
|
||||
# mock_create_truncator.return_value = mock_truncator
|
||||
|
||||
# # Mock user and tenant creation
|
||||
# mock_account = Mock()
|
||||
# mock_account.id = "test-user"
|
||||
# mock_account_class.return_value = mock_account
|
||||
|
||||
# mock_tenant = Mock()
|
||||
# mock_tenant.id = "test-tenant"
|
||||
# mock_tenant_class.return_value = mock_tenant
|
||||
|
||||
# large_values = {"large": "x" * 10000}
|
||||
# result = _truncate_and_upload_async(
|
||||
# values=large_values,
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "account"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# # Verify result structure
|
||||
# assert result is not None
|
||||
# assert "truncated_value" in result
|
||||
# assert "file" in result
|
||||
# assert "offload" in result
|
||||
# assert result["truncated_value"] == {"truncated": "data"}
|
||||
|
||||
# # Verify file upload was called
|
||||
# mock_file_service.upload_file.assert_called_once()
|
||||
# upload_call = mock_file_service.upload_file.call_args
|
||||
# assert upload_call[1]["filename"] == "node_execution_test-id_inputs.json"
|
||||
# assert upload_call[1]["mimetype"] == "application/json"
|
||||
# assert upload_call[1]["user"] == mock_account
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
|
||||
# @patch("models.EndUser")
|
||||
# def test_truncate_and_upload_with_end_user(self, mock_end_user_class, mock_create_truncator, mock_file_service):
|
||||
# """Test _truncate_and_upload_async with end user."""
|
||||
# # Mock truncator to return truncation needed
|
||||
# mock_truncator = Mock()
|
||||
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
|
||||
# mock_create_truncator.return_value = mock_truncator
|
||||
|
||||
# # Mock end user creation
|
||||
# mock_end_user = Mock()
|
||||
# mock_end_user.id = "test-user"
|
||||
# mock_end_user.tenant_id = "test-tenant"
|
||||
# mock_end_user_class.return_value = mock_end_user
|
||||
|
||||
# large_values = {"large": "x" * 10000}
|
||||
# result = _truncate_and_upload_async(
|
||||
# values=large_values,
|
||||
# execution_id="test-id",
|
||||
# type_=ExecutionOffLoadType.OUTPUTS,
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# user_data={"user_id": "test-user", "user_type": "end_user"},
|
||||
# file_service=mock_file_service,
|
||||
# )
|
||||
|
||||
# # Verify result structure
|
||||
# assert result is not None
|
||||
# assert result["truncated_value"] == {"truncated": "data"}
|
||||
|
||||
# # Verify file upload was called with end user
|
||||
# mock_file_service.upload_file.assert_called_once()
|
||||
# upload_call = mock_file_service.upload_file.call_args
|
||||
# assert upload_call[1]["filename"] == "node_execution_test-id_outputs.json"
|
||||
# assert upload_call[1]["user"] == mock_end_user
|
||||
|
||||
|
||||
# class TestHelperFunctions:
|
||||
# """Test cases for helper functions."""
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.dify_config")
|
||||
# def test_create_truncator(self, mock_config):
|
||||
# """Test _create_truncator function."""
|
||||
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
|
||||
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
|
||||
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500
|
||||
|
||||
# truncator = _create_truncator()
|
||||
|
||||
# # Verify truncator was created with correct config
|
||||
# assert truncator is not None
|
||||
|
||||
# def test_json_encode(self):
|
||||
# """Test _json_encode function."""
|
||||
# test_data = {"key": "value", "number": 42}
|
||||
# result = _json_encode(test_data)
|
||||
|
||||
# assert isinstance(result, str)
|
||||
# decoded = json.loads(result)
|
||||
# assert decoded == test_data
|
||||
|
||||
# def test_replace_or_append_offload_replace_existing(self):
|
||||
# """Test _replace_or_append_offload replaces existing offload of same type."""
|
||||
# existing_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id="old-file-id",
|
||||
# )
|
||||
|
||||
# new_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id="new-file-id",
|
||||
# )
|
||||
|
||||
# result = _replace_or_append_offload([existing_offload], new_offload)
|
||||
|
||||
# assert len(result) == 1
|
||||
# assert result[0].file_id == "new-file-id"
|
||||
|
||||
# def test_replace_or_append_offload_append_new_type(self):
|
||||
# """Test _replace_or_append_offload appends new offload of different type."""
|
||||
# existing_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.INPUTS,
|
||||
# file_id="inputs-file-id",
|
||||
# )
|
||||
|
||||
# new_offload = WorkflowNodeExecutionOffload(
|
||||
# id=str(uuid4()),
|
||||
# tenant_id="test-tenant",
|
||||
# app_id="test-app",
|
||||
# node_execution_id="test-execution",
|
||||
# type_=ExecutionOffLoadType.OUTPUTS,
|
||||
# file_id="outputs-file-id",
|
||||
# )
|
||||
|
||||
# result = _replace_or_append_offload([existing_offload], new_offload)
|
||||
|
||||
# assert len(result) == 2
|
||||
# file_ids = [offload.file_id for offload in result]
|
||||
# assert "inputs-file-id" in file_ids
|
||||
# assert "outputs-file-id" in file_ids
|
||||
|
||||
|
||||
# class TestSaveWorkflowNodeExecutionTask:
|
||||
# """Test cases for save_workflow_node_execution_task."""
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_workflow_node_execution_task_create_new(self, mock_select, mock_sessionmaker,
|
||||
# sample_execution_data):
|
||||
# """Test creating a new workflow node execution."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
# mock_session.scalar.return_value = None # No existing execution
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
# creator_user_id="test-user-id",
|
||||
# creator_user_role="account",
|
||||
# )
|
||||
|
||||
# # Verify success
|
||||
# assert result is True
|
||||
# mock_session.add.assert_called_once()
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# @patch("tasks.workflow_node_execution_tasks.select")
|
||||
# def test_save_workflow_node_execution_task_update_existing(
|
||||
# self, mock_select, mock_sessionmaker, sample_execution_data
|
||||
# ):
|
||||
# """Test updating an existing workflow node execution."""
|
||||
# # Setup mocks
|
||||
# mock_session = MagicMock()
|
||||
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# existing_execution = Mock(spec=WorkflowNodeExecutionModel)
|
||||
# mock_session.scalar.return_value = existing_execution
|
||||
|
||||
# # Execute task
|
||||
# result = save_workflow_node_execution_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
# creator_user_id="test-user-id",
|
||||
# creator_user_role="account",
|
||||
# )
|
||||
|
||||
# # Verify success
|
||||
# assert result is True
|
||||
# mock_session.add.assert_not_called() # Should not add new, just update existing
|
||||
# mock_session.commit.assert_called_once()
|
||||
|
||||
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
|
||||
# def test_save_workflow_node_execution_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
|
||||
# """Test task retry mechanism on exception."""
|
||||
# # Setup mock to raise exception
|
||||
# mock_sessionmaker.side_effect = Exception("Database error")
|
||||
|
||||
# # Create a mock task instance with proper retry behavior
|
||||
# with patch.object(save_workflow_node_execution_task, "retry") as mock_retry:
|
||||
# mock_retry.side_effect = Exception("Retry called")
|
||||
|
||||
# # Execute task and expect retry
|
||||
# with pytest.raises(Exception, match="Retry called"):
|
||||
# save_workflow_node_execution_task(
|
||||
# execution_data=sample_execution_data,
|
||||
# tenant_id="test-tenant-id",
|
||||
# app_id="test-app-id",
|
||||
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
# creator_user_id="test-user-id",
|
||||
# creator_user_role="account",
|
||||
# )
|
||||
|
||||
# # Verify retry was called
|
||||
# mock_retry.assert_called_once()
|
||||
Reference in New Issue
Block a user