load balance save api also can switch custom model credential_id

This commit is contained in:
hjlarry
2025-08-19 17:36:01 +08:00
parent 416b2634ed
commit b9a6bf89ef
3 changed files with 66 additions and 5 deletions

View File

@ -1014,10 +1014,10 @@ class ProviderConfiguration(BaseModel):
session.rollback()
raise
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
def add_model_credential_to_model(self, model_type: ModelType, model: str, credential_id: str) -> None:
"""
Not only switch the custom model credential.
It can also add credential to a new custom model record.
if model list exist this custom model, switch the custom model credential.
if model list not exist this custom model, use the credential to add a new custom model record.
:param model_type: model type
:param model: model name
@ -1056,6 +1056,36 @@ class ProviderConfiguration(BaseModel):
session.add(provider_model_record)
session.commit()
def switch_custom_model_credential(self, model_type: ModelType, model: str, credential_id: str) -> None:
"""
switch the custom model credential.
:param model_type: model type
:param model: model name
:param credential_id: credential id
:return:
"""
with Session(db.engine) as session:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
raise ValueError("Credential record not found.")
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
if not provider_model_record:
raise ValueError("The custom model record not found.")
provider_model_record.credential_id = credential_record.id
provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
session.add(provider_model_record)
session.commit()
def delete_custom_model(self, model_type: ModelType, model: str) -> None:
"""
Delete custom model.