mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 21:26:15 +08:00
feat(oauth): refactor tool provider methods and enhance credential handling
This commit is contained in:
@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
|
||||
from core.tools.entities.tool_entities import (
|
||||
OAuthSchema,
|
||||
ToolEntity,
|
||||
ToolProviderCredentialType,
|
||||
ToolProviderEntity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
|
||||
credentials_schema.append(credential_dict)
|
||||
|
||||
oauth_schema = None
|
||||
if provider_yaml.get("oauth_schema", None) is not None:
|
||||
oauth_schema = OAuthSchema(
|
||||
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
|
||||
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
entity=ToolProviderEntity(
|
||||
identity=provider_yaml["identity"],
|
||||
credentials_schema=credentials_schema,
|
||||
oauth_schema=oauth_schema,
|
||||
),
|
||||
)
|
||||
|
||||
@ -91,16 +105,20 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.tools
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
def get_credentials_schema(
|
||||
self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY
|
||||
) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if not self.entity.credentials_schema:
|
||||
return []
|
||||
|
||||
return self.entity.credentials_schema.copy()
|
||||
if credential_type == ToolProviderCredentialType.OAUTH2:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
elif credential_type == ToolProviderCredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user