refactor: tool

This commit is contained in:
Yeuoly
2024-09-20 23:48:48 +08:00
parent 3c1d32e3ac
commit 91cb80f795
29 changed files with 498 additions and 906 deletions

View File

@ -8,8 +8,9 @@ import httpx
from core.helper import ssrf_proxy
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
API_TOOL_DEFAULT_TIMEOUT = (
@ -25,7 +26,11 @@ class ApiTool(Tool):
Api tool
"""
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime):
super().__init__(entity, runtime)
self.api_bundle = api_bundle
def fork_tool_runtime(self, runtime: ToolRuntime):
"""
fork a new tool with meta data
@ -33,11 +38,9 @@ class ApiTool(Tool):
:return: the new tool
"""
return self.__class__(
identity=self.identity.model_copy(),
parameters=self.parameters.copy() if self.parameters else [],
description=self.description.model_copy() if self.description else None,
entity=self.entity,
api_bundle=self.api_bundle.model_copy(),
runtime=Tool.Runtime(**runtime),
runtime=runtime,
)
def validate_credentials(
@ -62,7 +65,7 @@ class ApiTool(Tool):
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
if self.runtime == None:
raise ToolProviderCredentialValidationError("runtime not initialized")
headers = {}
credentials = self.runtime.credentials or {}