refactor: using DeclarativeBase as parent class of models, refactored tools

This commit is contained in:
Yeuoly
2024-09-29 17:00:58 +08:00
parent c8bc3892b3
commit e9e5c8806a
17 changed files with 225 additions and 120 deletions

View File

@ -14,9 +14,9 @@ class PluginToolManager(BasePluginManager):
provider follows format: plugin_id/provider_name
"""
if "/" in provider:
parts = provider.split("/", 1)
if len(parts) == 2:
return parts[0], parts[1]
parts = provider.split("/", -1)
if len(parts) >= 2:
return "/".join(parts[:-1]), parts[-1]
raise ValueError(f"invalid provider format: {provider}")
raise ValueError(f"invalid provider format: {provider}")
@ -46,6 +46,10 @@ class PluginToolManager(BasePluginManager):
for provider in response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.tools:
tool.identity.provider = provider.declaration.identity.name
return response
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
@ -54,15 +58,26 @@ class PluginToolManager(BasePluginManager):
"""
plugin_id, provider_name = self._split_provider(provider)
def transformer(json_response: dict[str, Any]) -> dict:
for tool in json_response.get("data", {}).get("declaration", {}).get("tools", []):
tool["identity"]["provider"] = provider_name
return json_response
response = self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/tool",
PluginToolProviderEntity,
params={"provider": provider_name, "plugin_id": plugin_id},
transformer=transformer,
)
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in response.declaration.tools:
tool.identity.provider = response.declaration.identity.name
return response
def invoke(

View File

@ -11,12 +11,10 @@ from core.tools.plugin_tool.tool import PluginTool
class PluginToolProviderController(BuiltinToolProviderController):
entity: ToolProviderEntityWithPlugin
tenant_id: str
plugin_id: str
def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str, plugin_id: str) -> None:
def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str) -> None:
self.entity = entity
self.tenant_id = tenant_id
self.plugin_id = plugin_id
@property
def provider_type(self) -> ToolProviderType:
@ -35,7 +33,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
if not manager.validate_provider_credentials(
tenant_id=self.tenant_id,
user_id=user_id,
plugin_id=self.plugin_id,
provider=self.entity.identity.name,
credentials=credentials,
):
@ -54,7 +51,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
plugin_id=self.plugin_id,
)
def get_tools(self) -> list[PluginTool]:
@ -66,7 +62,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
plugin_id=self.plugin_id,
)
for tool_entity in self.entity.tools
]

View File

@ -9,12 +9,10 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
class PluginTool(Tool):
tenant_id: str
plugin_id: str
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, plugin_id: str) -> None:
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.plugin_id = plugin_id
@property
def tool_provider_type(self) -> ToolProviderType:
@ -25,7 +23,6 @@ class PluginTool(Tool):
return manager.invoke(
tenant_id=self.tenant_id,
user_id=user_id,
plugin_id=self.plugin_id,
tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name,
credentials=self.runtime.credentials,
@ -37,5 +34,4 @@ class PluginTool(Tool):
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
plugin_id=self.plugin_id,
)

View File

@ -86,7 +86,6 @@ class ToolManager:
return PluginToolProviderController(
entity=provider_entity.declaration,
tenant_id=tenant_id,
plugin_id=provider_entity.plugin_id,
)
@classmethod
@ -158,12 +157,11 @@ class ToolManager:
# decrypt the credentials
credentials = builtin_provider.credentials
controller = cls.get_builtin_provider(provider_id, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=controller.get_credentials_schema(),
provider_type=controller.provider_type.value,
provider_identity=controller.entity.identity.name,
config=provider_controller.get_credentials_schema(),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
@ -400,7 +398,6 @@ class ToolManager:
PluginToolProviderController(
entity=provider.declaration,
tenant_id=tenant_id,
plugin_id=provider.plugin_id,
)
for provider in provider_entities
]
@ -525,7 +522,7 @@ class ToolManager:
)
if isinstance(provider, PluginToolProviderController):
result_providers[f"plugin_provider.{user_provider.name}.{provider.plugin_id}"] = user_provider
result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
else:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider