refactor: tool

This commit is contained in:
Yeuoly
2024-09-20 23:48:48 +08:00
parent 3c1d32e3ac
commit 91cb80f795
29 changed files with 498 additions and 906 deletions

View File

@ -7,6 +7,8 @@ from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolProviderEntity,
ToolProviderIdentity,
ToolProviderType,
)
from extensions.ext_database import db
@ -18,6 +20,11 @@ class ApiToolProviderController(ToolProviderController):
tenant_id: str
tools: list[ApiTool] = Field(default_factory=list)
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
super().__init__(entity)
self.provider_id = provider_id
self.tenant_id = tenant_id
@staticmethod
def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> "ApiToolProviderController":
credentials_schema = {
@ -64,25 +71,23 @@ class ApiToolProviderController(ToolProviderController):
}
elif auth_type == ApiProviderAuthType.NONE:
pass
else:
raise ValueError(f"invalid auth type {auth_type}")
user = db_provider.user
user_name = user.name if user else ""
return ApiToolProviderController(
**{
"identity": {
"author": user_name,
"name": db_provider.name,
"label": {"en_US": db_provider.name, "zh_Hans": db_provider.name},
"description": {"en_US": db_provider.description, "zh_Hans": db_provider.description},
"icon": db_provider.icon,
},
"credentials_schema": credentials_schema,
"provider_id": db_provider.id or "",
"tenant_id": db_provider.tenant_id or "",
},
entity=ToolProviderEntity(
identity=ToolProviderIdentity(
author=user_name,
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US=db_provider.description, zh_Hans=db_provider.description),
icon=db_provider.icon,
),
credentials_schema=credentials_schema,
),
provider_id=db_provider.id or "",
tenant_id=db_provider.tenant_id or "",
)
@property
@ -103,7 +108,7 @@ class ApiToolProviderController(ToolProviderController):
"author": tool_bundle.author,
"name": tool_bundle.operation_id,
"label": {"en_US": tool_bundle.operation_id, "zh_Hans": tool_bundle.operation_id},
"icon": self.identity.icon,
"icon": self.entity.identity.icon,
"provider": self.provider_id,
},
"description": {
@ -141,7 +146,7 @@ class ApiToolProviderController(ToolProviderController):
# get tenant api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.identity.name)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all()
)
@ -149,7 +154,6 @@ class ApiToolProviderController(ToolProviderController):
for db_provider in db_providers:
for tool in db_provider.tools:
assistant_tool = self._parse_tool_bundle(tool)
assistant_tool.is_team_authorization = True
tools.append(assistant_tool)
self.tools = tools
@ -166,7 +170,7 @@ class ApiToolProviderController(ToolProviderController):
self.get_tools(self.tenant_id)
for tool in self.tools:
if tool.identity.name == tool_name:
if tool.entity.identity.name == tool_name:
return tool
raise ValueError(f"tool {tool_name} not found")