feat(oauth): refactor tool provider methods and enhance credential handling

This commit is contained in:
Harry
2025-06-27 13:17:09 +08:00
parent 8a954c0b19
commit daec82bd44
9 changed files with 309 additions and 170 deletions

View File

@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
from core.tools.entities.tool_entities import (
OAuthSchema,
ToolEntity,
ToolProviderCredentialType,
ToolProviderEntity,
ToolProviderType,
)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
oauth_schema = None
if provider_yaml.get("oauth_schema", None) is not None:
oauth_schema = OAuthSchema(
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
)
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=credentials_schema,
oauth_schema=oauth_schema,
),
)
@ -91,16 +105,20 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.tools
def get_credentials_schema(self) -> list[ProviderConfig]:
def get_credentials_schema(
self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY
) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
if not self.entity.credentials_schema:
return []
return self.entity.credentials_schema.copy()
if credential_type == ToolProviderCredentialType.OAUTH2:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
elif credential_type == ToolProviderCredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
else:
raise ValueError(f"Invalid credential type: {credential_type}")
def get_tools(self) -> list[BuiltinTool]:
"""

View File

@ -344,10 +344,18 @@ class ToolEntity(BaseModel):
return v or []
class OAuthSchema(BaseModel):
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
credentials_schema: list[ProviderConfig] = Field(
default_factory=list, description="The schema of the OAuth credentials"
)
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
class ToolProviderEntityWithPlugin(ToolProviderEntity):
@ -437,7 +445,7 @@ class ToolSelector(BaseModel):
class ToolProviderCredentialType(enum.StrEnum):
API_KEY = "api_key"
API_KEY = "api-key"
OAUTH2 = "oauth2"
def get_name(self):
@ -446,7 +454,7 @@ class ToolProviderCredentialType(enum.StrEnum):
elif self == ToolProviderCredentialType.OAUTH2:
return "AUTH"
else:
return self.value.replace("_", " ").upper()
return self.value.replace("-", " ").upper()
def is_editable(self):
return self == ToolProviderCredentialType.API_KEY
@ -461,7 +469,7 @@ class ToolProviderCredentialType(enum.StrEnum):
@classmethod
def of(cls, credential_type: str) -> "ToolProviderCredentialType":
type_name = credential_type.lower()
if type_name == "api_key":
if type_name == "api-key":
return cls.API_KEY
elif type_name == "oauth2":
return cls.OAUTH2

View File

@ -34,7 +34,13 @@ from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeFrom,
ToolParameter,
ToolProviderCredentialType,
ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
@ -202,7 +208,12 @@ class ToolManager:
credentials = builtin_provider.credentials
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema(
ToolProviderCredentialType.of(builtin_provider.credential_type)
)
],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)