|
|
|
|
@ -1,9 +1,9 @@
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
import re
|
|
|
|
|
from collections import defaultdict
|
|
|
|
|
from collections.abc import Iterator, Sequence
|
|
|
|
|
from json import JSONDecodeError
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|
|
|
|
from sqlalchemy import func, select
|
|
|
|
|
@ -28,7 +28,6 @@ from core.model_runtime.entities.provider_entities import (
|
|
|
|
|
)
|
|
|
|
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
|
|
|
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
|
|
|
|
from core.plugin.entities.plugin import ModelProviderID
|
|
|
|
|
from extensions.ext_database import db
|
|
|
|
|
from libs.datetime_utils import naive_utc_now
|
|
|
|
|
from models.provider import (
|
|
|
|
|
@ -41,6 +40,8 @@ from models.provider import (
|
|
|
|
|
ProviderType,
|
|
|
|
|
TenantPreferredModelProvider,
|
|
|
|
|
)
|
|
|
|
|
from models.provider_ids import ModelProviderID
|
|
|
|
|
from services.enterprise.plugin_manager_service import PluginCredentialType
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
@ -90,7 +91,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
):
|
|
|
|
|
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
|
|
|
|
|
|
|
|
|
|
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
|
|
|
|
|
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
|
|
|
|
|
"""
|
|
|
|
|
Get current credentials.
|
|
|
|
|
|
|
|
|
|
@ -128,18 +129,42 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
return copy_credentials
|
|
|
|
|
else:
|
|
|
|
|
credentials = None
|
|
|
|
|
current_credential_id = None
|
|
|
|
|
|
|
|
|
|
if self.custom_configuration.models:
|
|
|
|
|
for model_configuration in self.custom_configuration.models:
|
|
|
|
|
if model_configuration.model_type == model_type and model_configuration.model == model:
|
|
|
|
|
credentials = model_configuration.credentials
|
|
|
|
|
current_credential_id = model_configuration.current_credential_id
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
if not credentials and self.custom_configuration.provider:
|
|
|
|
|
credentials = self.custom_configuration.provider.credentials
|
|
|
|
|
current_credential_id = self.custom_configuration.provider.current_credential_id
|
|
|
|
|
|
|
|
|
|
if current_credential_id:
|
|
|
|
|
from core.helper.credential_utils import check_credential_policy_compliance
|
|
|
|
|
|
|
|
|
|
check_credential_policy_compliance(
|
|
|
|
|
credential_id=current_credential_id,
|
|
|
|
|
provider=self.provider.provider,
|
|
|
|
|
credential_type=PluginCredentialType.MODEL,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
# no current credential id, check all available credentials
|
|
|
|
|
if self.custom_configuration.provider:
|
|
|
|
|
for credential_configuration in self.custom_configuration.provider.available_credentials:
|
|
|
|
|
from core.helper.credential_utils import check_credential_policy_compliance
|
|
|
|
|
|
|
|
|
|
check_credential_policy_compliance(
|
|
|
|
|
credential_id=credential_configuration.credential_id,
|
|
|
|
|
provider=self.provider.provider,
|
|
|
|
|
credential_type=PluginCredentialType.MODEL,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return credentials
|
|
|
|
|
|
|
|
|
|
def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
|
|
|
|
|
def get_system_configuration_status(self) -> SystemConfigurationStatus | None:
|
|
|
|
|
"""
|
|
|
|
|
Get system configuration status.
|
|
|
|
|
:return:
|
|
|
|
|
@ -180,16 +205,10 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
Get custom provider record.
|
|
|
|
|
"""
|
|
|
|
|
# get provider
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
|
|
|
|
|
stmt = select(Provider).where(
|
|
|
|
|
Provider.tenant_id == self.tenant_id,
|
|
|
|
|
Provider.provider_type == ProviderType.CUSTOM.value,
|
|
|
|
|
Provider.provider_name.in_(provider_names),
|
|
|
|
|
Provider.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return session.execute(stmt).scalar_one_or_none()
|
|
|
|
|
@ -251,7 +270,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
stmt = select(ProviderCredential.id).where(
|
|
|
|
|
ProviderCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderCredential.credential_name == credential_name,
|
|
|
|
|
)
|
|
|
|
|
if exclude_id:
|
|
|
|
|
@ -265,7 +284,6 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:param credential_id: if provided, return the specified credential
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if credential_id:
|
|
|
|
|
return self._get_specific_provider_credential(credential_id)
|
|
|
|
|
|
|
|
|
|
@ -279,9 +297,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
else [],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def validate_provider_credentials(
|
|
|
|
|
self, credentials: dict, credential_id: str = "", session: Session | None = None
|
|
|
|
|
) -> dict:
|
|
|
|
|
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
|
|
|
|
|
"""
|
|
|
|
|
Validate custom credentials.
|
|
|
|
|
:param credentials: provider credentials
|
|
|
|
|
@ -290,7 +306,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _validate(s: Session) -> dict:
|
|
|
|
|
def _validate(s: Session):
|
|
|
|
|
# Get provider credential secret variables
|
|
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
|
|
|
self.provider.provider_credential_schema.credential_form_schemas
|
|
|
|
|
@ -302,7 +318,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
try:
|
|
|
|
|
stmt = select(ProviderCredential).where(
|
|
|
|
|
ProviderCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderCredential.id == credential_id,
|
|
|
|
|
)
|
|
|
|
|
credential_record = s.execute(stmt).scalar_one_or_none()
|
|
|
|
|
@ -343,7 +359,75 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
with Session(db.engine) as new_session:
|
|
|
|
|
return _validate(new_session)
|
|
|
|
|
|
|
|
|
|
def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
|
|
|
|
|
def _generate_provider_credential_name(self, session) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Generate a unique credential name for provider.
|
|
|
|
|
:return: credential name
|
|
|
|
|
"""
|
|
|
|
|
return self._generate_next_api_key_name(
|
|
|
|
|
session=session,
|
|
|
|
|
query_factory=lambda: select(ProviderCredential).where(
|
|
|
|
|
ProviderCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Generate a unique credential name for custom model.
|
|
|
|
|
:return: credential name
|
|
|
|
|
"""
|
|
|
|
|
return self._generate_next_api_key_name(
|
|
|
|
|
session=session,
|
|
|
|
|
query_factory=lambda: select(ProviderModelCredential).where(
|
|
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderModelCredential.model_name == model,
|
|
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _generate_next_api_key_name(self, session, query_factory) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Generate next available API KEY name by finding the highest numbered suffix.
|
|
|
|
|
:param session: database session
|
|
|
|
|
:param query_factory: function that returns the SQLAlchemy query
|
|
|
|
|
:return: next available API KEY name
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
stmt = query_factory()
|
|
|
|
|
credential_records = session.execute(stmt).scalars().all()
|
|
|
|
|
|
|
|
|
|
if not credential_records:
|
|
|
|
|
return "API KEY 1"
|
|
|
|
|
|
|
|
|
|
# Extract numbers from API KEY pattern using list comprehension
|
|
|
|
|
pattern = re.compile(r"^API KEY\s+(\d+)$")
|
|
|
|
|
numbers = [
|
|
|
|
|
int(match.group(1))
|
|
|
|
|
for cr in credential_records
|
|
|
|
|
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# Return next sequential number
|
|
|
|
|
next_number = max(numbers, default=0) + 1
|
|
|
|
|
return f"API KEY {next_number}"
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning("Error generating next credential name: %s", str(e))
|
|
|
|
|
return "API KEY 1"
|
|
|
|
|
|
|
|
|
|
def _get_provider_names(self):
|
|
|
|
|
"""
|
|
|
|
|
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
|
|
|
|
|
"""
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
return provider_names
|
|
|
|
|
|
|
|
|
|
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
|
|
|
|
"""
|
|
|
|
|
Add custom provider credentials.
|
|
|
|
|
:param credentials: provider credentials
|
|
|
|
|
@ -351,8 +435,11 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
|
|
|
|
|
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
|
|
|
|
if credential_name:
|
|
|
|
|
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
|
|
|
|
|
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
|
|
|
|
else:
|
|
|
|
|
credential_name = self._generate_provider_credential_name(session)
|
|
|
|
|
|
|
|
|
|
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
|
|
|
|
|
provider_record = self._get_provider_record(session)
|
|
|
|
|
@ -395,8 +482,8 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
self,
|
|
|
|
|
credentials: dict,
|
|
|
|
|
credential_id: str,
|
|
|
|
|
credential_name: str,
|
|
|
|
|
) -> None:
|
|
|
|
|
credential_name: str | None,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
update a saved provider credential (by credential_id).
|
|
|
|
|
|
|
|
|
|
@ -406,7 +493,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
if self._check_provider_credential_name_exists(
|
|
|
|
|
if credential_name and self._check_provider_credential_name_exists(
|
|
|
|
|
credential_name=credential_name, session=session, exclude_id=credential_id
|
|
|
|
|
):
|
|
|
|
|
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
|
|
|
|
@ -418,7 +505,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
stmt = select(ProviderCredential).where(
|
|
|
|
|
ProviderCredential.id == credential_id,
|
|
|
|
|
ProviderCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Get the credential record to update
|
|
|
|
|
@ -428,9 +515,9 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
try:
|
|
|
|
|
# Update credential
|
|
|
|
|
credential_record.encrypted_config = json.dumps(credentials)
|
|
|
|
|
credential_record.credential_name = credential_name
|
|
|
|
|
credential_record.updated_at = naive_utc_now()
|
|
|
|
|
|
|
|
|
|
if credential_name:
|
|
|
|
|
credential_record.credential_name = credential_name
|
|
|
|
|
session.commit()
|
|
|
|
|
|
|
|
|
|
if provider_record and provider_record.credential_id == credential_id:
|
|
|
|
|
@ -457,7 +544,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
credential_record: ProviderCredential | ProviderModelCredential,
|
|
|
|
|
credential_source: str,
|
|
|
|
|
session: Session,
|
|
|
|
|
) -> None:
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Update load balancing configurations that reference the given credential_id.
|
|
|
|
|
|
|
|
|
|
@ -471,7 +558,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
# Find all load balancing configs that use this credential_id
|
|
|
|
|
stmt = select(LoadBalancingModelConfig).where(
|
|
|
|
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
|
|
|
|
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
|
|
|
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
LoadBalancingModelConfig.credential_id == credential_id,
|
|
|
|
|
LoadBalancingModelConfig.credential_source_type == credential_source,
|
|
|
|
|
)
|
|
|
|
|
@ -497,7 +584,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
|
|
|
|
session.commit()
|
|
|
|
|
|
|
|
|
|
def delete_provider_credential(self, credential_id: str) -> None:
|
|
|
|
|
def delete_provider_credential(self, credential_id: str):
|
|
|
|
|
"""
|
|
|
|
|
Delete a saved provider credential (by credential_id).
|
|
|
|
|
|
|
|
|
|
@ -508,7 +595,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
stmt = select(ProviderCredential).where(
|
|
|
|
|
ProviderCredential.id == credential_id,
|
|
|
|
|
ProviderCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Get the credential record to update
|
|
|
|
|
@ -519,7 +606,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
# Check if this credential is used in load balancing configs
|
|
|
|
|
lb_stmt = select(LoadBalancingModelConfig).where(
|
|
|
|
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
|
|
|
|
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
|
|
|
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
LoadBalancingModelConfig.credential_id == credential_id,
|
|
|
|
|
LoadBalancingModelConfig.credential_source_type == "provider",
|
|
|
|
|
)
|
|
|
|
|
@ -532,13 +619,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
|
|
|
|
)
|
|
|
|
|
lb_credentials_cache.delete()
|
|
|
|
|
|
|
|
|
|
lb_config.credential_id = None
|
|
|
|
|
lb_config.encrypted_config = None
|
|
|
|
|
lb_config.enabled = False
|
|
|
|
|
lb_config.name = "__delete__"
|
|
|
|
|
lb_config.updated_at = naive_utc_now()
|
|
|
|
|
session.add(lb_config)
|
|
|
|
|
session.delete(lb_config)
|
|
|
|
|
|
|
|
|
|
# Check if this is the currently active credential
|
|
|
|
|
provider_record = self._get_provider_record(session)
|
|
|
|
|
@ -547,7 +628,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
# if this is the last credential, we need to delete the provider record
|
|
|
|
|
count_stmt = select(func.count(ProviderCredential.id)).where(
|
|
|
|
|
ProviderCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
)
|
|
|
|
|
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
|
|
|
|
session.delete(credential_record)
|
|
|
|
|
@ -580,7 +661,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
session.rollback()
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def switch_active_provider_credential(self, credential_id: str) -> None:
|
|
|
|
|
def switch_active_provider_credential(self, credential_id: str):
|
|
|
|
|
"""
|
|
|
|
|
Switch active provider credential (copy the selected one into current active snapshot).
|
|
|
|
|
|
|
|
|
|
@ -591,7 +672,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
stmt = select(ProviderCredential).where(
|
|
|
|
|
ProviderCredential.id == credential_id,
|
|
|
|
|
ProviderCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
)
|
|
|
|
|
credential_record = session.execute(stmt).scalar_one_or_none()
|
|
|
|
|
if not credential_record:
|
|
|
|
|
@ -627,6 +708,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
Get custom model credentials.
|
|
|
|
|
"""
|
|
|
|
|
# get provider model
|
|
|
|
|
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
@ -659,7 +741,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
stmt = select(ProviderModelCredential).where(
|
|
|
|
|
ProviderModelCredential.id == credential_id,
|
|
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderModelCredential.model_name == model,
|
|
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
)
|
|
|
|
|
@ -684,6 +766,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
|
|
|
|
current_credential_id = credential_record.id
|
|
|
|
|
current_credential_name = credential_record.credential_name
|
|
|
|
|
|
|
|
|
|
credentials = self.obfuscated_credentials(
|
|
|
|
|
credentials=credentials,
|
|
|
|
|
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
|
|
|
|
|
@ -705,7 +788,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
stmt = select(ProviderModelCredential).where(
|
|
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderModelCredential.model_name == model,
|
|
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
ProviderModelCredential.credential_name == credential_name,
|
|
|
|
|
@ -714,9 +797,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
|
|
|
|
|
return session.execute(stmt).scalar_one_or_none() is not None
|
|
|
|
|
|
|
|
|
|
def get_custom_model_credential(
|
|
|
|
|
self, model_type: ModelType, model: str, credential_id: str | None
|
|
|
|
|
) -> Optional[dict]:
|
|
|
|
|
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
|
|
|
|
|
"""
|
|
|
|
|
Get custom model credentials.
|
|
|
|
|
|
|
|
|
|
@ -738,6 +819,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
):
|
|
|
|
|
current_credential_id = model_configuration.current_credential_id
|
|
|
|
|
current_credential_name = model_configuration.current_credential_name
|
|
|
|
|
|
|
|
|
|
credentials = self.obfuscated_credentials(
|
|
|
|
|
credentials=model_configuration.credentials,
|
|
|
|
|
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
|
|
|
|
|
@ -758,7 +840,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
credentials: dict,
|
|
|
|
|
credential_id: str = "",
|
|
|
|
|
session: Session | None = None,
|
|
|
|
|
) -> dict:
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Validate custom model credentials.
|
|
|
|
|
|
|
|
|
|
@ -769,7 +851,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _validate(s: Session) -> dict:
|
|
|
|
|
def _validate(s: Session):
|
|
|
|
|
# Get provider credential secret variables
|
|
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
|
|
|
self.provider.model_credential_schema.credential_form_schemas
|
|
|
|
|
@ -782,7 +864,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
stmt = select(ProviderModelCredential).where(
|
|
|
|
|
ProviderModelCredential.id == credential_id,
|
|
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderModelCredential.model_name == model,
|
|
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
)
|
|
|
|
|
@ -822,7 +904,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
return _validate(new_session)
|
|
|
|
|
|
|
|
|
|
def create_custom_model_credential(
|
|
|
|
|
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
|
|
|
|
|
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
|
|
|
|
|
) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Create a custom model credential.
|
|
|
|
|
@ -833,10 +915,15 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
if self._check_custom_model_credential_name_exists(
|
|
|
|
|
model=model, model_type=model_type, credential_name=credential_name, session=session
|
|
|
|
|
):
|
|
|
|
|
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
|
|
|
|
if credential_name:
|
|
|
|
|
if self._check_custom_model_credential_name_exists(
|
|
|
|
|
model=model, model_type=model_type, credential_name=credential_name, session=session
|
|
|
|
|
):
|
|
|
|
|
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
|
|
|
|
else:
|
|
|
|
|
credential_name = self._generate_custom_model_credential_name(
|
|
|
|
|
model=model, model_type=model_type, session=session
|
|
|
|
|
)
|
|
|
|
|
# validate custom model config
|
|
|
|
|
credentials = self.validate_custom_model_credentials(
|
|
|
|
|
model_type=model_type, model=model, credentials=credentials, session=session
|
|
|
|
|
@ -880,7 +967,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def update_custom_model_credential(
|
|
|
|
|
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
|
|
|
|
|
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
|
|
|
|
|
) -> None:
|
|
|
|
|
"""
|
|
|
|
|
Update a custom model credential.
|
|
|
|
|
@ -893,7 +980,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
with Session(db.engine) as session:
|
|
|
|
|
if self._check_custom_model_credential_name_exists(
|
|
|
|
|
if credential_name and self._check_custom_model_credential_name_exists(
|
|
|
|
|
model=model,
|
|
|
|
|
model_type=model_type,
|
|
|
|
|
credential_name=credential_name,
|
|
|
|
|
@ -914,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
stmt = select(ProviderModelCredential).where(
|
|
|
|
|
ProviderModelCredential.id == credential_id,
|
|
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderModelCredential.model_name == model,
|
|
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
)
|
|
|
|
|
@ -925,8 +1012,9 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
try:
|
|
|
|
|
# Update credential
|
|
|
|
|
credential_record.encrypted_config = json.dumps(credentials)
|
|
|
|
|
credential_record.credential_name = credential_name
|
|
|
|
|
credential_record.updated_at = naive_utc_now()
|
|
|
|
|
if credential_name:
|
|
|
|
|
credential_record.credential_name = credential_name
|
|
|
|
|
session.commit()
|
|
|
|
|
|
|
|
|
|
if provider_model_record and provider_model_record.credential_id == credential_id:
|
|
|
|
|
@ -947,7 +1035,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
session.rollback()
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
|
|
|
|
|
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
|
|
|
|
|
"""
|
|
|
|
|
Delete a saved provider credential (by credential_id).
|
|
|
|
|
|
|
|
|
|
@ -958,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
stmt = select(ProviderModelCredential).where(
|
|
|
|
|
ProviderModelCredential.id == credential_id,
|
|
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderModelCredential.model_name == model,
|
|
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
)
|
|
|
|
|
@ -968,7 +1056,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
|
|
|
|
lb_stmt = select(LoadBalancingModelConfig).where(
|
|
|
|
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
|
|
|
|
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
|
|
|
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
LoadBalancingModelConfig.credential_id == credential_id,
|
|
|
|
|
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
|
|
|
|
)
|
|
|
|
|
@ -982,12 +1070,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
|
|
|
|
|
)
|
|
|
|
|
lb_credentials_cache.delete()
|
|
|
|
|
lb_config.credential_id = None
|
|
|
|
|
lb_config.encrypted_config = None
|
|
|
|
|
lb_config.enabled = False
|
|
|
|
|
lb_config.name = "__delete__"
|
|
|
|
|
lb_config.updated_at = naive_utc_now()
|
|
|
|
|
session.add(lb_config)
|
|
|
|
|
session.delete(lb_config)
|
|
|
|
|
|
|
|
|
|
# Check if this is the currently active credential
|
|
|
|
|
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
|
|
|
|
|
@ -996,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
# if this is the last credential, we need to delete the custom model record
|
|
|
|
|
count_stmt = select(func.count(ProviderModelCredential.id)).where(
|
|
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderModelCredential.model_name == model,
|
|
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
)
|
|
|
|
|
@ -1022,7 +1105,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
session.rollback()
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None:
|
|
|
|
|
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
|
|
|
|
|
"""
|
|
|
|
|
if model list exist this custom model, switch the custom model credential.
|
|
|
|
|
if model list not exist this custom model, use the credential to add a new custom model record.
|
|
|
|
|
@ -1036,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
stmt = select(ProviderModelCredential).where(
|
|
|
|
|
ProviderModelCredential.id == credential_id,
|
|
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderModelCredential.model_name == model,
|
|
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
)
|
|
|
|
|
@ -1054,6 +1137,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
provider_name=self.provider.provider,
|
|
|
|
|
model_name=model,
|
|
|
|
|
model_type=model_type.to_origin_model_type(),
|
|
|
|
|
is_valid=True,
|
|
|
|
|
credential_id=credential_id,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
@ -1064,7 +1148,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
session.add(provider_model_record)
|
|
|
|
|
session.commit()
|
|
|
|
|
|
|
|
|
|
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
|
|
|
|
|
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
|
|
|
|
|
"""
|
|
|
|
|
switch the custom model credential.
|
|
|
|
|
|
|
|
|
|
@ -1077,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
stmt = select(ProviderModelCredential).where(
|
|
|
|
|
ProviderModelCredential.id == credential_id,
|
|
|
|
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelCredential.provider_name == self.provider.provider,
|
|
|
|
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderModelCredential.model_name == model,
|
|
|
|
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
)
|
|
|
|
|
@ -1094,7 +1178,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
session.add(provider_model_record)
|
|
|
|
|
session.commit()
|
|
|
|
|
|
|
|
|
|
def delete_custom_model(self, model_type: ModelType, model: str) -> None:
|
|
|
|
|
def delete_custom_model(self, model_type: ModelType, model: str):
|
|
|
|
|
"""
|
|
|
|
|
Delete custom model.
|
|
|
|
|
:param model_type: model type
|
|
|
|
|
@ -1124,14 +1208,9 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
Get provider model setting.
|
|
|
|
|
"""
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
|
|
|
|
|
stmt = select(ProviderModelSetting).where(
|
|
|
|
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
|
|
|
|
ProviderModelSetting.provider_name.in_(provider_names),
|
|
|
|
|
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
|
|
|
|
ProviderModelSetting.model_name == model,
|
|
|
|
|
)
|
|
|
|
|
@ -1190,7 +1269,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
|
|
|
|
return model_setting
|
|
|
|
|
|
|
|
|
|
def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
|
|
|
|
|
def get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
|
|
|
|
|
"""
|
|
|
|
|
Get provider model setting.
|
|
|
|
|
:param model_type: model type
|
|
|
|
|
@ -1207,6 +1286,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
:param model: model name
|
|
|
|
|
:return:
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
@ -1289,7 +1369,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None:
|
|
|
|
|
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
|
|
|
|
|
"""
|
|
|
|
|
Switch preferred provider type.
|
|
|
|
|
:param provider_type:
|
|
|
|
|
@ -1301,16 +1381,10 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
def _switch(s: Session) -> None:
|
|
|
|
|
# get preferred provider
|
|
|
|
|
model_provider_id = ModelProviderID(self.provider.provider)
|
|
|
|
|
provider_names = [self.provider.provider]
|
|
|
|
|
if model_provider_id.is_langgenius():
|
|
|
|
|
provider_names.append(model_provider_id.provider_name)
|
|
|
|
|
|
|
|
|
|
def _switch(s: Session):
|
|
|
|
|
stmt = select(TenantPreferredModelProvider).where(
|
|
|
|
|
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
|
|
|
|
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
|
|
|
|
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
|
|
|
|
|
)
|
|
|
|
|
preferred_model_provider = s.execute(stmt).scalars().first()
|
|
|
|
|
|
|
|
|
|
@ -1340,12 +1414,12 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
secret_input_form_variables = []
|
|
|
|
|
for credential_form_schema in credential_form_schemas:
|
|
|
|
|
if credential_form_schema.type == FormType.SECRET_INPUT:
|
|
|
|
|
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
|
|
|
|
secret_input_form_variables.append(credential_form_schema.variable)
|
|
|
|
|
|
|
|
|
|
return secret_input_form_variables
|
|
|
|
|
|
|
|
|
|
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
|
|
|
|
|
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
|
|
|
|
|
"""
|
|
|
|
|
Obfuscated credentials.
|
|
|
|
|
|
|
|
|
|
@ -1366,7 +1440,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
|
|
|
|
|
def get_provider_model(
|
|
|
|
|
self, model_type: ModelType, model: str, only_active: bool = False
|
|
|
|
|
) -> Optional[ModelWithProviderEntity]:
|
|
|
|
|
) -> ModelWithProviderEntity | None:
|
|
|
|
|
"""
|
|
|
|
|
Get provider model.
|
|
|
|
|
:param model_type: model type
|
|
|
|
|
@ -1383,7 +1457,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
def get_provider_models(
|
|
|
|
|
self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None
|
|
|
|
|
self, model_type: ModelType | None = None, only_active: bool = False, model: str | None = None
|
|
|
|
|
) -> list[ModelWithProviderEntity]:
|
|
|
|
|
"""
|
|
|
|
|
Get provider models.
|
|
|
|
|
@ -1567,7 +1641,7 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
model_types: Sequence[ModelType],
|
|
|
|
|
provider_schema: ProviderEntity,
|
|
|
|
|
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
|
|
|
|
|
model: Optional[str] = None,
|
|
|
|
|
model: str | None = None,
|
|
|
|
|
) -> list[ModelWithProviderEntity]:
|
|
|
|
|
"""
|
|
|
|
|
Get custom provider models.
|
|
|
|
|
@ -1605,11 +1679,9 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
if config.credential_source_type != "custom_model"
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
if len(provider_model_lb_configs) > 1:
|
|
|
|
|
load_balancing_enabled = True
|
|
|
|
|
|
|
|
|
|
if any(config.name == "__delete__" for config in provider_model_lb_configs):
|
|
|
|
|
has_invalid_load_balancing_configs = True
|
|
|
|
|
load_balancing_enabled = model_setting.load_balancing_enabled
|
|
|
|
|
# when the user enable load_balancing but available configs are less than 2 display warning
|
|
|
|
|
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
|
|
|
|
|
|
|
|
|
|
provider_models.append(
|
|
|
|
|
ModelWithProviderEntity(
|
|
|
|
|
@ -1631,6 +1703,8 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
for model_configuration in self.custom_configuration.models:
|
|
|
|
|
if model_configuration.model_type not in model_types:
|
|
|
|
|
continue
|
|
|
|
|
if model_configuration.unadded_to_model_list:
|
|
|
|
|
continue
|
|
|
|
|
if model and model != model_configuration.model:
|
|
|
|
|
continue
|
|
|
|
|
try:
|
|
|
|
|
@ -1663,11 +1737,9 @@ class ProviderConfiguration(BaseModel):
|
|
|
|
|
if config.credential_source_type != "provider"
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
if len(custom_model_lb_configs) > 1:
|
|
|
|
|
load_balancing_enabled = True
|
|
|
|
|
|
|
|
|
|
if any(config.name == "__delete__" for config in custom_model_lb_configs):
|
|
|
|
|
has_invalid_load_balancing_configs = True
|
|
|
|
|
load_balancing_enabled = model_setting.load_balancing_enabled
|
|
|
|
|
# when the user enable load_balancing but available configs are less than 2 display warning
|
|
|
|
|
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
|
|
|
|
|
|
|
|
|
|
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
|
|
|
|
|
status = ModelStatus.CREDENTIAL_REMOVED
|
|
|
|
|
@ -1703,7 +1775,7 @@ class ProviderConfigurations(BaseModel):
|
|
|
|
|
super().__init__(tenant_id=tenant_id)
|
|
|
|
|
|
|
|
|
|
def get_models(
|
|
|
|
|
self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
|
|
|
|
|
self, provider: str | None = None, model_type: ModelType | None = None, only_active: bool = False
|
|
|
|
|
) -> list[ModelWithProviderEntity]:
|
|
|
|
|
"""
|
|
|
|
|
Get available models.
|
|
|
|
|
@ -1760,8 +1832,14 @@ class ProviderConfigurations(BaseModel):
|
|
|
|
|
def __setitem__(self, key, value):
|
|
|
|
|
self.configurations[key] = value
|
|
|
|
|
|
|
|
|
|
def __contains__(self, key):
|
|
|
|
|
if "/" not in key:
|
|
|
|
|
key = str(ModelProviderID(key))
|
|
|
|
|
return key in self.configurations
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
|
return iter(self.configurations)
|
|
|
|
|
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
|
|
|
|
|
yield from self.configurations.items()
|
|
|
|
|
|
|
|
|
|
def values(self) -> Iterator[ProviderConfiguration]:
|
|
|
|
|
return iter(self.configurations.values())
|
|
|
|
|
|