mirror of
https://github.com/langgenius/dify.git
synced 2026-03-15 20:07:23 +08:00
761 lines
40 KiB
Python
761 lines
40 KiB
Python
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from sqlalchemy.orm import Session
|
|
|
|
from core.plugin.entities.plugin_daemon import CredentialType
|
|
from dify_graph.model_runtime.entities.provider_entities import FormType
|
|
from models.account import Account
|
|
from models.model import EndUser
|
|
from models.oauth import DatasourceProvider
|
|
from models.provider_ids import DatasourceProviderID
|
|
from services.datasource_provider_service import DatasourceProviderService, get_current_user
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def make_id(s: str = "org/plugin/provider") -> DatasourceProviderID:
    """Build a ``DatasourceProviderID`` for tests.

    Defaults to a generic three-segment id so callers that do not care about
    the concrete provider can simply call ``make_id()``.
    """
    provider_id = DatasourceProviderID(s)
    return provider_id
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test class
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDatasourceProviderService:
|
|
"""Comprehensive tests for DatasourceProviderService targeting >95% coverage."""
|
|
|
|
@pytest.fixture
|
|
def service(self):
|
|
return DatasourceProviderService()
|
|
|
|
@pytest.fixture
|
|
def mock_db_session(self):
|
|
"""
|
|
Robust, chainable query mock.
|
|
q returns itself for .filter_by(), .order_by(), .where() so any
|
|
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
|
|
"""
|
|
with patch("services.datasource_provider_service.Session") as mock_cls:
|
|
sess = MagicMock(spec=Session)
|
|
|
|
q = MagicMock()
|
|
sess.query.return_value = q
|
|
|
|
# Self-returning chain — any method called on q returns q
|
|
q.filter_by.return_value = q
|
|
q.order_by.return_value = q
|
|
q.where.return_value = q
|
|
|
|
# Default terminal values (tests override per-case)
|
|
q.first.return_value = None
|
|
q.all.return_value = []
|
|
q.count.return_value = 0
|
|
q.delete.return_value = 1
|
|
|
|
mock_cls.return_value.__enter__.return_value = sess
|
|
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
|
|
|
|
yield sess
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def patch_db(self, mock_db_session):
|
|
with patch("services.datasource_provider_service.db") as mock_db:
|
|
mock_db.session = mock_db_session
|
|
mock_db.engine = MagicMock()
|
|
yield mock_db
|
|
|
|
@pytest.fixture(autouse=True)
|
|
def patch_externals(self):
|
|
with (
|
|
patch("httpx.request") as mock_httpx,
|
|
patch("services.datasource_provider_service.dify_config") as mock_cfg,
|
|
patch("services.datasource_provider_service.encrypter") as mock_enc,
|
|
patch("services.datasource_provider_service.redis_client") as mock_redis,
|
|
patch("services.datasource_provider_service.generate_incremental_name") as mock_genname,
|
|
patch("services.datasource_provider_service.OAuthHandler") as mock_oauth,
|
|
):
|
|
mock_cfg.CONSOLE_API_URL = "http://localhost"
|
|
mock_enc.encrypt_token.return_value = "enc_tok"
|
|
mock_enc.decrypt_token.return_value = "dec_tok"
|
|
mock_enc.decrypt.return_value = {"k": "dec"}
|
|
mock_enc.encrypt.return_value = {"k": "enc"}
|
|
mock_enc.obfuscated_token.return_value = "obf"
|
|
mock_enc.mask_plugin_credentials.return_value = {"k": "mask"}
|
|
|
|
mock_redis.lock.return_value.__enter__.return_value = MagicMock()
|
|
mock_genname.return_value = "gen_name"
|
|
|
|
mock_oauth.return_value.refresh_credentials.return_value = MagicMock(
|
|
credentials={"k": "v"}, expires_at=9999
|
|
)
|
|
|
|
resp = MagicMock()
|
|
resp.status_code = 200
|
|
resp.json.return_value = {
|
|
"code": 0,
|
|
"message": "ok",
|
|
"data": {
|
|
"provider": "prov",
|
|
"plugin_unique_identifier": "pui",
|
|
"plugin_id": "org/plug",
|
|
"is_authorized": False,
|
|
"declaration": {
|
|
"identity": {
|
|
"author": "a",
|
|
"name": "n",
|
|
"description": {"en_US": "d"},
|
|
"icon": "i",
|
|
"label": {"en_US": "l"},
|
|
},
|
|
"credentials_schema": [],
|
|
"oauth_schema": {"credentials_schema": [], "client_schema": []},
|
|
"provider_type": "local_file",
|
|
"datasources": [],
|
|
},
|
|
},
|
|
}
|
|
mock_httpx.return_value = resp
|
|
|
|
# Store handles for assertions
|
|
self._enc = mock_enc
|
|
self._redis = mock_redis
|
|
yield
|
|
|
|
@pytest.fixture
|
|
def mock_user(self):
|
|
u = MagicMock()
|
|
u.id = "uid-1"
|
|
return u
|
|
|
|
# -----------------------------------------------------------------------
|
|
# get_current_user (lines 27-40)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_return_proxy_when_current_object_is_account(self):
|
|
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
|
|
user_obj = MagicMock()
|
|
user_obj.__class__ = Account
|
|
proxy._get_current_object.return_value = user_obj
|
|
assert get_current_user() is proxy
|
|
|
|
def test_should_return_proxy_when_current_object_is_enduser(self):
|
|
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
|
|
user_obj = MagicMock()
|
|
user_obj.__class__ = EndUser
|
|
proxy._get_current_object.return_value = user_obj
|
|
assert get_current_user() is proxy
|
|
|
|
def test_should_return_proxy_when_get_current_object_raises_attribute_error(self):
|
|
"""AttributeError from LocalProxy falls back to the proxy itself."""
|
|
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
|
|
proxy._get_current_object.side_effect = AttributeError("no attr")
|
|
proxy.__class__ = Account # make the proxy itself satisfy isinstance
|
|
assert get_current_user() is proxy
|
|
|
|
def test_should_raise_type_error_when_user_is_not_account_or_enduser(self):
|
|
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
|
|
proxy._get_current_object.return_value = "plain_string"
|
|
with pytest.raises(TypeError, match="current_user must be Account or EndUser"):
|
|
get_current_user()
|
|
|
|
# -----------------------------------------------------------------------
|
|
# is_system_oauth_params_exist (line 357-363)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session):
|
|
mock_db_session.query().first.return_value = MagicMock()
|
|
assert service.is_system_oauth_params_exist(make_id()) is True
|
|
|
|
def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session):
|
|
mock_db_session.query().first.return_value = None
|
|
assert service.is_system_oauth_params_exist(make_id()) is False
|
|
|
|
# -----------------------------------------------------------------------
|
|
# is_tenant_oauth_params_enabled (lines 365-379)
|
|
# NOTE: uses .count() not .first()
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_return_true_when_tenant_oauth_params_enabled(self, service, mock_db_session):
|
|
mock_db_session.query().count.return_value = 1
|
|
assert service.is_tenant_oauth_params_enabled("t1", make_id()) is True
|
|
|
|
def test_should_return_false_when_tenant_oauth_params_disabled(self, service, mock_db_session):
|
|
mock_db_session.query().count.return_value = 0
|
|
assert service.is_tenant_oauth_params_enabled("t1", make_id()) is False
|
|
|
|
# -----------------------------------------------------------------------
|
|
# remove_oauth_custom_client_params (lines 55-61)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session):
|
|
service.remove_oauth_custom_client_params("t1", make_id())
|
|
mock_db_session.query().delete.assert_called_once()
|
|
|
|
# -----------------------------------------------------------------------
|
|
# setup_oauth_custom_client_params (315-351)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_skip_db_write_when_credentials_are_none(self, service, mock_db_session):
|
|
"""When credentials=None, should return immediately without any DB write."""
|
|
service.setup_oauth_custom_client_params("t1", make_id(), None, None)
|
|
mock_db_session.add.assert_not_called()
|
|
|
|
def test_should_create_new_config_when_none_exists(self, service, mock_db_session):
|
|
mock_db_session.query().first.return_value = None
|
|
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
|
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True)
|
|
mock_db_session.add.assert_called_once()
|
|
|
|
def test_should_update_existing_config_when_record_found(self, service, mock_db_session):
|
|
existing = MagicMock()
|
|
mock_db_session.query().first.return_value = existing
|
|
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
|
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False)
|
|
mock_db_session.add.assert_not_called() # update in place, no add
|
|
|
|
# -----------------------------------------------------------------------
|
|
# decrypt / encrypt credentials (lines 70-98)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_decrypt_secret_fields_when_decrypting_api_key_credentials(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.auth_type = "api_key"
|
|
p.encrypted_credentials = {"sk": "enc_val"}
|
|
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
|
|
result = service.decrypt_datasource_provider_credentials("t1", p, "org/plug", "prov")
|
|
assert result["sk"] == "dec_tok"
|
|
|
|
def test_should_encrypt_secret_fields_when_encrypting_api_key_credentials(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.auth_type = "api_key"
|
|
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
|
|
result = service.encrypt_datasource_provider_credentials("t1", "prov", "org/plug", {"sk": "plain"}, p)
|
|
assert result["sk"] == "enc_tok"
|
|
self._enc.encrypt_token.assert_called()
|
|
|
|
# -----------------------------------------------------------------------
|
|
# get_datasource_credentials (lines 113-165)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user):
|
|
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
|
mock_db_session.query().first.return_value = None
|
|
assert service.get_datasource_credentials("t1", "prov", "org/plug") == {}
|
|
|
|
def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user):
|
|
"""Expired OAuth credential (expires_at near zero) triggers a silent refresh."""
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.auth_type = "oauth2"
|
|
p.expires_at = 0 # expired
|
|
p.encrypted_credentials = {"tok": "x"}
|
|
mock_db_session.query().first.return_value = p
|
|
with (
|
|
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
|
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
|
|
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
|
|
):
|
|
service.get_datasource_credentials("t1", "prov", "org/plug")
|
|
mock_db_session.commit.assert_called_once()
|
|
|
|
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
|
|
"""API key credentials with expires_at=-1 skip refresh and return directly."""
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.auth_type = "api_key"
|
|
p.expires_at = -1 # sentinel: never expires
|
|
p.encrypted_credentials = {"k": "v"}
|
|
mock_db_session.query().first.return_value = p
|
|
with (
|
|
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
|
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}),
|
|
):
|
|
result = service.get_datasource_credentials("t1", "prov", "org/plug")
|
|
assert result == {"k": "plain"}
|
|
|
|
def test_should_fetch_by_credential_id_when_provided(self, service, mock_db_session, mock_user):
|
|
"""When credential_id is passed, the credential_id filter path (line 113) is taken."""
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.auth_type = "api_key"
|
|
p.expires_at = -1
|
|
p.encrypted_credentials = {}
|
|
mock_db_session.query().first.return_value = p
|
|
with (
|
|
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
|
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}),
|
|
):
|
|
result = service.get_datasource_credentials("t1", "prov", "org/plug", credential_id="cred-id")
|
|
assert result == {"k": "v"}
|
|
|
|
# -----------------------------------------------------------------------
|
|
# get_all_datasource_credentials_by_provider (lines 176-228)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user):
|
|
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
|
mock_db_session.query().all.return_value = []
|
|
assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == []
|
|
|
|
def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.auth_type = "oauth2"
|
|
p.expires_at = 0
|
|
p.encrypted_credentials = {"t": "x"}
|
|
mock_db_session.query().all.return_value = [p]
|
|
with (
|
|
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
|
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
|
|
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}),
|
|
):
|
|
result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
|
|
assert len(result) == 1
|
|
|
|
# -----------------------------------------------------------------------
|
|
# update_datasource_provider_name (lines 236-303)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session):
|
|
mock_db_session.query().first.return_value = None
|
|
with pytest.raises(ValueError, match="not found"):
|
|
service.update_datasource_provider_name("t1", make_id(), "new", "cred-id")
|
|
|
|
def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.name = "same"
|
|
mock_db_session.query().first.return_value = p
|
|
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
|
|
mock_db_session.commit.assert_not_called()
|
|
|
|
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.name = "old_name"
|
|
p.is_default = False
|
|
mock_db_session.query().first.return_value = p
|
|
mock_db_session.query().count.return_value = 1 # conflict
|
|
with pytest.raises(ValueError, match="already exists"):
|
|
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
|
|
|
|
def test_should_update_name_and_commit_when_no_conflict(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.name = "old_name"
|
|
p.is_default = False
|
|
mock_db_session.query().first.return_value = p
|
|
mock_db_session.query().count.return_value = 0
|
|
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
|
|
assert p.name == "new_name"
|
|
mock_db_session.commit.assert_called_once()
|
|
|
|
# -----------------------------------------------------------------------
|
|
# set_default_datasource_provider (lines 277-303)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session):
|
|
mock_db_session.query().first.return_value = None
|
|
with pytest.raises(ValueError, match="not found"):
|
|
service.set_default_datasource_provider("t1", make_id(), "bad-id")
|
|
|
|
def test_should_mark_target_as_default_and_commit(self, service, mock_db_session):
|
|
target = MagicMock(spec=DatasourceProvider)
|
|
target.provider = "provider"
|
|
target.plugin_id = "org/plug"
|
|
mock_db_session.query().first.return_value = target
|
|
service.set_default_datasource_provider("t1", make_id(), "new-id")
|
|
assert target.is_default is True
|
|
mock_db_session.commit.assert_called_once()
|
|
|
|
# -----------------------------------------------------------------------
|
|
# get_oauth_encrypter (lines 404-420)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_raise_value_error_when_oauth_schema_missing(self, service):
|
|
pm = MagicMock()
|
|
pm.declaration.oauth_schema = None
|
|
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
|
|
with pytest.raises(ValueError, match="oauth schema not found"):
|
|
service.get_oauth_encrypter("t1", make_id())
|
|
|
|
def test_should_return_encrypter_when_oauth_schema_exists(self, service):
|
|
schema_item = MagicMock()
|
|
schema_item.to_basic_provider_config.return_value = MagicMock()
|
|
pm = MagicMock()
|
|
pm.declaration.oauth_schema.client_schema = [schema_item]
|
|
with (
|
|
patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm),
|
|
patch(
|
|
"services.datasource_provider_service.create_provider_encrypter",
|
|
return_value=(MagicMock(), MagicMock()),
|
|
),
|
|
):
|
|
result = service.get_oauth_encrypter("t1", make_id())
|
|
assert result is not None
|
|
|
|
# -----------------------------------------------------------------------
|
|
# get_tenant_oauth_client (lines 381-402)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_return_masked_credentials_when_mask_is_true(self, service, mock_db_session):
|
|
tenant_params = MagicMock()
|
|
tenant_params.client_params = {"k": "v"}
|
|
mock_db_session.query().first.return_value = tenant_params
|
|
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
|
result = service.get_tenant_oauth_client("t1", make_id(), mask=True)
|
|
assert result == {"k": "mask"}
|
|
|
|
def test_should_return_decrypted_credentials_when_mask_is_false(self, service, mock_db_session):
|
|
tenant_params = MagicMock()
|
|
tenant_params.client_params = {"k": "v"}
|
|
mock_db_session.query().first.return_value = tenant_params
|
|
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
|
result = service.get_tenant_oauth_client("t1", make_id(), mask=False)
|
|
assert result == {"k": "dec"}
|
|
|
|
def test_should_return_none_when_no_tenant_oauth_config_exists(self, service, mock_db_session):
|
|
mock_db_session.query().first.return_value = None
|
|
assert service.get_tenant_oauth_client("t1", make_id()) is None
|
|
|
|
# -----------------------------------------------------------------------
|
|
# get_oauth_client (lines 423-457)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_use_tenant_config_when_available(self, service, mock_db_session):
|
|
mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"})
|
|
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
|
result = service.get_oauth_client("t1", make_id())
|
|
assert result == {"k": "dec"}
|
|
|
|
def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session):
|
|
mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
|
|
with (
|
|
patch.object(service.provider_manager, "fetch_datasource_provider"),
|
|
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True),
|
|
):
|
|
result = service.get_oauth_client("t1", make_id())
|
|
assert result == {"k": "sys"}
|
|
|
|
def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session):
|
|
"""Neither tenant nor system credentials → raises ValueError."""
|
|
mock_db_session.query().first.side_effect = [None, None]
|
|
with (
|
|
patch.object(service.provider_manager, "fetch_datasource_provider"),
|
|
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False),
|
|
):
|
|
with pytest.raises(ValueError, match="Please configure oauth client params"):
|
|
service.get_oauth_client("t1", make_id())
|
|
|
|
# -----------------------------------------------------------------------
|
|
# add_datasource_oauth_provider (lines 539-607)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session):
|
|
mock_db_session.query().count.return_value = 0
|
|
with patch.object(service, "extract_secret_variables", return_value=[]):
|
|
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
|
|
mock_db_session.add.assert_called_once()
|
|
mock_db_session.commit.assert_called_once()
|
|
|
|
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
|
|
"""Conflict on name results in auto-incremented name, not an error."""
|
|
mock_db_session.query().count.return_value = 1 # conflict first, then auto-named
|
|
mock_db_session.query().all.return_value = []
|
|
with (
|
|
patch.object(service, "extract_secret_variables", return_value=[]),
|
|
patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"),
|
|
):
|
|
service.add_datasource_oauth_provider("conflict", "t1", make_id(), "http://cb", 9999, {})
|
|
mock_db_session.add.assert_called_once()
|
|
|
|
def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session):
|
|
"""name=None causes auto-generation via generate_next_datasource_provider_name."""
|
|
mock_db_session.query().count.return_value = 0
|
|
mock_db_session.query().all.return_value = []
|
|
with (
|
|
patch.object(service, "extract_secret_variables", return_value=[]),
|
|
patch.object(service, "generate_next_datasource_provider_name", return_value="auto"),
|
|
):
|
|
service.add_datasource_oauth_provider(None, "t1", make_id(), "http://cb", 9999, {})
|
|
mock_db_session.add.assert_called_once()
|
|
|
|
def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session):
|
|
mock_db_session.query().count.return_value = 0
|
|
with patch.object(service, "extract_secret_variables", return_value=["secret_key"]):
|
|
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"})
|
|
self._enc.encrypt_token.assert_called()
|
|
|
|
def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session):
|
|
mock_db_session.query().count.return_value = 0
|
|
with patch.object(service, "extract_secret_variables", return_value=[]):
|
|
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {})
|
|
self._redis.lock.assert_called()
|
|
|
|
# -----------------------------------------------------------------------
|
|
# reauthorize_datasource_oauth_provider (lines 477-537)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session):
|
|
mock_db_session.query().first.return_value = None
|
|
with patch.object(service, "extract_secret_variables", return_value=[]):
|
|
with pytest.raises(ValueError, match="not found"):
|
|
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id")
|
|
|
|
def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
mock_db_session.query().first.return_value = p
|
|
mock_db_session.query().count.return_value = 0
|
|
with patch.object(service, "extract_secret_variables", return_value=[]):
|
|
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
|
|
mock_db_session.commit.assert_called_once()
|
|
|
|
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
mock_db_session.query().first.return_value = p
|
|
mock_db_session.query().count.return_value = 1 # conflict
|
|
mock_db_session.query().all.return_value = []
|
|
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
|
|
service.reauthorize_datasource_oauth_provider(
|
|
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
|
|
)
|
|
mock_db_session.commit.assert_called_once()
|
|
|
|
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
mock_db_session.query().first.return_value = p
|
|
mock_db_session.query().count.return_value = 0
|
|
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
|
|
service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id")
|
|
self._enc.encrypt_token.assert_called()
|
|
|
|
def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
mock_db_session.query().first.return_value = p
|
|
mock_db_session.query().count.return_value = 0
|
|
with patch.object(service, "extract_secret_variables", return_value=[]):
|
|
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
|
|
self._redis.lock.assert_called()
|
|
|
|
# -----------------------------------------------------------------------
|
|
# add_datasource_api_key_provider (lines 608-675)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user):
|
|
"""explicit name supplied + conflict → raises ValueError immediately."""
|
|
mock_db_session.query().count.return_value = 1
|
|
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
|
with pytest.raises(ValueError, match="already exists"):
|
|
service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"})
|
|
|
|
def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user):
|
|
mock_db_session.query().count.return_value = 0
|
|
with (
|
|
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
|
patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")),
|
|
patch.object(service, "extract_secret_variables", return_value=[]),
|
|
):
|
|
with pytest.raises(ValueError, match="Failed to validate"):
|
|
service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"})
|
|
|
|
def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user):
|
|
mock_db_session.query().count.return_value = 0
|
|
with (
|
|
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
|
patch.object(service.provider_manager, "validate_provider_credentials"),
|
|
patch.object(service, "extract_secret_variables", return_value=["sk"]),
|
|
):
|
|
service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
|
|
mock_db_session.add.assert_called_once()
|
|
mock_db_session.commit.assert_called_once()
|
|
|
|
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
|
|
mock_db_session.query().count.return_value = 0
|
|
with (
|
|
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
|
patch.object(service.provider_manager, "validate_provider_credentials"),
|
|
patch.object(service, "extract_secret_variables", return_value=[]),
|
|
):
|
|
service.add_datasource_api_key_provider(None, "t1", make_id(), {})
|
|
self._redis.lock.assert_called()
|
|
|
|
# -----------------------------------------------------------------------
|
|
# extract_secret_variables (lines 666-699)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_extract_secret_variable_names_for_api_key_schema(self, service):
|
|
schema = MagicMock()
|
|
schema.name = "my_secret"
|
|
schema.type = MagicMock()
|
|
schema.type.value = FormType.SECRET_INPUT # "secret-input"
|
|
pm = MagicMock()
|
|
pm.declaration.credentials_schema = [schema]
|
|
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
|
|
result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.API_KEY)
|
|
assert "my_secret" in result
|
|
|
|
def test_should_extract_secret_variable_names_for_oauth2_schema(self, service):
|
|
schema = MagicMock()
|
|
schema.name = "oauth_secret"
|
|
schema.type = MagicMock()
|
|
schema.type.value = FormType.SECRET_INPUT
|
|
pm = MagicMock()
|
|
pm.declaration.oauth_schema.credentials_schema = [schema]
|
|
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
|
|
result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.OAUTH2)
|
|
assert "oauth_secret" in result
|
|
|
|
def test_should_raise_value_error_when_credential_type_is_invalid(self, service):
|
|
pm = MagicMock()
|
|
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
|
|
with pytest.raises(ValueError, match="Invalid credential type"):
|
|
service.extract_secret_variables("t1", "org/plug/prov", CredentialType.UNAUTHORIZED)
|
|
|
|
# -----------------------------------------------------------------------
|
|
# list_datasource_credentials (lines 721-754)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_return_empty_list_when_no_credentials_stored(self, service, mock_db_session):
|
|
mock_db_session.query().all.return_value = []
|
|
assert service.list_datasource_credentials("t1", "prov", "org/plug") == []
|
|
|
|
def test_should_return_masked_credentials_list_when_credentials_exist(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.auth_type = "api_key"
|
|
p.encrypted_credentials = {"sk": "v"}
|
|
p.is_default = False
|
|
mock_db_session.query().all.return_value = [p]
|
|
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
|
|
result = service.list_datasource_credentials("t1", "prov", "org/plug")
|
|
assert len(result) == 1
|
|
|
|
# -----------------------------------------------------------------------
|
|
# get_all_datasource_credentials (lines 808-871)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_aggregate_credentials_for_non_hardcoded_plugin(self, service):
|
|
with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr:
|
|
ds = MagicMock()
|
|
ds.provider = "prov"
|
|
ds.plugin_id = "org/plug"
|
|
ds.declaration.identity.label.model_dump.return_value = {"en_US": "Label"}
|
|
mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds]
|
|
cred = {"credential": {"k": "v"}, "is_default": True}
|
|
with patch.object(service, "list_datasource_credentials", return_value=[cred]):
|
|
results = service.get_all_datasource_credentials("t1")
|
|
assert len(results) == 1
|
|
|
|
def test_should_include_oauth_schema_for_hardcoded_plugin_ids(self, service, mock_db_session):
|
|
"""Lines 819-871: get_all_datasource_credentials covers hardcoded langgenius plugin IDs."""
|
|
with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr:
|
|
ds = MagicMock()
|
|
ds.plugin_id = "langgenius/firecrawl_datasource"
|
|
ds.provider = "firecrawl"
|
|
ds.plugin_unique_identifier = "pui"
|
|
ds.declaration.identity.icon = "icon"
|
|
ds.declaration.identity.name = "langgenius/firecrawl_datasource"
|
|
ds.declaration.identity.label.model_dump.return_value = {"en_US": "Firecrawl"}
|
|
ds.declaration.identity.description.model_dump.return_value = {"en_US": "desc"}
|
|
ds.declaration.identity.author = "langgenius"
|
|
ds.declaration.credentials_schema = []
|
|
ds.declaration.oauth_schema.client_schema = []
|
|
ds.declaration.oauth_schema.credentials_schema = []
|
|
mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds]
|
|
with (
|
|
patch.object(service, "list_datasource_credentials", return_value=[]),
|
|
patch.object(service, "get_tenant_oauth_client", return_value=None),
|
|
patch.object(service, "is_tenant_oauth_params_enabled", return_value=False),
|
|
patch.object(service, "is_system_oauth_params_exist", return_value=False),
|
|
):
|
|
results = service.get_all_datasource_credentials("t1")
|
|
assert len(results) == 1
|
|
assert results[0]["oauth_schema"] is not None
|
|
|
|
# -----------------------------------------------------------------------
|
|
# get_real_datasource_credentials (lines 873-915)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_return_empty_list_when_no_real_credentials_exist(self, service, mock_db_session):
|
|
mock_db_session.query().all.return_value = []
|
|
assert service.get_real_datasource_credentials("t1", "prov", "org/plug") == []
|
|
|
|
def test_should_return_decrypted_credential_list_when_credentials_exist(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.auth_type = "api_key"
|
|
p.encrypted_credentials = {"sk": "v"}
|
|
mock_db_session.query().all.return_value = [p]
|
|
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
|
|
result = service.get_real_datasource_credentials("t1", "prov", "org/plug")
|
|
assert len(result) == 1
|
|
|
|
# -----------------------------------------------------------------------
|
|
# update_datasource_credentials (lines 917-978)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user):
|
|
mock_db_session.query().first.return_value = None
|
|
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
|
with pytest.raises(ValueError, match="not found"):
|
|
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name")
|
|
|
|
def test_should_raise_value_error_when_new_name_already_used_on_update(self, service, mock_db_session, mock_user):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.name = "old_name"
|
|
p.auth_type = "api_key"
|
|
p.encrypted_credentials = {"sk": "e"}
|
|
mock_db_session.query().first.return_value = p
|
|
mock_db_session.query().count.return_value = 1
|
|
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
|
with pytest.raises(ValueError, match="already exists"):
|
|
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name")
|
|
|
|
def test_should_raise_value_error_when_credential_validation_fails_on_update(
|
|
self, service, mock_db_session, mock_user
|
|
):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.name = "old_name"
|
|
p.auth_type = "api_key"
|
|
p.encrypted_credentials = {"sk": "e"}
|
|
mock_db_session.query().first.return_value = p
|
|
mock_db_session.query().count.return_value = 0
|
|
with (
|
|
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
|
patch.object(service, "extract_secret_variables", return_value=["sk"]),
|
|
patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad")),
|
|
):
|
|
with pytest.raises(ValueError, match="Failed to validate"):
|
|
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "v"}, "name")
|
|
|
|
def test_should_encrypt_credentials_and_commit_when_update_succeeds(self, service, mock_db_session, mock_user):
|
|
"""Verifies that encrypted_credentials is reassigned with encrypted value and commit is called."""
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
p.name = "old_name"
|
|
p.auth_type = "api_key"
|
|
p.encrypted_credentials = {"sk": "old_enc"}
|
|
mock_db_session.query().first.return_value = p
|
|
mock_db_session.query().count.return_value = 0
|
|
with (
|
|
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
|
patch.object(service, "extract_secret_variables", return_value=["sk"]),
|
|
patch.object(service.provider_manager, "validate_provider_credentials"),
|
|
):
|
|
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "new_val"}, "name")
|
|
# encrypter must have been called with the new secret value
|
|
self._enc.encrypt_token.assert_called()
|
|
# commit must be called exactly once
|
|
mock_db_session.commit.assert_called_once()
|
|
|
|
# -----------------------------------------------------------------------
|
|
# remove_datasource_credentials (lines 980-997)
|
|
# -----------------------------------------------------------------------
|
|
|
|
def test_should_delete_provider_and_commit_when_found(self, service, mock_db_session):
|
|
p = MagicMock(spec=DatasourceProvider)
|
|
mock_db_session.query().first.return_value = p
|
|
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
|
|
mock_db_session.delete.assert_called_once_with(p)
|
|
mock_db_session.commit.assert_called_once()
|
|
|
|
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
|
|
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""
|
|
mock_db_session.query().first.return_value = None
|
|
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
|
|
mock_db_session.delete.assert_not_called()
|