fix: move remote credential validation outside DB session to prevent … (#35350)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
zyssyz123
2026-04-17 15:42:29 +08:00
committed by GitHub
parent eaddd4a132
commit a74e12809b
2 changed files with 156 additions and 177 deletions

View File

@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel):
else [],
)
def validate_provider_credentials(
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
):
def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
"""
Validate custom credentials.
:param credentials: provider credentials
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:param session: optional database session
:return:
"""
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema
else []
)
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
if self.provider.provider_credential_schema
else []
)
if credential_id:
if credential_id:
with Session(db.engine) as session:
try:
stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id,
)
credential_record = s.execute(stmt).scalar_one_or_none()
# fix origin data
credential_record = session.execute(stmt).scalar_one_or_none()
if credential_record and credential_record.encrypted_config:
if not credential_record.encrypted_config.startswith("{"):
original_credentials = {"openai_api_key": credential_record.encrypted_config}
@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
# encrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
for key, value in validated_credentials.items():
for key, value in credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
return validated_credentials
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.provider_credentials_validate(
provider=self.provider.provider, credentials=credentials
)
if session:
return _validate(session)
else:
with Session(db.engine) as new_session:
return _validate(new_session)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
def _generate_provider_credential_name(self, session) -> str:
"""
@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name:
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
else:
credential_name = self._generate_provider_credential_name(session)
credential_name = self._generate_provider_credential_name(pre_session)
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
credentials = self.validate_provider_credentials(credentials=credentials)
with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
try:
new_record = ProviderCredential(
@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel):
session.flush()
if not provider_record:
# If provider record does not exist, create it
provider_record = Provider(
tenant_id=self.tenant_id,
provider_name=self.provider.provider,
@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel):
:param credential_name: credential name
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name and self._check_provider_credential_name_exists(
credential_name=credential_name, session=session, exclude_id=credential_id
credential_name=credential_name, session=pre_session, exclude_id=credential_id
):
raise ValueError(f"Credential with name '{credential_name}' already exists.")
credentials = self.validate_provider_credentials(
credentials=credentials, credential_id=credential_id, session=session
)
credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
with Session(db.engine) as session:
provider_record = self._get_provider_record(session)
stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id,
@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel):
ProviderCredential.provider_name.in_(self._get_provider_names()),
)
# Get the credential record to update
credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record:
raise ValueError("Credential record not found.")
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name:
@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel):
model: str,
credentials: dict[str, Any],
credential_id: str = "",
session: Session | None = None,
):
"""
Validate custom model credentials.
@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel):
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
:return:
"""
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema
else []
)
def _validate(s: Session):
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
if self.provider.model_credential_schema
else []
)
if credential_id:
if credential_id:
with Session(db.engine) as session:
try:
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel):
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type,
)
credential_record = s.execute(stmt).scalar_one_or_none()
credential_record = session.execute(stmt).scalar_one_or_none()
original_credentials = (
json.loads(credential_record.encrypted_config)
if credential_record and credential_record.encrypted_config
@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel):
except JSONDecodeError:
original_credentials = {}
# decrypt credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
for key, value in validated_credentials.items():
for key, value in credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
if value == HIDDEN_VALUE and key in original_credentials:
credentials[key] = encrypter.decrypt_token(
tenant_id=self.tenant_id, token=original_credentials[key]
)
return validated_credentials
model_provider_factory = self.get_model_provider_factory()
validated_credentials = model_provider_factory.model_credentials_validate(
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
)
if session:
return _validate(session)
else:
with Session(db.engine) as new_session:
return _validate(new_session)
for key, value in validated_credentials.items():
if key in provider_credential_secret_variables:
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
return validated_credentials
def create_custom_model_credential(
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel):
:param credentials: model credentials dict
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name:
if self._check_custom_model_credential_name_exists(
model=model, model_type=model_type, credential_name=credential_name, session=session
model=model, model_type=model_type, credential_name=credential_name, session=pre_session
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
else:
credential_name = self._generate_custom_model_credential_name(
model=model, model_type=model_type, session=session
model=model, model_type=model_type, session=pre_session
)
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials, session=session
)
credentials = self.validate_custom_model_credentials(
model_type=model_type, model=model, credentials=credentials
)
with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
try:
@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel):
session.add(credential)
session.flush()
# save provider model
if not provider_model_record:
provider_model_record = ProviderModel(
tenant_id=self.tenant_id,
@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel):
:param credential_id: credential id
:return:
"""
with Session(db.engine) as session:
with Session(db.engine) as pre_session:
if credential_name and self._check_custom_model_credential_name_exists(
model=model,
model_type=model_type,
credential_name=credential_name,
session=session,
session=pre_session,
exclude_id=credential_id,
):
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
# validate custom model config
credentials = self.validate_custom_model_credentials(
model_type=model_type,
model=model,
credentials=credentials,
credential_id=credential_id,
session=session,
)
credentials = self.validate_custom_model_credentials(
model_type=model_type,
model=model,
credentials=credentials,
credential_id=credential_id,
)
with Session(db.engine) as session:
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
stmt = select(ProviderModelCredential).where(
@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel):
raise ValueError("Credential record not found.")
try:
# Update credential
credential_record.encrypted_config = json.dumps(credentials)
credential_record.updated_at = naive_utc_now()
if credential_name: