feat(oauth): refactor tool provider methods and enhance credential handling

This commit is contained in:
Harry
2025-06-27 13:17:09 +08:00
parent 8a954c0b19
commit daec82bd44
9 changed files with 309 additions and 170 deletions

View File

@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
from core.tools.entities.tool_entities import (
OAuthSchema,
ToolEntity,
ToolProviderCredentialType,
ToolProviderEntity,
ToolProviderType,
)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
oauth_schema = None
if provider_yaml.get("oauth_schema", None) is not None:
oauth_schema = OAuthSchema(
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
)
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=credentials_schema,
oauth_schema=oauth_schema,
),
)
@ -91,16 +105,20 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.tools
def get_credentials_schema(self) -> list[ProviderConfig]:
def get_credentials_schema(
self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY
) -> list[ProviderConfig]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
if not self.entity.credentials_schema:
return []
return self.entity.credentials_schema.copy()
if credential_type == ToolProviderCredentialType.OAUTH2:
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
elif credential_type == ToolProviderCredentialType.API_KEY:
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
else:
raise ValueError(f"Invalid credential type: {credential_type}")
def get_tools(self) -> list[BuiltinTool]:
"""