feat: add custom OAuth client setup and enhance datasource provider model with avatar_url

This commit is contained in:
Harry
2025-07-21 12:35:07 +08:00
parent 7364d051d2
commit e97f03c130
7 changed files with 294 additions and 33 deletions

View File

@ -1,19 +1,22 @@
import logging
from typing import Any
from flask_login import current_user
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper import encrypter
from core.helper.name_generator import generate_incremental_name
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.model_runtime.entities.provider_entities import FormType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.oauth import DatasourceProvider
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
logger = logging.getLogger(__name__)
@ -26,6 +29,165 @@ class DatasourceProviderService:
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()
def setup_oauth_custom_client_params(
self,
tenant_id: str,
datasource_provider_id: DatasourceProviderID,
client_params: dict | None,
enabled: bool | None,
):
"""
setup oauth custom client params
"""
if client_params is None and enabled is None:
return
provider_controller = PluginDatasourceManager()
datasource_provider = provider_controller.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
)
if not datasource_provider.declaration.oauth_schema:
raise ValueError("Datasource provider oauth schema not found")
with Session(db.engine) as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if not tenant_oauth_client_params:
tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
client_params={},
enabled=False,
)
session.add(tenant_oauth_client_params)
if client_params is not None:
client_schema = datasource_provider.declaration.oauth_schema.client_schema
encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in client_schema],
cache=NoOpProviderCredentialCache(),
)
original_params = (
encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
)
new_params: dict = {
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
for key, value in client_params.items()
}
tenant_oauth_client_params.client_params = encrypter.encrypt(new_params)
if enabled is not None:
tenant_oauth_client_params.enabled = enabled
session.commit()
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
"""
check if system oauth params exist
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthParamConfig)
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
.first()
is not None
)
def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
"""
check if tenant oauth params is enabled
"""
with Session(db.engine).no_autoflush as session:
return (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
enabled=True,
)
.count()
> 0
)
def get_tenant_oauth_client(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
) -> dict[str, Any] | None:
"""
get tenant oauth client
"""
with Session(db.engine).no_autoflush as session:
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=datasource_provider_id.provider_name,
plugin_id=datasource_provider_id.plugin_id,
)
.first()
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
return encrypter.decrypt(tenant_oauth_client_params.client_params)
return None
def get_oauth_encrypter(
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
"""
get oauth encrypter
"""
datasource_provider = self.provider_manager.fetch_datasource_provider(
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
)
if not datasource_provider.declaration.oauth_schema:
raise ValueError("Datasource provider oauth schema not found")
client_schema = datasource_provider.declaration.oauth_schema.client_schema
return create_provider_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in client_schema],
cache=NoOpProviderCredentialCache(),
)
def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
"""
get oauth client
"""
provider = datasource_provider_id.provider_name
plugin_id = datasource_provider_id.plugin_id
with Session(db.engine).no_autoflush as session:
# get tenant oauth client params
tenant_oauth_client_params = (
session.query(DatasourceOauthTenantParamConfig)
.filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
)
.first()
)
if tenant_oauth_client_params:
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
return encrypter.decrypt(tenant_oauth_client_params.client_params)
# fallback to system oauth client params
oauth_client_params = (
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
)
if oauth_client_params:
return oauth_client_params.system_credentials
raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
@staticmethod
def generate_next_datasource_provider_name(
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
@ -69,24 +231,29 @@ class DatasourceProviderService:
credential_type=credential_type,
)
else:
if session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
).count() > 0:
if (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
)
.count()
> 0
):
db_provider_name = generate_incremental_name(
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
],
db_provider_name,
)
[
provider.name
for provider in session.query(DatasourceProvider).filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
],
db_provider_name,
)
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}"
@ -103,7 +270,7 @@ class DatasourceProviderService:
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
encrypted_credentials=credentials,
avatar_url=avatar_url,
avatar_url=avatar_url or "default",
)
session.add(datasource_provider)
session.commit()
@ -222,6 +389,7 @@ class DatasourceProviderService:
"credential": copy_credentials,
"type": datasource_provider.auth_type,
"name": datasource_provider.name,
"avatar_url": datasource_provider.avatar_url,
"id": datasource_provider.id,
}
)
@ -239,6 +407,7 @@ class DatasourceProviderService:
datasources = manager.fetch_installed_datasource_providers(tenant_id)
datasource_credentials = []
for datasource in datasources:
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
credentials = self.get_datasource_credentials(
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
)
@ -302,6 +471,11 @@ class DatasourceProviderService:
}
for credential in datasource.declaration.oauth_schema.credentials_schema or []
],
"oauth_custom_client_params": self.get_tenant_oauth_client(tenant_id, datasource_provider_id),
"is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
tenant_id, datasource_provider_id
),
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
}
if datasource.declaration.oauth_schema
else None,