feat: Human Input Node (#32060)

The frontend and backend implementation for the human input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
QuantumGhost
2026-02-09 14:57:23 +08:00
committed by GitHub
parent 56e3a55023
commit a1fc280102
474 changed files with 32667 additions and 2050 deletions

View File

@ -0,0 +1,210 @@
from __future__ import annotations
from datetime import datetime, timedelta
from types import SimpleNamespace
from typing import Any
import pytest
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
from tasks import human_input_timeout_tasks as task_module
class _FakeScalarResult:
def __init__(self, items: list[Any]):
self._items = items
def all(self) -> list[Any]:
return self._items
class _FakeSession:
def __init__(self, items: list[Any], capture: dict[str, Any]):
self._items = items
self._capture = capture
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def scalars(self, stmt):
self._capture["stmt"] = stmt
return _FakeScalarResult(self._items)
class _FakeSessionFactory:
def __init__(self, items: list[Any], capture: dict[str, Any]):
self._items = items
self._capture = capture
self._capture["session_factory"] = self
def __call__(self):
session = _FakeSession(self._items, self._capture)
self._capture["session"] = session
return session
class _FakeFormRepo:
def __init__(self, _session_factory, form_map: dict[str, Any] | None = None):
self.calls: list[dict[str, Any]] = []
self._form_map = form_map or {}
def mark_timeout(self, *, form_id: str, timeout_status: HumanInputFormStatus, reason: str | None = None):
self.calls.append(
{
"form_id": form_id,
"timeout_status": timeout_status,
"reason": reason,
}
)
form = self._form_map.get(form_id)
return SimpleNamespace(
form_id=form_id,
workflow_run_id=getattr(form, "workflow_run_id", None),
node_id=getattr(form, "node_id", None),
)
class _FakeService:
def __init__(self, _session_factory, form_repository=None):
self.enqueued: list[str] = []
def enqueue_resume(self, workflow_run_id: str | None) -> None:
if workflow_run_id is not None:
self.enqueued.append(workflow_run_id)
def _build_form(
*,
form_id: str,
form_kind: HumanInputFormKind,
created_at: datetime,
expiration_time: datetime,
workflow_run_id: str | None,
node_id: str,
) -> SimpleNamespace:
return SimpleNamespace(
id=form_id,
form_kind=form_kind,
created_at=created_at,
expiration_time=expiration_time,
workflow_run_id=workflow_run_id,
node_id=node_id,
status=HumanInputFormStatus.WAITING,
)
def test_is_global_timeout_uses_created_at():
now = datetime(2025, 1, 1, 12, 0, 0)
form = SimpleNamespace(created_at=now - timedelta(seconds=61), workflow_run_id="run-1")
assert task_module._is_global_timeout(form, 60, now=now) is True
form.workflow_run_id = None
assert task_module._is_global_timeout(form, 60, now=now) is False
form.workflow_run_id = "run-1"
form.created_at = now - timedelta(seconds=59)
assert task_module._is_global_timeout(form, 60, now=now) is False
assert task_module._is_global_timeout(form, 0, now=now) is False
def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pytest.MonkeyPatch):
now = datetime(2025, 1, 1, 12, 0, 0)
monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600)
monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
forms = [
_build_form(
form_id="form-global",
form_kind=HumanInputFormKind.RUNTIME,
created_at=now - timedelta(hours=2),
expiration_time=now + timedelta(hours=1),
workflow_run_id="run-global",
node_id="node-global",
),
_build_form(
form_id="form-node",
form_kind=HumanInputFormKind.RUNTIME,
created_at=now - timedelta(minutes=5),
expiration_time=now - timedelta(seconds=1),
workflow_run_id="run-node",
node_id="node-node",
),
_build_form(
form_id="form-delivery",
form_kind=HumanInputFormKind.DELIVERY_TEST,
created_at=now - timedelta(minutes=1),
expiration_time=now - timedelta(seconds=1),
workflow_run_id=None,
node_id="node-delivery",
),
]
capture: dict[str, Any] = {}
monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture))
form_map = {form.id: form for form in forms}
repo = _FakeFormRepo(None, form_map=form_map)
def _repo_factory(_session_factory):
return repo
service = _FakeService(None)
def _service_factory(_session_factory, form_repository=None):
return service
global_calls: list[dict[str, Any]] = []
monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _repo_factory)
monkeypatch.setattr(task_module, "HumanInputService", _service_factory)
monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **kwargs: global_calls.append(kwargs))
task_module.check_and_handle_human_input_timeouts(limit=100)
assert {(call["form_id"], call["timeout_status"], call["reason"]) for call in repo.calls} == {
("form-global", HumanInputFormStatus.EXPIRED, "global_timeout"),
("form-node", HumanInputFormStatus.TIMEOUT, "node_timeout"),
("form-delivery", HumanInputFormStatus.TIMEOUT, "delivery_test_timeout"),
}
assert service.enqueued == ["run-node"]
assert global_calls == [
{
"form_id": "form-global",
"workflow_run_id": "run-global",
"node_id": "node-global",
"session_factory": capture.get("session_factory"),
}
]
stmt = capture.get("stmt")
assert stmt is not None
stmt_text = str(stmt)
assert "created_at <=" in stmt_text
assert "expiration_time <=" in stmt_text
assert "ORDER BY human_input_forms.id" in stmt_text
def test_check_and_handle_human_input_timeouts_omits_global_filter_when_disabled(monkeypatch: pytest.MonkeyPatch):
now = datetime(2025, 1, 1, 12, 0, 0)
monkeypatch.setattr(task_module, "naive_utc_now", lambda: now)
monkeypatch.setattr(task_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0)
monkeypatch.setattr(task_module, "db", SimpleNamespace(engine=object()))
capture: dict[str, Any] = {}
monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory([], capture))
monkeypatch.setattr(task_module, "HumanInputFormSubmissionRepository", _FakeFormRepo)
monkeypatch.setattr(task_module, "HumanInputService", _FakeService)
monkeypatch.setattr(task_module, "_handle_global_timeout", lambda **_kwargs: None)
task_module.check_and_handle_human_input_timeouts(limit=1)
stmt = capture.get("stmt")
assert stmt is not None
stmt_text = str(stmt)
assert "created_at <=" not in stmt_text

View File

@ -0,0 +1,123 @@
from collections.abc import Sequence
from types import SimpleNamespace
import pytest
from tasks import mail_human_input_delivery_task as task_module
class _DummyMail:
def __init__(self):
self.sent: list[dict[str, str]] = []
self._inited = True
def is_inited(self) -> bool:
return self._inited
def send(self, *, to: str, subject: str, html: str):
self.sent.append({"to": to, "subject": subject, "html": html})
class _DummySession:
def __init__(self, form):
self._form = form
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return False
def get(self, _model, _form_id):
return self._form
def _build_job(recipient_count: int = 1) -> task_module._EmailDeliveryJob:
recipients: list[task_module._EmailRecipient] = []
for idx in range(recipient_count):
recipients.append(task_module._EmailRecipient(email=f"user{idx}@example.com", token=f"token-{idx}"))
return task_module._EmailDeliveryJob(
form_id="form-1",
subject="Subject",
body="Body for {{#url}}",
form_content="content",
recipients=recipients,
)
def test_dispatch_human_input_email_task_sends_to_each_recipient(monkeypatch: pytest.MonkeyPatch):
mail = _DummyMail()
form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None)
monkeypatch.setattr(task_module, "mail", mail)
monkeypatch.setattr(
task_module.FeatureService,
"get_features",
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
)
jobs: Sequence[task_module._EmailDeliveryJob] = [_build_job(recipient_count=2)]
monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: jobs)
task_module.dispatch_human_input_email_task(
form_id="form-1",
node_title="Approve",
session_factory=lambda: _DummySession(form),
)
assert len(mail.sent) == 2
assert all(payload["subject"] == "Subject" for payload in mail.sent)
assert all("Body for" in payload["html"] for payload in mail.sent)
def test_dispatch_human_input_email_task_skips_when_feature_disabled(monkeypatch: pytest.MonkeyPatch):
mail = _DummyMail()
form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None)
monkeypatch.setattr(task_module, "mail", mail)
monkeypatch.setattr(
task_module.FeatureService,
"get_features",
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False),
)
monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [])
task_module.dispatch_human_input_email_task(
form_id="form-1",
node_title="Approve",
session_factory=lambda: _DummySession(form),
)
assert mail.sent == []
def test_dispatch_human_input_email_task_replaces_body_variables(monkeypatch: pytest.MonkeyPatch):
mail = _DummyMail()
form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id="run-1")
job = task_module._EmailDeliveryJob(
form_id="form-1",
subject="Subject",
body="Body {{#node1.value#}}",
form_content="content",
recipients=[task_module._EmailRecipient(email="user@example.com", token="token-1")],
)
variable_pool = task_module.VariablePool()
variable_pool.add(["node1", "value"], "OK")
monkeypatch.setattr(task_module, "mail", mail)
monkeypatch.setattr(
task_module.FeatureService,
"get_features",
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
)
monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [job])
monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: variable_pool)
task_module.dispatch_human_input_email_task(
form_id="form-1",
node_title="Approve",
session_factory=lambda: _DummySession(form),
)
assert mail.sent[0]["html"] == "Body OK"

