mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
feat(oauth): update api
This commit is contained in:
@ -20,7 +20,6 @@ from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -35,18 +34,10 @@ from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.custom_tool.tool import ApiTool
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ProviderConfigEncrypter,
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
@ -64,8 +55,11 @@ class ToolManager:
|
||||
@classmethod
|
||||
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
|
||||
"""
|
||||
|
||||
get the hardcoded provider
|
||||
|
||||
"""
|
||||
|
||||
if len(cls._hardcoded_providers) == 0:
|
||||
# init the builtin providers
|
||||
cls.load_hardcoded_providers_cache()
|
||||
@ -109,7 +103,12 @@ class ToolManager:
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(Lock())
|
||||
|
||||
plugin_tool_providers = contexts.plugin_tool_providers.get()
|
||||
if provider in plugin_tool_providers:
|
||||
return plugin_tool_providers[provider]
|
||||
|
||||
with contexts.plugin_tool_providers_lock.get():
|
||||
# double check
|
||||
plugin_tool_providers = contexts.plugin_tool_providers.get()
|
||||
if provider in plugin_tool_providers:
|
||||
return plugin_tool_providers[provider]
|
||||
@ -127,25 +126,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
plugin_tool_providers[provider] = controller
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
|
||||
"""
|
||||
get the builtin tool
|
||||
|
||||
:param provider: the name of the provider
|
||||
:param tool_name: the name of the tool
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the provider, the tool
|
||||
"""
|
||||
provider_controller = cls.get_builtin_provider(provider, tenant_id)
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
if tool is None:
|
||||
raise ToolNotFoundError(f"tool {tool_name} not found")
|
||||
|
||||
return tool
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(
|
||||
@ -563,6 +544,22 @@ class ToolManager:
|
||||
|
||||
return cls._builtin_tools_labels[tool_name]
|
||||
|
||||
@classmethod
|
||||
def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
|
||||
"""
|
||||
list all the builtin providers
|
||||
"""
|
||||
# according to multi credentials, select the one with is_default=True first, then created_at oldest
|
||||
# for compatibility with old version
|
||||
sql = """
|
||||
SELECT DISTINCT ON (tenant_id, provider) id
|
||||
FROM tool_builtin_providers
|
||||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
@classmethod
|
||||
def list_providers_from_api(
|
||||
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
|
||||
@ -577,30 +574,13 @@ class ToolManager:
|
||||
|
||||
with db.session.no_autoflush:
|
||||
if "builtin" in filters:
|
||||
|
||||
def get_builtin_providers(tenant_id):
|
||||
# according to multi credentials, select the one with is_default=True first, then created_at oldest
|
||||
# for compatibility with old version
|
||||
sql = """
|
||||
SELECT DISTINCT ON (tenant_id, provider) id
|
||||
FROM tool_builtin_providers
|
||||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
builtin_providers = cls.list_builtin_providers(tenant_id)
|
||||
|
||||
# get builtin providers
|
||||
db_builtin_providers = get_builtin_providers(tenant_id)
|
||||
|
||||
# rewrite db_builtin_providers
|
||||
for db_provider in db_builtin_providers:
|
||||
db_provider.provider = str(ToolProviderID(db_provider.provider))
|
||||
|
||||
def find_db_builtin_provider(provider):
|
||||
return next((x for x in db_builtin_providers if x.provider == provider), None)
|
||||
# key: provider name, value: provider
|
||||
db_builtin_providers = {
|
||||
str(ToolProviderID(provider.provider)): provider
|
||||
for provider in cls.list_default_builtin_providers(tenant_id)
|
||||
}
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
@ -612,10 +592,9 @@ class ToolManager:
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.entity.identity.name),
|
||||
db_provider=db_builtin_providers.get(provider.entity.identity.name),
|
||||
decrypt_credentials=False,
|
||||
)
|
||||
|
||||
@ -625,7 +604,6 @@ class ToolManager:
|
||||
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
|
||||
|
||||
# get db api providers
|
||||
|
||||
if "api" in filters:
|
||||
db_api_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
Reference in New Issue
Block a user