fix: remove debugging flags

This commit is contained in:
Harry
2025-06-23 12:49:18 +08:00
parent b3a8dbe2f5
commit 7f292dc261
5 changed files with 65 additions and 22 deletions

View File

@ -118,12 +118,8 @@ class BuiltinToolManageService:
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
if not ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable():
raise ValueError(f"you cannot update oauth2 provider {provider_name} credentials")
try:
# exclude oauth2 provider
if provider.credential_type != ToolProviderCredentialType.OAUTH2.value:
if ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
@ -139,11 +135,15 @@ class BuiltinToolManageService:
credentials = BuiltinToolManageService._decrypt_and_restore_credentials(
provider_controller, tool_configuration, provider, credentials
)
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
else:
raise ValueError(
f"provider {provider_name} is not editable, you can only delete it and add a new one"
)
# update name if provided
if name is not None and provider.name != name:
@ -162,15 +162,60 @@ class BuiltinToolManageService:
@staticmethod
def add_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str, credentials: dict, name: str | None = None
user_id: str, type: ToolProviderCredentialType, tenant_id: str, provider_name:str, credentials: dict, name: str | None = None
):
"""
add builtin tool provider
"""
if name is None:
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, type)
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
credential_type=type.value,
credentials=json.dumps(credentials),
name=name,
)
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
db.session.add(provider)
return {"result": "success"}
@staticmethod
def get_next_builtin_tool_provider_name(tenant_id: str, type: ToolProviderCredentialType) -> str:
"""
next name = max(provider_names) + 1
"""
provider_names = db.session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id,
credential_type=type.value,
).all()
if not provider_names:
return f"{type.value} 1"
# OAuth 1 then OAuth 2, if don't have OAuth 1, then return OAuth 1
# if dont have number, then get name and add 1
for provider_name in provider_names:
if provider_name.provider.startswith(type.value):
return f"{type.value} {int(provider_name.provider.split(' ')[1]) + 1}"
return f"{type.value} 1"
@staticmethod
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
"""
@ -416,7 +461,7 @@ class BuiltinToolManageService:
def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials):
"""
Decrypt original credentials and restore masked values from the input credentials
:param provider_controller: the provider controller
:param tool_configuration: the tool configuration encrypter
:param provider: the provider object from database
@ -425,19 +470,19 @@ class BuiltinToolManageService:
"""
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: # type: ignore
credentials[name] = original_credentials[name] # type: ignore
return credentials
@staticmethod
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):
"""
Validate and encrypt credentials, then save to database
:param provider_controller: the provider controller
:param tool_configuration: the tool configuration encrypter
:param provider: the provider object from database