mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
feat: add multi model credentials (#24451)
Co-authored-by: zxhlyh <jasonapring2015@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -19,6 +19,7 @@ class ModelStatus(Enum):
|
||||
QUOTA_EXCEEDED = "quota-exceeded"
|
||||
NO_PERMISSION = "no-permission"
|
||||
DISABLED = "disabled"
|
||||
CREDENTIAL_REMOVED = "credential-removed"
|
||||
|
||||
|
||||
class SimpleModelProviderEntity(BaseModel):
|
||||
@ -54,6 +55,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
|
||||
|
||||
status: ModelStatus
|
||||
load_balancing_enabled: bool = False
|
||||
has_invalid_load_balancing_configs: bool = False
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -69,6 +69,15 @@ class QuotaConfiguration(BaseModel):
|
||||
restrict_models: list[RestrictModel] = []
|
||||
|
||||
|
||||
class CredentialConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for credential configuration.
|
||||
"""
|
||||
|
||||
credential_id: str
|
||||
credential_name: str
|
||||
|
||||
|
||||
class SystemConfiguration(BaseModel):
|
||||
"""
|
||||
Model class for provider system configuration.
|
||||
@ -86,6 +95,9 @@ class CustomProviderConfiguration(BaseModel):
|
||||
"""
|
||||
|
||||
credentials: dict
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
available_credentials: list[CredentialConfiguration] = []
|
||||
|
||||
|
||||
class CustomModelConfiguration(BaseModel):
|
||||
@ -95,7 +107,10 @@ class CustomModelConfiguration(BaseModel):
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict
|
||||
credentials: dict | None
|
||||
current_credential_id: Optional[str] = None
|
||||
current_credential_name: Optional[str] = None
|
||||
available_model_credentials: list[CredentialConfiguration] = []
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
@ -118,6 +133,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
credentials: dict
|
||||
credential_source_type: str | None = None
|
||||
|
||||
|
||||
class ModelSettings(BaseModel):
|
||||
|
||||
@ -201,7 +201,7 @@ class ModelProviderFactory:
|
||||
return filtered_credentials
|
||||
|
||||
def get_model_schema(
|
||||
self, *, provider: str, model_type: ModelType, model: str, credentials: dict
|
||||
self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
|
||||
@ -12,6 +12,7 @@ from configs import dify_config
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
CredentialConfiguration,
|
||||
CustomConfiguration,
|
||||
CustomModelConfiguration,
|
||||
CustomProviderConfiguration,
|
||||
@ -40,7 +41,9 @@ from extensions.ext_redis import redis_client
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
ProviderCredential,
|
||||
ProviderModel,
|
||||
ProviderModelCredential,
|
||||
ProviderModelSetting,
|
||||
ProviderType,
|
||||
TenantDefaultModel,
|
||||
@ -488,6 +491,61 @@ class ProviderManager:
|
||||
|
||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||
|
||||
@staticmethod
|
||||
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
Get provider all credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderCredential)
|
||||
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
|
||||
.order_by(ProviderCredential.created_at.desc())
|
||||
)
|
||||
|
||||
available_credentials = session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_provider_model_available_credentials(
|
||||
tenant_id: str, provider_name: str, model_name: str, model_type: str
|
||||
) -> list[CredentialConfiguration]:
|
||||
"""
|
||||
Get provider custom model all credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
:param model_name: model name
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(ProviderModelCredential)
|
||||
.where(
|
||||
ProviderModelCredential.tenant_id == tenant_id,
|
||||
ProviderModelCredential.provider_name == provider_name,
|
||||
ProviderModelCredential.model_name == model_name,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
.order_by(ProviderModelCredential.created_at.desc())
|
||||
)
|
||||
|
||||
available_credentials = session.scalars(stmt).all()
|
||||
|
||||
return [
|
||||
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _init_trial_provider_records(
|
||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
||||
@ -590,9 +648,6 @@ class ProviderManager:
|
||||
if provider_record.provider_type == ProviderType.SYSTEM.value:
|
||||
continue
|
||||
|
||||
if not provider_record.encrypted_config:
|
||||
continue
|
||||
|
||||
custom_provider_record = provider_record
|
||||
|
||||
# Get custom provider credentials
|
||||
@ -611,8 +666,8 @@ class ProviderManager:
|
||||
try:
|
||||
# fix origin data
|
||||
if custom_provider_record.encrypted_config is None:
|
||||
raise ValueError("No credentials found")
|
||||
if not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {}
|
||||
elif not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
@ -637,7 +692,14 @@ class ProviderManager:
|
||||
else:
|
||||
provider_credentials = cached_provider_credentials
|
||||
|
||||
custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials)
|
||||
custom_provider_configuration = CustomProviderConfiguration(
|
||||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
available_credentials=self.get_provider_available_credentials(
|
||||
tenant_id, custom_provider_record.provider_name
|
||||
),
|
||||
)
|
||||
|
||||
# Get provider model credential secret variables
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
@ -649,8 +711,12 @@ class ProviderManager:
|
||||
# Get custom provider model credentials
|
||||
custom_model_configurations = []
|
||||
for provider_model_record in provider_model_records:
|
||||
if not provider_model_record.encrypted_config:
|
||||
continue
|
||||
available_model_credentials = self.get_provider_model_available_credentials(
|
||||
tenant_id,
|
||||
provider_model_record.provider_name,
|
||||
provider_model_record.model_name,
|
||||
provider_model_record.model_type,
|
||||
)
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
|
||||
@ -659,7 +725,7 @@ class ProviderManager:
|
||||
# Get cached provider model credentials
|
||||
cached_provider_model_credentials = provider_model_credentials_cache.get()
|
||||
|
||||
if not cached_provider_model_credentials:
|
||||
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
|
||||
try:
|
||||
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
@ -688,6 +754,9 @@ class ProviderManager:
|
||||
model=provider_model_record.model_name,
|
||||
model_type=ModelType.value_of(provider_model_record.model_type),
|
||||
credentials=provider_model_credentials,
|
||||
current_credential_id=provider_model_record.credential_id,
|
||||
current_credential_name=provider_model_record.credential_name,
|
||||
available_model_credentials=available_model_credentials,
|
||||
)
|
||||
)
|
||||
|
||||
@ -899,6 +968,18 @@ class ProviderManager:
|
||||
load_balancing_model_config.model_name == provider_model_setting.model_name
|
||||
and load_balancing_model_config.model_type == provider_model_setting.model_type
|
||||
):
|
||||
if load_balancing_model_config.name == "__delete__":
|
||||
# to calculate current model whether has invalidate lb configs
|
||||
load_balancing_configs.append(
|
||||
ModelLoadBalancingConfiguration(
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials={},
|
||||
credential_source_type=load_balancing_model_config.credential_source_type,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if not load_balancing_model_config.enabled:
|
||||
continue
|
||||
|
||||
@ -955,6 +1036,7 @@ class ProviderManager:
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials=provider_model_credentials,
|
||||
credential_source_type=load_balancing_model_config.credential_source_type,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user