mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
feat: add custom OAuth client setup and enhance datasource provider model with avatar_url
This commit is contained in:
@ -1,19 +1,22 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.helper import encrypter
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.model_runtime.entities.provider_entities import FormType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.plugin.entities.plugin import DatasourceProviderID
|
||||
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||
from core.tools.entities.tool_entities import CredentialType
|
||||
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.oauth import DatasourceProvider
|
||||
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -26,6 +29,165 @@ class DatasourceProviderService:
|
||||
def __init__(self) -> None:
|
||||
self.provider_manager = PluginDatasourceManager()
|
||||
|
||||
def setup_oauth_custom_client_params(
|
||||
self,
|
||||
tenant_id: str,
|
||||
datasource_provider_id: DatasourceProviderID,
|
||||
client_params: dict | None,
|
||||
enabled: bool | None,
|
||||
):
|
||||
"""
|
||||
setup oauth custom client params
|
||||
"""
|
||||
if client_params is None and enabled is None:
|
||||
return
|
||||
provider_controller = PluginDatasourceManager()
|
||||
datasource_provider = provider_controller.fetch_datasource_provider(
|
||||
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
|
||||
)
|
||||
if not datasource_provider.declaration.oauth_schema:
|
||||
raise ValueError("Datasource provider oauth schema not found")
|
||||
with Session(db.engine) as session:
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tenant_oauth_client_params:
|
||||
tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
client_params={},
|
||||
enabled=False,
|
||||
)
|
||||
session.add(tenant_oauth_client_params)
|
||||
|
||||
if client_params is not None:
|
||||
client_schema = datasource_provider.declaration.oauth_schema.client_schema
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in client_schema],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
original_params = (
|
||||
encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
|
||||
)
|
||||
new_params: dict = {
|
||||
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
||||
for key, value in client_params.items()
|
||||
}
|
||||
tenant_oauth_client_params.client_params = encrypter.encrypt(new_params)
|
||||
|
||||
if enabled is not None:
|
||||
tenant_oauth_client_params.enabled = enabled
|
||||
session.commit()
|
||||
|
||||
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
|
||||
"""
|
||||
check if system oauth params exist
|
||||
"""
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
return (
|
||||
session.query(DatasourceOauthParamConfig)
|
||||
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
|
||||
"""
|
||||
check if tenant oauth params is enabled
|
||||
"""
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
return (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
enabled=True,
|
||||
)
|
||||
.count()
|
||||
> 0
|
||||
)
|
||||
|
||||
def get_tenant_oauth_client(
|
||||
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
get tenant oauth client
|
||||
"""
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if tenant_oauth_client_params:
|
||||
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
||||
return encrypter.decrypt(tenant_oauth_client_params.client_params)
|
||||
return None
|
||||
|
||||
def get_oauth_encrypter(
|
||||
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
"""
|
||||
get oauth encrypter
|
||||
"""
|
||||
datasource_provider = self.provider_manager.fetch_datasource_provider(
|
||||
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
|
||||
)
|
||||
if not datasource_provider.declaration.oauth_schema:
|
||||
raise ValueError("Datasource provider oauth schema not found")
|
||||
|
||||
client_schema = datasource_provider.declaration.oauth_schema.client_schema
|
||||
return create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in client_schema],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
|
||||
"""
|
||||
get oauth client
|
||||
"""
|
||||
provider = datasource_provider_id.provider_name
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
# get tenant oauth client params
|
||||
tenant_oauth_client_params = (
|
||||
session.query(DatasourceOauthTenantParamConfig)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
enabled=True,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if tenant_oauth_client_params:
|
||||
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
||||
return encrypter.decrypt(tenant_oauth_client_params.client_params)
|
||||
|
||||
# fallback to system oauth client params
|
||||
oauth_client_params = (
|
||||
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
||||
)
|
||||
if oauth_client_params:
|
||||
return oauth_client_params.system_credentials
|
||||
|
||||
raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
|
||||
|
||||
@staticmethod
|
||||
def generate_next_datasource_provider_name(
|
||||
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
|
||||
@ -69,24 +231,29 @@ class DatasourceProviderService:
|
||||
credential_type=credential_type,
|
||||
)
|
||||
else:
|
||||
if session.query(DatasourceProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
name=db_provider_name,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
auth_type=credential_type.value,
|
||||
).count() > 0:
|
||||
if (
|
||||
session.query(DatasourceProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
name=db_provider_name,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
auth_type=credential_type.value,
|
||||
)
|
||||
.count()
|
||||
> 0
|
||||
):
|
||||
db_provider_name = generate_incremental_name(
|
||||
[
|
||||
provider.name
|
||||
for provider in session.query(DatasourceProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
)
|
||||
],
|
||||
db_provider_name,
|
||||
)
|
||||
[
|
||||
provider.name
|
||||
for provider in session.query(DatasourceProvider).filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_id.provider_name,
|
||||
plugin_id=provider_id.plugin_id,
|
||||
)
|
||||
],
|
||||
db_provider_name,
|
||||
)
|
||||
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id, provider_id=f"{provider_id}"
|
||||
@ -103,7 +270,7 @@ class DatasourceProviderService:
|
||||
plugin_id=provider_id.plugin_id,
|
||||
auth_type=credential_type.value,
|
||||
encrypted_credentials=credentials,
|
||||
avatar_url=avatar_url,
|
||||
avatar_url=avatar_url or "default",
|
||||
)
|
||||
session.add(datasource_provider)
|
||||
session.commit()
|
||||
@ -222,6 +389,7 @@ class DatasourceProviderService:
|
||||
"credential": copy_credentials,
|
||||
"type": datasource_provider.auth_type,
|
||||
"name": datasource_provider.name,
|
||||
"avatar_url": datasource_provider.avatar_url,
|
||||
"id": datasource_provider.id,
|
||||
}
|
||||
)
|
||||
@ -239,6 +407,7 @@ class DatasourceProviderService:
|
||||
datasources = manager.fetch_installed_datasource_providers(tenant_id)
|
||||
datasource_credentials = []
|
||||
for datasource in datasources:
|
||||
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
|
||||
credentials = self.get_datasource_credentials(
|
||||
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
||||
)
|
||||
@ -302,6 +471,11 @@ class DatasourceProviderService:
|
||||
}
|
||||
for credential in datasource.declaration.oauth_schema.credentials_schema or []
|
||||
],
|
||||
"oauth_custom_client_params": self.get_tenant_oauth_client(tenant_id, datasource_provider_id),
|
||||
"is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
|
||||
tenant_id, datasource_provider_id
|
||||
),
|
||||
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
|
||||
}
|
||||
if datasource.declaration.oauth_schema
|
||||
else None,
|
||||
|
||||
Reference in New Issue
Block a user