Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@ -1,8 +1,8 @@
from enum import Enum
from enum import StrEnum, auto
class PlanningStrategy(Enum):
ROUTER = "router"
REACT_ROUTER = "react_router"
REACT = "react"
FUNCTION_CALL = "function_call"
class PlanningStrategy(StrEnum):
ROUTER = auto()
REACT_ROUTER = auto()
REACT = auto()
FUNCTION_CALL = auto()

View File

@ -1,10 +1,10 @@
from enum import Enum
from enum import StrEnum, auto
class EmbeddingInputType(Enum):
class EmbeddingInputType(StrEnum):
"""
Enum for embedding input type.
"""
DOCUMENT = "document"
QUERY = "query"
DOCUMENT = auto()
QUERY = auto()

View File

@ -1,11 +1,9 @@
from typing import Optional
from pydantic import BaseModel
class PreviewDetail(BaseModel):
content: str
child_chunks: Optional[list[str]] = None
child_chunks: list[str] | None = None
class QAPreviewDetail(BaseModel):
@ -16,4 +14,28 @@ class QAPreviewDetail(BaseModel):
class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None
qa_preview: list[QAPreviewDetail] | None = None
class PipelineDataset(BaseModel):
id: str
name: str
description: str
chunk_structure: str
class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: dict | None = None
name: str
indexing_status: str
error: str | None = None
enabled: bool
class PipelineGenerateResponse(BaseModel):
batch: str
dataset: PipelineDataset
documents: list[PipelineDocument]

View File

@ -1,6 +1,5 @@
from collections.abc import Sequence
from enum import Enum
from typing import Optional
from enum import StrEnum, auto
from pydantic import BaseModel, ConfigDict
@ -9,16 +8,16 @@ from core.model_runtime.entities.model_entities import ModelType, ProviderModel
from core.model_runtime.entities.provider_entities import ProviderEntity
class ModelStatus(Enum):
class ModelStatus(StrEnum):
"""
Enum class for model status.
"""
ACTIVE = "active"
ACTIVE = auto()
NO_CONFIGURE = "no-configure"
QUOTA_EXCEEDED = "quota-exceeded"
NO_PERMISSION = "no-permission"
DISABLED = "disabled"
DISABLED = auto()
CREDENTIAL_REMOVED = "credential-removed"
@ -29,11 +28,11 @@ class SimpleModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: list[ModelType]
def __init__(self, provider_entity: ProviderEntity) -> None:
def __init__(self, provider_entity: ProviderEntity):
"""
Init simple provider.
@ -57,7 +56,7 @@ class ProviderModelWithStatusEntity(ProviderModel):
load_balancing_enabled: bool = False
has_invalid_load_balancing_configs: bool = False
def raise_for_status(self) -> None:
def raise_for_status(self):
"""
Check model status and raise ValueError if not active.
@ -92,8 +91,8 @@ class DefaultModelProviderEntity(BaseModel):
provider: str
label: I18nObject
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
icon_small: I18nObject | None = None
icon_large: I18nObject | None = None
supported_model_types: Sequence[ModelType] = []

View File

