feat(api): Implement HITL for Workflow, add is_resumption for start event

This commit is contained in:
QuantumGhost
2025-12-30 16:40:08 +08:00
parent 01325c543f
commit 37dd61558c
27 changed files with 762 additions and 344 deletions

View File

@ -0,0 +1,160 @@
from __future__ import annotations
from dataclasses import dataclass
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from controllers.console import wraps as console_wraps
from controllers.console.app import workflow as workflow_module
from controllers.console.app import wraps as app_wraps
from libs import login as login_lib
from models.account import Account, AccountStatus, TenantAccountRole
from models.model import AppMode
def _make_account() -> Account:
    """Build an active OWNER account wired with a fake tenant for these tests."""
    acct = Account(name="tester", email="tester@example.com")
    acct.status = AccountStatus.ACTIVE
    acct.role = TenantAccountRole.OWNER
    acct.id = "account-123"  # type: ignore[assignment]
    # SimpleNamespace stands in for a Tenant row; only .id is read by the guards.
    acct._current_tenant = SimpleNamespace(id="tenant-123")  # type: ignore[attr-defined]
    # Flask-Login proxies unwrap via _get_current_object(); return the account itself.
    acct._get_current_object = lambda: acct  # type: ignore[attr-defined]
    return acct
def _make_app(mode: AppMode) -> SimpleNamespace:
return SimpleNamespace(id="app-123", tenant_id="tenant-123", mode=mode.value)
def _patch_console_guards(monkeypatch: pytest.MonkeyPatch, account: Account, app_model: SimpleNamespace) -> None:
    """Neutralize console setup/auth guards so resource methods can run directly."""

    def account_ctx():
        # Every patched account lookup resolves to the same (account, tenant_id) pair.
        return (account, account.current_tenant_id)

    monkeypatch.setattr("configs.dify_config.EDITION", "CLOUD")
    monkeypatch.setattr(login_lib.dify_config, "LOGIN_DISABLED", True)
    monkeypatch.setattr(login_lib, "current_user", account)
    monkeypatch.setattr(login_lib, "current_account_with_tenant", account_ctx)
    monkeypatch.setattr(login_lib, "check_csrf_token", lambda *_, **__: None)
    monkeypatch.setattr(console_wraps, "current_account_with_tenant", account_ctx)
    monkeypatch.setattr(app_wraps, "current_account_with_tenant", account_ctx)
    monkeypatch.setattr(workflow_module, "current_account_with_tenant", account_ctx)
    monkeypatch.setattr(console_wraps.dify_config, "EDITION", "CLOUD")
    monkeypatch.delenv("INIT_PASSWORD", raising=False)
    # Bypass the DB lookup that would resolve app_id to a real App row.
    monkeypatch.setattr(app_wraps, "_load_app_model", lambda _app_id: app_model)
@dataclass
class PreviewCase:
    """Parametrization record for the human-input form preview (GET) tests."""

    # Console API resource class under test.
    resource_cls: type
    # Request path routed to that resource.
    path: str
    # App mode the fake app model reports (advanced-chat vs. workflow).
    mode: AppMode
@pytest.mark.parametrize(
    "case",
    [
        PreviewCase(
            resource_cls=workflow_module.AdvancedChatDraftHumanInputFormApi,
            path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-42/form",
            mode=AppMode.ADVANCED_CHAT,
        ),
        PreviewCase(
            resource_cls=workflow_module.WorkflowDraftHumanInputFormApi,
            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-42/form",
            mode=AppMode.WORKFLOW,
        ),
    ],
)
def test_human_input_preview_delegates_to_service(
    app: Flask, monkeypatch: pytest.MonkeyPatch, case: PreviewCase
) -> None:
    """GET on the draft human-input form endpoint returns the service's preview verbatim."""
    account = _make_account()
    app_model = _make_app(case.mode)
    _patch_console_guards(monkeypatch, account, app_model)

    expected = {
        "form_id": "node-42",
        "form_content": "<div>example</div>",
        "inputs": [{"name": "topic"}],
        "actions": [{"id": "continue"}],
    }
    workflow_service = MagicMock()
    workflow_service.get_human_input_form_preview.return_value = expected
    # Resource constructs WorkflowService itself, so patch the class in the module.
    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=workflow_service))

    with app.test_request_context(case.path, method="GET", json={"inputs": {"topic": "tech"}}):
        response = case.resource_cls().get(app_id=app_model.id, node_id="node-42")

    assert response == expected
    workflow_service.get_human_input_form_preview.assert_called_once_with(
        app_model=app_model,
        account=account,
        node_id="node-42",
        manual_inputs={"topic": "tech"},
    )
@dataclass
class SubmitCase:
    """Parametrization record for the human-input form submission (POST) tests."""

    # Console API resource class under test.
    resource_cls: type
    # Request path routed to that resource.
    path: str
    # App mode the fake app model reports (advanced-chat vs. workflow).
    mode: AppMode
@pytest.mark.parametrize(
    "case",
    [
        SubmitCase(
            resource_cls=workflow_module.AdvancedChatDraftHumanInputFormApi,
            path="/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-99/form",
            mode=AppMode.ADVANCED_CHAT,
        ),
        SubmitCase(
            resource_cls=workflow_module.WorkflowDraftHumanInputFormApi,
            path="/console/api/apps/app-123/workflows/draft/human-input/nodes/node-99/form",
            mode=AppMode.WORKFLOW,
        ),
    ],
)
def test_human_input_submit_forwards_payload(app: Flask, monkeypatch: pytest.MonkeyPatch, case: SubmitCase) -> None:
    """POST on the draft human-input form endpoint forwards inputs and action to the service."""
    account = _make_account()
    app_model = _make_app(case.mode)
    _patch_console_guards(monkeypatch, account, app_model)

    expected = {"node_id": "node-99", "outputs": {"__rendered_content": "<p>done</p>"}, "action": "approve"}
    workflow_service = MagicMock()
    workflow_service.submit_human_input_form_preview.return_value = expected
    # Resource constructs WorkflowService itself, so patch the class in the module.
    monkeypatch.setattr(workflow_module, "WorkflowService", MagicMock(return_value=workflow_service))

    request_ctx = app.test_request_context(
        case.path,
        method="POST",
        json={"inputs": {"answer": "42"}, "action": "approve"},
    )
    with request_ctx:
        response = case.resource_cls().post(app_id=app_model.id, node_id="node-99")

    assert response == expected
    workflow_service.submit_human_input_form_preview.assert_called_once_with(
        app_model=app_model,
        account=account,
        node_id="node-99",
        form_inputs={"answer": "42"},
        action="approve",
    )
def test_human_input_preview_rejects_non_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
    """A non-dict ``inputs`` payload is rejected with ValueError before reaching the service."""
    account = _make_account()
    app_model = _make_app(AppMode.ADVANCED_CHAT)
    _patch_console_guards(monkeypatch, account, app_model)

    request_ctx = app.test_request_context(
        "/console/api/apps/app-123/advanced-chat/workflows/draft/human-input/nodes/node-1/form",
        method="GET",
        json={"inputs": ["not-a-dict"]},
    )
    with request_ctx, pytest.raises(ValueError):
        workflow_module.AdvancedChatDraftHumanInputFormApi().get(app_id=app_model.id, node_id="node-1")

View File

@ -124,7 +124,12 @@ class TestWorkflowResponseConverter:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
is_resumption=False,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -160,7 +165,12 @@ class TestWorkflowResponseConverter:
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
is_resumption=False,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -191,7 +201,12 @@ class TestWorkflowResponseConverter:
"""Test node finish response when process_data is None."""
converter = self.create_workflow_response_converter()
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
is_resumption=False,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -225,7 +240,12 @@ class TestWorkflowResponseConverter:
original_data = {"large_field": "x" * 10000, "metadata": "info"}
truncated_data = {"large_field": "[TRUNCATED]", "metadata": "info"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
is_resumption=False,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -261,7 +281,12 @@ class TestWorkflowResponseConverter:
original_data = {"small": "data"}
converter.workflow_start_to_stream_response(task_id="bootstrap", workflow_run_id="run-id", workflow_id="wf-id")
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
is_resumption=False,
)
start_event = self.create_node_started_event()
converter.workflow_node_start_to_stream_response(
event=start_event,
@ -400,6 +425,7 @@ class TestWorkflowResponseConverterServiceApiTruncation:
task_id="test-task-id",
workflow_run_id="test-workflow-run-id",
workflow_id="test-workflow-id",
is_resumption=False,
)
return converter

View File

