feat: add multi model credentials (#24451)

Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
非法操作
2025-08-25 16:12:29 +08:00
committed by GitHub
parent b08bfa203a
commit 6010d5f24c
65 changed files with 5202 additions and 1814 deletions

View File

@ -19,6 +19,7 @@ class ModelStatus(Enum):
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
DISABLED = "disabled"
CREDENTIAL_REMOVED = "credential-removed"
class SimpleModelProviderEntity(BaseModel):
@ -54,6 +55,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
status: ModelStatus
load_balancing_enabled: bool = False
has_invalid_load_balancing_configs: bool = False
def raise_for_status(self) -> None:
"""

File diff suppressed because it is too large Load Diff

View File

@ -69,6 +69,15 @@ class QuotaConfiguration(BaseModel):
restrict_models: list[RestrictModel] = []
class CredentialConfiguration(BaseModel):
"""
Model class for credential configuration.
"""
credential_id: str
credential_name: str
class SystemConfiguration(BaseModel):
"""
Model class for provider system configuration.
@ -86,6 +95,9 @@ class CustomProviderConfiguration(BaseModel):
"""
credentials: dict
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_credentials: list[CredentialConfiguration] = []
class CustomModelConfiguration(BaseModel):
@ -95,7 +107,10 @@ class CustomModelConfiguration(BaseModel):
model: str
model_type: ModelType
credentials: dict
credentials: dict | None
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_model_credentials: list[CredentialConfiguration] = []
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@ -118,6 +133,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
id: str
name: str
credentials: dict
credential_source_type: str | None = None
class ModelSettings(BaseModel):

View File

@ -201,7 +201,7 @@ class ModelProviderFactory:
return filtered_credentials
def get_model_schema(
self, *, provider: str, model_type: ModelType, model: str, credentials: dict
self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
) -> AIModelEntity | None:
"""
Get model schema

View File

@ -12,6 +12,7 @@ from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
CredentialConfiguration,
CustomConfiguration,
CustomModelConfiguration,
CustomProviderConfiguration,
@ -40,7 +41,9 @@ from extensions.ext_redis import redis_client
from models.provider import (
LoadBalancingModelConfig,
Provider,
ProviderCredential,
ProviderModel,
ProviderModelCredential,
ProviderModelSetting,
ProviderType,
TenantDefaultModel,
@ -488,6 +491,61 @@ class ProviderManager:
return provider_name_to_provider_load_balancing_model_configs_dict
@staticmethod
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
"""
Get provider all credentials.
:param tenant_id: workspace id
:param provider_name: provider name
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(ProviderCredential)
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
.order_by(ProviderCredential.created_at.desc())
)
available_credentials = session.scalars(stmt).all()
return [
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
for credential in available_credentials
]
@staticmethod
def get_provider_model_available_credentials(
tenant_id: str, provider_name: str, model_name: str, model_type: str
) -> list[CredentialConfiguration]:
"""
Get provider custom model all credentials.
:param tenant_id: workspace id
:param provider_name: provider name
:param model_name: model name
:param model_type: model type
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(ProviderModelCredential)
.where(
ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name == provider_name,
ProviderModelCredential.model_name == model_name,
ProviderModelCredential.model_type == model_type,
)
.order_by(ProviderModelCredential.created_at.desc())
)
available_credentials = session.scalars(stmt).all()
return [
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
for credential in available_credentials
]
@staticmethod
def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
@ -590,9 +648,6 @@ class ProviderManager:
if provider_record.provider_type == ProviderType.SYSTEM.value:
continue
if not provider_record.encrypted_config:
continue
custom_provider_record = provider_record
# Get custom provider credentials
@ -611,8 +666,8 @@ class ProviderManager:
try:
# fix origin data
if custom_provider_record.encrypted_config is None:
raise ValueError("No credentials found")
if not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {}
elif not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
@ -637,7 +692,14 @@ class ProviderManager:
else:
provider_credentials = cached_provider_credentials
custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials)
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get provider model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
@ -649,8 +711,12 @@ class ProviderManager:
# Get custom provider model credentials
custom_model_configurations = []
for provider_model_record in provider_model_records:
if not provider_model_record.encrypted_config:
continue
available_model_credentials = self.get_provider_model_available_credentials(
tenant_id,
provider_model_record.provider_name,
provider_model_record.model_name,
provider_model_record.model_type,
)
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
@ -659,7 +725,7 @@ class ProviderManager:
# Get cached provider model credentials
cached_provider_model_credentials = provider_model_credentials_cache.get()
if not cached_provider_model_credentials:
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
try:
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
except JSONDecodeError:
@ -688,6 +754,9 @@ class ProviderManager:
model=provider_model_record.model_name,
model_type=ModelType.value_of(provider_model_record.model_type),
credentials=provider_model_credentials,
current_credential_id=provider_model_record.credential_id,
current_credential_name=provider_model_record.credential_name,
available_model_credentials=available_model_credentials,
)
)
@ -899,6 +968,18 @@ class ProviderManager:
load_balancing_model_config.model_name == provider_model_setting.model_name
and load_balancing_model_config.model_type == provider_model_setting.model_type
):
if load_balancing_model_config.name == "__delete__":
# to calculate current model whether has invalidate lb configs
load_balancing_configs.append(
ModelLoadBalancingConfiguration(
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials={},
credential_source_type=load_balancing_model_config.credential_source_type,
)
)
continue
if not load_balancing_model_config.enabled:
continue
@ -955,6 +1036,7 @@ class ProviderManager:
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials=provider_model_credentials,
credential_source_type=load_balancing_model_config.credential_source_type,
)
)