View File

@ -0,0 +1,39 @@
from __future__ import annotations
import json
import uuid
from unittest.mock import MagicMock
import pytest
from models.model import AppMode
from tasks.app_generate.workflow_execute_task import _publish_streaming_response
@pytest.fixture
def mock_topic(mocker) -> MagicMock:
topic = MagicMock()
mocker.patch(
"tasks.app_generate.workflow_execute_task.MessageBasedAppGenerator.get_response_topic",
return_value=topic,
)
return topic
def test_publish_streaming_response_with_uuid(mock_topic: MagicMock):
workflow_run_id = uuid.uuid4()
response_stream = iter([{"event": "foo"}, "ping"])
_publish_streaming_response(response_stream, workflow_run_id, app_mode=AppMode.ADVANCED_CHAT)
payloads = [call.args[0] for call in mock_topic.publish.call_args_list]
assert payloads == [json.dumps({"event": "foo"}).encode(), json.dumps("ping").encode()]
def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock):
workflow_run_id = uuid.uuid4()
response_stream = iter([{"event": "bar"}])
_publish_streaming_response(response_stream, str(workflow_run_id), app_mode=AppMode.ADVANCED_CHAT)
mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode())

View File

@ -0,0 +1,488 @@
# """
# Unit tests for workflow node execution Celery tasks.
# These tests verify the asynchronous storage functionality for workflow node execution data,
# including truncation and offloading logic.
# """
# import json
# from unittest.mock import MagicMock, Mock, patch
# from uuid import uuid4
# import pytest
# from core.workflow.entities.workflow_node_execution import (
# WorkflowNodeExecution,
# WorkflowNodeExecutionStatus,
# )
# from core.workflow.enums import NodeType
# from libs.datetime_utils import naive_utc_now
# from models import WorkflowNodeExecutionModel
# from models.enums import ExecutionOffLoadType
# from models.model import UploadFile
# from models.workflow import WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom
# from tasks.workflow_node_execution_tasks import (
# _create_truncator,
# _json_encode,
# _replace_or_append_offload,
# _truncate_and_upload_async,
# save_workflow_node_execution_data_task,
# save_workflow_node_execution_task,
# )
# @pytest.fixture
# def sample_execution_data():
# """Sample execution data for testing."""
# execution = WorkflowNodeExecution(
# id=str(uuid4()),
# node_execution_id=str(uuid4()),
# workflow_id=str(uuid4()),
# workflow_execution_id=str(uuid4()),
# index=1,
# node_id="test_node",
# node_type=NodeType.LLM,
# title="Test Node",
# inputs={"input_key": "input_value"},
# outputs={"output_key": "output_value"},
# process_data={"process_key": "process_value"},
# status=WorkflowNodeExecutionStatus.RUNNING,
# created_at=naive_utc_now(),
# )
# return execution.model_dump()
# @pytest.fixture
# def mock_db_model():
# """Mock database model for testing."""
# db_model = Mock(spec=WorkflowNodeExecutionModel)
# db_model.id = "test-execution-id"
# db_model.offload_data = []
# return db_model
# @pytest.fixture
# def mock_file_service():
# """Mock file service for testing."""
# file_service = Mock()
# mock_upload_file = Mock(spec=UploadFile)
# mock_upload_file.id = "mock-file-id"
# file_service.upload_file.return_value = mock_upload_file
# return file_service
# class TestSaveWorkflowNodeExecutionDataTask:
# """Test cases for save_workflow_node_execution_data_task."""
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_execution_data_task_success(
# self, mock_select, mock_sessionmaker, sample_execution_data, mock_db_model
# ):
# """Test successful execution of save_workflow_node_execution_data_task."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
# # Execute task
# result = save_workflow_node_execution_data_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify success
# assert result is True
# mock_session.merge.assert_called_once_with(mock_db_model)
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_execution_data_task_execution_not_found(self, mock_select, mock_sessionmaker,
# sample_execution_data):
# """Test task when execution is not found in database."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.execute.return_value.scalars.return_value.first.return_value = None
# # Execute task
# result = save_workflow_node_execution_data_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify failure
# assert result is False
# mock_session.merge.assert_not_called()
# mock_session.commit.assert_not_called()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_execution_data_task_with_truncation(self, mock_select, mock_sessionmaker, mock_db_model):
# """Test task with data that requires truncation."""
# # Create execution with large data
# large_data = {"large_field": "x" * 10000}
# execution = WorkflowNodeExecution(
# id=str(uuid4()),
# node_execution_id=str(uuid4()),
# workflow_id=str(uuid4()),
# workflow_execution_id=str(uuid4()),
# index=1,
# node_id="test_node",
# node_type=NodeType.LLM,
# title="Test Node",
# inputs=large_data,
# outputs=large_data,
# process_data=large_data,
# status=WorkflowNodeExecutionStatus.RUNNING,
# created_at=naive_utc_now(),
# )
# execution_data = execution.model_dump()
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.execute.return_value.scalars.return_value.first.return_value = mock_db_model
# # Create mock upload file
# mock_upload_file = Mock(spec=UploadFile)
# mock_upload_file.id = "mock-file-id"
# # Execute task
# with patch("tasks.workflow_node_execution_tasks._truncate_and_upload_async") as mock_truncate:
# # Mock truncation results
# mock_truncate.return_value = {
# "truncated_value": {"large_field": "[TRUNCATED]"},
# "file": mock_upload_file,
# "offload": WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# node_execution_id=execution.id,
# type_=ExecutionOffLoadType.INPUTS,
# file_id=mock_upload_file.id,
# ),
# }
# result = save_workflow_node_execution_data_task(
# execution_data=execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify success and truncation was called
# assert result is True
# assert mock_truncate.call_count == 3 # inputs, outputs, process_data
# mock_session.merge.assert_called_once_with(mock_db_model)
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# def test_save_execution_data_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
# """Test task retry mechanism on exception."""
# # Setup mock to raise exception
# mock_sessionmaker.side_effect = Exception("Database error")
# # Create a mock task instance with proper retry behavior
# with patch.object(save_workflow_node_execution_data_task, "retry") as mock_retry:
# mock_retry.side_effect = Exception("Retry called")
# # Execute task and expect retry
# with pytest.raises(Exception, match="Retry called"):
# save_workflow_node_execution_data_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# user_data={"user_id": "test-user-id", "user_type": "account"},
# )
# # Verify retry was called
# mock_retry.assert_called_once()
# class TestTruncateAndUploadAsync:
# """Test cases for _truncate_and_upload_async function."""
# def test_truncate_and_upload_with_none_values(self, mock_file_service):
# """Test _truncate_and_upload_async with None values."""
# # The function handles None values internally, so we test with empty dict instead
# result = _truncate_and_upload_async(
# values={},
# execution_id="test-id",
# type_=ExecutionOffLoadType.INPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "account"},
# file_service=mock_file_service,
# )
# # Empty dict should not require truncation
# assert result is None
# mock_file_service.upload_file.assert_not_called()
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
# def test_truncate_and_upload_no_truncation_needed(self, mock_create_truncator, mock_file_service):
# """Test _truncate_and_upload_async when no truncation is needed."""
# # Mock truncator to return no truncation
# mock_truncator = Mock()
# mock_truncator.truncate_variable_mapping.return_value = ({"small": "data"}, False)
# mock_create_truncator.return_value = mock_truncator
# small_values = {"small": "data"}
# result = _truncate_and_upload_async(
# values=small_values,
# execution_id="test-id",
# type_=ExecutionOffLoadType.INPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "account"},
# file_service=mock_file_service,
# )
# assert result is None
# mock_file_service.upload_file.assert_not_called()
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
# @patch("models.Account")
# @patch("models.Tenant")
# def test_truncate_and_upload_with_account_user(
# self, mock_tenant_class, mock_account_class, mock_create_truncator, mock_file_service
# ):
# """Test _truncate_and_upload_async with account user."""
# # Mock truncator to return truncation needed
# mock_truncator = Mock()
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
# mock_create_truncator.return_value = mock_truncator
# # Mock user and tenant creation
# mock_account = Mock()
# mock_account.id = "test-user"
# mock_account_class.return_value = mock_account
# mock_tenant = Mock()
# mock_tenant.id = "test-tenant"
# mock_tenant_class.return_value = mock_tenant
# large_values = {"large": "x" * 10000}
# result = _truncate_and_upload_async(
# values=large_values,
# execution_id="test-id",
# type_=ExecutionOffLoadType.INPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "account"},
# file_service=mock_file_service,
# )
# # Verify result structure
# assert result is not None
# assert "truncated_value" in result
# assert "file" in result
# assert "offload" in result
# assert result["truncated_value"] == {"truncated": "data"}
# # Verify file upload was called
# mock_file_service.upload_file.assert_called_once()
# upload_call = mock_file_service.upload_file.call_args
# assert upload_call[1]["filename"] == "node_execution_test-id_inputs.json"
# assert upload_call[1]["mimetype"] == "application/json"
# assert upload_call[1]["user"] == mock_account
# @patch("tasks.workflow_node_execution_tasks._create_truncator")
# @patch("models.EndUser")
# def test_truncate_and_upload_with_end_user(self, mock_end_user_class, mock_create_truncator, mock_file_service):
# """Test _truncate_and_upload_async with end user."""
# # Mock truncator to return truncation needed
# mock_truncator = Mock()
# mock_truncator.truncate_variable_mapping.return_value = ({"truncated": "data"}, True)
# mock_create_truncator.return_value = mock_truncator
# # Mock end user creation
# mock_end_user = Mock()
# mock_end_user.id = "test-user"
# mock_end_user.tenant_id = "test-tenant"
# mock_end_user_class.return_value = mock_end_user
# large_values = {"large": "x" * 10000}
# result = _truncate_and_upload_async(
# values=large_values,
# execution_id="test-id",
# type_=ExecutionOffLoadType.OUTPUTS,
# tenant_id="test-tenant",
# app_id="test-app",
# user_data={"user_id": "test-user", "user_type": "end_user"},
# file_service=mock_file_service,
# )
# # Verify result structure
# assert result is not None
# assert result["truncated_value"] == {"truncated": "data"}
# # Verify file upload was called with end user
# mock_file_service.upload_file.assert_called_once()
# upload_call = mock_file_service.upload_file.call_args
# assert upload_call[1]["filename"] == "node_execution_test-id_outputs.json"
# assert upload_call[1]["user"] == mock_end_user
# class TestHelperFunctions:
# """Test cases for helper functions."""
# @patch("tasks.workflow_node_execution_tasks.dify_config")
# def test_create_truncator(self, mock_config):
# """Test _create_truncator function."""
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE = 1000
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH = 100
# mock_config.WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH = 500
# truncator = _create_truncator()
# # Verify truncator was created with correct config
# assert truncator is not None
# def test_json_encode(self):
# """Test _json_encode function."""
# test_data = {"key": "value", "number": 42}
# result = _json_encode(test_data)
# assert isinstance(result, str)
# decoded = json.loads(result)
# assert decoded == test_data
# def test_replace_or_append_offload_replace_existing(self):
# """Test _replace_or_append_offload replaces existing offload of same type."""
# existing_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.INPUTS,
# file_id="old-file-id",
# )
# new_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.INPUTS,
# file_id="new-file-id",
# )
# result = _replace_or_append_offload([existing_offload], new_offload)
# assert len(result) == 1
# assert result[0].file_id == "new-file-id"
# def test_replace_or_append_offload_append_new_type(self):
# """Test _replace_or_append_offload appends new offload of different type."""
# existing_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.INPUTS,
# file_id="inputs-file-id",
# )
# new_offload = WorkflowNodeExecutionOffload(
# id=str(uuid4()),
# tenant_id="test-tenant",
# app_id="test-app",
# node_execution_id="test-execution",
# type_=ExecutionOffLoadType.OUTPUTS,
# file_id="outputs-file-id",
# )
# result = _replace_or_append_offload([existing_offload], new_offload)
# assert len(result) == 2
# file_ids = [offload.file_id for offload in result]
# assert "inputs-file-id" in file_ids
# assert "outputs-file-id" in file_ids
# class TestSaveWorkflowNodeExecutionTask:
# """Test cases for save_workflow_node_execution_task."""
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_workflow_node_execution_task_create_new(self, mock_select, mock_sessionmaker,
# sample_execution_data):
# """Test creating a new workflow node execution."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# mock_session.scalar.return_value = None # No existing execution
# # Execute task
# result = save_workflow_node_execution_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
# creator_user_id="test-user-id",
# creator_user_role="account",
# )
# # Verify success
# assert result is True
# mock_session.add.assert_called_once()
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# @patch("tasks.workflow_node_execution_tasks.select")
# def test_save_workflow_node_execution_task_update_existing(
# self, mock_select, mock_sessionmaker, sample_execution_data
# ):
# """Test updating an existing workflow node execution."""
# # Setup mocks
# mock_session = MagicMock()
# mock_sessionmaker.return_value.return_value.__enter__.return_value = mock_session
# existing_execution = Mock(spec=WorkflowNodeExecutionModel)
# mock_session.scalar.return_value = existing_execution
# # Execute task
# result = save_workflow_node_execution_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
# creator_user_id="test-user-id",
# creator_user_role="account",
# )
# # Verify success
# assert result is True
# mock_session.add.assert_not_called() # Should not add new, just update existing
# mock_session.commit.assert_called_once()
# @patch("tasks.workflow_node_execution_tasks.sessionmaker")
# def test_save_workflow_node_execution_task_retry_on_exception(self, mock_sessionmaker, sample_execution_data):
# """Test task retry mechanism on exception."""
# # Setup mock to raise exception
# mock_sessionmaker.side_effect = Exception("Database error")
# # Create a mock task instance with proper retry behavior
# with patch.object(save_workflow_node_execution_task, "retry") as mock_retry:
# mock_retry.side_effect = Exception("Retry called")
# # Execute task and expect retry
# with pytest.raises(Exception, match="Retry called"):
# save_workflow_node_execution_task(
# execution_data=sample_execution_data,
# tenant_id="test-tenant-id",
# app_id="test-app-id",
# triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
# creator_user_id="test-user-id",
# creator_user_role="account",
# )
# # Verify retry was called
# mock_retry.assert_called_once()