mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
fix: move remote credential validation outside DB session to prevent … (#35350)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -318,34 +318,28 @@ class ProviderConfiguration(BaseModel):
|
||||
else [],
|
||||
)
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, credentials: dict[str, Any], credential_id: str = "", session: Session | None = None
|
||||
):
|
||||
def validate_provider_credentials(self, credentials: dict[str, Any], credential_id: str = ""):
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param credentials: provider credentials
|
||||
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
|
||||
:param session: optional database session
|
||||
:return:
|
||||
"""
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
def _validate(s: Session):
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.provider_credential_schema.credential_form_schemas
|
||||
if self.provider.provider_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if credential_id:
|
||||
if credential_id:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.tenant_id == self.tenant_id,
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
ProviderCredential.id == credential_id,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
# fix origin data
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if credential_record and credential_record.encrypted_config:
|
||||
if not credential_record.encrypted_config.startswith("{"):
|
||||
original_credentials = {"openai_api_key": credential_record.encrypted_config}
|
||||
@ -356,31 +350,23 @@ class ProviderConfiguration(BaseModel):
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
# encrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
return validated_credentials
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=self.provider.provider, credentials=credentials
|
||||
)
|
||||
|
||||
if session:
|
||||
return _validate(session)
|
||||
else:
|
||||
with Session(db.engine) as new_session:
|
||||
return _validate(new_session)
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
|
||||
def _generate_provider_credential_name(self, session) -> str:
|
||||
"""
|
||||
@ -457,14 +443,16 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name:
|
||||
if self._check_provider_credential_name_exists(credential_name=credential_name, session=session):
|
||||
if self._check_provider_credential_name_exists(credential_name=credential_name, session=pre_session):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
else:
|
||||
credential_name = self._generate_provider_credential_name(session)
|
||||
credential_name = self._generate_provider_credential_name(pre_session)
|
||||
|
||||
credentials = self.validate_provider_credentials(credentials=credentials, session=session)
|
||||
credentials = self.validate_provider_credentials(credentials=credentials)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_record = self._get_provider_record(session)
|
||||
try:
|
||||
new_record = ProviderCredential(
|
||||
@ -477,7 +465,6 @@ class ProviderConfiguration(BaseModel):
|
||||
session.flush()
|
||||
|
||||
if not provider_record:
|
||||
# If provider record does not exist, create it
|
||||
provider_record = Provider(
|
||||
tenant_id=self.tenant_id,
|
||||
provider_name=self.provider.provider,
|
||||
@ -530,15 +517,15 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_name: credential name
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name and self._check_provider_credential_name_exists(
|
||||
credential_name=credential_name, session=session, exclude_id=credential_id
|
||||
credential_name=credential_name, session=pre_session, exclude_id=credential_id
|
||||
):
|
||||
raise ValueError(f"Credential with name '{credential_name}' already exists.")
|
||||
|
||||
credentials = self.validate_provider_credentials(
|
||||
credentials=credentials, credential_id=credential_id, session=session
|
||||
)
|
||||
credentials = self.validate_provider_credentials(credentials=credentials, credential_id=credential_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_record = self._get_provider_record(session)
|
||||
stmt = select(ProviderCredential).where(
|
||||
ProviderCredential.id == credential_id,
|
||||
@ -546,12 +533,10 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||
)
|
||||
|
||||
# Get the credential record to update
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
if not credential_record:
|
||||
raise ValueError("Credential record not found.")
|
||||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
if credential_name:
|
||||
@ -879,7 +864,6 @@ class ProviderConfiguration(BaseModel):
|
||||
model: str,
|
||||
credentials: dict[str, Any],
|
||||
credential_id: str = "",
|
||||
session: Session | None = None,
|
||||
):
|
||||
"""
|
||||
Validate custom model credentials.
|
||||
@ -890,16 +874,14 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_id: (Optional)If provided, can use existing credential's hidden api key to validate
|
||||
:return:
|
||||
"""
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
def _validate(s: Session):
|
||||
# Get provider credential secret variables
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
self.provider.model_credential_schema.credential_form_schemas
|
||||
if self.provider.model_credential_schema
|
||||
else []
|
||||
)
|
||||
|
||||
if credential_id:
|
||||
if credential_id:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
ProviderModelCredential.id == credential_id,
|
||||
@ -908,7 +890,7 @@ class ProviderConfiguration(BaseModel):
|
||||
ProviderModelCredential.model_name == model,
|
||||
ProviderModelCredential.model_type == model_type,
|
||||
)
|
||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||
original_credentials = (
|
||||
json.loads(credential_record.encrypted_config)
|
||||
if credential_record and credential_record.encrypted_config
|
||||
@ -917,31 +899,23 @@ class ProviderConfiguration(BaseModel):
|
||||
except JSONDecodeError:
|
||||
original_credentials = {}
|
||||
|
||||
# decrypt credentials
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
# if send [__HIDDEN__] in secret input, it will be same as original value
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in validated_credentials.items():
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
if value == HIDDEN_VALUE and key in original_credentials:
|
||||
credentials[key] = encrypter.decrypt_token(
|
||||
tenant_id=self.tenant_id, token=original_credentials[key]
|
||||
)
|
||||
|
||||
return validated_credentials
|
||||
model_provider_factory = self.get_model_provider_factory()
|
||||
validated_credentials = model_provider_factory.model_credentials_validate(
|
||||
provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
if session:
|
||||
return _validate(session)
|
||||
else:
|
||||
with Session(db.engine) as new_session:
|
||||
return _validate(new_session)
|
||||
for key, value in validated_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
validated_credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
|
||||
|
||||
return validated_credentials
|
||||
|
||||
def create_custom_model_credential(
|
||||
self, model_type: ModelType, model: str, credentials: dict[str, Any], credential_name: str | None
|
||||
@ -954,20 +928,22 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credentials: model credentials dict
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name:
|
||||
if self._check_custom_model_credential_name_exists(
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=session
|
||||
model=model, model_type=model_type, credential_name=credential_name, session=pre_session
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
else:
|
||||
credential_name = self._generate_custom_model_credential_name(
|
||||
model=model, model_type=model_type, session=session
|
||||
model=model, model_type=model_type, session=pre_session
|
||||
)
|
||||
# validate custom model config
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type, model=model, credentials=credentials, session=session
|
||||
)
|
||||
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type, model=model, credentials=credentials
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
|
||||
|
||||
try:
|
||||
@ -982,7 +958,6 @@ class ProviderConfiguration(BaseModel):
|
||||
session.add(credential)
|
||||
session.flush()
|
||||
|
||||
# save provider model
|
||||
if not provider_model_record:
|
||||
provider_model_record = ProviderModel(
|
||||
tenant_id=self.tenant_id,
|
||||
@ -1024,23 +999,24 @@ class ProviderConfiguration(BaseModel):
|
||||
:param credential_id: credential id
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as pre_session:
|
||||
if credential_name and self._check_custom_model_credential_name_exists(
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
credential_name=credential_name,
|
||||
session=session,
|
||||
session=pre_session,
|
||||
exclude_id=credential_id,
|
||||
):
|
||||
raise ValueError(f"Model credential with name '{credential_name}' already exists for {model}.")
|
||||
# validate custom model config
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
credentials = self.validate_custom_model_credentials(
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
provider_model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
|
||||
|
||||
stmt = select(ProviderModelCredential).where(
|
||||
@ -1055,7 +1031,6 @@ class ProviderConfiguration(BaseModel):
|
||||
raise ValueError("Credential record not found.")
|
||||
|
||||
try:
|
||||
# Update credential
|
||||
credential_record.encrypted_config = json.dumps(credentials)
|
||||
credential_record.updated_at = naive_utc_now()
|
||||
if credential_name:
|
||||
|
||||
Reference in New Issue
Block a user