mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
add multi model credentials
This commit is contained in:
@ -11,6 +11,7 @@ from configs import dify_config
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
CredentialConfiguration,
|
||||
CustomConfiguration,
|
||||
CustomModelConfiguration,
|
||||
CustomProviderConfiguration,
|
||||
@ -39,7 +40,9 @@ from extensions.ext_redis import redis_client
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
ProviderCredential,
|
||||
ProviderModel,
|
||||
ProviderModelCredential,
|
||||
ProviderModelSetting,
|
||||
ProviderType,
|
||||
TenantDefaultModel,
|
||||
@ -487,6 +490,61 @@ class ProviderManager:
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@staticmethod
|
||||
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
Get provider all credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
|
||||
.order_by(ProviderCredential.created_at.desc())
|
||||
)
|
||||
|
||||
available_credentials = session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_provider_model_available_credentials(
|
||||
tenant_id: str, provider_name: str, model_name: str, model_type: str
|
||||
) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
Get provider custom model all credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
:param model_name: model name
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
ProviderModelCredential.tenant_id == tenant_id,
|
||||
ProviderModelCredential.provider_name == provider_name,
|
||||
ProviderModelCredential.model_name == model_name,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
.order_by(ProviderModelCredential.created_at.desc())
|
||||
)
|
||||
|
||||
available_credentials = session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _init_trial_provider_records(
|
||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
||||
@ -589,9 +647,6 @@ class ProviderManager:
|
||||
if provider_record.provider_type == ProviderType.SYSTEM.value:
|
||||
continue
|
||||
|
||||
if not provider_record.encrypted_config:
|
||||
continue
|
||||
|
||||
custom_provider_record = provider_record
|
||||
|
||||
# Get custom provider credentials
|
||||
@ -610,8 +665,8 @@ class ProviderManager:
|
||||
try:
|
||||
# fix origin data
|
||||
if custom_provider_record.encrypted_config is None:
|
||||
raise ValueError("No credentials found")
|
||||
if not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {}
|
||||
elif not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
@ -638,7 +693,14 @@ class ProviderManager:
|
||||
else:
|
||||
provider_credentials = cached_provider_credentials
|
||||
|
||||
custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials)
|
||||
custom_provider_configuration = CustomProviderConfiguration(
|
||||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
available_credentials=self.get_provider_available_credentials(
|
||||
tenant_id, custom_provider_record.provider_name
|
||||
),
|
||||
)
|
||||
|
||||
# Get provider model credential secret variables
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
@ -650,8 +712,12 @@ class ProviderManager:
|
||||
# Get custom provider model credentials
|
||||
custom_model_configurations = []
|
||||
for provider_model_record in provider_model_records:
|
||||
if not provider_model_record.encrypted_config:
|
||||
continue
|
||||
available_model_credentials = self.get_provider_model_available_credentials(
|
||||
tenant_id,
|
||||
provider_model_record.provider_name,
|
||||
provider_model_record.model_name,
|
||||
provider_model_record.model_type,
|
||||
)
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
|
||||
@ -691,6 +757,8 @@ class ProviderManager:
|
||||
model=provider_model_record.model_name,
|
||||
model_type=ModelType.value_of(provider_model_record.model_type),
|
||||
credentials=provider_model_credentials,
|
||||
current_credential_id=provider_model_record.credential_id,
|
||||
available_model_credentials=available_model_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
@ -894,6 +962,7 @@ class ProviderManager:
|
||||
if not provider_model_settings:
|
||||
return model_settings
|
||||
|
||||
has_invalid_load_balancing_configs = False
|
||||
for provider_model_setting in provider_model_settings:
|
||||
load_balancing_configs = []
|
||||
if provider_model_setting.load_balancing_enabled and load_balancing_model_configs:
|
||||
@ -902,6 +971,10 @@ class ProviderManager:
|
||||
load_balancing_model_config.model_name == provider_model_setting.model_name
|
||||
and load_balancing_model_config.model_type == provider_model_setting.model_type
|
||||
):
|
||||
if load_balancing_model_config.name == "__delete__":
|
||||
has_invalid_load_balancing_configs = True
|
||||
continue
|
||||
|
||||
if not load_balancing_model_config.enabled:
|
||||
continue
|
||||
|
||||
@ -967,6 +1040,7 @@ class ProviderManager:
|
||||
model_type=ModelType.value_of(provider_model_setting.model_type),
|
||||
enabled=provider_model_setting.enabled,
|
||||
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
|
||||
has_invalid_load_balancing_configs=has_invalid_load_balancing_configs,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user