Merge branch 'fix/chore-fix' into dev/plugin-deploy

This commit is contained in:
Yeuoly
2024-12-02 21:08:53 +08:00
24 changed files with 217 additions and 55 deletions

View File

@ -74,6 +74,7 @@ class BuiltinToolProviderController(ToolProviderController):
tool["identity"]["provider"] = provider
tools.append(
assistant_tool_class(
provider=provider,
entity=ToolEntity(**tool),
runtime=ToolRuntime(tenant_id=""),
)

View File

@ -1,6 +1,7 @@
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
@ -19,6 +20,25 @@ class BuiltinTool(Tool):
:param meta: the meta data of a tool call processing
"""
provider: str
def __init__(self, provider: str, **kwargs):
super().__init__(**kwargs)
self.provider = provider
def fork_tool_runtime(self, runtime: ToolRuntime) -> "BuiltinTool":
"""
fork a new tool with meta data
:param meta: the meta data of a tool call processing, tenant_id is required
:return: the new tool
"""
return self.__class__(
entity=self.entity.model_copy(),
runtime=runtime,
provider=self.provider,
)
def invoke_model(self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]) -> LLMResult:
"""
invoke model

View File

@ -109,6 +109,7 @@ class ApiToolProviderController(ToolProviderController):
"""
return ApiTool(
api_bundle=tool_bundle,
provider_id=self.provider_id,
entity=ToolEntity(
identity=ToolIdentity(
author=tool_bundle.author,

View File

@ -22,14 +22,16 @@ API_TOOL_DEFAULT_TIMEOUT = (
class ApiTool(Tool):
api_bundle: ApiToolBundle
provider_id: str
"""
Api tool
"""
def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime):
def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime, provider_id: str):
super().__init__(entity, runtime)
self.api_bundle = api_bundle
self.provider_id = provider_id
def fork_tool_runtime(self, runtime: ToolRuntime):
"""
@ -42,6 +44,7 @@ class ApiTool(Tool):
entity=self.entity,
api_bundle=self.api_bundle.model_copy(),
runtime=runtime,
provider_id=self.provider_id,
)
def validate_credentials(

View File

@ -34,6 +34,7 @@ class ToolProviderApiEntity(BaseModel):
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
tools: list[ToolApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
@ -58,6 +59,7 @@ class ToolProviderApiEntity(BaseModel):
"author": self.author,
"name": self.name,
"plugin_id": self.plugin_id,
"plugin_unique_identifier": self.plugin_unique_identifier,
"description": self.description.to_dict(),
"icon": self.icon,
"label": self.label.to_dict(),

View File

@ -12,11 +12,15 @@ class PluginToolProviderController(BuiltinToolProviderController):
entity: ToolProviderEntityWithPlugin
tenant_id: str
plugin_id: str
plugin_unique_identifier: str
def __init__(self, entity: ToolProviderEntityWithPlugin, plugin_id: str, tenant_id: str) -> None:
def __init__(
self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
self.entity = entity
self.tenant_id = tenant_id
self.plugin_id = plugin_id
self.plugin_unique_identifier = plugin_unique_identifier
@property
def provider_type(self) -> ToolProviderType:
@ -53,6 +57,8 @@ class PluginToolProviderController(BuiltinToolProviderController):
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
def get_tools(self) -> list[PluginTool]:
@ -64,6 +70,8 @@ class PluginToolProviderController(BuiltinToolProviderController):
entity=tool_entity,
runtime=ToolRuntime(tenant_id=self.tenant_id),
tenant_id=self.tenant_id,
icon=self.entity.identity.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
for tool_entity in self.entity.tools
]

View File

@ -11,11 +11,17 @@ from models.model import File
class PluginTool(Tool):
tenant_id: str
icon: str
plugin_unique_identifier: str
runtime_parameters: Optional[list[ToolParameter]]
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str) -> None:
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
) -> None:
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon
self.plugin_unique_identifier = plugin_unique_identifier
self.runtime_parameters = None
def tool_provider_type(self) -> ToolProviderType:
@ -64,6 +70,8 @@ class PluginTool(Tool):
entity=self.entity,
runtime=runtime,
tenant_id=self.tenant_id,
icon=self.icon,
plugin_unique_identifier=self.plugin_unique_identifier,
)
def get_runtime_parameters(self) -> list[ToolParameter]:

View File

@ -6,6 +6,9 @@ from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Union, cast
from yarl import URL
import contexts
from core.plugin.entities.plugin import GenericProviderID
from core.plugin.manager.tool import PluginToolManager
from core.tools.__base.tool_runtime import ToolRuntime
@ -97,16 +100,26 @@ class ToolManager:
"""
get the plugin provider
"""
manager = PluginToolManager()
provider_entity = manager.fetch_tool_provider(tenant_id, provider)
if not provider_entity:
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
with contexts.plugin_tool_providers_lock.get():
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
return PluginToolProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
tenant_id=tenant_id,
)
manager = PluginToolManager()
provider_entity = manager.fetch_tool_provider(tenant_id, provider)
if not provider_entity:
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
controller = PluginToolProviderController(
entity=provider_entity.declaration,
plugin_id=provider_entity.plugin_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
tenant_id=tenant_id,
)
plugin_tool_providers[provider] = controller
return controller
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
@ -132,7 +145,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
) -> Union[BuiltinTool, ApiTool, WorkflowTool]:
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
"""
get the tool runtime
@ -260,6 +273,8 @@ class ToolManager:
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
elif provider_type == ToolProviderType.PLUGIN:
return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
else:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
@ -477,6 +492,7 @@ class ToolManager:
PluginToolProviderController(
entity=provider.declaration,
plugin_id=provider.plugin_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
tenant_id=tenant_id,
)
for provider in provider_entities
@ -758,7 +774,66 @@ class ToolManager:
)
@classmethod
def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
return (
dify_config.CONSOLE_API_URL
+ "/console/api/workspaces/current/tool-provider/builtin/"
+ provider_id
+ "/icon"
)
@classmethod
def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
return str(
URL(dify_config.CONSOLE_API_URL)
/ "console"
/ "api"
/ "workspaces"
/ "current"
/ "plugin"
/ "icon"
% {"tenant_id": tenant_id, "filename": filename}
)
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return json.loads(workflow_provider.icon)
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
if api_provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
return json.loads(api_provider.icon)
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def get_tool_icon(
cls,
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> Union[str, dict]:
"""
get the tool icon
@ -770,36 +845,25 @@ class ToolManager:
provider_type = provider_type
provider_id = provider_id
if provider_type == ToolProviderType.BUILT_IN:
return (
dify_config.CONSOLE_API_URL
+ "/console/api/workspaces/current/tool-provider/builtin/"
+ provider_id
+ "/icon"
)
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
if isinstance(provider, PluginToolProviderController):
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
return cls.generate_builtin_tool_icon_url(provider_id)
elif provider_type == ToolProviderType.API:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
if not api_provider:
raise ValueError("api tool not found")
return json.loads(api_provider.icon)
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
return cls.generate_api_tool_icon_url(tenant_id, provider_id)
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return json.loads(workflow_provider.icon)
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
elif provider_type == ToolProviderType.PLUGIN:
provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
if isinstance(provider, PluginToolProviderController):
try:
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
except:
return {"background": "#252525", "content": "\ud83d\ude01"}
raise ValueError(f"plugin provider {provider_id} not found")
else:
raise ValueError(f"provider type {provider_type} not found")

View File

@ -148,6 +148,7 @@ class WorkflowToolProviderController(ToolProviderController):
raise ValueError("variable not found")
return WorkflowTool(
workflow_as_tool_id=db_provider.id,
entity=ToolEntity(
identity=ToolIdentity(
author=user.name if user else "",

View File

@ -21,6 +21,7 @@ class WorkflowTool(Tool):
workflow_entities: dict[str, Any]
workflow_call_depth: int
thread_pool_id: Optional[str] = None
workflow_as_tool_id: str
label: str
@ -31,6 +32,7 @@ class WorkflowTool(Tool):
def __init__(
self,
workflow_app_id: str,
workflow_as_tool_id: str,
version: str,
workflow_entities: dict[str, Any],
workflow_call_depth: int,
@ -40,6 +42,7 @@ class WorkflowTool(Tool):
thread_pool_id: Optional[str] = None,
):
self.workflow_app_id = workflow_app_id
self.workflow_as_tool_id = workflow_as_tool_id
self.version = version
self.workflow_entities = workflow_entities
self.workflow_call_depth = workflow_call_depth
@ -134,6 +137,7 @@ class WorkflowTool(Tool):
entity=self.entity.model_copy(),
runtime=runtime,
workflow_app_id=self.workflow_app_id,
workflow_as_tool_id=self.workflow_as_tool_id,
workflow_entities=self.workflow_entities,
workflow_call_depth=self.workflow_call_depth,
version=self.version,