mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine
This commit is contained in:
@ -1,8 +1,9 @@
|
||||
import contextlib
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@ -22,6 +23,7 @@ from core.entities.provider_entities import (
|
||||
QuotaConfiguration,
|
||||
QuotaUnit,
|
||||
SystemConfiguration,
|
||||
UnaddedModelConfiguration,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
@ -537,6 +539,23 @@ class ProviderManager:
|
||||
for credential in available_credentials
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
|
||||
"""
|
||||
Get all the credentials records from ProviderModelCredential by provider_name
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider_name: provider name
|
||||
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
|
||||
)
|
||||
|
||||
all_credentials = session.scalars(stmt).all()
|
||||
return all_credentials
|
||||
|
||||
@staticmethod
|
||||
def _init_trial_provider_records(
|
||||
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
|
||||
@ -623,6 +642,44 @@ class ProviderManager:
|
||||
:param provider_model_records: provider model records
|
||||
:return:
|
||||
"""
|
||||
# Get custom provider configuration
|
||||
custom_provider_configuration = self._get_custom_provider_configuration(
|
||||
tenant_id, provider_entity, provider_records
|
||||
)
|
||||
|
||||
# Get all model credentials once
|
||||
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
|
||||
|
||||
# Get custom models which have not been added to the model list yet
|
||||
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
|
||||
|
||||
# Get custom model configurations
|
||||
custom_model_configurations = self._get_custom_model_configurations(
|
||||
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
|
||||
)
|
||||
|
||||
can_added_models = [
|
||||
UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models
|
||||
]
|
||||
|
||||
return CustomConfiguration(
|
||||
provider=custom_provider_configuration,
|
||||
models=custom_model_configurations,
|
||||
can_added_models=can_added_models,
|
||||
)
|
||||
|
||||
def _get_custom_provider_configuration(
|
||||
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
|
||||
) -> CustomProviderConfiguration | None:
|
||||
"""Get custom provider configuration."""
|
||||
# Find custom provider record (non-system)
|
||||
custom_provider_record = next(
|
||||
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
|
||||
)
|
||||
|
||||
if not custom_provider_record:
|
||||
return None
|
||||
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.provider_credential_schema.credential_form_schemas
|
||||
@ -630,113 +687,98 @@ class ProviderManager:
|
||||
else []
|
||||
)
|
||||
|
||||
# Get custom provider record
|
||||
custom_provider_record = None
|
||||
for provider_record in provider_records:
|
||||
if provider_record.provider_type == ProviderType.SYSTEM.value:
|
||||
continue
|
||||
# Get and decrypt provider credentials
|
||||
provider_credentials = self._get_and_decrypt_credentials(
|
||||
tenant_id=tenant_id,
|
||||
record_id=custom_provider_record.id,
|
||||
encrypted_config=custom_provider_record.encrypted_config,
|
||||
secret_variables=provider_credential_secret_variables,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
is_provider=True,
|
||||
)
|
||||
|
||||
custom_provider_record = provider_record
|
||||
return CustomProviderConfiguration(
|
||||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
available_credentials=self.get_provider_available_credentials(
|
||||
tenant_id, custom_provider_record.provider_name
|
||||
),
|
||||
)
|
||||
|
||||
# Get custom provider credentials
|
||||
custom_provider_configuration = None
|
||||
if custom_provider_record:
|
||||
provider_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=custom_provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
def _get_can_added_models(
|
||||
self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential]
|
||||
) -> list[dict]:
|
||||
"""Get the custom models and credentials from enterprise version which haven't add to the model list"""
|
||||
existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records}
|
||||
|
||||
# Get cached provider credentials
|
||||
cached_provider_credentials = provider_credentials_cache.get()
|
||||
# Get not added custom models credentials
|
||||
not_added_custom_models_credentials = [
|
||||
credential
|
||||
for credential in all_model_credentials
|
||||
if (credential.model_name, credential.model_type) not in existing_model_set
|
||||
]
|
||||
|
||||
if not cached_provider_credentials:
|
||||
try:
|
||||
# fix origin data
|
||||
if custom_provider_record.encrypted_config is None:
|
||||
provider_credentials = {}
|
||||
elif not custom_provider_record.encrypted_config.startswith("{"):
|
||||
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
|
||||
else:
|
||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
provider_credentials = {}
|
||||
# Group credentials by model
|
||||
model_to_credentials = defaultdict(list)
|
||||
for credential in not_added_custom_models_credentials:
|
||||
model_to_credentials[(credential.model_name, credential.model_type)].append(credential)
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
return [
|
||||
{
|
||||
"model": model_key[0],
|
||||
"model_type": ModelType.value_of(model_key[1]),
|
||||
"available_model_credentials": [
|
||||
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
|
||||
for cred in creds
|
||||
],
|
||||
}
|
||||
for model_key, creds in model_to_credentials.items()
|
||||
]
|
||||
|
||||
for variable in provider_credential_secret_variables:
|
||||
if variable in provider_credentials:
|
||||
with contextlib.suppress(ValueError):
|
||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_credentials.get(variable) or "", # type: ignore
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
|
||||
# cache provider credentials
|
||||
provider_credentials_cache.set(credentials=provider_credentials)
|
||||
else:
|
||||
provider_credentials = cached_provider_credentials
|
||||
|
||||
custom_provider_configuration = CustomProviderConfiguration(
|
||||
credentials=provider_credentials,
|
||||
current_credential_name=custom_provider_record.credential_name,
|
||||
current_credential_id=custom_provider_record.credential_id,
|
||||
available_credentials=self.get_provider_available_credentials(
|
||||
tenant_id, custom_provider_record.provider_name
|
||||
),
|
||||
)
|
||||
|
||||
# Get provider model credential secret variables
|
||||
def _get_custom_model_configurations(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider_entity: ProviderEntity,
|
||||
provider_model_records: list[ProviderModel],
|
||||
can_added_models: list[dict],
|
||||
all_model_credentials: Sequence[ProviderModelCredential],
|
||||
) -> list[CustomModelConfiguration]:
|
||||
"""Get custom model configurations."""
|
||||
# Get model credential secret variables
|
||||
model_credential_secret_variables = self._extract_secret_variables(
|
||||
provider_entity.model_credential_schema.credential_form_schemas
|
||||
if provider_entity.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
# Get custom provider model credentials
|
||||
# Create credentials lookup for efficient access
|
||||
credentials_map = defaultdict(list)
|
||||
for credential in all_model_credentials:
|
||||
credentials_map[(credential.model_name, credential.model_type)].append(credential)
|
||||
|
||||
custom_model_configurations = []
|
||||
|
||||
# Process existing model records
|
||||
for provider_model_record in provider_model_records:
|
||||
available_model_credentials = self.get_provider_model_available_credentials(
|
||||
tenant_id,
|
||||
provider_model_record.provider_name,
|
||||
provider_model_record.model_name,
|
||||
provider_model_record.model_type,
|
||||
# Use pre-fetched credentials instead of individual database calls
|
||||
available_model_credentials = [
|
||||
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
|
||||
for cred in credentials_map.get(
|
||||
(provider_model_record.model_name, provider_model_record.model_type), []
|
||||
)
|
||||
]
|
||||
|
||||
# Get and decrypt model credentials
|
||||
provider_model_credentials = self._get_and_decrypt_credentials(
|
||||
tenant_id=tenant_id,
|
||||
record_id=provider_model_record.id,
|
||||
encrypted_config=provider_model_record.encrypted_config,
|
||||
secret_variables=model_credential_secret_variables,
|
||||
cache_type=ProviderCredentialsCacheType.MODEL,
|
||||
is_provider=False,
|
||||
)
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
|
||||
)
|
||||
|
||||
# Get cached provider model credentials
|
||||
cached_provider_model_credentials = provider_model_credentials_cache.get()
|
||||
|
||||
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
|
||||
try:
|
||||
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Get decoding rsa key and cipher for decrypting credentials
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
for variable in model_credential_secret_variables:
|
||||
if variable in provider_model_credentials:
|
||||
with contextlib.suppress(ValueError):
|
||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
provider_model_credentials.get(variable),
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
|
||||
# cache provider model credentials
|
||||
provider_model_credentials_cache.set(credentials=provider_model_credentials)
|
||||
else:
|
||||
provider_model_credentials = cached_provider_model_credentials
|
||||
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
model=provider_model_record.model_name,
|
||||
@ -748,7 +790,71 @@ class ProviderManager:
|
||||
)
|
||||
)
|
||||
|
||||
return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations)
|
||||
# Add models that can be added
|
||||
for model in can_added_models:
|
||||
custom_model_configurations.append(
|
||||
CustomModelConfiguration(
|
||||
model=model["model"],
|
||||
model_type=model["model_type"],
|
||||
credentials=None,
|
||||
current_credential_id=None,
|
||||
current_credential_name=None,
|
||||
available_model_credentials=model["available_model_credentials"],
|
||||
unadded_to_model_list=True,
|
||||
)
|
||||
)
|
||||
|
||||
return custom_model_configurations
|
||||
|
||||
def _get_and_decrypt_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
record_id: str,
|
||||
encrypted_config: str | None,
|
||||
secret_variables: list[str],
|
||||
cache_type: ProviderCredentialsCacheType,
|
||||
is_provider: bool = False,
|
||||
) -> dict:
|
||||
"""Get and decrypt credentials with caching."""
|
||||
credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=record_id,
|
||||
cache_type=cache_type,
|
||||
)
|
||||
|
||||
# Try to get from cache first
|
||||
cached_credentials = credentials_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
# Parse encrypted config
|
||||
if not encrypted_config:
|
||||
return {}
|
||||
|
||||
if is_provider and not encrypted_config.startswith("{"):
|
||||
return {"openai_api_key": encrypted_config}
|
||||
|
||||
try:
|
||||
credentials = cast(dict, json.loads(encrypted_config))
|
||||
except JSONDecodeError:
|
||||
return {}
|
||||
|
||||
# Decrypt secret variables
|
||||
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
|
||||
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
for variable in secret_variables:
|
||||
if variable in credentials:
|
||||
with contextlib.suppress(ValueError):
|
||||
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
credentials.get(variable) or "",
|
||||
self.decoding_rsa_key,
|
||||
self.decoding_cipher_rsa,
|
||||
)
|
||||
|
||||
# Cache the decrypted credentials
|
||||
credentials_cache.set(credentials=credentials)
|
||||
return credentials
|
||||
|
||||
def _to_system_configuration(
|
||||
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
|
||||
@ -956,18 +1062,6 @@ class ProviderManager:
|
||||
load_balancing_model_config.model_name == provider_model_setting.model_name
|
||||
and load_balancing_model_config.model_type == provider_model_setting.model_type
|
||||
):
|
||||
if load_balancing_model_config.name == "__delete__":
|
||||
# to calculate current model whether has invalidate lb configs
|
||||
load_balancing_configs.append(
|
||||
ModelLoadBalancingConfiguration(
|
||||
id=load_balancing_model_config.id,
|
||||
name=load_balancing_model_config.name,
|
||||
credentials={},
|
||||
credential_source_type=load_balancing_model_config.credential_source_type,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if not load_balancing_model_config.enabled:
|
||||
continue
|
||||
|
||||
@ -1033,6 +1127,7 @@ class ProviderManager:
|
||||
model=provider_model_setting.model_name,
|
||||
model_type=ModelType.value_of(provider_model_setting.model_type),
|
||||
enabled=provider_model_setting.enabled,
|
||||
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
|
||||
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user