feat(api): Implement HITL for Workflow, add is_resumption for start event

This commit is contained in:
QuantumGhost
2025-12-30 16:40:08 +08:00
parent 01325c543f
commit 37dd61558c
27 changed files with 762 additions and 344 deletions

View File

@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation:
within conversations.
"""
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
def test_pagination_by_first_id_without_first_id(
self, mock_get_conversation, mock_db_session, mock_create_extra_repo
):
"""
Test message pagination without specifying first_id.
@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act - Call the pagination method without first_id
result = MessageService.pagination_by_first_id(
@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation:
# Verify conversation was looked up with correct parameters
mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test message pagination with first_id specified.
@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation:
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.first.return_value = first_message # First message returned
mock_query.all.return_value = messages # Remaining messages returned
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act - Call the pagination method with first_id
result = MessageService.pagination_by_first_id(
@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation:
assert result.data == []
assert result.has_more is False
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test that has_more flag is correctly set when there are more messages.
@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act
result = MessageService.pagination_by_first_id(
@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == limit # Extra message should be removed
assert result.has_more is True # Flag should be set
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test message pagination with ascending order.
@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act
result = MessageService.pagination_by_first_id(

View File

@ -65,72 +65,25 @@ def sample_form_record():
)
def test_enqueue_resume_dispatches_task(mocker, mock_session_factory):
def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
trigger_log = MagicMock()
trigger_log.id = "trigger-log-id"
trigger_log.queue_name = "workflow_queue"
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = trigger_log
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
service._enqueue_resume("workflow-run-id")
repo_cls.assert_called_once_with(session)
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == "workflow_queue"
payload = call_kwargs["kwargs"]["task_data_dict"]
assert payload["workflow_trigger_log_id"] == "trigger-log-id"
assert payload["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_no_trigger_log(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = None
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
service._enqueue_resume("workflow-run-id")
repo_cls.assert_called_once_with(session)
resume_task.apply_async.assert_not_called()
def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
repo_cls = mocker.patch(
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
autospec=True,
)
repo = repo_cls.return_value
repo.get_by_workflow_run_id.return_value = None
workflow_run = MagicMock()
workflow_run.app_id = "app-id"
workflow_run_repo = MagicMock()
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
mocker.patch(
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
return_value=workflow_run_repo,
)
app = MagicMock()
app.mode = "advanced-chat"
app.mode = "workflow"
session.execute.return_value.scalar_one_or_none.return_value = app
session.get.side_effect = [workflow_run, app]
resume_task = mocker.patch("services.human_input_service.resume_chatflow_execution")
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
service._enqueue_resume("workflow-run-id")
@ -140,6 +93,59 @@ def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
workflow_run = MagicMock()
workflow_run.app_id = "app-id"
workflow_run_repo = MagicMock()
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
mocker.patch(
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
return_value=workflow_run_repo,
)
app = MagicMock()
app.mode = "advanced-chat"
session.execute.return_value.scalar_one_or_none.return_value = app
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
service._enqueue_resume("workflow-run-id")
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == "chatflow_execute"
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
workflow_run = MagicMock()
workflow_run.app_id = "app-id"
workflow_run_repo = MagicMock()
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
mocker.patch(
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
return_value=workflow_run_repo,
)
app = MagicMock()
app.mode = "completion"
session.execute.return_value.scalar_one_or_none.return_value = app
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
service._enqueue_resume("workflow-run-id")
resume_task.apply_async.assert_not_called()
def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)