mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
refactor: tool
This commit is contained in:
@ -7,6 +7,8 @@ from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolProviderEntity,
|
||||
ToolProviderIdentity,
|
||||
ToolProviderType,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
@ -18,6 +20,11 @@ class ApiToolProviderController(ToolProviderController):
|
||||
tenant_id: str
|
||||
tools: list[ApiTool] = Field(default_factory=list)
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
|
||||
super().__init__(entity)
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@staticmethod
|
||||
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
|
||||
credentials_schema = {
|
||||
@ -64,25 +71,23 @@ class ApiToolProviderController(ToolProviderController):
|
||||
}
|
||||
elif auth_type == ApiProviderAuthType.NONE:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"invalid auth type {auth_type}")
|
||||
|
||||
user = db_provider.user
|
||||
user_name = user.name if user else ""
|
||||
|
||||
return ApiToolProviderController(
|
||||
**{
|
||||
"identity": {
|
||||
"author": user_name,
|
||||
"name": db_provider.name,
|
||||
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
|
||||
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
|
||||
"icon": db_provider.icon,
|
||||
},
|
||||
"credentials_schema": credentials_schema,
|
||||
"provider_id": db_provider.id or "",
|
||||
"tenant_id": db_provider.tenant_id or "",
|
||||
},
|
||||
entity=ToolProviderEntity(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user_name,
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
credentials_schema=credentials_schema,
|
||||
),
|
||||
provider_id=db_provider.id or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
)
|
||||
|
||||
@property
|
||||
@ -103,7 +108,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
"author": tool_bundle.author,
|
||||
"name": tool_bundle.operation_id,
|
||||
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
|
||||
"icon": self.identity.icon,
|
||||
"icon": self.entity.identity.icon,
|
||||
"provider": self.provider_id,
|
||||
},
|
||||
"description": {
|
||||
@ -141,7 +146,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
# get tenant api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
|
||||
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
|
||||
.all()
|
||||
)
|
||||
|
||||
@ -149,7 +154,6 @@ class ApiToolProviderController(ToolProviderController):
|
||||
for db_provider in db_providers:
|
||||
for tool in db_provider.tools:
|
||||
assistant_tool = self._parse_tool_bundle(tool)
|
||||
assistant_tool.is_team_authorization = True
|
||||
tools.append(assistant_tool)
|
||||
|
||||
self.tools = tools
|
||||
@ -166,7 +170,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
self.get_tools(self.tenant_id)
|
||||
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
raise ValueError(f"tool {tool_name} not found")
|
||||
|
||||
Reference in New Issue
Block a user