mirror of
https://github.com/langgenius/dify.git
synced 2026-05-24 10:57:52 +08:00
Compare commits
1 Commits
main
...
fix/model-
| Author | SHA1 | Date | |
|---|---|---|---|
| 521545d52e |
@ -795,7 +795,7 @@ class ProviderManager:
|
||||
return [
|
||||
{
|
||||
"model": model_key[0],
|
||||
"model_type": ModelType.value_of(model_key[1]),
|
||||
"model_type": ModelType(model_key[1]),
|
||||
"available_model_credentials": [
|
||||
CredentialConfiguration(credential_id=cred.id, credential_name=cred.credential_name)
|
||||
for cred in creds
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing import Any, Union
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.rag.entities import RerankingModelConfig, WeightedScoreConfig
|
||||
from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
|
||||
@ -101,14 +101,3 @@ class KnowledgeIndexNodeData(BaseNodeData):
|
||||
index_chunk_variable_selector: list[str]
|
||||
indexing_technique: str | None = None
|
||||
summary_index_setting: SummaryIndexSettingDict | None = None
|
||||
|
||||
@field_validator("summary_index_setting", mode="before")
|
||||
@classmethod
|
||||
def normalize_summary_index_setting(cls, v: Any) -> Any:
|
||||
"""Treat dicts with enable=None (or missing enable) as None (#36233)."""
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, dict):
|
||||
if v.get("enable") is None:
|
||||
return None
|
||||
return v
|
||||
|
||||
@ -66,7 +66,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model load balancing
|
||||
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType(model_type))
|
||||
|
||||
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
@ -87,7 +87,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# disable model load balancing
|
||||
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType(model_type))
|
||||
|
||||
def get_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
|
||||
@ -109,7 +109,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType(model_type)
|
||||
|
||||
# Get provider model setting
|
||||
provider_model_setting = provider_configuration.get_provider_model_setting(
|
||||
@ -250,7 +250,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType(model_type)
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_model_config = db.session.scalar(
|
||||
@ -338,7 +338,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType(model_type)
|
||||
|
||||
if not isinstance(configs, list):
|
||||
raise ValueError("Invalid load balancing configs")
|
||||
@ -524,7 +524,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Convert model type to ModelType
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType(model_type)
|
||||
|
||||
load_balancing_model_config = None
|
||||
if config_id:
|
||||
|
||||
@ -67,7 +67,7 @@ class ModelProviderService:
|
||||
provider_responses = []
|
||||
for provider_configuration in provider_configurations.values():
|
||||
if model_type:
|
||||
model_type_entity = ModelType.value_of(model_type)
|
||||
model_type_entity = ModelType(model_type)
|
||||
if model_type_entity not in provider_configuration.provider.supported_model_types:
|
||||
continue
|
||||
|
||||
@ -269,7 +269,7 @@ class ModelProviderService:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
model_type=ModelType(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def validate_model_credentials(
|
||||
@ -287,7 +287,7 @@ class ModelProviderService:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.validate_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
|
||||
model_type=ModelType(model_type), model=model, credentials=credentials
|
||||
)
|
||||
|
||||
def create_model_credential(
|
||||
@ -312,7 +312,7 @@ class ModelProviderService:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.create_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model_type=ModelType(model_type),
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_name=credential_name,
|
||||
@ -342,7 +342,7 @@ class ModelProviderService:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.update_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model_type=ModelType(model_type),
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
@ -362,7 +362,7 @@ class ModelProviderService:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.delete_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
model_type=ModelType(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def switch_active_custom_model_credential(
|
||||
@ -380,7 +380,7 @@ class ModelProviderService:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.switch_custom_model_credential(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
model_type=ModelType(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def add_model_credential_to_model_list(
|
||||
@ -398,7 +398,7 @@ class ModelProviderService:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.add_model_credential_to_model(
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
model_type=ModelType(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
|
||||
@ -412,7 +412,7 @@ class ModelProviderService:
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.delete_custom_model(model_type=ModelType.value_of(model_type), model=model)
|
||||
provider_configuration.delete_custom_model(model_type=ModelType(model_type), model=model)
|
||||
|
||||
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
|
||||
"""
|
||||
@ -426,7 +426,7 @@ class ModelProviderService:
|
||||
provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id)
|
||||
|
||||
# Get provider available models
|
||||
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True)
|
||||
models = provider_configurations.get_models(model_type=ModelType(model_type), only_active=True)
|
||||
|
||||
# Group models by provider
|
||||
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
|
||||
@ -505,7 +505,7 @@ class ModelProviderService:
|
||||
:param model_type: model type
|
||||
:return:
|
||||
"""
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType(model_type)
|
||||
|
||||
try:
|
||||
result = self._get_provider_manager(tenant_id).get_default_model(
|
||||
@ -540,7 +540,7 @@ class ModelProviderService:
|
||||
:param model: model name
|
||||
:return:
|
||||
"""
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
model_type_enum = ModelType(model_type)
|
||||
self._get_provider_manager(tenant_id).update_default_model_record(
|
||||
tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model
|
||||
)
|
||||
@ -590,7 +590,7 @@ class ModelProviderService:
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
provider_configuration.enable_model(model=model, model_type=ModelType(model_type))
|
||||
|
||||
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
@ -603,4 +603,4 @@ class ModelProviderService:
|
||||
:return:
|
||||
"""
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
provider_configuration.disable_model(model=model, model_type=ModelType(model_type))
|
||||
|
||||
Reference in New Issue
Block a user