@ -1,20 +1,20 @@
from enum import StrEnum
from enum import StrEnum, auto
class CommonParameterType(StrEnum):
SECRET_INPUT = "secret-input"
TEXT_INPUT = "text-input"
SELECT = "select"
STRING = "string"
NUMBER = "number"
FILE = "file"
FILES = "files"
SELECT = auto()
STRING = auto()
NUMBER = auto()
FILE = auto()
FILES = auto()
SYSTEM_FILES = "system-files"
BOOLEAN = "boolean"
BOOLEAN = auto()
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
ANY = "any"
ANY = auto()
# Dynamic select parameter
# Once you are not sure about the available options until authorization is done
@ -23,29 +23,29 @@ class CommonParameterType(StrEnum):
# TOOL_SELECTOR = "tool-selector"
# MCP object and array type parameters
ARRAY = "array"
OBJECT = "object"
ARRAY = auto()
OBJECT = auto()
class AppSelectorScope(StrEnum):
ALL = "all"
CHAT = "chat"
WORKFLOW = "workflow"
COMPLETION = "completion"
ALL = auto()
CHAT = auto()
WORKFLOW = auto()
COMPLETION = auto()
class ModelSelectorScope(StrEnum):
LLM = "llm"
LLM = auto()
TEXT_EMBEDDING = "text-embedding"
RERANK = "rerank"
TTS = "tts"
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
VISION = "vision"
RERANK = auto()
TTS = auto()
SPEECH2TEXT = auto()
MODERATION = auto()
VISION = auto()
class ToolSelectorScope(StrEnum):
ALL = "all"
CUSTOM = "custom"
BUILTIN = "builtin"
WORKFLOW = "workflow"
ALL = auto()
CUSTOM = auto()
BUILTIN = auto()
WORKFLOW = auto()

View File

