from unittest.mock import MagicMock, patch

import pytest
from sqlalchemy.orm import Session

from core.plugin.entities.plugin_daemon import CredentialType
from dify_graph.model_runtime.entities.provider_entities import FormType
from models.account import Account
from models.model import EndUser
from models.oauth import DatasourceProvider
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService, get_current_user

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def make_id(s: str = "org/plugin/provider") -> DatasourceProviderID:
    """Wrap the string *s* in a ``DatasourceProviderID``."""
    provider_id = DatasourceProviderID(s)
    return provider_id


def _provider_record(**attrs) -> MagicMock:
    """Return a ``MagicMock(spec=DatasourceProvider)`` with *attrs* assigned in keyword order."""
    record = MagicMock(spec=DatasourceProvider)
    for attr_name, attr_value in attrs.items():
        setattr(record, attr_name, attr_value)
    return record


def _patch_current_user(user):
    """Return a patcher that makes the service module's ``get_current_user`` return *user*."""
    return patch("services.datasource_provider_service.get_current_user", return_value=user)


# ---------------------------------------------------------------------------
# Test class
# ---------------------------------------------------------------------------


class TestDatasourceProviderService:
    """Unit tests for ``DatasourceProviderService`` and ``get_current_user`` (coverage goal: >95%)."""

    @pytest.fixture
    def service(self):
        """Provide a fresh ``DatasourceProviderService`` instance."""
        return DatasourceProviderService()

    @pytest.fixture
    def mock_db_session(self):
        """
        Yield a ``MagicMock(spec=Session)`` served by the patched ``Session`` class.

        ``query()`` always hands back one shared query mock whose ``filter_by``,
        ``order_by`` and ``where`` return that same mock, so chained calls need no
        extra wiring. Terminal calls default to ``first() -> None``, ``all() -> []``,
        ``count() -> 0`` and ``delete() -> 1``; individual tests override them.
        """
        with patch("services.datasource_provider_service.Session") as session_cls:
            session = MagicMock(spec=Session)
            query = MagicMock()
            session.query.return_value = query

            # Every chaining method returns the query mock itself.
            for chain_method in ("filter_by", "order_by", "where"):
                getattr(query, chain_method).return_value = query

            # Terminal defaults.
            query.first.return_value = None
            query.all.return_value = []
            query.count.return_value = 0
            query.delete.return_value = 1

            session_ctx = session_cls.return_value
            session_ctx.__enter__.return_value = session
            session_ctx.no_autoflush.__enter__.return_value = session
            yield session

    @pytest.fixture(autouse=True)
    def patch_db(self, mock_db_session):
        """Replace the service module's ``db`` with a mock whose ``session`` is ``mock_db_session``."""
        with patch("services.datasource_provider_service.db") as db_mock:
            db_mock.session = mock_db_session
            db_mock.engine = MagicMock()
            yield db_mock

    @pytest.fixture(autouse=True)
    def patch_externals(self):
        """
        Patch ``httpx.request`` plus the service module's ``dify_config``, ``encrypter``,
        ``redis_client``, ``generate_incremental_name`` and ``OAuthHandler``.

        The encrypter and redis mocks are kept on ``self._enc`` / ``self._redis`` so
        tests can assert against them.
        """
        identity = {
            "author": "a",
            "name": "n",
            "description": {"en_US": "d"},
            "icon": "i",
            "label": {"en_US": "l"},
        }
        declaration = {
            "identity": identity,
            "credentials_schema": [],
            "oauth_schema": {"credentials_schema": [], "client_schema": []},
            "provider_type": "local_file",
            "datasources": [],
        }
        payload = {
            "code": 0,
            "message": "ok",
            "data": {
                "provider": "prov",
                "plugin_unique_identifier": "pui",
                "plugin_id": "org/plug",
                "is_authorized": False,
                "declaration": declaration,
            },
        }

        with (
            patch("httpx.request") as http_request,
            patch("services.datasource_provider_service.dify_config") as config_mock,
            patch("services.datasource_provider_service.encrypter") as encrypter_mock,
            patch("services.datasource_provider_service.redis_client") as redis_mock,
            patch("services.datasource_provider_service.generate_incremental_name") as name_generator,
            patch("services.datasource_provider_service.OAuthHandler") as oauth_handler_cls,
        ):
            config_mock.CONSOLE_API_URL = "http://localhost"

            encrypter_mock.encrypt_token.return_value = "enc_tok"
            encrypter_mock.decrypt_token.return_value = "dec_tok"
            encrypter_mock.decrypt.return_value = {"k": "dec"}
            encrypter_mock.encrypt.return_value = {"k": "enc"}
            encrypter_mock.obfuscated_token.return_value = "obf"
            encrypter_mock.mask_plugin_credentials.return_value = {"k": "mask"}

            redis_mock.lock.return_value.__enter__.return_value = MagicMock()
            name_generator.return_value = "gen_name"
            oauth_handler_cls.return_value.refresh_credentials.return_value = MagicMock(
                credentials={"k": "v"}, expires_at=9999
            )

            response = MagicMock()
            response.status_code = 200
            response.json.return_value = payload
            http_request.return_value = response

            # Handles that individual tests assert against.
            self._enc = encrypter_mock
            self._redis = redis_mock
            yield

    @pytest.fixture
    def mock_user(self):
        """Provide a mock user whose ``id`` is ``"uid-1"``."""
        user = MagicMock()
        user.id = "uid-1"
        return user

    # -----------------------------------------------------------------------
    # get_current_user (lines 27-40)
    # -----------------------------------------------------------------------

    def test_should_return_proxy_when_current_object_is_account(self):
        """The proxy is returned when its current object reports ``Account`` as its class."""
        with patch("libs.login.current_user", new_callable=MagicMock) as current_user_proxy:
            account_like = MagicMock()
            account_like.__class__ = Account
            current_user_proxy._get_current_object.return_value = account_like
            assert get_current_user() is current_user_proxy

    def test_should_return_proxy_when_current_object_is_enduser(self):
        """The proxy is returned when its current object reports ``EndUser`` as its class."""
        with patch("libs.login.current_user", new_callable=MagicMock) as current_user_proxy:
            end_user_like = MagicMock()
            end_user_like.__class__ = EndUser
            current_user_proxy._get_current_object.return_value = end_user_like
            assert get_current_user() is current_user_proxy

    def test_should_return_proxy_when_get_current_object_raises_attribute_error(self):
        """The proxy is still returned when ``_get_current_object`` raises ``AttributeError``."""
        with patch("libs.login.current_user", new_callable=MagicMock) as current_user_proxy:
            current_user_proxy._get_current_object.side_effect = AttributeError("no attr")
            # The proxy itself must then pass the isinstance check.
            current_user_proxy.__class__ = Account
            assert get_current_user() is current_user_proxy

    def test_should_raise_type_error_when_user_is_not_account_or_enduser(self):
        """A plain string as the current object must produce a ``TypeError``."""
        with patch("libs.login.current_user", new_callable=MagicMock) as current_user_proxy:
            current_user_proxy._get_current_object.return_value = "plain_string"
            with pytest.raises(TypeError, match="current_user must be Account or EndUser"):
                get_current_user()

    # -----------------------------------------------------------------------
    # is_system_oauth_params_exist (line 357-363)
    # -----------------------------------------------------------------------

    def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session):
        """``first()`` yielding a row makes the check return ``True``."""
        mock_db_session.query().first.return_value = MagicMock()
        exists = service.is_system_oauth_params_exist(make_id())
        assert exists is True

    def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session):
        """``first()`` yielding ``None`` makes the check return ``False``."""
        mock_db_session.query().first.return_value = None
        exists = service.is_system_oauth_params_exist(make_id())
        assert exists is False

    # -----------------------------------------------------------------------
    # is_tenant_oauth_params_enabled (lines 365-379)
    # NOTE: uses .count() not .first()
    # -----------------------------------------------------------------------

    def test_should_return_true_when_tenant_oauth_params_enabled(self, service, mock_db_session):
        """A ``count()`` of 1 makes the check return ``True``."""
        mock_db_session.query().count.return_value = 1
        enabled = service.is_tenant_oauth_params_enabled("t1", make_id())
        assert enabled is True

    def test_should_return_false_when_tenant_oauth_params_disabled(self, service, mock_db_session):
        """A ``count()`` of 0 makes the check return ``False``."""
        mock_db_session.query().count.return_value = 0
        enabled = service.is_tenant_oauth_params_enabled("t1", make_id())
        assert enabled is False

    # -----------------------------------------------------------------------
    # remove_oauth_custom_client_params (lines 55-61)
    # -----------------------------------------------------------------------

    def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session):
        """Removing the params issues exactly one ``delete()`` on the query."""
        service.remove_oauth_custom_client_params("t1", make_id())
        mock_db_session.query().delete.assert_called_once()

    # -----------------------------------------------------------------------
    # setup_oauth_custom_client_params (315-351)
    # -----------------------------------------------------------------------

    def test_should_skip_db_write_when_credentials_are_none(self, service, mock_db_session):
        """Passing ``None`` for both trailing arguments must not add anything to the session."""
        service.setup_oauth_custom_client_params("t1", make_id(), None, None)
        mock_db_session.add.assert_not_called()

    def test_should_create_new_config_when_none_exists(self, service, mock_db_session):
        """With no existing row, one new object is added to the session."""
        mock_db_session.query().first.return_value = None
        with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
            service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True)
        mock_db_session.add.assert_called_once()

    def test_should_update_existing_config_when_record_found(self, service, mock_db_session):
        """With an existing row, nothing is added to the session."""
        existing_row = MagicMock()
        mock_db_session.query().first.return_value = existing_row
        with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
            service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False)
        # The existing row is reused, so no add() is expected.
        mock_db_session.add.assert_not_called()

    # -----------------------------------------------------------------------
    # decrypt / encrypt credentials (lines 70-98)
    # -----------------------------------------------------------------------

    def test_should_decrypt_secret_fields_when_decrypting_api_key_credentials(self, service, mock_db_session):
        """A field reported as secret comes back as the mocked ``decrypt_token`` value."""
        record = _provider_record(auth_type="api_key", encrypted_credentials={"sk": "enc_val"})
        with patch.object(service, "extract_secret_variables", return_value=["sk"]):
            decrypted = service.decrypt_datasource_provider_credentials("t1", record, "org/plug", "prov")
        assert decrypted["sk"] == "dec_tok"

    def test_should_encrypt_secret_fields_when_encrypting_api_key_credentials(self, service, mock_db_session):
        """A field reported as secret comes back as the mocked ``encrypt_token`` value."""
        record = _provider_record(auth_type="api_key")
        with patch.object(service, "extract_secret_variables", return_value=["sk"]):
            encrypted = service.encrypt_datasource_provider_credentials(
                "t1", "prov", "org/plug", {"sk": "plain"}, record
            )
        assert encrypted["sk"] == "enc_tok"
        self._enc.encrypt_token.assert_called()

    # -----------------------------------------------------------------------
    # get_datasource_credentials (lines 113-165)
    # -----------------------------------------------------------------------

    def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user):
        """No matching row yields an empty dict."""
        mock_db_session.query().first.return_value = None
        with _patch_current_user(mock_user):
            credentials = service.get_datasource_credentials("t1", "prov", "org/plug")
            assert credentials == {}

    def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user):
        """An ``oauth2`` row with ``expires_at = 0`` leads to exactly one commit."""
        record = _provider_record(
            auth_type="oauth2",
            expires_at=0,  # treated as expired by this test
            encrypted_credentials={"tok": "x"},
        )
        mock_db_session.query().first.return_value = record
        with (
            _patch_current_user(mock_user),
            patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
            patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
        ):
            service.get_datasource_credentials("t1", "prov", "org/plug")
        mock_db_session.commit.assert_called_once()

    def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
        """An ``api_key`` row with ``expires_at = -1`` returns the decrypted credentials as-is."""
        record = _provider_record(
            auth_type="api_key",
            expires_at=-1,  # sentinel used by this test for "never expires"
            encrypted_credentials={"k": "v"},
        )
        mock_db_session.query().first.return_value = record
        with (
            _patch_current_user(mock_user),
            patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}),
        ):
            credentials = service.get_datasource_credentials("t1", "prov", "org/plug")
        assert credentials == {"k": "plain"}

    def test_should_fetch_by_credential_id_when_provided(self, service, mock_db_session, mock_user):
        """Supplying ``credential_id`` (the line 113 path) still returns the decrypted credentials."""
        record = _provider_record(auth_type="api_key", expires_at=-1, encrypted_credentials={})
        mock_db_session.query().first.return_value = record
        with (
            _patch_current_user(mock_user),
            patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}),
        ):
            credentials = service.get_datasource_credentials("t1", "prov", "org/plug", credential_id="cred-id")
        assert credentials == {"k": "v"}

    # -----------------------------------------------------------------------
    # get_all_datasource_credentials_by_provider (lines 176-228)
    # -----------------------------------------------------------------------

    def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user):
        """No rows yields an empty list."""
        mock_db_session.query().all.return_value = []
        with _patch_current_user(mock_user):
            all_credentials = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
            assert all_credentials == []

    def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user):
        """One ``oauth2`` row with ``expires_at = 0`` produces one returned entry."""
        record = _provider_record(auth_type="oauth2", expires_at=0, encrypted_credentials={"t": "x"})
        mock_db_session.query().all.return_value = [record]
        with (
            _patch_current_user(mock_user),
            patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
            patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}),
        ):
            all_credentials = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
        assert len(all_credentials) == 1

    # -----------------------------------------------------------------------
    # update_datasource_provider_name (lines 236-303)
    # -----------------------------------------------------------------------

    def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session):
        """A missing row raises ``ValueError`` matching "not found"."""
        mock_db_session.query().first.return_value = None
        with pytest.raises(ValueError, match="not found"):
            service.update_datasource_provider_name("t1", make_id(), "new", "cred-id")

    def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session):
        """Renaming to the current name does not commit."""
        record = _provider_record(name="same")
        mock_db_session.query().first.return_value = record
        service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
        mock_db_session.commit.assert_not_called()

    def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
        """A non-zero ``count()`` raises ``ValueError`` matching "already exists"."""
        record = _provider_record(name="old_name", is_default=False)
        mock_db_session.query().first.return_value = record
        mock_db_session.query().count.return_value = 1  # simulated name conflict
        with pytest.raises(ValueError, match="already exists"):
            service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")

    def test_should_update_name_and_commit_when_no_conflict(self, service, mock_db_session):
        """Without a conflict the row's name is replaced and one commit happens."""
        record = _provider_record(name="old_name", is_default=False)
        mock_db_session.query().first.return_value = record
        mock_db_session.query().count.return_value = 0
        service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
        assert record.name == "new_name"
        mock_db_session.commit.assert_called_once()

    # -----------------------------------------------------------------------
    # set_default_datasource_provider (lines 277-303)
    # -----------------------------------------------------------------------

    def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session):
        """A missing target row raises ``ValueError`` matching "not found"."""
        mock_db_session.query().first.return_value = None
        with pytest.raises(ValueError, match="not found"):
            service.set_default_datasource_provider("t1", make_id(), "bad-id")

    def test_should_mark_target_as_default_and_commit(self, service, mock_db_session):
        """The found row ends up with ``is_default = True`` and one commit happens."""
        target_row = _provider_record(provider="provider", plugin_id="org/plug")
        mock_db_session.query().first.return_value = target_row
        service.set_default_datasource_provider("t1", make_id(), "new-id")
        assert target_row.is_default is True
        mock_db_session.commit.assert_called_once()

    # -----------------------------------------------------------------------
    # get_oauth_encrypter (lines 404-420)
    # -----------------------------------------------------------------------

    def test_should_raise_value_error_when_oauth_schema_missing(self, service):
        """A provider whose ``declaration.oauth_schema`` is ``None`` raises ``ValueError``."""
        provider_meta = MagicMock()
        provider_meta.declaration.oauth_schema = None
        with (
            patch.object(service.provider_manager, "fetch_datasource_provider", return_value=provider_meta),
            pytest.raises(ValueError, match="oauth schema not found"),
        ):
            service.get_oauth_encrypter("t1", make_id())

    def test_should_return_encrypter_when_oauth_schema_exists(self, service):
        """With a client schema present, a non-``None`` result is returned."""
        client_schema_entry = MagicMock()
        client_schema_entry.to_basic_provider_config.return_value = MagicMock()
        provider_meta = MagicMock()
        provider_meta.declaration.oauth_schema.client_schema = [client_schema_entry]
        with (
            patch.object(service.provider_manager, "fetch_datasource_provider", return_value=provider_meta),
            patch(
                "services.datasource_provider_service.create_provider_encrypter",
                return_value=(MagicMock(), MagicMock()),
            ),
        ):
            encrypter_pair = service.get_oauth_encrypter("t1", make_id())
        assert encrypter_pair is not None

    # -----------------------------------------------------------------------
    # get_tenant_oauth_client (lines 381-402)
    # -----------------------------------------------------------------------

    def test_should_return_masked_credentials_when_mask_is_true(self, service, mock_db_session):
        """``mask=True`` returns the mocked ``mask_plugin_credentials`` value."""
        tenant_row = MagicMock()
        tenant_row.client_params = {"k": "v"}
        mock_db_session.query().first.return_value = tenant_row
        with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
            client_params = service.get_tenant_oauth_client("t1", make_id(), mask=True)
        assert client_params == {"k": "mask"}

    def test_should_return_decrypted_credentials_when_mask_is_false(self, service, mock_db_session):
        """``mask=False`` returns the mocked ``decrypt`` value."""
        tenant_row = MagicMock()
        tenant_row.client_params = {"k": "v"}
        mock_db_session.query().first.return_value = tenant_row
        with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
            client_params = service.get_tenant_oauth_client("t1", make_id(), mask=False)
        assert client_params == {"k": "dec"}

    def test_should_return_none_when_no_tenant_oauth_config_exists(self, service, mock_db_session):
        """No tenant row yields ``None``."""
        mock_db_session.query().first.return_value = None
        client_params = service.get_tenant_oauth_client("t1", make_id())
        assert client_params is None

    # -----------------------------------------------------------------------
    # get_oauth_client (lines 423-457)
    # -----------------------------------------------------------------------

    def test_should_use_tenant_config_when_available(self, service, mock_db_session):
        """A tenant row is present, so the mocked ``decrypt`` value is returned."""
        mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"})
        with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
            oauth_client = service.get_oauth_client("t1", make_id())
        assert oauth_client == {"k": "dec"}

    def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session):
        """First lookup returns ``None``, second returns a row with ``system_credentials``."""
        system_row = MagicMock(system_credentials={"k": "sys"})
        mock_db_session.query().first.side_effect = [None, system_row]
        with (
            patch.object(service.provider_manager, "fetch_datasource_provider"),
            patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True),
        ):
            oauth_client = service.get_oauth_client("t1", make_id())
        assert oauth_client == {"k": "sys"}

    def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session):
        """Both lookups return ``None`` and the plugin is unverified, so ``ValueError`` is raised."""
        mock_db_session.query().first.side_effect = [None, None]
        with (
            patch.object(service.provider_manager, "fetch_datasource_provider"),
            patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False),
            pytest.raises(ValueError, match="Please configure oauth client params"),
        ):
            service.get_oauth_client("t1", make_id())

    # -----------------------------------------------------------------------
    # add_datasource_oauth_provider (lines 539-607)
    # -----------------------------------------------------------------------

    def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session):
        """With ``count() == 0`` one object is added and one commit happens."""
        mock_db_session.query().count.return_value = 0
        with patch.object(service, "extract_secret_variables", return_value=[]):
            service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
        mock_db_session.add.assert_called_once()
        mock_db_session.commit.assert_called_once()

    def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
        """With ``count() == 1`` the call still adds one object instead of raising."""
        mock_db_session.query().count.return_value = 1  # simulated name conflict
        mock_db_session.query().all.return_value = []
        with (
            patch.object(service, "extract_secret_variables", return_value=[]),
            patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"),
        ):
            service.add_datasource_oauth_provider("conflict", "t1", make_id(), "http://cb", 9999, {})
        mock_db_session.add.assert_called_once()

    def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session):
        """With ``name=None`` the call still adds one object."""
        mock_db_session.query().count.return_value = 0
        mock_db_session.query().all.return_value = []
        with (
            patch.object(service, "extract_secret_variables", return_value=[]),
            patch.object(service, "generate_next_datasource_provider_name", return_value="auto"),
        ):
            service.add_datasource_oauth_provider(None, "t1", make_id(), "http://cb", 9999, {})
        mock_db_session.add.assert_called_once()

    def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session):
        """A credential key reported as secret causes ``encrypt_token`` to be called."""
        mock_db_session.query().count.return_value = 0
        with patch.object(service, "extract_secret_variables", return_value=["secret_key"]):
            service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"})
        self._enc.encrypt_token.assert_called()

    def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session):
        """Adding an OAuth provider calls ``redis_client.lock``."""
        mock_db_session.query().count.return_value = 0
        with patch.object(service, "extract_secret_variables", return_value=[]):
            service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {})
        self._redis.lock.assert_called()

    # -----------------------------------------------------------------------
    # reauthorize_datasource_oauth_provider (lines 477-537)
    # -----------------------------------------------------------------------

    def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session):
        """A missing row raises ``ValueError`` matching "not found"."""
        mock_db_session.query().first.return_value = None
        with (
            patch.object(service, "extract_secret_variables", return_value=[]),
            pytest.raises(ValueError, match="not found"),
        ):
            service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id")

    def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session):
        """A found row with no name conflict leads to one commit."""
        record = _provider_record()
        mock_db_session.query().first.return_value = record
        mock_db_session.query().count.return_value = 0
        with patch.object(service, "extract_secret_variables", return_value=[]):
            service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
        mock_db_session.commit.assert_called_once()

    def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
        """With ``count() == 1`` the call still commits once instead of raising."""
        record = _provider_record()
        mock_db_session.query().first.return_value = record
        mock_db_session.query().count.return_value = 1  # simulated name conflict
        mock_db_session.query().all.return_value = []
        with patch.object(service, "extract_secret_variables", return_value=["tok"]):
            service.reauthorize_datasource_oauth_provider(
                "conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
            )
        mock_db_session.commit.assert_called_once()

    def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
        """A credential key reported as secret causes ``encrypt_token`` to be called."""
        record = _provider_record()
        mock_db_session.query().first.return_value = record
        mock_db_session.query().count.return_value = 0
        with patch.object(service, "extract_secret_variables", return_value=["tok"]):
            service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id")
        self._enc.encrypt_token.assert_called()

    def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session):
        """Reauthorizing calls ``redis_client.lock``."""
        record = _provider_record()
        mock_db_session.query().first.return_value = record
        mock_db_session.query().count.return_value = 0
        with patch.object(service, "extract_secret_variables", return_value=[]):
            service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
        self._redis.lock.assert_called()

    # -----------------------------------------------------------------------
    # add_datasource_api_key_provider (lines 608-675)
    # -----------------------------------------------------------------------

    def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user):
        """An explicit name combined with ``count() == 1`` raises ``ValueError``."""
        mock_db_session.query().count.return_value = 1
        with (
            _patch_current_user(mock_user),
            pytest.raises(ValueError, match="already exists"),
        ):
            service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"})

    def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user):
        """An exception from ``validate_provider_credentials`` surfaces as ``ValueError``."""
        mock_db_session.query().count.return_value = 0
        with (
            _patch_current_user(mock_user),
            patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")),
            patch.object(service, "extract_secret_variables", return_value=[]),
            pytest.raises(ValueError, match="Failed to validate"),
        ):
            service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"})

    def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user):
        """When validation passes, one object is added and one commit happens."""
        mock_db_session.query().count.return_value = 0
        with (
            _patch_current_user(mock_user),
            patch.object(service.provider_manager, "validate_provider_credentials"),
            patch.object(service, "extract_secret_variables", return_value=["sk"]),
        ):
            service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
        mock_db_session.add.assert_called_once()
        mock_db_session.commit.assert_called_once()

    def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
        """Adding an API-key provider calls ``redis_client.lock``."""
        mock_db_session.query().count.return_value = 0
        with (
            _patch_current_user(mock_user),
            patch.object(service.provider_manager, "validate_provider_credentials"),
            patch.object(service, "extract_secret_variables", return_value=[]),
        ):
            service.add_datasource_api_key_provider(None, "t1", make_id(), {})
        self._redis.lock.assert_called()

    # -----------------------------------------------------------------------
    # extract_secret_variables (lines 666-699)
    # -----------------------------------------------------------------------

    def test_should_extract_secret_variable_names_for_api_key_schema(self, service):
        """A ``SECRET_INPUT`` entry in ``credentials_schema`` is reported for ``API_KEY``."""
        secret_field = MagicMock()
        secret_field.name = "my_secret"
        secret_field.type = MagicMock()
        secret_field.type.value = FormType.SECRET_INPUT
        provider_meta = MagicMock()
        provider_meta.declaration.credentials_schema = [secret_field]
        with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=provider_meta):
            secret_names = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.API_KEY)
        assert "my_secret" in secret_names

    def test_should_extract_secret_variable_names_for_oauth2_schema(self, service):
        """A ``SECRET_INPUT`` entry in ``oauth_schema.credentials_schema`` is reported for ``OAUTH2``."""
        secret_field = MagicMock()
        secret_field.name = "oauth_secret"
        secret_field.type = MagicMock()
        secret_field.type.value = FormType.SECRET_INPUT
        provider_meta = MagicMock()
        provider_meta.declaration.oauth_schema.credentials_schema = [secret_field]
        with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=provider_meta):
            secret_names = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.OAUTH2)
        assert "oauth_secret" in secret_names

    def test_should_raise_value_error_when_credential_type_is_invalid(self, service):
        """``CredentialType.UNAUTHORIZED`` raises ``ValueError`` matching "Invalid credential type"."""
        provider_meta = MagicMock()
        with (
            patch.object(service.provider_manager, "fetch_datasource_provider", return_value=provider_meta),
            pytest.raises(ValueError, match="Invalid credential type"),
        ):
            service.extract_secret_variables("t1", "org/plug/prov", CredentialType.UNAUTHORIZED)

    # -----------------------------------------------------------------------
    # list_datasource_credentials (lines 721-754)
    # -----------------------------------------------------------------------

    def test_should_return_empty_list_when_no_credentials_stored(self, service, mock_db_session):
        """No rows yields an empty list."""
        mock_db_session.query().all.return_value = []
        listed = service.list_datasource_credentials("t1", "prov", "org/plug")
        assert listed == []

    def test_should_return_masked_credentials_list_when_credentials_exist(self, service, mock_db_session):
        """One stored ``api_key`` row yields one listed entry."""
        record = _provider_record(auth_type="api_key", encrypted_credentials={"sk": "v"}, is_default=False)
        mock_db_session.query().all.return_value = [record]
        with patch.object(service, "extract_secret_variables", return_value=["sk"]):
            listed = service.list_datasource_credentials("t1", "prov", "org/plug")
        assert len(listed) == 1

    # -----------------------------------------------------------------------
    # get_all_datasource_credentials (lines 808-871)
    # -----------------------------------------------------------------------

    def test_should_aggregate_credentials_for_non_hardcoded_plugin(self, service):
        """One installed datasource with one credential yields one aggregated entry."""
        datasource = MagicMock()
        datasource.provider = "prov"
        datasource.plugin_id = "org/plug"
        datasource.declaration.identity.label.model_dump.return_value = {"en_US": "Label"}
        credential_entry = {"credential": {"k": "v"}, "is_default": True}

        with (
            patch("services.datasource_provider_service.PluginDatasourceManager") as manager_cls,
            patch.object(service, "list_datasource_credentials", return_value=[credential_entry]),
        ):
            manager_cls.return_value.fetch_installed_datasource_providers.return_value = [datasource]
            aggregated = service.get_all_datasource_credentials("t1")
        assert len(aggregated) == 1

    def test_should_include_oauth_schema_for_hardcoded_plugin_ids(self, service, mock_db_session):
        """Lines 819-871: a ``langgenius/firecrawl_datasource`` entry carries a non-``None`` ``oauth_schema``."""
        datasource = MagicMock()
        datasource.plugin_id = "langgenius/firecrawl_datasource"
        datasource.provider = "firecrawl"
        datasource.plugin_unique_identifier = "pui"

        identity = datasource.declaration.identity
        identity.icon = "icon"
        identity.name = "langgenius/firecrawl_datasource"
        identity.label.model_dump.return_value = {"en_US": "Firecrawl"}
        identity.description.model_dump.return_value = {"en_US": "desc"}
        identity.author = "langgenius"

        datasource.declaration.credentials_schema = []
        datasource.declaration.oauth_schema.client_schema = []
        datasource.declaration.oauth_schema.credentials_schema = []

        with (
            patch("services.datasource_provider_service.PluginDatasourceManager") as manager_cls,
            patch.object(service, "list_datasource_credentials", return_value=[]),
            patch.object(service, "get_tenant_oauth_client", return_value=None),
            patch.object(service, "is_tenant_oauth_params_enabled", return_value=False),
            patch.object(service, "is_system_oauth_params_exist", return_value=False),
        ):
            manager_cls.return_value.fetch_installed_datasource_providers.return_value = [datasource]
            aggregated = service.get_all_datasource_credentials("t1")
        assert len(aggregated) == 1
        assert aggregated[0]["oauth_schema"] is not None

    # -----------------------------------------------------------------------
    # get_real_datasource_credentials (lines 873-915)
    # -----------------------------------------------------------------------

    def test_should_return_empty_list_when_no_real_credentials_exist(self, service, mock_db_session):
        """No rows yields an empty list."""
        mock_db_session.query().all.return_value = []
        real_credentials = service.get_real_datasource_credentials("t1", "prov", "org/plug")
        assert real_credentials == []

    def test_should_return_decrypted_credential_list_when_credentials_exist(self, service, mock_db_session):
        """One stored ``api_key`` row yields one returned entry."""
        record = _provider_record(auth_type="api_key", encrypted_credentials={"sk": "v"})
        mock_db_session.query().all.return_value = [record]
        with patch.object(service, "extract_secret_variables", return_value=["sk"]):
            real_credentials = service.get_real_datasource_credentials("t1", "prov", "org/plug")
        assert len(real_credentials) == 1

    # -----------------------------------------------------------------------
    # update_datasource_credentials (lines 917-978)
    # -----------------------------------------------------------------------

    def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user):
        """A missing row raises ``ValueError`` matching "not found"."""
        mock_db_session.query().first.return_value = None
        with (
            _patch_current_user(mock_user),
            pytest.raises(ValueError, match="not found"),
        ):
            service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name")

    def test_should_raise_value_error_when_new_name_already_used_on_update(self, service, mock_db_session, mock_user):
        """A non-zero ``count()`` for the new name raises ``ValueError`` matching "already exists"."""
        record = _provider_record(name="old_name", auth_type="api_key", encrypted_credentials={"sk": "e"})
        mock_db_session.query().first.return_value = record
        mock_db_session.query().count.return_value = 1
        with (
            _patch_current_user(mock_user),
            pytest.raises(ValueError, match="already exists"),
        ):
            service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name")

    def test_should_raise_value_error_when_credential_validation_fails_on_update(
        self, service, mock_db_session, mock_user
    ):
        """An exception from ``validate_provider_credentials`` surfaces as ``ValueError``."""
        record = _provider_record(name="old_name", auth_type="api_key", encrypted_credentials={"sk": "e"})
        mock_db_session.query().first.return_value = record
        mock_db_session.query().count.return_value = 0
        with (
            _patch_current_user(mock_user),
            patch.object(service, "extract_secret_variables", return_value=["sk"]),
            patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad")),
            pytest.raises(ValueError, match="Failed to validate"),
        ):
            service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "v"}, "name")

    def test_should_encrypt_credentials_and_commit_when_update_succeeds(self, service, mock_db_session, mock_user):
        """A successful update calls ``encrypt_token`` and commits exactly once."""
        record = _provider_record(name="old_name", auth_type="api_key", encrypted_credentials={"sk": "old_enc"})
        mock_db_session.query().first.return_value = record
        mock_db_session.query().count.return_value = 0
        with (
            _patch_current_user(mock_user),
            patch.object(service, "extract_secret_variables", return_value=["sk"]),
            patch.object(service.provider_manager, "validate_provider_credentials"),
        ):
            service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "new_val"}, "name")
        # The encrypter must have been used for the update.
        self._enc.encrypt_token.assert_called()
        # Exactly one commit is expected.
        mock_db_session.commit.assert_called_once()

    # -----------------------------------------------------------------------
    # remove_datasource_credentials (lines 980-997)
    # -----------------------------------------------------------------------

    def test_should_delete_provider_and_commit_when_found(self, service, mock_db_session):
        """A found row is passed to ``session.delete`` and one commit happens."""
        record = _provider_record()
        mock_db_session.query().first.return_value = record
        service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
        mock_db_session.delete.assert_called_once_with(record)
        mock_db_session.commit.assert_called_once()

    def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
        """With no row found, no error is raised and ``session.delete`` is not called (lines 994 branch)."""
        mock_db_session.query().first.return_value = None
        service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
        mock_db_session.delete.assert_not_called()