mirror of
https://github.com/langgenius/dify.git
synced 2026-04-21 19:27:40 +08:00
feat(oauth): refactor tool provider methods and enhance credential handling
This commit is contained in:
@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
|
||||
from core.tools.entities.tool_entities import (
|
||||
OAuthSchema,
|
||||
ToolEntity,
|
||||
ToolProviderCredentialType,
|
||||
ToolProviderEntity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
|
||||
credentials_schema.append(credential_dict)
|
||||
|
||||
oauth_schema = None
|
||||
if provider_yaml.get("oauth_schema", None) is not None:
|
||||
oauth_schema = OAuthSchema(
|
||||
client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
|
||||
credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
entity=ToolProviderEntity(
|
||||
identity=provider_yaml["identity"],
|
||||
credentials_schema=credentials_schema,
|
||||
oauth_schema=oauth_schema,
|
||||
),
|
||||
)
|
||||
|
||||
@ -91,16 +105,20 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.tools
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
def get_credentials_schema(
|
||||
self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY
|
||||
) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if not self.entity.credentials_schema:
|
||||
return []
|
||||
|
||||
return self.entity.credentials_schema.copy()
|
||||
if credential_type == ToolProviderCredentialType.OAUTH2:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
elif credential_type == ToolProviderCredentialType.API_KEY:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
|
||||
@ -344,10 +344,18 @@ class ToolEntity(BaseModel):
|
||||
return v or []
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: Optional[str] = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
oauth_schema: Optional[OAuthSchema] = None
|
||||
|
||||
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
@ -437,7 +445,7 @@ class ToolSelector(BaseModel):
|
||||
|
||||
|
||||
class ToolProviderCredentialType(enum.StrEnum):
|
||||
API_KEY = "api_key"
|
||||
API_KEY = "api-key"
|
||||
OAUTH2 = "oauth2"
|
||||
|
||||
def get_name(self):
|
||||
@ -446,7 +454,7 @@ class ToolProviderCredentialType(enum.StrEnum):
|
||||
elif self == ToolProviderCredentialType.OAUTH2:
|
||||
return "AUTH"
|
||||
else:
|
||||
return self.value.replace("_", " ").upper()
|
||||
return self.value.replace("-", " ").upper()
|
||||
|
||||
def is_editable(self):
|
||||
return self == ToolProviderCredentialType.API_KEY
|
||||
@ -461,7 +469,7 @@ class ToolProviderCredentialType(enum.StrEnum):
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "ToolProviderCredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name == "api_key":
|
||||
if type_name == "api-key":
|
||||
return cls.API_KEY
|
||||
elif type_name == "oauth2":
|
||||
return cls.OAUTH2
|
||||
|
||||
@ -34,7 +34,13 @@ from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderCredentialType,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
|
||||
@ -202,7 +208,12 @@ class ToolManager:
|
||||
credentials = builtin_provider.credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema(
|
||||
ToolProviderCredentialType.of(builtin_provider.credential_type)
|
||||
)
|
||||
],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user