@ -1,9 +1,9 @@
import json
import logging
import re
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import func, select
@ -28,7 +28,6 @@ from core.model_runtime.entities.provider_entities import (
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import (
@ -41,6 +40,8 @@ from models.provider import (
ProviderType,
TenantPreferredModelProvider,
)
from models.provider_ids import ModelProviderID
from services.enterprise.plugin_manager_service import PluginCredentialType
logger = logging.getLogger(__name__)
@ -90,7 +91,7 @@ class ProviderConfiguration(BaseModel):
):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
"""
Get current credentials.
@ -128,18 +129,42 @@ class ProviderConfiguration(BaseModel):
return copy_credentials
else:
credentials = None
current_credential_id = None
if self.custom_configuration.models:
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type == model_type and model_configuration.model == model:
credentials = model_configuration.credentials
current_credential_id = model_configuration.current_credential_id
break
if not credentials and self.custom_configuration.provider:
credentials = self.custom_configuration.provider.credentials
current_credential_id = self.custom_configuration.provider.current_credential_id
if current_credential_id:
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=current_credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)
else:
# no current credential id, check all available credentials
if self.custom_configuration.provider:
for credential_configuration in self.custom_configuration.provider.available_credentials:
from core.helper.credential_utils import check_credential_policy_compliance
check_credential_policy_compliance(
credential_id=credential_configuration.credential_id,
provider=self.provider.provider,
credential_type=PluginCredentialType.MODEL,
)
return credentials
def get_system_configuration_status(self) -> Optional[SystemConfigurationStatus]:
def get_system_configuration_status(self) -> SystemConfigurationStatus | None:
"""
Get system configuration status.
:return:
@ -180,16 +205,10 @@ class ProviderConfiguration(BaseModel):
"""
Get custom provider record.
"""
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
Provider.provider_name.in_(self._get_provider_names()),
)
return session.execute(stmt).scalar_one_or_none()
@ -251,7 +270,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(ProviderCredential.id).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.credential_name == credential_name,
)
if exclude_id:
@ -265,7 +284,6 @@ class ProviderConfiguration(BaseModel):
:param credential_id: if provided, return the specified credential
:return:
"""
if credential_id:
return self._get_specific_provider_credential(credential_id)
@ -279,9 +297,7 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(
self, credentials: dict, credential_id: str = "", session: Session | None = None
) -> dict:
def validate_provider_credentials(self, credentials: dict, credential_id: str = "", session: Session | None = None):
"""
Validate custom credentials.
:param credentials: provider credentials
@ -290,7 +306,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
def _validate(s: Session) -> dict:
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
@ -302,7 +318,7 @@ class ProviderConfiguration(BaseModel):
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
credential_record = s.execute(stmt).scalar_one_or_none()
@ -343,7 +359,75 @@ class ProviderConfiguration(BaseModel):
with Session(db.engine) as new_session:
return _validate(new_session)
def create_provider_credential(self, credentials: dict, credential_name: str) -> None:
def _generate_provider_credential_name(self, session) -> str:
"""
Generate a unique credential name for provider.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
),
)
def _generate_custom_model_credential_name(self, model: str, model_type: ModelType, session) -> str:
"""
Generate a unique credential name for custom model.
:return: credential name
"""
return self._generate_next_api_key_name(
session=session,
query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
),
)
def _generate_next_api_key_name(self, session, query_factory) -> str:
"""
Generate next available API KEY name by finding the highest numbered suffix.
:param session: database session
:param query_factory: function that returns the SQLAlchemy query
:return: next available API KEY name
"""
try:
stmt = query_factory()
credential_records = session.execute(stmt).scalars().all()
if not credential_records:
return "API KEY 1"
# Extract numbers from API KEY pattern using list comprehension
pattern = re.compile(r"^API KEY\s+(\d+)$")
numbers = [
int(match.group(1))
for cr in credential_records
if cr.credential_name and (match := pattern.match(cr.credential_name.strip()))
]
# Return next sequential number
next_number = max(numbers, default=0) + 1
return f"API KEY {next_number}"
except Exception as e:
logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1"
def _get_provider_names(self):
"""
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
return provider_names
def create_provider_credential(self, credentials: dict, credential_name: str | None):
"""
Add custom provider credentials.
:param credentials: provider credentials
@ -351,8 +435,11 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
provider_record = self._get_provider_record(session)
@ -395,8 +482,8 @@ class ProviderConfiguration(BaseModel):
self,
credentials: dict,
credential_id: str,
credential_name: str,
) -> None:
credential_name: str | None,
):
"""
update a saved provider credential (by credential_id).
@ -406,7 +493,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_provider_credential_name_exists(
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
@ -418,7 +505,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
@ -428,9 +515,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_record and provider_record.credential_id == credential_id:
@ -457,7 +544,7 @@ class ProviderConfiguration(BaseModel):
credential_record: ProviderCredential | ProviderModelCredential,
credential_source: str,
session: Session,
) -> None:
):
"""
Update load balancing configurations that reference the given credential_id.
@ -471,7 +558,7 @@ class ProviderConfiguration(BaseModel):
# Find all load balancing configs that use this credential_id
stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == credential_source,
)
@ -497,7 +584,7 @@ class ProviderConfiguration(BaseModel):
session.commit()
def delete_provider_credential(self, credential_id: str) -> None:
def delete_provider_credential(self, credential_id: str):
"""
Delete a saved provider credential (by credential_id).
@ -508,7 +595,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
@ -519,7 +606,7 @@ class ProviderConfiguration(BaseModel):
# Check if this credential is used in load balancing configs
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider",
)
@ -532,13 +619,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_record = self._get_provider_record(session)
@ -547,7 +628,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the provider record
count_stmt = select(func.count(ProviderCredential.id)).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
@ -580,7 +661,7 @@ class ProviderConfiguration(BaseModel):
session.rollback()
raise
def switch_active_provider_credential(self, credential_id: str) -> None:
def switch_active_provider_credential(self, credential_id: str):
"""
Switch active provider credential (copy the selected one into current active snapshot).
@ -591,7 +672,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
@ -627,6 +708,7 @@ class ProviderConfiguration(BaseModel):
Get custom model credentials.
"""
# get provider model
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -659,7 +741,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -684,6 +766,7 @@ class ProviderConfiguration(BaseModel):
current_credential_id = credential_record.id
current_credential_name = credential_record.credential_name
credentials = self.obfuscated_credentials(
credentials=credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
@ -705,7 +788,7 @@ class ProviderConfiguration(BaseModel):
"""
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.credential_name == credential_name,
@ -714,9 +797,7 @@ class ProviderConfiguration(BaseModel):
stmt = stmt.where(ProviderModelCredential.id != exclude_id)
return session.execute(stmt).scalar_one_or_none() is not None
def get_custom_model_credential(
self, model_type: ModelType, model: str, credential_id: str | None
) -> Optional[dict]:
def get_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str | None) -> dict | None:
"""
Get custom model credentials.
@ -738,6 +819,7 @@ class ProviderConfiguration(BaseModel):
):
current_credential_id = model_configuration.current_credential_id
current_credential_name = model_configuration.current_credential_name
credentials = self.obfuscated_credentials(
credentials=model_configuration.credentials,
credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
@ -758,7 +840,7 @@ class ProviderConfiguration(BaseModel):
credentials: dict,
credential_id: str = "",
session: Session | None = None,
) -> dict:
):
"""
Validate custom model credentials.
@ -769,7 +851,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
def _validate(s: Session) -> dict:
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
@ -782,7 +864,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -822,7 +904,7 @@ class ProviderConfiguration(BaseModel):
return _validate(new_session)
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None
) -> None:
"""
Create a custom model credential.
@ -833,10 +915,15 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
if credential_name:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
@ -880,7 +967,7 @@ class ProviderConfiguration(BaseModel):
raise
def update_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict, credential_name: str, credential_id: str
self, model_type: ModelType, model: str, credentials: dict, credential_name: str | None, credential_id: str
) -> None:
"""
Update a custom model credential.
@ -893,7 +980,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
with Session(db.engine) as session:
if self._check_custom_model_credential_name_exists(
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
@ -914,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -925,8 +1012,9 @@ class ProviderConfiguration(BaseModel):
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.credential_name = credential_name
credential_record.updated_at = naive_utc_now()
if credential_name:
credential_record.credential_name = credential_name
session.commit()
if provider_model_record and provider_model_record.credential_id == credential_id:
@ -947,7 +1035,7 @@ class ProviderConfiguration(BaseModel):
session.rollback()
raise
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
def delete_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
"""
Delete a saved provider credential (by credential_id).
@ -958,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -968,7 +1056,7 @@ class ProviderConfiguration(BaseModel):
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model",
)
@ -982,12 +1070,7 @@ class ProviderConfiguration(BaseModel):
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
)
lb_credentials_cache.delete()
lb_config.credential_id = None
lb_config.encrypted_config = None
lb_config.enabled = False
lb_config.name = "__delete__"
lb_config.updated_at = naive_utc_now()
session.add(lb_config)
session.delete(lb_config)
# Check if this is the currently active credential
provider_model_record = self._get_custom_model_record(model_type, model, session=session)
@ -996,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the custom model record
count_stmt = select(func.count(ProviderModelCredential.id)).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1022,7 +1105,7 @@ class ProviderConfiguration(BaseModel):
session.rollback()
raise
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None:
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str):
"""
if model list exist this custom model, switch the custom model credential.
if model list not exist this custom model, use the credential to add a new custom model record.
@ -1036,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1054,6 +1137,7 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
model_name=model,
model_type=model_type.to_origin_model_type(),
is_valid=True,
credential_id=credential_id,
)
else:
@ -1064,7 +1148,7 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record)
session.commit()
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str):
"""
switch the custom model credential.
@ -1077,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
@ -1094,7 +1178,7 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record)
session.commit()
def delete_custom_model(self, model_type: ModelType, model: str) -> None:
def delete_custom_model(self, model_type: ModelType, model: str):
"""
Delete custom model.
:param model_type: model type
@ -1124,14 +1208,9 @@ class ProviderConfiguration(BaseModel):
"""
Get provider model setting.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model,
)
@ -1190,7 +1269,7 @@ class ProviderConfiguration(BaseModel):
return model_setting
def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
def get_provider_model_setting(self, model_type: ModelType, model: str) -> ProviderModelSetting | None:
"""
Get provider model setting.
:param model_type: model type
@ -1207,6 +1286,7 @@ class ProviderConfiguration(BaseModel):
:param model: model name
:return:
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
@ -1289,7 +1369,7 @@ class ProviderConfiguration(BaseModel):
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None) -> None:
def switch_preferred_provider_type(self, provider_type: ProviderType, session: Session | None = None):
"""
Switch preferred provider type.
:param provider_type:
@ -1301,16 +1381,10 @@ class ProviderConfiguration(BaseModel):
if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
return
def _switch(s: Session) -> None:
# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
def _switch(s: Session):
stmt = select(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names),
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
)
preferred_model_provider = s.execute(stmt).scalars().first()
@ -1340,12 +1414,12 @@ class ProviderConfiguration(BaseModel):
"""
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
if credential_form_schema.type == FormType.SECRET_INPUT:
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
secret_input_form_variables.append(credential_form_schema.variable)
return secret_input_form_variables
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]):
"""
Obfuscated credentials.
@ -1366,7 +1440,7 @@ class ProviderConfiguration(BaseModel):
def get_provider_model(
self, model_type: ModelType, model: str, only_active: bool = False
) -> Optional[ModelWithProviderEntity]:
) -> ModelWithProviderEntity | None:
"""
Get provider model.
:param model_type: model type
@ -1383,7 +1457,7 @@ class ProviderConfiguration(BaseModel):
return None
def get_provider_models(
self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None
self, model_type: ModelType | None = None, only_active: bool = False, model: str | None = None
) -> list[ModelWithProviderEntity]:
"""
Get provider models.
@ -1567,7 +1641,7 @@ class ProviderConfiguration(BaseModel):
model_types: Sequence[ModelType],
provider_schema: ProviderEntity,
model_setting_map: dict[ModelType, dict[str, ModelSettings]],
model: Optional[str] = None,
model: str | None = None,
) -> list[ModelWithProviderEntity]:
"""
Get custom provider models.
@ -1605,11 +1679,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "custom_model"
]
if len(provider_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in provider_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enable load_balancing but available configs are less than 2 display warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(provider_model_lb_configs) < 2
provider_models.append(
ModelWithProviderEntity(
@ -1631,6 +1703,8 @@ class ProviderConfiguration(BaseModel):
for model_configuration in self.custom_configuration.models:
if model_configuration.model_type not in model_types:
continue
if model_configuration.unadded_to_model_list:
continue
if model and model != model_configuration.model:
continue
try:
@ -1663,11 +1737,9 @@ class ProviderConfiguration(BaseModel):
if config.credential_source_type != "provider"
]
if len(custom_model_lb_configs) > 1:
load_balancing_enabled = True
if any(config.name == "__delete__" for config in custom_model_lb_configs):
has_invalid_load_balancing_configs = True
load_balancing_enabled = model_setting.load_balancing_enabled
# when the user enable load_balancing but available configs are less than 2 display warning
has_invalid_load_balancing_configs = load_balancing_enabled and len(custom_model_lb_configs) < 2
if len(model_configuration.available_model_credentials) > 0 and not model_configuration.credentials:
status = ModelStatus.CREDENTIAL_REMOVED
@ -1703,7 +1775,7 @@ class ProviderConfigurations(BaseModel):
super().__init__(tenant_id=tenant_id)
def get_models(
self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
self, provider: str | None = None, model_type: ModelType | None = None, only_active: bool = False
) -> list[ModelWithProviderEntity]:
"""
Get available models.
@ -1760,8 +1832,14 @@ class ProviderConfigurations(BaseModel):
def __setitem__(self, key, value):
self.configurations[key] = value
def __contains__(self, key):
if "/" not in key:
key = str(ModelProviderID(key))
return key in self.configurations
def __iter__(self):
return iter(self.configurations)
# Return an iterator of (key, value) tuples to match BaseModel's __iter__
yield from self.configurations.items()
def values(self) -> Iterator[ProviderConfiguration]:
return iter(self.configurations.values())

