Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

This commit is contained in:
-LAN-
2025-09-03 13:53:43 +08:00
42 changed files with 1565 additions and 750 deletions

View File

@ -1,8 +1,9 @@
import contextlib
import json
from collections import defaultdict
from collections.abc import Sequence
from json import JSONDecodeError
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@ -22,6 +23,7 @@ from core.entities.provider_entities import (
QuotaConfiguration,
QuotaUnit,
SystemConfiguration,
UnaddedModelConfiguration,
)
from core.helper import encrypter
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
@ -537,6 +539,23 @@ class ProviderManager:
for credential in available_credentials
]
@staticmethod
def get_credentials_from_provider_model(tenant_id: str, provider_name: str) -> Sequence[ProviderModelCredential]:
"""
Get all the credentials records from ProviderModelCredential by provider_name
:param tenant_id: workspace id
:param provider_name: provider name
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.provider_name == provider_name
)
all_credentials = session.scalars(stmt).all()
return all_credentials
@staticmethod
def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
@ -623,6 +642,44 @@ class ProviderManager:
:param provider_model_records: provider model records
:return:
"""
# Get custom provider configuration
custom_provider_configuration = self._get_custom_provider_configuration(
tenant_id, provider_entity, provider_records
)
# Get all model credentials once
all_model_credentials = self.get_credentials_from_provider_model(tenant_id, provider_entity.provider)
# Get custom models which have not been added to the model list yet
unadded_models = self._get_can_added_models(provider_model_records, all_model_credentials)
# Get custom model configurations
custom_model_configurations = self._get_custom_model_configurations(
tenant_id, provider_entity, provider_model_records, unadded_models, all_model_credentials
)
can_added_models = [
UnaddedModelConfiguration(model=model["model"], model_type=model["model_type"]) for model in unadded_models
]
return CustomConfiguration(
provider=custom_provider_configuration,
models=custom_model_configurations,
can_added_models=can_added_models,
)
def _get_custom_provider_configuration(
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
) -> CustomProviderConfiguration | None:
"""Get custom provider configuration."""
# Find custom provider record (non-system)
custom_provider_record = next(
(record for record in provider_records if record.provider_type != ProviderType.SYSTEM.value), None
)
if not custom_provider_record:
return None
# Get provider credential secret variables
provider_credential_secret_variables = self._extract_secret_variables(
provider_entity.provider_credential_schema.credential_form_schemas
@ -630,113 +687,98 @@ class ProviderManager:
else []
)
# Get custom provider record
custom_provider_record = None
for provider_record in provider_records:
if provider_record.provider_type == ProviderType.SYSTEM.value:
continue
# Get and decrypt provider credentials
provider_credentials = self._get_and_decrypt_credentials(
tenant_id=tenant_id,
record_id=custom_provider_record.id,
encrypted_config=custom_provider_record.encrypted_config,
secret_variables=provider_credential_secret_variables,
cache_type=ProviderCredentialsCacheType.PROVIDER,
is_provider=True,
)
custom_provider_record = provider_record
return CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get custom provider credentials
custom_provider_configuration = None
if custom_provider_record:
provider_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=custom_provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
def _get_can_added_models(
self, provider_model_records: list[ProviderModel], all_model_credentials: Sequence[ProviderModelCredential]
) -> list[dict]:
"""Get the custom models and credentials from enterprise version which haven't add to the model list"""
existing_model_set = {(record.model_name, record.model_type) for record in provider_model_records}
# Get cached provider credentials
cached_provider_credentials = provider_credentials_cache.get()
# Get not added custom models credentials
not_added_custom_models_credentials = [
credential
for credential in all_model_credentials
if (credential.model_name, credential.model_type) not in existing_model_set
]
if not cached_provider_credentials:
try:
# fix origin data
if custom_provider_record.encrypted_config is None:
provider_credentials = {}
elif not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
# Group credentials by model
model_to_credentials = defaultdict(list)
for credential in not_added_custom_models_credentials:
model_to_credentials[(credential.model_name, credential.model_type)].append(credential)
# Get decoding rsa key and cipher for decrypting credentials
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
return [
{
"model": model_key[0],
"model_type": ModelType.value_of(model_key[1]),
"available_model_credentials": [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in creds
],
}
for model_key, creds in model_to_credentials.items()
]
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
with contextlib.suppress(ValueError):
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable) or "", # type: ignore
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# cache provider credentials
provider_credentials_cache.set(credentials=provider_credentials)
else:
provider_credentials = cached_provider_credentials
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get provider model credential secret variables
def _get_custom_model_configurations(
self,
tenant_id: str,
provider_entity: ProviderEntity,
provider_model_records: list[ProviderModel],
can_added_models: list[dict],
all_model_credentials: Sequence[ProviderModelCredential],
) -> list[CustomModelConfiguration]:
"""Get custom model configurations."""
# Get model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
provider_entity.model_credential_schema.credential_form_schemas
if provider_entity.model_credential_schema
else []
)
# Get custom provider model credentials
# Create credentials lookup for efficient access
credentials_map = defaultdict(list)
for credential in all_model_credentials:
credentials_map[(credential.model_name, credential.model_type)].append(credential)
custom_model_configurations = []
# Process existing model records
for provider_model_record in provider_model_records:
available_model_credentials = self.get_provider_model_available_credentials(
tenant_id,
provider_model_record.provider_name,
provider_model_record.model_name,
provider_model_record.model_type,
# Use pre-fetched credentials instead of individual database calls
available_model_credentials = [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in credentials_map.get(
(provider_model_record.model_name, provider_model_record.model_type), []
)
]
# Get and decrypt model credentials
provider_model_credentials = self._get_and_decrypt_credentials(
tenant_id=tenant_id,
record_id=provider_model_record.id,
encrypted_config=provider_model_record.encrypted_config,
secret_variables=model_credential_secret_variables,
cache_type=ProviderCredentialsCacheType.MODEL,
is_provider=False,
)
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
)
# Get cached provider model credentials
cached_provider_model_credentials = provider_model_credentials_cache.get()
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
try:
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
except JSONDecodeError:
continue
# Get decoding rsa key and cipher for decrypting credentials
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
with contextlib.suppress(ValueError):
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# cache provider model credentials
provider_model_credentials_cache.set(credentials=provider_model_credentials)
else:
provider_model_credentials = cached_provider_model_credentials
custom_model_configurations.append(
CustomModelConfiguration(
model=provider_model_record.model_name,
@ -748,7 +790,71 @@ class ProviderManager:
)
)
return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations)
# Add models that can be added
for model in can_added_models:
custom_model_configurations.append(
CustomModelConfiguration(
model=model["model"],
model_type=model["model_type"],
credentials=None,
current_credential_id=None,
current_credential_name=None,
available_model_credentials=model["available_model_credentials"],
unadded_to_model_list=True,
)
)
return custom_model_configurations
def _get_and_decrypt_credentials(
self,
tenant_id: str,
record_id: str,
encrypted_config: str | None,
secret_variables: list[str],
cache_type: ProviderCredentialsCacheType,
is_provider: bool = False,
) -> dict:
"""Get and decrypt credentials with caching."""
credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id,
identity_id=record_id,
cache_type=cache_type,
)
# Try to get from cache first
cached_credentials = credentials_cache.get()
if cached_credentials:
return cached_credentials
# Parse encrypted config
if not encrypted_config:
return {}
if is_provider and not encrypted_config.startswith("{"):
return {"openai_api_key": encrypted_config}
try:
credentials = cast(dict, json.loads(encrypted_config))
except JSONDecodeError:
return {}
# Decrypt secret variables
if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
for variable in secret_variables:
if variable in credentials:
with contextlib.suppress(ValueError):
credentials[variable] = encrypter.decrypt_token_with_decoding(
credentials.get(variable) or "",
self.decoding_rsa_key,
self.decoding_cipher_rsa,
)
# Cache the decrypted credentials
credentials_cache.set(credentials=credentials)
return credentials
def _to_system_configuration(
self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider]
@ -956,18 +1062,6 @@ class ProviderManager:
load_balancing_model_config.model_name == provider_model_setting.model_name
and load_balancing_model_config.model_type == provider_model_setting.model_type
):
if load_balancing_model_config.name == "__delete__":
# to calculate current model whether has invalidate lb configs
load_balancing_configs.append(
ModelLoadBalancingConfiguration(
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials={},
credential_source_type=load_balancing_model_config.credential_source_type,
)
)
continue
if not load_balancing_model_config.enabled:
continue
@ -1033,6 +1127,7 @@ class ProviderManager:
model=provider_model_setting.model_name,
model_type=ModelType.value_of(provider_model_setting.model_type),
enabled=provider_model_setting.enabled,
load_balancing_enabled=provider_model_setting.load_balancing_enabled,
load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
)
)