diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index c2a95ddad2..9e7faa09c5 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -1,7 +1,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from configs import dify_config from constants.languages import languages @@ -73,7 +73,7 @@ class EmailRegisterSendEmailApi(Resource): if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): raise AccountInFreezeError() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language) return {"result": "success", "data": token} @@ -145,7 +145,7 @@ class EmailRegisterResetApi(Resource): email = register_data.get("email", "") normalized_email = email.lower() - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index 1ed931b0d7..844f3c91ff 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -4,7 +4,7 @@ import secrets from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from controllers.common.schema import register_schema_models from controllers.console import console_ns @@ -102,7 +102,7 @@ class ForgotPasswordSendEmailApi(Resource): else: language = "en-US" - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session) token = AccountService.send_reset_password_email( @@ -201,7 +201,7 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(args.new_password, salt) email = reset_data.get("email", "") - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(email, session=session) if account: @@ -215,7 +215,6 @@ class ForgotPasswordResetApi(Resource): # Update existing account credentials account.password = base64.b64encode(password_hashed).decode() account.password_salt = base64.b64encode(salt).decode() - session.commit() # Create workspace if needed if ( diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 5c9023f27b..5c7011fd22 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -4,7 +4,7 @@ import urllib.parse import httpx from flask import current_app, redirect, request from flask_restx import Resource -from sqlalchemy.orm import Session +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -180,7 +180,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> account: Account | None = Account.get_by_openid(provider, user_info.id) if not account: - with Session(db.engine) as session: + with sessionmaker(db.engine).begin() as session: account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session) return account diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index bc90c4ffbd..074694e7ea 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -10,7 +10,7 @@ import sqlalchemy as sa from flask import request, send_file from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel, Field -from sqlalchemy import asc, desc, select +from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -211,12 +211,11 @@ class GetProcessRuleApi(Resource): raise Forbidden(str(e)) # get the latest process rule - dataset_process_rule = ( - db.session.query(DatasetProcessRule) + dataset_process_rule = db.session.scalar( + select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == document.dataset_id) .order_by(DatasetProcessRule.created_at.desc()) .limit(1) - .one_or_none() ) if dataset_process_rule: mode = dataset_process_rule.mode @@ -330,21 +329,23 @@ class DatasetDocumentListApi(Resource): if fetch: for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) document.completed_segments = completed_segments document.total_segments = total_segments @@ -521,10 +522,10 @@ class DocumentIndexingEstimateApi(DocumentResource): if data_source_info and "upload_file_id" in data_source_info: file_id = data_source_info["upload_file_id"] - file = ( - db.session.query(UploadFile) + file = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) # raise error if file not found @@ -586,10 +587,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): if not data_source_info: continue file_id = data_source_info["upload_file_id"] - file_detail = ( - db.session.query(UploadFile) + file_detail = db.session.scalar( + select(UploadFile) .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) - .first() + .limit(1) ) if file_detail is None: @@ -672,20 +673,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -723,18 +727,23 @@ class DocumentIndexingStatusApi(DocumentResource): document = self.get_document(dataset_id, document_id) completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) total_segments = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT) - .count() + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) + ) + or 0 ) # Create a dictionary with document attributes and additional fields @@ -1258,11 +1267,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource): document = DocumentService.get_document(dataset.id, document_id) if not document: raise NotFound("Document not found.") - log = ( - db.session.query(DocumentPipelineExecutionLog) - .filter_by(document_id=document_id) + log = db.session.scalar( + select(DocumentPipelineExecutionLog) + .where(DocumentPipelineExecutionLog.document_id == document_id) .order_by(DocumentPipelineExecutionLog.created_at.desc()) - .first() + .limit(1) ) if not log: return { diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 3fd0f3b712..fa9bc7f159 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -45,7 +45,7 @@ def _get_segment_with_summary(segment, dataset_id): """Helper function to marshal segment and add summary information.""" from services.summary_index_service import SummaryIndexService - segment_dict = dict(marshal(segment, segment_fields)) + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore # Query summary for this segment (only enabled summaries) summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id) segment_dict["summary"] = summary.summary_content if summary else None @@ -206,7 +206,7 @@ class DatasetDocumentSegmentListApi(Resource): # Add summary to each segment segments_with_summary = [] for segment in segments.items: - segment_dict = dict(marshal(segment, segment_fields)) + segment_dict = dict(marshal(segment, segment_fields)) # type: ignore segment_dict["summary"] = summaries.get(segment.id) segments_with_summary.append(segment_dict) diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py index 3ef1341abc..d533e6c5b1 100644 --- a/api/controllers/console/datasets/wraps.py +++ b/api/controllers/console/datasets/wraps.py @@ -2,6 +2,8 @@ from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar +from sqlalchemy import select + from controllers.console.datasets.error import PipelineNotFoundError from extensions.ext_database import db from libs.login import current_account_with_tenant @@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]): del kwargs["pipeline_id"] - pipeline = ( - db.session.query(Pipeline) - .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id) - .first() + pipeline = db.session.scalar( + select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1) ) if not pipeline: diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 83d07087ab..89be847cd3 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -153,15 +153,15 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: - item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) - item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" + if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: # type: ignore + item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"])) # type: ignore + item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" # type: ignore if item_model in model_names: - item["embedding_available"] = True + item["embedding_available"] = True # type: ignore else: - item["embedding_available"] = False + item["embedding_available"] = False # type: ignore else: - item["embedding_available"] = True + item["embedding_available"] = True # type: ignore response = { "data": data, "has_more": len(datasets) == query.limit, diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index ef1a3be45b..ed6a7dabbb 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -67,7 +67,8 @@ class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): if field_name == "inputs": data = { "messages": [ - dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v + dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) # type: ignore + for msg in v ] if isinstance(v, list) else v, diff --git a/api/models/model.py b/api/models/model.py index 20daa010d8..bea17246fb 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -591,7 +591,9 @@ class AppModelConfig(TypeBase): __tablename__ = "app_model_configs" __table_args__ = (sa.PrimaryKeyConstraint("id", name="app_model_config_pkey"), sa.Index("app_app_id_idx", "app_id")) - id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False + ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) model_id: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) diff --git a/api/pyproject.toml b/api/pyproject.toml index 1fb0d97dc7..8a196f4485 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -182,7 +182,7 @@ dev = [ "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.55.0", + "pyrefly>=0.57.1", ] ############################################################ diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py similarity index 82% rename from api/tests/unit_tests/controllers/console/auth/test_email_register.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 724c80f18c..879c337319 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -1,8 +1,11 @@ +"""Testcontainers integration tests for email register controller endpoints.""" + +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.email_register import ( EmailRegisterCheckApi, @@ -13,14 +16,11 @@ from services.account_service import AccountService @pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app +def app(flask_app_with_containers): + return flask_app_with_containers class TestEmailRegisterSendEmailApi: - @patch("controllers.console.auth.email_register.Session") @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.send_email_register_email") @patch("controllers.console.auth.email_register.BillingService.is_email_in_freeze") @@ -33,20 +33,15 @@ class TestEmailRegisterSendEmailApi: mock_is_freeze, mock_send_mail, mock_get_account, - mock_session_cls, app, ): mock_send_mail.return_value = "token-123" mock_is_freeze.return_value = False mock_account = MagicMock() - - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session mock_get_account.return_value = mock_account feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.auth.email_register.dify_config", SimpleNamespace(BILLING_ENABLED=True)), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), @@ -61,7 +56,6 @@ class TestEmailRegisterSendEmailApi: assert response == {"result": "success", "data": "token-123"} mock_is_freeze.assert_called_once_with("invitee@example.com") mock_send_mail.assert_called_once_with(email="invitee@example.com", account=mock_account, language="en-US") - mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) mock_extract_ip.assert_called_once() mock_is_email_send_ip_limit.assert_called_once_with("127.0.0.1") @@ -89,7 +83,6 @@ class TestEmailRegisterCheckApi: feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), ): @@ -114,7 +107,6 @@ class TestEmailRegisterResetApi: @patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit") @patch("controllers.console.auth.email_register.AccountService.login") @patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account") - @patch("controllers.console.auth.email_register.Session") @patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token") @patch("controllers.console.auth.email_register.AccountService.get_email_register_data") @@ -125,7 +117,6 @@ class TestEmailRegisterResetApi: mock_get_data, mock_revoke_token, mock_get_account, - mock_session_cls, mock_create_account, mock_login, mock_reset_login_rate, @@ -136,14 +127,10 @@ class TestEmailRegisterResetApi: token_pair = MagicMock() token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"} mock_login.return_value = token_pair - - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session mock_get_account.return_value = None feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) with ( - patch("controllers.console.auth.email_register.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags), ): @@ -159,19 +146,19 @@ class TestEmailRegisterResetApi: mock_reset_login_rate.assert_called_once_with("invitee@example.com") mock_revoke_token.assert_called_once_with("token-123") mock_extract_ip.assert_called_once() - mock_get_account.assert_called_once_with("Invitee@Example.com", session=mock_session) -def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): +def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): + """Test that case fallback tries lowercase when exact match fails.""" mock_session = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None expected_account = MagicMock() - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = expected_account - mock_session.execute.side_effect = [first_query, second_query] + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] - account = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) + result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) - assert account is expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py similarity index 82% rename from api/tests/unit_tests/controllers/console/auth/test_forgot_password.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index 8403777dc9..7b7393dade 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -1,8 +1,11 @@ +"""Testcontainers integration tests for forgot password controller endpoints.""" + +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.forgot_password import ( ForgotPasswordCheckApi, @@ -13,14 +16,11 @@ from services.account_service import AccountService @pytest.fixture -def app(): - flask_app = Flask(__name__) - flask_app.config["TESTING"] = True - return flask_app +def app(flask_app_with_containers): + return flask_app_with_containers class TestForgotPasswordSendEmailApi: - @patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.send_reset_password_email") @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit", return_value=False) @@ -31,19 +31,15 @@ class TestForgotPasswordSendEmailApi: mock_is_ip_limit, mock_send_email, mock_get_account, - mock_session_cls, app, ): mock_account = MagicMock() mock_get_account.return_value = mock_account mock_send_email.return_value = "token-123" - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session wraps_features = SimpleNamespace(enable_email_password_login=True, is_allow_register=True) controller_features = SimpleNamespace(is_allow_register=True) with ( - patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch( "controllers.console.auth.forgot_password.FeatureService.get_system_features", return_value=controller_features, @@ -59,7 +55,6 @@ class TestForgotPasswordSendEmailApi: response = ForgotPasswordSendEmailApi().post() assert response == {"result": "success", "data": "token-123"} - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_send_email.assert_called_once_with( account=mock_account, email="user@example.com", @@ -117,7 +112,6 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.ForgotPasswordResetApi._update_existing_account") - @patch("controllers.console.auth.forgot_password.Session") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @@ -126,7 +120,6 @@ class TestForgotPasswordResetApi: mock_get_reset_data, mock_revoke_token, mock_get_account, - mock_session_cls, mock_update_account, app, ): @@ -134,12 +127,8 @@ class TestForgotPasswordResetApi: mock_account = MagicMock() mock_get_account.return_value = mock_account - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - wraps_features = SimpleNamespace(enable_email_password_login=True) with ( - patch("controllers.console.auth.forgot_password.db", SimpleNamespace(engine="engine")), patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")), patch("controllers.console.wraps.FeatureService.get_system_features", return_value=wraps_features), ): @@ -157,20 +146,22 @@ class TestForgotPasswordResetApi: assert response == {"result": "success"} mock_get_reset_data.assert_called_once_with("token-123") mock_revoke_token.assert_called_once_with("token-123") - mock_get_account.assert_called_once_with("User@Example.com", session=mock_session) mock_update_account.assert_called_once() -def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(): +def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(): + """Test that case fallback tries lowercase when exact match fails.""" + from unittest.mock import MagicMock + mock_session = MagicMock() - first_query = MagicMock() - first_query.scalar_one_or_none.return_value = None + first_result = MagicMock() + first_result.scalar_one_or_none.return_value = None expected_account = MagicMock() - second_query = MagicMock() - second_query.scalar_one_or_none.return_value = expected_account - mock_session.execute.side_effect = [first_query, second_query] + second_result = MagicMock() + second_result.scalar_one_or_none.return_value = expected_account + mock_session.execute.side_effect = [first_result, second_result] - account = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) + result = AccountService.get_account_by_email_with_case_fallback("Mixed@Test.com", session=mock_session) - assert account is expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py similarity index 92% rename from api/tests/unit_tests/controllers/console/auth/test_oauth.py rename to api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 6345c2ab23..a2f1328579 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -1,7 +1,10 @@ +"""Testcontainers integration tests for OAuth controller endpoints.""" + +from __future__ import annotations + from unittest.mock import MagicMock, patch import pytest -from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, @@ -18,10 +21,8 @@ from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.mark.parametrize( ("github_config", "google_config", "expected_github", "expected_google"), @@ -64,10 +65,8 @@ class TestOAuthLogin: return OAuthLogin() @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def mock_oauth_provider(self): @@ -131,10 +130,8 @@ class TestOAuthCallback: return OAuthCallback() @pytest.fixture - def app(self): - app = Flask(__name__) - app.config["TESTING"] = True - return app + def app(self, flask_app_with_containers): + return flask_app_with_containers @pytest.fixture def oauth_setup(self): @@ -190,15 +187,8 @@ class TestOAuthCallback: (KeyError("Missing key"), "OAuth process failed"), ], ) - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.get_oauth_providers") - def test_should_handle_oauth_exceptions( - self, mock_get_providers, mock_db, resource, app, exception, expected_error - ): - # Mock database session - mock_db.session = MagicMock() - mock_db.session.rollback = MagicMock() - + def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error): # Import the real requests module to create a proper exception import httpx @@ -258,7 +248,6 @@ class TestOAuthCallback: ) @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") @@ -269,7 +258,6 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - mock_db, mock_tenant_service, mock_account_service, resource, @@ -278,10 +266,6 @@ class TestOAuthCallback: account_status, expected_redirect, ): - # Mock database session - mock_db.session = MagicMock() - mock_db.session.rollback = MagicMock() - mock_db.session.commit = MagicMock() mock_config.CONSOLE_WEB_URL = "http://localhost:3000" mock_get_providers.return_value = {"github": oauth_setup["provider"]} @@ -306,14 +290,12 @@ class TestOAuthCallback: @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.AccountService") def test_should_activate_pending_account( self, mock_account_service, mock_tenant_service, - mock_db, mock_generate_account, mock_get_providers, mock_config, @@ -338,12 +320,10 @@ class TestOAuthCallback: assert mock_account.status == AccountStatus.ACTIVE assert mock_account.initialized_at is not None - mock_db.session.commit.assert_called_once() @patch("controllers.console.auth.oauth.dify_config") @patch("controllers.console.auth.oauth.get_oauth_providers") @patch("controllers.console.auth.oauth._generate_account") - @patch("controllers.console.auth.oauth.db") @patch("controllers.console.auth.oauth.TenantService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.redirect") @@ -352,7 +332,6 @@ class TestOAuthCallback: mock_redirect, mock_account_service, mock_tenant_service, - mock_db, mock_generate_account, mock_get_providers, mock_config, @@ -414,6 +393,10 @@ class TestOAuthCallback: class TestAccountGeneration: + @pytest.fixture + def app(self, flask_app_with_containers): + return flask_app_with_containers + @pytest.fixture def user_info(self): return OAuthUserInfo(id="123", name="Test User", email="test@example.com") @@ -425,15 +408,10 @@ class TestAccountGeneration: return account @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") - @patch("controllers.console.auth.oauth.Session") @patch("controllers.console.auth.oauth.Account") - @patch("controllers.console.auth.oauth.db") def test_should_get_account_by_openid_or_email( - self, mock_db, mock_account_model, mock_session, mock_get_account, user_info, mock_account + self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account ): - # Mock db.engine for Session creation - mock_db.engine = MagicMock() - # Test OpenID found mock_account_model.get_by_openid.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) @@ -443,15 +421,14 @@ class TestAccountGeneration: # Test fallback to email lookup mock_account_model.get_by_openid.return_value = None - mock_session_instance = MagicMock() - mock_session.return_value.__enter__.return_value = mock_session_instance mock_get_account.return_value = mock_account result = _get_account_by_openid_or_email("github", user_info) assert result == mock_account - mock_get_account.assert_called_once_with(user_info.email, session=mock_session_instance) + mock_get_account.assert_called_once() - def test_get_account_by_email_with_case_fallback_uses_lowercase_lookup(self): + def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase(self): + """Test that case fallback tries lowercase when exact match fails.""" mock_session = MagicMock() first_result = MagicMock() first_result.scalar_one_or_none.return_value = None @@ -462,7 +439,7 @@ class TestAccountGeneration: result = AccountService.get_account_by_email_with_case_fallback("Case@Test.com", session=mock_session) - assert result == expected_account + assert result is expected_account assert mock_session.execute.call_count == 2 @pytest.mark.parametrize( @@ -478,10 +455,8 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") def test_should_handle_account_generation_scenarios( self, - mock_db, mock_tenant_service, mock_account_service, mock_register_service, @@ -519,10 +494,8 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.RegisterService") @patch("controllers.console.auth.oauth.AccountService") @patch("controllers.console.auth.oauth.TenantService") - @patch("controllers.console.auth.oauth.db") def test_should_register_with_lowercase_email( self, - mock_db, mock_tenant_service, mock_account_service, mock_register_service, diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index a484c7be87..c4d20bc02c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -707,3 +707,104 @@ class TestDatasetServiceRetrievalConfiguration: db_session_with_containers.refresh(dataset) assert result.id == dataset.id assert dataset.retrieval_model == update_data["retrieval_model"] + + +class TestDocumentServicePauseRecoverRetry: + """Tests for pause/recover/retry orchestration using real DB and Redis.""" + + def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"): + factory = DatasetServiceIntegrationDataFactory + account, tenant = factory.create_account_with_tenant(db_session_with_containers) + dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) + doc = factory.create_document(db_session_with_containers, dataset, account.id) + doc.indexing_status = indexing_status + db_session_with_containers.commit() + return doc, account + + def test_pause_document_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing") + + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + DocumentService.pause_document(doc) + + db_session_with_containers.refresh(doc) + assert doc.is_paused is True + assert doc.paused_by == account.id + assert doc.paused_at is not None + + cache_key = f"document_{doc.id}_is_paused" + assert redis_client.get(cache_key) is not None + redis_client.delete(cache_key) + + def test_pause_document_invalid_status_error(self, db_session_with_containers): + from services.dataset_service import DocumentService + from services.errors.document import DocumentIndexingError + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="completed") + + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + with pytest.raises(DocumentIndexingError): + DocumentService.pause_document(doc) + + def test_recover_document_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + doc, account = self._create_indexing_document(db_session_with_containers, indexing_status="indexing") + + # Pause first + with patch("services.dataset_service.current_user") as mock_user: + mock_user.id = account.id + DocumentService.pause_document(doc) + + # Recover + with patch("services.dataset_service.recover_document_indexing_task") as recover_task: + DocumentService.recover_document(doc) + + db_session_with_containers.refresh(doc) + assert doc.is_paused is False + assert doc.paused_by is None + assert doc.paused_at is None + + cache_key = f"document_{doc.id}_is_paused" + assert redis_client.get(cache_key) is None + recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id) + + def test_retry_document_indexing_success(self, db_session_with_containers): + from extensions.ext_redis import redis_client + from services.dataset_service import DocumentService + + factory = DatasetServiceIntegrationDataFactory + account, tenant = factory.create_account_with_tenant(db_session_with_containers) + dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) + doc1 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc1.txt") + doc2 = factory.create_document(db_session_with_containers, dataset, account.id, name="doc2.txt") + doc2.position = 2 + doc1.indexing_status = "error" + doc2.indexing_status = "error" + db_session_with_containers.commit() + + with ( + patch("services.dataset_service.current_user") as mock_user, + patch("services.dataset_service.retry_document_indexing_task") as retry_task, + ): + mock_user.id = account.id + DocumentService.retry_document(dataset.id, [doc1, doc2]) + + db_session_with_containers.refresh(doc1) + db_session_with_containers.refresh(doc2) + assert doc1.indexing_status == "waiting" + assert doc2.indexing_status == "waiting" + + # Verify redis keys were set + assert redis_client.get(f"document_{doc1.id}_is_retried") is not None + assert redis_client.get(f"document_{doc2.id}_is_retried") is not None + retry_task.delay.assert_called_once_with(dataset.id, [doc1.id, doc2.id], account.id) + + # Cleanup + redis_client.delete(f"document_{doc1.id}_is_retried", f"document_{doc2.id}_is_retried") diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index 5f86cb2ae9..376a89d1ce 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -141,3 +141,73 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expunge_all() deleted_run = db_session_with_containers.get(WorkflowRun, run_id) assert deleted_run is None + + def test_delete_run_dry_run(self, db_session_with_containers): + """Dry run should return success without actually deleting.""" + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + run_id = run.id + deleter = ArchivedWorkflowRunDeletion(dry_run=True) + + result = deleter._delete_run(run) + + assert result.success is True + assert result.run_id == run_id + # Run should still exist because it's a dry run + db_session_with_containers.expire_all() + assert db_session_with_containers.get(WorkflowRun, run_id) is not None + + def test_delete_run_exception_returns_error(self, db_session_with_containers): + """Exception during deletion should return failure result.""" + from unittest.mock import MagicMock, patch + + tenant_id = str(uuid4()) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=datetime.now(UTC), + ) + deleter = ArchivedWorkflowRunDeletion(dry_run=False) + + with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.side_effect = Exception("Database error") + + result = deleter._delete_run(run) + + assert result.success is False + assert result.error == "Database error" + + def test_delete_by_run_id_success(self, db_session_with_containers): + """Successfully delete an archived workflow run by ID.""" + tenant_id = str(uuid4()) + base_time = datetime.now(UTC) + run = self._create_workflow_run( + db_session_with_containers, + tenant_id=tenant_id, + created_at=base_time, + ) + self._create_archive_log(db_session_with_containers, run=run) + run_id = run.id + + deleter = ArchivedWorkflowRunDeletion() + result = deleter.delete_by_run_id(run_id) + + assert result.success is True + db_session_with_containers.expunge_all() + assert db_session_with_containers.get(WorkflowRun, run_id) is None + + def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers): + """_get_workflow_run_repo should return a cached repo on subsequent calls.""" + deleter = ArchivedWorkflowRunDeletion() + + repo1 = deleter._get_workflow_run_repo() + repo2 = deleter._get_workflow_run_repo() + + assert repo1 is repo2 + assert deleter.workflow_run_repo is repo1 diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index f08f21ee14..ce2278de4f 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -140,8 +140,8 @@ class TestDatasetDocumentListApi: return_value=pagination, ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=2, ), patch( "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", @@ -700,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi: return_value=MagicMock(), ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log)) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=log, ), ): response, status = method(api, "ds-1", "doc-1") @@ -827,15 +825,12 @@ class TestDocumentIndexingEstimateApi: dataset_process_rule=None, ) - query_mock = MagicMock() - query_mock.where.return_value.first.return_value = None - with ( app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=query_mock, + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=None, ), ): with pytest.raises(NotFound): @@ -863,10 +858,8 @@ class TestDocumentIndexingEstimateApi: app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file))) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_document.ExtractSetting", @@ -1239,12 +1232,8 @@ class TestDocumentPermissionCases: return_value=None, ), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock( - where=lambda *a: MagicMock( - order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule)) - ) - ), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=process_rule, ), ): result = method(api) @@ -1364,8 +1353,8 @@ class TestDocumentIndexingEdgeCases: app.test_request_context("/"), patch.object(api, "get_document", return_value=document), patch( - "controllers.console.datasets.datasets_document.db.session.query", - return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)), + "controllers.console.datasets.datasets_document.db.session.scalar", + return_value=upload_file, ), patch( "controllers.console.datasets.datasets_document.ExtractSetting", diff --git a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py index 90f00711c1..e358435de4 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py @@ -26,12 +26,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = None - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=None, ) with pytest.raises(PipelineNotFoundError): @@ -51,12 +48,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = pipeline - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id="pipeline-1") @@ -76,12 +70,9 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - mock_query = Mock() - mock_query.where.return_value.first.return_value = pipeline - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id="pipeline-1") @@ -100,18 +91,15 @@ class TestGetRagPipeline: return_value=(Mock(), "tenant-1"), ) - def where_side_effect(*args, **kwargs): - assert args[0].right.value == "123" - return Mock(first=lambda: pipeline) - - mock_query = Mock() - mock_query.where.side_effect = where_side_effect - - mocker.patch( - "controllers.console.datasets.wraps.db.session.query", - return_value=mock_query, + mock_scalar = mocker.patch( + "controllers.console.datasets.wraps.db.session.scalar", + return_value=pipeline, ) result = dummy_view(pipeline_id=123) assert result is pipeline + # Verify the pipeline_id was cast to string in the where clause + stmt = mock_scalar.call_args[0][0] + where_clauses = stmt.whereclause.clauses + assert where_clauses[0].right.value == "123" diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py deleted file mode 100644 index 9fe153c153..0000000000 --- a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py +++ /dev/null @@ -1,216 +0,0 @@ -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest -from sqlalchemy.orm import Session - -from models.workflow import WorkflowRun -from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion, DeleteResult - - -class TestArchivedWorkflowRunDeletion: - @pytest.fixture - def mock_db(self): - with patch("services.retention.workflow_run.delete_archived_workflow_run.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db - - @pytest.fixture - def mock_sessionmaker(self): - with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: - mock_session = MagicMock(spec=Session) - mock_sm.return_value.return_value.__enter__.return_value = mock_session - yield mock_sm, mock_session - - @pytest.fixture - def mock_workflow_run_repo(self): - with patch( - "services.retention.workflow_run.delete_archived_workflow_run.APIWorkflowRunRepository" - ) as mock_repo_cls: - mock_repo = MagicMock() - yield mock_repo - - def test_delete_by_run_id_success(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - tenant_id = "tenant-456" - - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = run_id - mock_run.tenant_id = tenant_id - mock_session.get.return_value = mock_run - - deletion = ArchivedWorkflowRunDeletion() - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_run_ids.return_value = [run_id] - - with patch.object(deletion, "_delete_run") as mock_delete_run: - expected_result = DeleteResult(run_id=run_id, tenant_id=tenant_id, success=True) - mock_delete_run.return_value = expected_result - - result = deletion.delete_by_run_id(run_id) - - assert result == expected_result - mock_session.get.assert_called_once_with(WorkflowRun, run_id) - mock_repo.get_archived_run_ids.assert_called_once() - mock_delete_run.assert_called_once_with(mock_run) - - def test_delete_by_run_id_not_found(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - mock_session.get.return_value = None - - deletion = ArchivedWorkflowRunDeletion() - with patch.object(deletion, "_get_workflow_run_repo"): - result = deletion.delete_by_run_id(run_id) - - assert result.success is False - assert "not found" in result.error - assert result.run_id == run_id - - def test_delete_by_run_id_not_archived(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - run_id = "run-123" - - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = run_id - mock_session.get.return_value = mock_run - - deletion = ArchivedWorkflowRunDeletion() - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_run_ids.return_value = [] - - result = deletion.delete_by_run_id(run_id) - - assert result.success is False - assert "is not archived" in result.error - - def test_delete_batch(self, mock_db, mock_sessionmaker): - mock_sm, mock_session = mock_sessionmaker - deletion = ArchivedWorkflowRunDeletion() - - mock_run1 = MagicMock(spec=WorkflowRun) - mock_run1.id = "run-1" - mock_run2 = MagicMock(spec=WorkflowRun) - mock_run2.id = "run-2" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.get_archived_runs_by_time_range.return_value = [mock_run1, mock_run2] - - with patch.object(deletion, "_delete_run") as mock_delete_run: - mock_delete_run.side_effect = [ - DeleteResult(run_id="run-1", tenant_id="t1", success=True), - DeleteResult(run_id="run-2", tenant_id="t1", success=True), - ] - - results = deletion.delete_batch(tenant_ids=["t1"], start_date=datetime.now(), end_date=datetime.now()) - - assert len(results) == 2 - assert results[0].run_id == "run-1" - assert results[1].run_id == "run-2" - assert mock_delete_run.call_count == 2 - - def test_delete_run_dry_run(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=True) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - mock_run.tenant_id = "tenant-456" - - result = deletion._delete_run(mock_run) - - assert result.success is True - assert result.run_id == "run-123" - - def test_delete_run_success(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=False) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - mock_run.tenant_id = "tenant-456" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.delete_runs_with_related.return_value = {"workflow_runs": 1} - - result = deletion._delete_run(mock_run) - - assert result.success is True - assert result.deleted_counts == {"workflow_runs": 1} - - def test_delete_run_exception(self): - deletion = ArchivedWorkflowRunDeletion(dry_run=False) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-123" - - with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: - mock_repo = MagicMock() - mock_get_repo.return_value = mock_repo - mock_repo.delete_runs_with_related.side_effect = Exception("Database error") - - result = deletion._delete_run(mock_run) - - assert result.success is False - assert result.error == "Database error" - - def test_delete_trigger_logs(self): - mock_session = MagicMock(spec=Session) - run_ids = ["run-1", "run-2"] - - with patch( - "services.retention.workflow_run.delete_archived_workflow_run.SQLAlchemyWorkflowTriggerLogRepository" - ) as mock_repo_cls: - mock_repo = MagicMock() - mock_repo_cls.return_value = mock_repo - mock_repo.delete_by_run_ids.return_value = 5 - - count = ArchivedWorkflowRunDeletion._delete_trigger_logs(mock_session, run_ids) - - assert count == 5 - mock_repo_cls.assert_called_once_with(mock_session) - mock_repo.delete_by_run_ids.assert_called_once_with(run_ids) - - def test_delete_node_executions(self): - mock_session = MagicMock(spec=Session) - mock_run = MagicMock(spec=WorkflowRun) - mock_run.id = "run-1" - runs = [mock_run] - - with patch( - "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository" - ) as mock_create_repo: - mock_repo = MagicMock() - mock_create_repo.return_value = mock_repo - mock_repo.delete_by_runs.return_value = (1, 2) - - with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: - result = ArchivedWorkflowRunDeletion._delete_node_executions(mock_session, runs) - - assert result == (1, 2) - mock_create_repo.assert_called_once() - mock_repo.delete_by_runs.assert_called_once_with(mock_session, ["run-1"]) - - def test_get_workflow_run_repo(self, mock_db): - deletion = ArchivedWorkflowRunDeletion() - - with patch( - "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository" - ) as mock_create_repo: - mock_repo = MagicMock() - mock_create_repo.return_value = mock_repo - - # First call - repo1 = deletion._get_workflow_run_repo() - assert repo1 == mock_repo - assert deletion.workflow_run_repo == mock_repo - - # Second call (should return cached) - repo2 = deletion._get_workflow_run_repo() - assert repo2 == mock_repo - mock_create_repo.assert_called_once() diff --git a/api/tests/unit_tests/services/test_agent_service.py b/api/tests/unit_tests/services/test_agent_service.py deleted file mode 100644 index 7ce3d7ef7b..0000000000 --- a/api/tests/unit_tests/services/test_agent_service.py +++ /dev/null @@ -1,346 +0,0 @@ -""" -Unit tests for services.agent_service -""" - -from collections.abc import Callable -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest -import pytz - -from core.plugin.impl.exc import PluginDaemonClientSideError -from models import Account -from models.model import App, Conversation, EndUser, Message, MessageAgentThought -from services.agent_service import AgentService - - -def _make_current_user_account(timezone: str = "UTC") -> Account: - account = Account(name="Test User", email="test@example.com") - account.timezone = timezone - return account - - -def _make_app_model(app_model_config: MagicMock | None) -> MagicMock: - app_model = MagicMock(spec=App) - app_model.id = "app-123" - app_model.tenant_id = "tenant-123" - app_model.app_model_config = app_model_config - return app_model - - -def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock: - conversation = MagicMock(spec=Conversation) - conversation.id = "conv-123" - conversation.app_id = "app-123" - conversation.from_end_user_id = from_end_user_id - conversation.from_account_id = from_account_id - return conversation - - -def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock: - message = MagicMock(spec=Message) - message.id = "msg-123" - message.conversation_id = "conv-123" - message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) - message.provider_response_latency = 1.23 - message.answer_tokens = 4 - message.message_tokens = 6 - message.agent_thoughts = agent_thoughts - message.message_files = ["file-a.txt"] - return message - - -def _make_agent_thought() -> MagicMock: - agent_thought = MagicMock(spec=MessageAgentThought) - agent_thought.tokens = 3 - agent_thought.tool_input = "raw-input" - agent_thought.observation = "raw-output" - agent_thought.thought = "thinking" - agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) - agent_thought.files = [] - agent_thought.tools = ["tool_a", "dataset_tool"] - agent_thought.tool_labels = {"tool_a": "Tool A"} - agent_thought.tool_meta = { - "tool_a": { - "tool_config": { - "tool_provider_type": "custom", - "tool_provider": "provider-1", - }, - "tool_parameters": {"param": "value"}, - "time_cost": 2.5, - }, - "dataset_tool": { - "tool_config": { - "tool_provider_type": "dataset-retrieval", - "tool_provider": "dataset-provider", - } - }, - } - agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}} - agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}} - return agent_thought - - -def _build_query_side_effect( - conversation: Conversation | None, - message: Message | None, - executor: EndUser | Account | None, -) -> Callable[..., MagicMock]: - def _query_side_effect(*args: object, **kwargs: object) -> MagicMock: - query = MagicMock() - query.where.return_value = query - if any(arg is Conversation for arg in args): - query.first.return_value = conversation - elif any(arg is Message for arg in args): - query.first.return_value = message - elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args): - query.first.return_value = executor - return query - - return _query_side_effect - - -class TestAgentServiceGetAgentLogs: - """Test suite for AgentService.get_agent_logs.""" - - def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None: - """Test missing conversation raises ValueError.""" - # Arrange - app_model = _make_app_model(MagicMock()) - with patch("services.agent_service.db") as mock_db: - query = MagicMock() - query.where.return_value = query - query.first.return_value = None - mock_db.session.query.return_value = query - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, "missing-conv", "msg-1") - - def test_get_agent_logs_should_raise_when_message_missing(self) -> None: - """Test missing message raises ValueError.""" - # Arrange - app_model = _make_app_model(MagicMock()) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - with patch("services.agent_service.db") as mock_db: - conversation_query = MagicMock() - conversation_query.where.return_value = conversation_query - conversation_query.first.return_value = conversation - - message_query = MagicMock() - message_query.where.return_value = message_query - message_query.first.return_value = None - - mock_db.session.query.side_effect = [conversation_query, message_query] - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, "missing-msg") - - def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None: - """Test missing app model config raises ValueError.""" - # Arrange - app_model = _make_app_model(None) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - message = _make_message([]) - current_user = _make_current_user_account() - - with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, message.id) - - def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None: - """Test missing agent config raises ValueError.""" - # Arrange - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - message = _make_message([]) - current_user = _make_current_user_account() - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=None), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_logs(app_model, conversation.id, message.id) - - def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None: - """Test agent logs returned for end-user executor with tool icons.""" - # Arrange - agent_thought = _make_agent_thought() - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - executor = MagicMock(spec=EndUser) - executor.name = "End User" - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_tool = MagicMock() - agent_tool.tool_name = "tool_a" - agent_tool.provider_type = "custom" - agent_tool.provider_id = "provider-2" - agent_config = MagicMock() - agent_config.tools = [agent_tool] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert, - patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon, - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) - mock_get_icon.side_effect = [None, "icon-a"] - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["status"] == "success" - assert result["meta"]["executor"] == "End User" - assert result["meta"]["total_tokens"] == 10 - assert result["meta"]["agent_mode"] == "react" - assert result["meta"]["iterations"] == 1 - assert result["files"] == ["file-a.txt"] - assert len(result["iterations"]) == 1 - tool_calls = result["iterations"][0]["tool_calls"] - assert tool_calls[0]["tool_name"] == "tool_a" - assert tool_calls[0]["tool_icon"] == "icon-a" - assert tool_calls[1]["tool_name"] == "dataset_tool" - assert tool_calls[1]["tool_icon"] == "" - mock_convert.assert_called_once() - - def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None: - """Test agent logs fall back to account executor when end user is missing.""" - # Arrange - agent_thought = _make_agent_thought() - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1") - executor = MagicMock(spec=Account) - executor.name = "Account User" - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {"strategy": "react"} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_config = MagicMock() - agent_config.tools = [] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), - patch("services.agent_service.ToolManager.get_tool_icon", return_value=""), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["executor"] == "Account User" - - def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None: - """Test unknown executor and missing tool details fall back to defaults.""" - # Arrange - agent_thought = _make_agent_thought() - agent_thought.tool_labels = {} - agent_thought.tool_inputs_dict = {} - agent_thought.tool_outputs_dict = None - agent_thought.tool_meta = {"tool_a": {"error": "failed"}} - agent_thought.tools = ["tool_a"] - - message = _make_message([agent_thought]) - conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) - app_model_config = MagicMock() - app_model_config.agent_mode_dict = {} - app_model_config.to_dict.return_value = {"tools": []} - app_model = _make_app_model(app_model_config) - current_user = _make_current_user_account() - agent_config = MagicMock() - agent_config.tools = [] - - with ( - patch("services.agent_service.db") as mock_db, - patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), - patch("services.agent_service.ToolManager.get_tool_icon", return_value=None), - patch("services.agent_service.current_user", current_user), - ): - mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None) - - # Act - result = AgentService.get_agent_logs(app_model, conversation.id, message.id) - - # Assert - assert result["meta"]["executor"] == "Unknown" - assert result["meta"]["agent_mode"] == "react" - tool_call = result["iterations"][0]["tool_calls"][0] - assert tool_call["status"] == "error" - assert tool_call["error"] == "failed" - assert tool_call["tool_label"] == "tool_a" - assert tool_call["tool_input"] == {} - assert tool_call["tool_output"] == {} - assert tool_call["time_cost"] == 0 - assert tool_call["tool_parameters"] == {} - assert tool_call["tool_icon"] is None - - -class TestAgentServiceProviders: - """Test suite for AgentService provider methods.""" - - def test_list_agent_providers_should_delegate_to_plugin_client(self) -> None: - """Test list_agent_providers delegates to PluginAgentClient.""" - # Arrange - tenant_id = "tenant-1" - expected = [{"name": "provider"}] - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_providers.return_value = expected - - # Act - result = AgentService.list_agent_providers("user-1", tenant_id) - - # Assert - assert result == expected - mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id) - - def test_get_agent_provider_should_return_provider_when_successful(self) -> None: - """Test get_agent_provider returns provider when successful.""" - # Arrange - tenant_id = "tenant-1" - provider_name = "provider-a" - expected = {"name": provider_name} - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_provider.return_value = expected - - # Act - result = AgentService.get_agent_provider("user-1", tenant_id, provider_name) - - # Assert - assert result == expected - mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name) - - def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None: - """Test get_agent_provider wraps PluginDaemonClientSideError into ValueError.""" - # Arrange - tenant_id = "tenant-1" - provider_name = "provider-a" - with patch("services.agent_service.PluginAgentClient") as mock_client: - mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError( - "plugin error" - ) - - # Act & Assert - with pytest.raises(ValueError): - AgentService.get_agent_provider("user-1", tenant_id, provider_name) diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service.py deleted file mode 100644 index a1d2f6410c..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Unit tests for non-SQL DocumentService orchestration behaviors. - -This file intentionally keeps only collaborator-oriented document indexing -orchestration tests. SQL-backed dataset lifecycle cases are covered by -integration tests under testcontainers. -""" - -from unittest.mock import Mock, patch - -import pytest - -from models.dataset import Document -from services.errors.document import DocumentIndexingError - - -class DatasetServiceUnitDataFactory: - """Factory for creating lightweight document doubles used in unit tests.""" - - @staticmethod - def create_document_mock( - document_id: str = "doc-123", - dataset_id: str = "dataset-123", - indexing_status: str = "completed", - is_paused: bool = False, - ) -> Mock: - """Create a document-shaped mock for DocumentService orchestration tests.""" - document = Mock(spec=Document) - document.id = document_id - document.dataset_id = dataset_id - document.indexing_status = indexing_status - document.is_paused = is_paused - document.paused_by = None - document.paused_at = None - return document - - -class TestDatasetServiceDocumentIndexing: - """Unit tests for pause/recover/retry orchestration without SQL assertions.""" - - @pytest.fixture - def mock_document_service_dependencies(self): - """Patch non-SQL collaborators used by DocumentService methods.""" - with ( - patch("services.dataset_service.redis_client") as mock_redis, - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.current_user") as mock_current_user, - ): - mock_current_user.id = "user-123" - yield { - "redis_client": mock_redis, - "db_session": mock_db, - "current_user": mock_current_user, - } - - def test_pause_document_success(self, mock_document_service_dependencies): - """Pause a document that is currently in an indexable status.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing") - - # Act - from services.dataset_service import DocumentService - - DocumentService.pause_document(document) - - # Assert - assert document.is_paused is True - assert document.paused_by == "user-123" - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with( - f"document_{document.id}_is_paused", - "True", - ) - - def test_pause_document_invalid_status_error(self, mock_document_service_dependencies): - """Raise DocumentIndexingError when pausing a completed document.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed") - - # Act / Assert - from services.dataset_service import DocumentService - - with pytest.raises(DocumentIndexingError): - DocumentService.pause_document(document) - - def test_recover_document_success(self, mock_document_service_dependencies): - """Recover a paused document and dispatch the recover indexing task.""" - # Arrange - document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True) - - # Act - with patch("services.dataset_service.recover_document_indexing_task") as recover_task: - from services.dataset_service import DocumentService - - DocumentService.recover_document(document) - - # Assert - assert document.is_paused is False - assert document.paused_by is None - assert document.paused_at is None - mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) - mock_document_service_dependencies["db_session"].commit.assert_called_once() - mock_document_service_dependencies["redis_client"].delete.assert_called_once_with( - f"document_{document.id}_is_paused" - ) - recover_task.delay.assert_called_once_with(document.dataset_id, document.id) - - def test_retry_document_indexing_success(self, mock_document_service_dependencies): - """Reset documents to waiting state and dispatch retry indexing task.""" - # Arrange - dataset_id = "dataset-123" - documents = [ - DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"), - DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"), - ] - mock_document_service_dependencies["redis_client"].get.return_value = None - - # Act - with patch("services.dataset_service.retry_document_indexing_task") as retry_task: - from services.dataset_service import DocumentService - - DocumentService.retry_document(dataset_id, documents) - - # Assert - assert all(document.indexing_status == "waiting" for document in documents) - assert mock_document_service_dependencies["db_session"].add.call_count == 2 - assert mock_document_service_dependencies["db_session"].commit.call_count == 2 - assert mock_document_service_dependencies["redis_client"].setex.call_count == 2 - retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123") diff --git a/api/tests/unit_tests/services/test_feedback_service.py b/api/tests/unit_tests/services/test_feedback_service.py deleted file mode 100644 index 1f70839ee2..0000000000 --- a/api/tests/unit_tests/services/test_feedback_service.py +++ /dev/null @@ -1,626 +0,0 @@ -import csv -import io -import json -from datetime import datetime -from unittest.mock import MagicMock, patch - -import pytest - -from services.feedback_service import FeedbackService - - -class TestFeedbackServiceFactory: - """Factory class for creating test data and mock objects for feedback service tests.""" - - @staticmethod - def create_feedback_mock( - feedback_id: str = "feedback-123", - app_id: str = "app-456", - conversation_id: str = "conv-789", - message_id: str = "msg-001", - rating: str = "like", - content: str | None = "Great response!", - from_source: str = "user", - from_account_id: str | None = None, - from_end_user_id: str | None = "end-user-001", - created_at: datetime | None = None, - ) -> MagicMock: - """Create a mock MessageFeedback object.""" - feedback = MagicMock() - feedback.id = feedback_id - feedback.app_id = app_id - feedback.conversation_id = conversation_id - feedback.message_id = message_id - feedback.rating = rating - feedback.content = content - feedback.from_source = from_source - feedback.from_account_id = from_account_id - feedback.from_end_user_id = from_end_user_id - feedback.created_at = created_at or datetime.now() - return feedback - - @staticmethod - def create_message_mock( - message_id: str = "msg-001", - query: str = "What is AI?", - answer: str = "AI stands for Artificial Intelligence.", - inputs: dict | None = None, - created_at: datetime | None = None, - ): - """Create a mock Message object.""" - - # Create a simple object with instance attributes - # Using a class with __init__ ensures attributes are instance attributes - class Message: - def __init__(self): - self.id = message_id - self.query = query - self.answer = answer - self.inputs = inputs - self.created_at = created_at or datetime.now() - - return Message() - - @staticmethod - def create_conversation_mock( - conversation_id: str = "conv-789", - name: str | None = "Test Conversation", - ) -> MagicMock: - """Create a mock Conversation object.""" - conversation = MagicMock() - conversation.id = conversation_id - conversation.name = name - return conversation - - @staticmethod - def create_app_mock( - app_id: str = "app-456", - name: str = "Test App", - ) -> MagicMock: - """Create a mock App object.""" - app = MagicMock() - app.id = app_id - app.name = name - return app - - @staticmethod - def create_account_mock( - account_id: str = "account-123", - name: str = "Test Admin", - ) -> MagicMock: - """Create a mock Account object.""" - account = MagicMock() - account.id = account_id - account.name = name - return account - - -class TestFeedbackService: - """ - Comprehensive unit tests for FeedbackService. - - This test suite covers: - - CSV and JSON export formats - - All filter combinations - - Edge cases and error handling - - Response validation - """ - - @pytest.fixture - def factory(self): - """Provide test data factory.""" - return TestFeedbackServiceFactory() - - @pytest.fixture - def sample_feedback_data(self, factory): - """Create sample feedback data for testing.""" - feedback = factory.create_feedback_mock( - rating="like", - content="Excellent answer!", - from_source="user", - ) - message = factory.create_message_mock( - query="What is Python?", - answer="Python is a programming language.", - ) - conversation = factory.create_conversation_mock(name="Python Discussion") - app = factory.create_app_mock(name="AI Assistant") - account = factory.create_account_mock(name="Admin User") - - return [(feedback, message, conversation, app, account)] - - # Test 01: CSV Export - Basic Functionality - @patch("services.feedback_service.db") - def test_export_feedbacks_csv_basic(self, mock_db, factory, sample_feedback_data): - """Test basic CSV export with single feedback record.""" - # Arrange - mock_query = MagicMock() - # Configure the mock to return itself for all chaining methods - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - # Set up the session.query to return our mock - mock_db.session.query.return_value = mock_query - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - assert response.mimetype == "text/csv" - assert "charset=utf-8-sig" in response.content_type - assert "attachment" in response.headers["Content-Disposition"] - assert "dify_feedback_export_app-456" in response.headers["Content-Disposition"] - - # Verify CSV content - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - rows = list(reader) - - assert len(rows) == 1 - assert rows[0]["feedback_rating"] == "👍" - assert rows[0]["feedback_rating_raw"] == "like" - assert rows[0]["feedback_comment"] == "Excellent answer!" - assert rows[0]["user_query"] == "What is Python?" - assert rows[0]["ai_response"] == "Python is a programming language." - - # Test 02: JSON Export - Basic Functionality - @patch("services.feedback_service.db") - def test_export_feedbacks_json_basic(self, mock_db, factory, sample_feedback_data): - """Test basic JSON export with metadata structure.""" - # Arrange - mock_query = MagicMock() - # Configure the mock to return itself for all chaining methods - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - # Set up the session.query to return our mock - mock_db.session.query.return_value = mock_query - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - assert response.mimetype == "application/json" - assert "charset=utf-8" in response.content_type - assert "attachment" in response.headers["Content-Disposition"] - - # Verify JSON structure - json_content = json.loads(response.get_data(as_text=True)) - assert "export_info" in json_content - assert "feedback_data" in json_content - assert json_content["export_info"]["app_id"] == "app-456" - assert json_content["export_info"]["total_records"] == 1 - assert len(json_content["feedback_data"]) == 1 - - # Test 03: Filter by from_source - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_from_source(self, mock_db, factory): - """Test filtering by feedback source (user/admin).""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", from_source="admin") - - # Assert - mock_query.filter.assert_called() - - # Test 04: Filter by rating - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_rating(self, mock_db, factory): - """Test filtering by rating (like/dislike).""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", rating="dislike") - - # Assert - mock_query.filter.assert_called() - - # Test 05: Filter by has_comment (True) - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_has_comment_true(self, mock_db, factory): - """Test filtering for feedback with comments.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", has_comment=True) - - # Assert - mock_query.filter.assert_called() - - # Test 06: Filter by has_comment (False) - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_has_comment_false(self, mock_db, factory): - """Test filtering for feedback without comments.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks(app_id="app-456", has_comment=False) - - # Assert - mock_query.filter.assert_called() - - # Test 07: Filter by date range - @patch("services.feedback_service.db") - def test_export_feedbacks_filter_date_range(self, mock_db, factory): - """Test filtering by start and end dates.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks( - app_id="app-456", - start_date="2024-01-01", - end_date="2024-12-31", - ) - - # Assert - assert mock_query.filter.call_count >= 2 # Called for both start and end dates - - # Test 08: Invalid date format - start_date - @patch("services.feedback_service.db") - def test_export_feedbacks_invalid_start_date(self, mock_db): - """Test error handling for invalid start_date format.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - - # Act & Assert - with pytest.raises(ValueError, match="Invalid start_date format"): - FeedbackService.export_feedbacks(app_id="app-456", start_date="invalid-date") - - # Test 09: Invalid date format - end_date - @patch("services.feedback_service.db") - def test_export_feedbacks_invalid_end_date(self, mock_db): - """Test error handling for invalid end_date format.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - - # Act & Assert - with pytest.raises(ValueError, match="Invalid end_date format"): - FeedbackService.export_feedbacks(app_id="app-456", end_date="2024-13-45") - - # Test 10: Unsupported format - def test_export_feedbacks_unsupported_format(self): - """Test error handling for unsupported export format.""" - # Act & Assert - with pytest.raises(ValueError, match="Unsupported format"): - FeedbackService.export_feedbacks(app_id="app-456", format_type="xml") - - # Test 11: Empty result set - CSV - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_results_csv(self, mock_db): - """Test CSV export with no feedback records.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - rows = list(reader) - assert len(rows) == 0 - # But headers should still be present - assert reader.fieldnames is not None - - # Test 12: Empty result set - JSON - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_results_json(self, mock_db): - """Test JSON export with no feedback records.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["export_info"]["total_records"] == 0 - assert len(json_content["feedback_data"]) == 0 - - # Test 13: Long response truncation - @patch("services.feedback_service.db") - def test_export_feedbacks_long_response_truncation(self, mock_db, factory): - """Test that long AI responses are truncated to 500 characters.""" - # Arrange - long_answer = "A" * 600 # 600 characters - feedback = factory.create_feedback_mock() - message = factory.create_message_mock(answer=long_answer) - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - ai_response = json_content["feedback_data"][0]["ai_response"] - assert len(ai_response) == 503 # 500 + "..." - assert ai_response.endswith("...") - - # Test 14: Null account (end user feedback) - @patch("services.feedback_service.db") - def test_export_feedbacks_null_account(self, mock_db, factory): - """Test handling of feedback from end users (no account).""" - # Arrange - feedback = factory.create_feedback_mock(from_account_id=None) - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = None # No account for end user - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["from_account_name"] == "" - - # Test 15: Null conversation name - @patch("services.feedback_service.db") - def test_export_feedbacks_null_conversation_name(self, mock_db, factory): - """Test handling of conversations without names.""" - # Arrange - feedback = factory.create_feedback_mock() - message = factory.create_message_mock() - conversation = factory.create_conversation_mock(name=None) - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["conversation_name"] == "" - - # Test 16: Dislike rating emoji - @patch("services.feedback_service.db") - def test_export_feedbacks_dislike_rating(self, mock_db, factory): - """Test that dislike rating shows thumbs down emoji.""" - # Arrange - feedback = factory.create_feedback_mock(rating="dislike") - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["feedback_rating"] == "👎" - assert json_content["feedback_data"][0]["feedback_rating_raw"] == "dislike" - - # Test 17: Combined filters - @patch("services.feedback_service.db") - def test_export_feedbacks_combined_filters(self, mock_db, factory): - """Test applying multiple filters simultaneously.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [] - - # Act - FeedbackService.export_feedbacks( - app_id="app-456", - from_source="admin", - rating="like", - has_comment=True, - start_date="2024-01-01", - end_date="2024-12-31", - ) - - # Assert - # Should have called filter multiple times for each condition - assert mock_query.filter.call_count >= 4 - - # Test 18: Message query fallback to inputs - @patch("services.feedback_service.db") - def test_export_feedbacks_message_query_from_inputs(self, mock_db, factory): - """Test fallback to inputs.query when message.query is None.""" - # Arrange - feedback = factory.create_feedback_mock() - message = factory.create_message_mock(query=None, inputs={"query": "Query from inputs"}) - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["user_query"] == "Query from inputs" - - # Test 19: Empty feedback content - @patch("services.feedback_service.db") - def test_export_feedbacks_empty_feedback_content(self, mock_db, factory): - """Test handling of feedback with empty/null content.""" - # Arrange - feedback = factory.create_feedback_mock(content=None) - message = factory.create_message_mock() - conversation = factory.create_conversation_mock() - app = factory.create_app_mock() - account = factory.create_account_mock() - - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = [(feedback, message, conversation, app, account)] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="json") - - # Assert - json_content = json.loads(response.get_data(as_text=True)) - assert json_content["feedback_data"][0]["feedback_comment"] == "" - assert json_content["feedback_data"][0]["has_comment"] == "No" - - # Test 20: CSV headers validation - @patch("services.feedback_service.db") - def test_export_feedbacks_csv_headers(self, mock_db, factory, sample_feedback_data): - """Test that CSV contains all expected headers.""" - # Arrange - mock_query = MagicMock() - mock_db.session.query.return_value = mock_query - mock_query.join.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = sample_feedback_data - - expected_headers = [ - "feedback_id", - "app_name", - "app_id", - "conversation_id", - "conversation_name", - "message_id", - "user_query", - "ai_response", - "feedback_rating", - "feedback_rating_raw", - "feedback_comment", - "feedback_source", - "feedback_date", - "message_date", - "from_account_name", - "from_end_user_id", - "has_comment", - ] - - # Act - response = FeedbackService.export_feedbacks(app_id="app-456", format_type="csv") - - # Assert - csv_content = response.get_data(as_text=True) - reader = csv.DictReader(io.StringIO(csv_content)) - assert list(reader.fieldnames) == expected_headers diff --git a/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py deleted file mode 100644 index d35e014fab..0000000000 --- a/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py +++ /dev/null @@ -1,1045 +0,0 @@ -from __future__ import annotations - -import hashlib -import json -from datetime import datetime -from types import SimpleNamespace -from typing import cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture -from sqlalchemy.exc import IntegrityError - -from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity -from core.mcp.entities import AuthActionType -from core.mcp.error import MCPAuthError, MCPError -from models.tools import MCPToolProvider -from services.tools.mcp_tools_manage_service import ( - EMPTY_CREDENTIALS_JSON, - EMPTY_TOOLS_JSON, - UNCHANGED_SERVER_URL_PLACEHOLDER, - MCPToolManageService, - OAuthDataType, - ProviderUrlValidationData, - ReconnectResult, - ServerUrlValidationResult, -) - - -class _ToolStub: - def __init__(self, name: str, description: str | None) -> None: - self._name = name - self._description = description - - def model_dump(self) -> dict[str, str | None]: - return {"name": self._name, "description": self._description} - - -@pytest.fixture -def mock_session() -> MagicMock: - # Arrange - return MagicMock() - - -@pytest.fixture -def service(mock_session: MagicMock) -> MCPToolManageService: - # Arrange - return MCPToolManageService(session=mock_session) - - -def _provider_entity_stub(*, authed: bool = True) -> MCPProviderEntity: - return cast( - MCPProviderEntity, - SimpleNamespace( - authed=authed, - timeout=30.0, - sse_read_timeout=300.0, - provider_id="server-1", - headers={"x-api-key": "enc"}, - decrypt_headers=lambda: {"x-api-key": "key"}, - retrieve_tokens=lambda: SimpleNamespace(token_type="bearer", access_token="token-1"), - decrypt_server_url=lambda: "https://mcp.example.com/sse", - to_api_response=lambda user_name=None: { - "id": "provider-1", - "author": user_name or "Anonymous", - "name": "MCP Tool", - "description": {"en_US": "", "zh_Hans": ""}, - "icon": "icon", - "label": {"en_US": "MCP Tool", "zh_Hans": "MCP Tool"}, - "type": "mcp", - "is_team_authorization": True, - "server_url": "https://mcp.example.com/******", - "updated_at": 1, - "server_identifier": "server-1", - "configuration": {"timeout": "30", "sse_read_timeout": "300"}, - "masked_headers": {}, - "is_dynamic_registration": True, - }, - decrypt_credentials=lambda: {"client_id": "plain-id", "client_secret": "plain-secret"}, - masked_credentials=lambda: {"client_id": "pl***id", "client_secret": "pl***et"}, - masked_headers=lambda: {"x-api-key": "ke***ey"}, - ), - ) - - -def _provider_stub(*, authed: bool = True) -> MCPToolProvider: - entity = _provider_entity_stub(authed=authed) - return cast( - MCPToolProvider, - SimpleNamespace( - id="provider-1", - tenant_id="tenant-1", - user_id="user-1", - name="Provider A", - server_identifier="server-1", - server_url="encrypted-url", - server_url_hash="old-hash", - authed=authed, - tools=EMPTY_TOOLS_JSON, - encrypted_credentials=json.dumps({"existing": "credential"}), - encrypted_headers=json.dumps({"x-api-key": "enc"}), - credentials={"existing": "credential"}, - timeout=30.0, - sse_read_timeout=300.0, - updated_at=datetime.now(), - icon="icon", - to_entity=lambda: entity, - load_user=lambda: SimpleNamespace(name="Tester"), - ), - ) - - -def test_server_url_validation_result_should_update_server_url_when_all_conditions_match() -> None: - # Arrange - result = ServerUrlValidationResult( - needs_validation=True, - validation_passed=True, - reconnect_result=ReconnectResult(authed=True, tools="[]", encrypted_credentials="{}"), - ) - - # Act - should_update = result.should_update_server_url - - # Assert - assert should_update is True - - -def test_get_provider_should_return_provider_when_exists( - service: MCPToolManageService, - mock_session: MagicMock, -) -> None: - # Arrange - provider = _provider_stub() - mock_session.scalar.return_value = provider - - # Act - result = service.get_provider(provider_id="provider-1", tenant_id="tenant-1") - - # Assert - assert result is provider - - -def test_get_provider_should_raise_error_when_provider_not_found( - service: MCPToolManageService, mock_session: MagicMock -) -> None: - # Arrange - mock_session.scalar.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="MCP tool not found"): - service.get_provider(provider_id="provider-404", tenant_id="tenant-1") - - -def test_get_provider_entity_should_get_entity_by_provider_id_when_by_server_id_is_false( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mock_get_provider = mocker.patch.object(service, "get_provider", return_value=provider) - - # Act - result = service.get_provider_entity("provider-1", "tenant-1", by_server_id=False) - - # Assert - assert result is provider.to_entity() - mock_get_provider.assert_called_once_with(provider_id="provider-1", tenant_id="tenant-1") - - -def test_get_provider_entity_should_get_entity_by_server_identifier_when_by_server_id_is_true( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mock_get_provider = mocker.patch.object(service, "get_provider", return_value=provider) - - # Act - result = service.get_provider_entity("server-1", "tenant-1", by_server_id=True) - - # Assert - assert result is provider.to_entity() - mock_get_provider.assert_called_once_with(server_identifier="server-1", tenant_id="tenant-1") - - -def test_create_provider_should_raise_error_when_server_url_is_invalid(service: MCPToolManageService) -> None: - # Arrange - config = MCPConfiguration(timeout=30, sse_read_timeout=300) - - # Act + Assert - with pytest.raises(ValueError, match="Server URL is not valid"): - service.create_provider( - tenant_id="tenant-1", - name="Provider A", - server_url="invalid-url", - user_id="user-1", - icon="icon", - icon_type="emoji", - icon_background="#fff", - server_identifier="server-1", - configuration=config, - ) - - -def test_create_provider_should_create_and_return_user_provider_when_input_is_valid( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - config = MCPConfiguration(timeout=42, sse_read_timeout=123) - auth_data = MCPAuthentication(client_id="client-id", client_secret="secret") - mocker.patch.object(service, "_check_provider_exists") - mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="encrypted-url") - mocker.patch.object(service, "_prepare_encrypted_dict", return_value='{"x":"enc"}') - mocker.patch.object(service, "_build_and_encrypt_credentials", return_value='{"client_information":{}}') - mocker.patch.object(service, "_prepare_icon", return_value='{"content":"😀"}') - expected_user_provider = {"id": "provider-1"} - mock_convert = mocker.patch( - "services.tools.mcp_tools_manage_service.ToolTransformService.mcp_provider_to_user_provider", - return_value=expected_user_provider, - ) - - # Act - result = service.create_provider( - tenant_id="tenant-1", - name="Provider A", - server_url="https://mcp.example.com", - user_id="user-1", - icon="😀", - icon_type="emoji", - icon_background="#fff", - server_identifier="server-1", - configuration=config, - authentication=auth_data, - headers={"x-api-key": "v1"}, - ) - - # Assert - assert result == expected_user_provider - mock_session.add.assert_called_once() - mock_session.flush.assert_called_once() - mock_convert.assert_called_once() - - -def test_update_provider_should_raise_error_when_new_name_conflicts( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mocker.patch.object(service, "get_provider", return_value=provider) - mock_session.scalar.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - service.update_provider( - tenant_id="tenant-1", - provider_id="provider-1", - name="New Name", - server_url="https://mcp.example.com", - icon="😀", - icon_type="emoji", - icon_background="#fff", - server_identifier="server-1", - configuration=MCPConfiguration(), - ) - - -def test_update_provider_should_update_fields_when_input_is_valid( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - validation = ServerUrlValidationResult( - needs_validation=True, - validation_passed=True, - reconnect_result=ReconnectResult(authed=True, tools='[{"name":"t"}]', encrypted_credentials='{"x":"y"}'), - encrypted_server_url="new-encrypted-url", - server_url_hash="new-hash", - ) - mocker.patch.object(service, "get_provider", return_value=provider) - mock_session.scalar.return_value = None - mocker.patch.object(service, "_prepare_icon", return_value="new-icon") - mocker.patch.object(service, "_process_headers", return_value='{"x":"enc"}') - mocker.patch.object(service, "_process_credentials", return_value='{"client":"enc"}') - - # Act - service.update_provider( - tenant_id="tenant-1", - provider_id="provider-1", - name="Provider B", - server_url="https://mcp.example.com/new", - icon="😎", - icon_type="emoji", - icon_background="#000", - server_identifier="server-2", - headers={"x-api-key": "v2"}, - configuration=MCPConfiguration(timeout=50, sse_read_timeout=120), - authentication=MCPAuthentication(client_id="new-id", client_secret="new-secret"), - validation_result=validation, - ) - - # Assert - assert provider.name == "Provider B" - assert provider.server_identifier == "server-2" - assert provider.server_url == "new-encrypted-url" - assert provider.server_url_hash == "new-hash" - assert provider.authed is True - assert provider.tools == '[{"name":"t"}]' - assert provider.encrypted_credentials == '{"client":"enc"}' - assert provider.encrypted_headers == '{"x":"enc"}' - assert provider.timeout == 50 - assert provider.sse_read_timeout == 120 - mock_session.flush.assert_called_once() - - -def test_update_provider_should_handle_integrity_error_with_readable_message( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mocker.patch.object(service, "get_provider", return_value=provider) - mock_session.scalar.return_value = None - mocker.patch.object(service, "_prepare_icon", return_value="icon") - mock_session.flush.side_effect = IntegrityError("stmt", {}, Exception("unique_mcp_provider_name")) - - # Act + Assert - with pytest.raises(ValueError, match="MCP tool Provider A already exists"): - service.update_provider( - tenant_id="tenant-1", - provider_id="provider-1", - name="Provider A", - server_url="https://mcp.example.com", - icon="😀", - icon_type="emoji", - icon_background="#fff", - server_identifier="server-1", - configuration=MCPConfiguration(), - ) - - -def test_delete_provider_should_delete_existing_provider( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mocker.patch.object(service, "get_provider", return_value=provider) - - # Act - service.delete_provider(tenant_id="tenant-1", provider_id="provider-1") - - # Assert - mock_session.delete.assert_called_once_with(provider) - - -def test_list_providers_should_return_empty_list_when_no_provider_exists( - service: MCPToolManageService, - mock_session: MagicMock, -) -> None: - # Arrange - mock_session.scalars.return_value.all.return_value = [] - - # Act - result = service.list_providers(tenant_id="tenant-1") - - # Assert - assert result == [] - - -def test_list_providers_should_convert_all_providers_and_attach_user_names( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider_1 = _provider_stub() - provider_2 = _provider_stub() - provider_2.user_id = "user-2" - mock_session.scalars.return_value.all.return_value = [provider_1, provider_2] - mock_session.query.return_value.where.return_value.all.return_value = [ - SimpleNamespace(id="user-1", name="Alice"), - SimpleNamespace(id="user-2", name="Bob"), - ] - mock_convert = mocker.patch( - "services.tools.mcp_tools_manage_service.ToolTransformService.mcp_provider_to_user_provider", - side_effect=[{"id": "1"}, {"id": "2"}], - ) - - # Act - result = service.list_providers(tenant_id="tenant-1", for_list=True, include_sensitive=False) - - # Assert - assert result == [{"id": "1"}, {"id": "2"}] - assert mock_convert.call_count == 2 - - -def test_list_provider_tools_should_raise_error_when_provider_is_not_authenticated( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub(authed=False) - mocker.patch.object(service, "get_provider", return_value=provider) - - # Act + Assert - with pytest.raises(ValueError, match="Please auth the tool first"): - service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") - - -def test_list_provider_tools_should_raise_error_when_remote_client_fails( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub(authed=True) - mocker.patch.object(service, "get_provider", return_value=provider) - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.side_effect = MCPError("connection failed") - mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - - # Act + Assert - with pytest.raises(ValueError, match="Failed to connect to MCP server"): - service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") - - -def test_list_provider_tools_should_update_db_and_return_response_on_success( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub(authed=True) - mocker.patch.object(service, "get_provider", return_value=provider) - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.return_value = [ - _ToolStub("tool-a", None), - _ToolStub("tool-b", "desc"), - ] - mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - mocker.patch("services.tools.mcp_tools_manage_service.ToolTransformService.mcp_tool_to_user_tool", return_value=[]) - - # Act - result = service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") - - # Assert - assert result.plugin_unique_identifier == "server-1" - assert provider.authed is True - payload = json.loads(provider.tools) - assert payload[0]["description"] == "" - assert payload[1]["description"] == "desc" - mock_session.flush.assert_called_once() - - -def test_update_provider_credentials_should_update_encrypted_credentials_and_auth_state( - service: MCPToolManageService, - mock_session: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub(authed=True) - provider.encrypted_credentials = json.dumps({"existing": "value"}) - mocker.patch.object(service, "get_provider", return_value=provider) - mock_controller = MagicMock() - mocker.patch("core.tools.mcp_tool.provider.MCPToolProviderController.from_db", return_value=mock_controller) - mock_encryptor = MagicMock() - mock_encryptor.encrypt.return_value = {"access_token": "encrypted-token"} - mocker.patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter", return_value=mock_encryptor) - - # Act - service.update_provider_credentials( - provider_id="provider-1", - tenant_id="tenant-1", - credentials={"access_token": "plain-token"}, - authed=False, - ) - - # Assert - assert provider.authed is False - assert provider.tools == EMPTY_TOOLS_JSON - assert json.loads(cast(str, provider.encrypted_credentials))["access_token"] == "encrypted-token" - mock_session.flush.assert_called_once() - - -@pytest.mark.parametrize( - ("data_type", "data", "expected_authed"), - [ - (OAuthDataType.TOKENS, {"access_token": "token"}, True), - (OAuthDataType.MIXED, {"access_token": "token"}, True), - (OAuthDataType.MIXED, {"client_id": "id"}, None), - (OAuthDataType.CLIENT_INFO, {"client_id": "id"}, None), - ], -) -def test_save_oauth_data_should_delegate_with_expected_authed_value( - data_type: OAuthDataType, - data: dict[str, str], - expected_authed: bool | None, - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - mock_update = mocker.patch.object(service, "update_provider_credentials") - - # Act - service.save_oauth_data("provider-1", "tenant-1", data, data_type) - - # Assert - assert mock_update.call_args.kwargs["authed"] == expected_authed - - -def test_clear_provider_credentials_should_reset_provider_state( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub(authed=True) - mocker.patch.object(service, "get_provider", return_value=provider) - - # Act - service.clear_provider_credentials(provider_id="provider-1", tenant_id="tenant-1") - - # Assert - assert provider.tools == EMPTY_TOOLS_JSON - assert provider.encrypted_credentials == EMPTY_CREDENTIALS_JSON - assert provider.authed is False - - -def test_check_provider_exists_should_raise_different_errors_for_conflicts( - service: MCPToolManageService, - mock_session: MagicMock, -) -> None: - # Arrange - mock_session.scalar.return_value = SimpleNamespace( - name="name-a", - server_url_hash="hash-a", - server_identifier="server-a", - ) - - # Act + Assert - with pytest.raises(ValueError, match="MCP tool name-a already exists"): - service._check_provider_exists("tenant-1", "name-a", "hash-b", "server-b") - with pytest.raises(ValueError, match="MCP tool with this server URL already exists"): - service._check_provider_exists("tenant-1", "name-b", "hash-a", "server-b") - with pytest.raises(ValueError, match="MCP tool server-a already exists"): - service._check_provider_exists("tenant-1", "name-b", "hash-b", "server-a") - - -def test_prepare_icon_should_return_json_for_emoji_and_raw_value_for_non_emoji(service: MCPToolManageService) -> None: - # Arrange - # Act - emoji_icon = service._prepare_icon("😀", "emoji", "#fff") - raw_icon = service._prepare_icon("https://icon.png", "file", "#000") - - # Assert - assert json.loads(emoji_icon)["content"] == "😀" - assert raw_icon == "https://icon.png" - - -def test_encrypt_dict_fields_should_encrypt_secret_fields(service: MCPToolManageService, mocker: MockerFixture) -> None: - # Arrange - mock_encryptor = MagicMock() - mock_encryptor.encrypt.return_value = {"Authorization": "enc-token"} - mocker.patch("core.tools.utils.encryption.create_provider_encrypter", return_value=(mock_encryptor, MagicMock())) - - # Act - result = service._encrypt_dict_fields({"Authorization": "token"}, ["Authorization"], "tenant-1") - - # Assert - assert result == {"Authorization": "enc-token"} - - -def test_prepare_encrypted_dict_should_return_json_string(service: MCPToolManageService, mocker: MockerFixture) -> None: - # Arrange - mocker.patch.object(service, "_encrypt_dict_fields", return_value={"x": "enc"}) - - # Act - result = service._prepare_encrypted_dict({"x": "v"}, "tenant-1") - - # Assert - assert result == '{"x": "enc"}' - - -def test_prepare_auth_headers_should_append_authorization_when_tokens_exist(service: MCPToolManageService) -> None: - # Arrange - provider_entity = _provider_entity_stub() - - # Act - headers = service._prepare_auth_headers(provider_entity) - - # Assert - assert headers["Authorization"] == "Bearer token-1" - - -def test_retrieve_remote_mcp_tools_should_return_tools_from_client( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.return_value = [_ToolStub("tool-a", "desc")] - mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - - # Act - tools = service._retrieve_remote_mcp_tools("https://mcp.example.com", {}, _provider_entity_stub()) - - # Assert - assert len(tools) == 1 - assert tools[0].model_dump()["name"] == "tool-a" - - -def test_execute_auth_actions_should_dispatch_supported_actions( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - mock_save = mocker.patch.object(service, "save_oauth_data") - auth_result = SimpleNamespace( - actions=[ - SimpleNamespace( - action_type=AuthActionType.SAVE_CLIENT_INFO, - data={"client_id": "c1"}, - provider_id="provider-1", - tenant_id="tenant-1", - ), - SimpleNamespace( - action_type=AuthActionType.SAVE_TOKENS, - data={"access_token": "t1"}, - provider_id="provider-1", - tenant_id="tenant-1", - ), - SimpleNamespace( - action_type=AuthActionType.SAVE_CODE_VERIFIER, - data={"code_verifier": "cv"}, - provider_id="provider-1", - tenant_id="tenant-1", - ), - SimpleNamespace( - action_type=AuthActionType.SAVE_TOKENS, - data={"access_token": "skip"}, - provider_id=None, - tenant_id="tenant-1", - ), - ], - response={"ok": "1"}, - ) - - # Act - result = service.execute_auth_actions(auth_result) - - # Assert - assert result == {"ok": "1"} - assert mock_save.call_count == 3 - - -def test_auth_with_actions_should_call_auth_and_execute_actions( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider_entity = _provider_entity_stub() - auth_result = SimpleNamespace(actions=[], response={"status": "ok"}) - mocker.patch("services.tools.mcp_tools_manage_service.auth", return_value=auth_result) - mock_execute = mocker.patch.object(service, "execute_auth_actions", return_value={"status": "ok"}) - - # Act - result = service.auth_with_actions(provider_entity=provider_entity, authorization_code="code-1") - - # Assert - assert result == {"status": "ok"} - mock_execute.assert_called_once_with(auth_result) - - -def test_get_provider_for_url_validation_should_return_validation_data( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mocker.patch.object(service, "get_provider", return_value=provider) - - # Act - result = service.get_provider_for_url_validation(tenant_id="tenant-1", provider_id="provider-1") - - # Assert - assert result.current_server_url_hash == "old-hash" - assert result.headers == {"x-api-key": "enc"} - - -def test_validate_server_url_standalone_should_skip_validation_for_unchanged_placeholder() -> None: - # Arrange - data = ProviderUrlValidationData(current_server_url_hash="hash", headers={}, timeout=30, sse_read_timeout=300) - - # Act - result = MCPToolManageService.validate_server_url_standalone( - tenant_id="tenant-1", - new_server_url=UNCHANGED_SERVER_URL_PLACEHOLDER, - validation_data=data, - ) - - # Assert - assert result.needs_validation is False - - -def test_validate_server_url_standalone_should_raise_error_for_invalid_url() -> None: - # Arrange - data = ProviderUrlValidationData(current_server_url_hash="hash", headers={}, timeout=30, sse_read_timeout=300) - - # Act + Assert - with pytest.raises(ValueError, match="Server URL is not valid"): - MCPToolManageService.validate_server_url_standalone( - tenant_id="tenant-1", - new_server_url="bad-url", - validation_data=data, - ) - - -def test_validate_server_url_standalone_should_return_no_validation_when_hash_unchanged(mocker: MockerFixture) -> None: - # Arrange - url = "https://mcp.example.com" - current_hash = hashlib.sha256(url.encode()).hexdigest() - data = ProviderUrlValidationData(current_server_url_hash=current_hash, headers={}, timeout=30, sse_read_timeout=300) - mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="enc-url") - - # Act - result = MCPToolManageService.validate_server_url_standalone( - tenant_id="tenant-1", - new_server_url=url, - validation_data=data, - ) - - # Assert - assert result.needs_validation is False - assert result.encrypted_server_url == "enc-url" - assert result.server_url_hash == current_hash - - -def test_validate_server_url_standalone_should_reconnect_when_url_changes(mocker: MockerFixture) -> None: - # Arrange - url = "https://mcp-new.example.com" - data = ProviderUrlValidationData(current_server_url_hash="old", headers={}, timeout=30, sse_read_timeout=300) - reconnect_result = ReconnectResult(authed=True, tools='[{"name":"x"}]', encrypted_credentials="{}") - mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="enc-new") - mock_reconnect = mocker.patch.object(MCPToolManageService, "_reconnect_with_url", return_value=reconnect_result) - - # Act - result = MCPToolManageService.validate_server_url_standalone( - tenant_id="tenant-1", - new_server_url=url, - validation_data=data, - ) - - # Assert - assert result.validation_passed is True - assert result.reconnect_result == reconnect_result - mock_reconnect.assert_called_once() - - -def test_reconnect_with_url_should_delegate_to_private_method(mocker: MockerFixture) -> None: - # Arrange - expected = ReconnectResult(authed=True, tools="[]", encrypted_credentials="{}") - mock_delegate = mocker.patch.object(MCPToolManageService, "_reconnect_with_url", return_value=expected) - - # Act - result = MCPToolManageService.reconnect_with_url( - server_url="https://mcp.example.com", - headers={}, - timeout=30, - sse_read_timeout=300, - ) - - # Assert - assert result == expected - mock_delegate.assert_called_once() - - -def test_private_reconnect_with_url_should_return_authed_true_when_connection_succeeds(mocker: MockerFixture) -> None: - # Arrange - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.return_value = [_ToolStub("tool-a", None)] - mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - - # Act - result = MCPToolManageService._reconnect_with_url( - server_url="https://mcp.example.com", - headers={}, - timeout=30, - sse_read_timeout=300, - ) - - # Assert - assert result.authed is True - assert json.loads(result.tools)[0]["description"] == "" - - -def test_private_reconnect_with_url_should_return_authed_false_on_auth_error(mocker: MockerFixture) -> None: - # Arrange - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.side_effect = MCPAuthError("auth required") - mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - - # Act - result = MCPToolManageService._reconnect_with_url( - server_url="https://mcp.example.com", - headers={}, - timeout=30, - sse_read_timeout=300, - ) - - # Assert - assert result.authed is False - assert result.tools == EMPTY_TOOLS_JSON - - -def test_private_reconnect_with_url_should_raise_value_error_on_mcp_error(mocker: MockerFixture) -> None: - # Arrange - mcp_client_instance = MagicMock() - mcp_client_instance.list_tools.side_effect = MCPError("network failure") - mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") - mock_client_cls.return_value.__enter__.return_value = mcp_client_instance - - # Act + Assert - with pytest.raises(ValueError, match="Failed to re-connect MCP server: network failure"): - MCPToolManageService._reconnect_with_url( - server_url="https://mcp.example.com", - headers={}, - timeout=30, - sse_read_timeout=300, - ) - - -def test_build_tool_provider_response_should_build_api_entity_with_tools( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - db_provider = _provider_stub() - provider_entity = _provider_entity_stub() - tools = [_ToolStub("tool-a", "desc")] - mocker.patch("services.tools.mcp_tools_manage_service.ToolTransformService.mcp_tool_to_user_tool", return_value=[]) - - # Act - result = service._build_tool_provider_response(db_provider, provider_entity, tools) - - # Assert - assert result.plugin_unique_identifier == "server-1" - assert result.name == "MCP Tool" - - -@pytest.mark.parametrize( - ("orig_message", "expected_error"), - [ - ("unique_mcp_provider_name", "MCP tool name already exists"), - ("unique_mcp_provider_server_url", "MCP tool https://mcp.example.com already exists"), - ("unique_mcp_provider_server_identifier", "MCP tool server-1 already exists"), - ], -) -def test_handle_integrity_error_should_raise_readable_value_errors( - orig_message: str, - expected_error: str, - service: MCPToolManageService, -) -> None: - """Test that known integrity errors raise readable value errors.""" - # Arrange - error = IntegrityError("stmt", {}, Exception(orig_message)) - - # Act + Assert - with pytest.raises(ValueError, match=expected_error): - service._handle_integrity_error(error, "name", "https://mcp.example.com", "server-1") - - -def test_handle_integrity_error_should_reraise_unknown_error(service: MCPToolManageService) -> None: - """Test that unknown integrity errors are re-raised.""" - # Arrange - error = IntegrityError("stmt", {}, Exception("unknown-constraint")) - - # Act + Assert - with pytest.raises(IntegrityError) as exc_info: - service._handle_integrity_error(error, "name", "url", "identifier") - - assert exc_info.value is error - - -@pytest.mark.parametrize( - ("url", "expected"), - [ - ("https://mcp.example.com", True), - ("http://mcp.example.com", True), - ("", False), - ("invalid", False), - ("ftp://mcp.example.com", False), - ], -) -def test_is_valid_url_should_validate_supported_schemes( - url: str, - expected: bool, - service: MCPToolManageService, -) -> None: - # Arrange - # Act - result = service._is_valid_url(url) - - # Assert - assert result is expected - - -def test_update_optional_fields_should_update_only_non_none_values(service: MCPToolManageService) -> None: - # Arrange - provider = _provider_stub() - configuration = MCPConfiguration(timeout=99, sse_read_timeout=300) - - # Act - service._update_optional_fields(provider, configuration) - - # Assert - assert provider.timeout == 99 - assert provider.sse_read_timeout == 300 - - -def test_process_headers_should_return_none_when_empty_headers(service: MCPToolManageService) -> None: - # Arrange - provider = _provider_stub() - - # Act - result = service._process_headers({}, provider, "tenant-1") - - # Assert - assert result is None - - -def test_process_headers_should_merge_and_encrypt_headers( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - mocker.patch.object(service, "_merge_headers_with_masked", return_value={"x-api-key": "plain"}) - mocker.patch.object(service, "_prepare_encrypted_dict", return_value='{"x-api-key":"enc"}') - - # Act - result = service._process_headers({"x-api-key": "*****"}, provider, "tenant-1") - - # Assert - assert result == '{"x-api-key":"enc"}' - - -def test_process_credentials_should_merge_and_encrypt_credentials( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - provider = _provider_stub() - authentication = MCPAuthentication(client_id="masked-id", client_secret="masked-secret") - mocker.patch.object(service, "_merge_credentials_with_masked", return_value=("plain-id", "plain-secret")) - mocker.patch.object(service, "_build_and_encrypt_credentials", return_value='{"client_information":{}}') - - # Act - result = service._process_credentials(authentication, provider, "tenant-1") - - # Assert - assert result == '{"client_information":{}}' - - -def test_merge_headers_with_masked_should_preserve_original_values_for_unchanged_masked_inputs( - service: MCPToolManageService, -) -> None: - # Arrange - provider = _provider_stub() - incoming_headers = {"x-api-key": "ke***ey", "new-header": "new-value", "dropped": "*****"} - - # Act - result = service._merge_headers_with_masked(incoming_headers, provider) - - # Assert - assert result["x-api-key"] == "key" - assert result["new-header"] == "new-value" - assert result["dropped"] == "*****" - - -def test_merge_credentials_with_masked_should_preserve_decrypted_values_when_masked_match( - service: MCPToolManageService, -) -> None: - # Arrange - provider = _provider_stub() - - # Act - client_id, client_secret = service._merge_credentials_with_masked("pl***id", "pl***et", provider) - - # Assert - assert client_id == "plain-id" - assert client_secret == "plain-secret" - - -def test_build_and_encrypt_credentials_should_encrypt_secret_when_client_secret_present( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch.object( - service, - "_encrypt_dict_fields", - return_value={ - "client_id": "id", - "client_name": "Dify", - "is_dynamic_registration": False, - "encrypted_client_secret": "enc-secret", - }, - ) - - # Act - result = service._build_and_encrypt_credentials("id", "secret", "tenant-1") - - # Assert - payload = json.loads(result) - assert payload["client_information"]["encrypted_client_secret"] == "enc-secret" - - -def test_build_and_encrypt_credentials_should_skip_secret_field_when_client_secret_is_none( - service: MCPToolManageService, - mocker: MockerFixture, -) -> None: - # Arrange - mocker.patch.object( - service, - "_encrypt_dict_fields", - return_value={"client_id": "id", "client_name": "Dify", "is_dynamic_registration": False}, - ) - - # Act - result = service._build_and_encrypt_credentials("id", None, "tenant-1") - - # Assert - payload = json.loads(result) - assert "encrypted_client_secret" not in payload["client_information"] diff --git a/api/uv.lock b/api/uv.lock index 1b0cc495d9..f6ae4e3e90 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1901,7 +1901,7 @@ dev = [ { name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "mypy", specifier = "~=1.19.1" }, { name = "pandas-stubs", specifier = "~=3.0.0" }, - { name = "pyrefly", specifier = ">=0.55.0" }, + { name = "pyrefly", specifier = ">=0.57.1" }, { name = "pytest", specifier = "~=9.0.2" }, { name = "pytest-benchmark", specifier = "~=5.2.3" }, { name = "pytest-cov", specifier = "~=7.1.0" }, @@ -5707,18 +5707,18 @@ wheels = [ [[package]] name = "pyrefly" -version = "0.55.0" +version = "0.57.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bf/c4/76e0797215e62d007f81f86c9c4fb5d6202685a3f5e70810f3fd94294f92/pyrefly-0.55.0.tar.gz", hash = "sha256:434c3282532dd4525c4840f2040ed0eb79b0ec8224fe18d957956b15471f2441", size = 5135682, upload-time = "2026-03-03T00:46:38.122Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/b0/16e50cf716784513648e23e726a24f71f9544aa4f86103032dcaa5ff71a2/pyrefly-0.55.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:49aafcefe5e2dd4256147db93e5b0ada42bff7d9a60db70e03d1f7055338eec9", size = 12210073, upload-time = "2026-03-03T00:46:15.51Z" }, - { url = "https://files.pythonhosted.org/packages/3a/ad/89500c01bac3083383011600370289fbc67700c5be46e781787392628a3a/pyrefly-0.55.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2827426e6b28397c13badb93c0ede0fb0f48046a7a89e3d774cda04e8e2067cd", size = 11767474, upload-time = "2026-03-03T00:46:18.003Z" }, - { url = "https://files.pythonhosted.org/packages/78/68/4c66b260f817f304ead11176ff13985625f7c269e653304b4bdb546551af/pyrefly-0.55.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7346b2d64dc575bd61aa3bca854fbf8b5a19a471cbdb45e0ca1e09861b63488c", size = 33260395, upload-time = "2026-03-03T00:46:20.509Z" }, - { url = "https://files.pythonhosted.org/packages/47/09/10bd48c9f860064f29f412954126a827d60f6451512224912c265e26bbe6/pyrefly-0.55.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:233b861b4cff008b1aff62f4f941577ed752e4d0060834229eb9b6826e6973c9", size = 35848269, upload-time = "2026-03-03T00:46:23.418Z" }, - { url = "https://files.pythonhosted.org/packages/a9/39/bc65cdd5243eb2dfea25dd1321f9a5a93e8d9c3a308501c4c6c05d011585/pyrefly-0.55.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5aa85657d76da1d25d081a49f0e33c8fc3ec91c1a0f185a8ed393a5a3d9e178", size = 38449820, upload-time = "2026-03-03T00:46:26.309Z" }, - { url = "https://files.pythonhosted.org/packages/e5/64/58b38963b011af91209e87f868cc85cfc762ec49a4568ce610c45e7a5f40/pyrefly-0.55.0-py3-none-win32.whl", hash = "sha256:23f786a78536a56fed331b245b7d10ec8945bebee7b723491c8d66fdbc155fe6", size = 11259415, upload-time = "2026-03-03T00:46:30.875Z" }, - { url = "https://files.pythonhosted.org/packages/7a/0b/a4aa519ff632a1ea69eec942566951670b870b99b5c08407e1387b85b6a4/pyrefly-0.55.0-py3-none-win_amd64.whl", hash = "sha256:d465b49e999b50eeb069ad23f0f5710651cad2576f9452a82991bef557df91ee", size = 12043581, upload-time = "2026-03-03T00:46:33.674Z" }, - { url = "https://files.pythonhosted.org/packages/f1/51/89017636fbe1ffd166ad478990c6052df615b926182fa6d3c0842b407e89/pyrefly-0.55.0-py3-none-win_arm64.whl", hash = "sha256:732ff490e0e863b296e7c0b2471e08f8ba7952f9fa6e9de09d8347fd67dde77f", size = 11548076, upload-time = "2026-03-03T00:46:36.193Z" }, + { url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" }, + { url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" }, + { url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" }, + { url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" }, + { url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" }, + { url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" }, + { url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" }, + { url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" }, ] [[package]]