@ -112,7 +112,12 @@ def _build_converter():
def test_queue_workflow_paused_event_to_stream_responses():
converter = _build_converter()
converter.workflow_start_to_stream_response(task_id="task", workflow_run_id="run-id", workflow_id="workflow-id")
converter.workflow_start_to_stream_response(
task_id="task",
workflow_run_id="run-id",
workflow_id="workflow-id",
is_resumption=False,
)
reason = HumanInputRequired(
form_id="form-1",

View File

@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation:
within conversations.
"""
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
def test_pagination_by_first_id_without_first_id(
self, mock_get_conversation, mock_db_session, mock_create_extra_repo
):
"""
Test message pagination without specifying first_id.
@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act - Call the pagination method without first_id
result = MessageService.pagination_by_first_id(
@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation:
# Verify conversation was looked up with correct parameters
mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test message pagination with first_id specified.
@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation:
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.first.return_value = first_message # First message returned
mock_query.all.return_value = messages # Remaining messages returned
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act - Call the pagination method with first_id
result = MessageService.pagination_by_first_id(
@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation:
assert result.data == []
assert result.has_more is False
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test that has_more flag is correctly set when there are more messages.
@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act
result = MessageService.pagination_by_first_id(
@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == limit # Extra message should be removed
assert result.has_more is True # Flag should be set
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test message pagination with ascending order.
@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act
result = MessageService.pagination_by_first_id(

View File

@ -65,72 +65,25 @@ def sample_form_record():
)
def test_enqueue_resume_dispatches_task(mocker, mock_session_factory):
def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
trigger_log = MagicMock()
trigger_log.id = "trigger-log-id"
trigger_log.queue_name = "workflow_queue"
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = trigger_log
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
service._enqueue_resume("workflow-run-id")
repo_cls.assert_called_once_with(session)
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == "workflow_queue"
payload = call_kwargs["kwargs"]["task_data_dict"]
assert payload["workflow_trigger_log_id"] == "trigger-log-id"
assert payload["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_no_trigger_log(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = None
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
service._enqueue_resume("workflow-run-id")
repo_cls.assert_called_once_with(session)
resume_task.apply_async.assert_not_called()
def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = None
workflow_run = MagicMock()
workflow_run.app_id = "app-id"
workflow_run_repo = MagicMock()
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
mocker.patch(
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
return_value=workflow_run_repo,
)
app = MagicMock()
app.mode = "advanced-chat"
app.mode = "workflow"
session.execute.return_value.scalar_one_or_none.return_value = app
session.get.side_effect = [workflow_run, app]
resume_task = mocker.patch("services.human_input_service.resume_chatflow_execution")
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
service._enqueue_resume("workflow-run-id")
@ -140,6 +93,59 @@ def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory):
    """An advanced-chat app's paused run resumes via resume_app_execution on the chatflow queue."""
    session_factory, session = mock_session_factory
    service = HumanInputService(session_factory)
    # Run lookup resolves to a workflow run owned by "app-id".
    workflow_run = MagicMock()
    workflow_run.app_id = "app-id"
    workflow_run_repo = MagicMock()
    workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
    mocker.patch(
        "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
        return_value=workflow_run_repo,
    )
    # The owning app reports advanced-chat mode, which should route to the chatflow queue.
    app = MagicMock()
    app.mode = "advanced-chat"
    session.execute.return_value.scalar_one_or_none.return_value = app
    resume_task = mocker.patch("services.human_input_service.resume_app_execution")
    service._enqueue_resume("workflow-run-id")
    # The Celery task is enqueued exactly once, on chatflow_execute, carrying the run id.
    resume_task.apply_async.assert_called_once()
    call_kwargs = resume_task.apply_async.call_args.kwargs
    assert call_kwargs["queue"] == "chatflow_execute"
    assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory):
    """No resume task is dispatched when the owning app's mode is neither workflow nor advanced-chat."""
    session_factory, session = mock_session_factory
    service = HumanInputService(session_factory)
    # Run lookup resolves to a workflow run owned by "app-id".
    workflow_run = MagicMock()
    workflow_run.app_id = "app-id"
    workflow_run_repo = MagicMock()
    workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
    mocker.patch(
        "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
        return_value=workflow_run_repo,
    )
    # "completion" is not a resumable mode, so enqueue should be a no-op.
    app = MagicMock()
    app.mode = "completion"
    session.execute.return_value.scalar_one_or_none.return_value = app
    resume_task = mocker.patch("services.human_input_service.resume_app_execution")
    service._enqueue_resume("workflow-run-id")
    resume_task.apply_async.assert_not_called()
def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)