mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
feat(api): Implement HITL for Workflow, add is_resumption for start event
This commit is contained in:
@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation:
|
||||
within conversations.
|
||||
"""
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_by_first_id_without_first_id(
|
||||
self, mock_get_conversation, mock_db_session, mock_create_extra_repo
|
||||
):
|
||||
"""
|
||||
Test message pagination without specifying first_id.
|
||||
|
||||
@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act - Call the pagination method without first_id
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation:
|
||||
# Verify conversation was looked up with correct parameters
|
||||
mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test message pagination with first_id specified.
|
||||
|
||||
@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.first.return_value = first_message # First message returned
|
||||
mock_query.all.return_value = messages # Remaining messages returned
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act - Call the pagination method with first_id
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation:
|
||||
assert result.data == []
|
||||
assert result.has_more is False
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test that has_more flag is correctly set when there are more messages.
|
||||
|
||||
@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation:
|
||||
assert len(result.data) == limit # Extra message should be removed
|
||||
assert result.has_more is True # Flag should be set
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test message pagination with ascending order.
|
||||
|
||||
@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
|
||||
@ -65,72 +65,25 @@ def sample_form_record():
|
||||
)
|
||||
|
||||
|
||||
def test_enqueue_resume_dispatches_task(mocker, mock_session_factory):
|
||||
def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
trigger_log = MagicMock()
|
||||
trigger_log.id = "trigger-log-id"
|
||||
trigger_log.queue_name = "workflow_queue"
|
||||
|
||||
repo_cls = mocker.patch(
|
||||
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
|
||||
autospec=True,
|
||||
)
|
||||
repo = repo_cls.return_value
|
||||
repo.get_by_workflow_run_id.return_value = trigger_log
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
repo_cls.assert_called_once_with(session)
|
||||
resume_task.apply_async.assert_called_once()
|
||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||
assert call_kwargs["queue"] == "workflow_queue"
|
||||
payload = call_kwargs["kwargs"]["task_data_dict"]
|
||||
assert payload["workflow_trigger_log_id"] == "trigger-log-id"
|
||||
assert payload["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_enqueue_resume_no_trigger_log(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
repo_cls = mocker.patch(
|
||||
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
|
||||
autospec=True,
|
||||
)
|
||||
repo = repo_cls.return_value
|
||||
repo.get_by_workflow_run_id.return_value = None
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_workflow_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
repo_cls.assert_called_once_with(session)
|
||||
resume_task.apply_async.assert_not_called()
|
||||
|
||||
|
||||
def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
repo_cls = mocker.patch(
|
||||
"services.human_input_service.SQLAlchemyWorkflowTriggerLogRepository",
|
||||
autospec=True,
|
||||
)
|
||||
repo = repo_cls.return_value
|
||||
repo.get_by_workflow_run_id.return_value = None
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "advanced-chat"
|
||||
app.mode = "workflow"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
session.get.side_effect = [workflow_run, app]
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_chatflow_execution")
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
@ -140,6 +93,59 @@ def test_enqueue_resume_chatflow_fallback(mocker, mock_session_factory):
|
||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "advanced-chat"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_called_once()
|
||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||
assert call_kwargs["queue"] == "chatflow_execute"
|
||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "completion"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service._enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_not_called()
|
||||
|
||||
|
||||
def test_get_form_definition_by_id_uses_repository(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
|
||||
Reference in New Issue
Block a user