View File

@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional, Union
from enum import StrEnum, auto
from typing import Union
from pydantic import BaseModel, ConfigDict, Field
@ -13,14 +13,14 @@ from core.model_runtime.entities.model_entities import ModelType
from core.tools.entities.common_entities import I18nObject
class ProviderQuotaType(Enum):
PAID = "paid"
class ProviderQuotaType(StrEnum):
PAID = auto()
"""hosted paid quota"""
FREE = "free"
FREE = auto()
"""third-party free quota"""
TRIAL = "trial"
TRIAL = auto()
"""hosted trial quota"""
@staticmethod
@ -31,25 +31,25 @@ class ProviderQuotaType(Enum):
raise ValueError(f"No matching enum found for value '{value}'")
class QuotaUnit(Enum):
TIMES = "times"
TOKENS = "tokens"
CREDITS = "credits"
class QuotaUnit(StrEnum):
TIMES = auto()
TOKENS = auto()
CREDITS = auto()
class SystemConfigurationStatus(Enum):
class SystemConfigurationStatus(StrEnum):
"""
Enum class for system configuration status.
"""
ACTIVE = "active"
ACTIVE = auto()
QUOTA_EXCEEDED = "quota-exceeded"
UNSUPPORTED = "unsupported"
UNSUPPORTED = auto()
class RestrictModel(BaseModel):
model: str
base_model_name: Optional[str] = None
base_model_name: str | None = None
model_type: ModelType
# pydantic configs
@ -111,11 +111,21 @@ class CustomModelConfiguration(BaseModel):
current_credential_id: Optional[str] = None
current_credential_name: Optional[str] = None
available_model_credentials: list[CredentialConfiguration] = []
unadded_to_model_list: bool | None = False
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
class UnaddedModelConfiguration(BaseModel):
"""
Model class for provider unadded model configuration.
"""
model: str
model_type: ModelType
class CustomConfiguration(BaseModel):
"""
Model class for provider custom configuration.
@ -123,6 +133,7 @@ class CustomConfiguration(BaseModel):
provider: Optional[CustomProviderConfiguration] = None
models: list[CustomModelConfiguration] = []
can_added_models: list[UnaddedModelConfiguration] = []
class ModelLoadBalancingConfiguration(BaseModel):
@ -134,6 +145,7 @@ class ModelLoadBalancingConfiguration(BaseModel):
name: str
credentials: dict
credential_source_type: str | None = None
credential_id: str | None = None
class ModelSettings(BaseModel):
@ -144,6 +156,7 @@ class ModelSettings(BaseModel):
model: str
model_type: ModelType
enabled: bool = True
load_balancing_enabled: bool = False
load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
# pydantic configs
@ -155,14 +168,14 @@ class BasicProviderConfig(BaseModel):
Base model class for common provider settings like credentials
"""
class Type(Enum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
SELECT = CommonParameterType.SELECT.value
BOOLEAN = CommonParameterType.BOOLEAN.value
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
class Type(StrEnum):
SECRET_INPUT = CommonParameterType.SECRET_INPUT
TEXT_INPUT = CommonParameterType.TEXT_INPUT
SELECT = CommonParameterType.SELECT
BOOLEAN = CommonParameterType.BOOLEAN
APP_SELECTOR = CommonParameterType.APP_SELECTOR
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR
@classmethod
def value_of(cls, value: str) -> "ProviderConfig.Type":
@ -192,13 +205,13 @@ class ProviderConfig(BasicProviderConfig):
scope: AppSelectorScope | ModelSelectorScope | ToolSelectorScope | None = None
required: bool = False
default: Optional[Union[int, str, float, bool, list]] = None
options: Optional[list[Option]] = None
default: Union[int, str, float, bool] | None = None
options: list[Option] | None = None
multiple: bool | None = False
label: Optional[I18nObject] = None
help: Optional[I18nObject] = None
url: Optional[str] = None
placeholder: Optional[I18nObject] = None
label: I18nObject | None = None
help: I18nObject | None = None
url: str | None = None
placeholder: I18nObject | None = None
def to_basic_provider_config(self) -> BasicProviderConfig:
return BasicProviderConfig(type=self.type, name=self.name)