Compare commits

..

1 Commits

Author SHA1 Message Date
521545d52e fix: migrate model type enum construction 2026-05-23 18:56:48 +08:00
4 changed files with 23 additions and 34 deletions

View File

@ -795,7 +795,7 @@ class ProviderManager:
return [
{
"model": model_key[0],
"model_type": ModelType.value_of(model_key[1]),
"model_type": ModelType(model_key[1]),
"available_model_credentials": [
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
for cred in creds

View File

@ -1,6 +1,6 @@
from typing import Any, Union
from typing import Union
from pydantic import BaseModel, field_validator
from pydantic import BaseModel
from core.rag.entities import RerankingModelConfig, WeightedScoreConfig
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
@ -101,14 +101,3 @@ class KnowledgeIndexNodeData(BaseNodeData):
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: SummaryIndexSettingDict | None = None
@field_validator("summary_index_setting", mode="before")
@classmethod
def normalize_summary_index_setting(cls, v: Any) -> Any:
"""Treat dicts with enable=None (or missing enable) as None (#36233)."""
if v is None:
return None
if isinstance(v, dict):
if v.get("enable") is None:
return None
return v

View File

@ -66,7 +66,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Enable model load balancing
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType(model_type))
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
@ -87,7 +87,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# disable model load balancing
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType(model_type))
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
@ -109,7 +109,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
# Get provider model setting
provider_model_setting = provider_configuration.get_provider_model_setting(
@ -250,7 +250,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
# Get load balancing configurations
load_balancing_model_config = db.session.scalar(
@ -338,7 +338,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
if not isinstance(configs, list):
raise ValueError("Invalid load balancing configs")
@ -524,7 +524,7 @@ class ModelLoadBalancingService:
raise ValueError(f"Provider {provider} does not exist.")
# Convert model type to ModelType
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
load_balancing_model_config = None
if config_id:

View File

@ -67,7 +67,7 @@ class ModelProviderService:
provider_responses = []
for provider_configuration in provider_configurations.values():
if model_type:
model_type_entity = ModelType.value_of(model_type)
model_type_entity = ModelType(model_type)
if model_type_entity not in provider_configuration.provider.supported_model_types:
continue
@ -269,7 +269,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
model_type=ModelType(model_type), model=model, credential_id=credential_id
)
def validate_model_credentials(
@ -287,7 +287,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.validate_custom_model_credentials(
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
model_type=ModelType(model_type), model=model, credentials=credentials
)
def create_model_credential(
@ -312,7 +312,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.create_custom_model_credential(
model_type=ModelType.value_of(model_type),
model_type=ModelType(model_type),
model=model,
credentials=credentials,
credential_name=credential_name,
@ -342,7 +342,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.update_custom_model_credential(
model_type=ModelType.value_of(model_type),
model_type=ModelType(model_type),
model=model,
credentials=credentials,
credential_id=credential_id,
@ -362,7 +362,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.delete_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
model_type=ModelType(model_type), model=model, credential_id=credential_id
)
def switch_active_custom_model_credential(
@ -380,7 +380,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.switch_custom_model_credential(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
model_type=ModelType(model_type), model=model, credential_id=credential_id
)
def add_model_credential_to_model_list(
@ -398,7 +398,7 @@ class ModelProviderService:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.add_model_credential_to_model(
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
model_type=ModelType(model_type), model=model, credential_id=credential_id
)
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
@ -412,7 +412,7 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model)
provider_configuration.delete_custom_model(model_type=ModelType(model_type), model=model)
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
"""
@ -426,7 +426,7 @@ class ModelProviderService:
provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id)
# Get provider available models
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True)
models = provider_configurations.get_models(model_type=ModelType(model_type), only_active=True)
# Group models by provider
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
@ -505,7 +505,7 @@ class ModelProviderService:
:param model_type: model type
:return:
"""
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
try:
result = self._get_provider_manager(tenant_id).get_default_model(
@ -540,7 +540,7 @@ class ModelProviderService:
:param model: model name
:return:
"""
model_type_enum = ModelType.value_of(model_type)
model_type_enum = ModelType(model_type)
self._get_provider_manager(tenant_id).update_default_model_record(
tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model
)
@ -590,7 +590,7 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
provider_configuration.enable_model(model=model, model_type=ModelType(model_type))
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
@ -603,4 +603,4 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
provider_configuration.disable_model(model=model, model_type=ModelType(model_type))