feat: inner api encrypt

This commit is contained in:
Yeuoly
2024-08-30 21:25:58 +08:00
parent 60e75dc748
commit de01ca8d55
8 changed files with 88 additions and 59 deletions

View File

@ -15,60 +15,60 @@ from core.tools.entities.tool_entities import (
from core.tools.tool.tool import Tool
class ToolConfigurationManager(BaseModel):
class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: Mapping[str, BasicProviderConfig]
provider_type: str
provider_identity: str
def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy credentials
deep copy data
"""
return deepcopy(credentials)
return deepcopy(data)
def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
def encrypt(self, data: dict[str, str]) -> Mapping[str, str]:
"""
encrypt tool credentials with tenant id
return a deep copy of credentials with encrypted values
"""
credentials = self._deep_copy(credentials)
data = self._deep_copy(data)
# get fields need to be decrypted
fields = self.config
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
credentials[field_name] = encrypted
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name])
data[field_name] = encrypted
return credentials
return data
def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
def mask_tool_credentials(self, data: dict[str, Any]) -> Mapping[str, Any]:
"""
mask tool credentials
return a deep copy of credentials with masked values
"""
credentials = self._deep_copy(credentials)
data = self._deep_copy(data)
# get fields need to be decrypted
fields = self.config
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
if len(credentials[field_name]) > 6:
credentials[field_name] = \
credentials[field_name][:2] + \
'*' * (len(credentials[field_name]) - 4) + \
credentials[field_name][-2:]
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = \
data[field_name][:2] + \
'*' * (len(data[field_name]) - 4) + \
data[field_name][-2:]
else:
credentials[field_name] = '*' * len(credentials[field_name])
data[field_name] = '*' * len(data[field_name])
return credentials
return data
def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
def decrypt(self, data: dict[str, str]) -> Mapping[str, str]:
"""
decrypt tool credentials with tenant id
@ -82,19 +82,19 @@ class ToolConfigurationManager(BaseModel):
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
credentials = self._deep_copy(credentials)
data = self._deep_copy(data)
# get fields need to be decrypted
fields = self.config
for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in credentials:
if field_name in data:
try:
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except:
pass
cache.set(credentials)
return credentials
cache.set(data)
return data
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(