mirror of
https://github.com/langgenius/dify.git
synced 2026-04-23 20:36:14 +08:00
Merge branch 'main' into feat/mcp-06-18
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from openai import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
|
||||
|
||||
@ -1 +1 @@
|
||||
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg" class="w-3.5 h-3.5" data-icon="Code" aria-hidden="true"><g id="icons/code"><path id="Vector (Stroke)" fill-rule="evenodd" clip-rule="evenodd" d="M8.32593 1.69675C8.67754 1.78466 8.89132 2.14096 8.80342 2.49257L6.47009 11.8259C6.38218 12.1775 6.02588 12.3913 5.67427 12.3034C5.32265 12.2155 5.10887 11.8592 5.19678 11.5076L7.53011 2.17424C7.61801 1.82263 7.97431 1.60885 8.32593 1.69675ZM3.96414 4.20273C4.22042 4.45901 4.22042 4.87453 3.96413 5.13081L2.45578 6.63914C2.45577 6.63915 2.45578 6.63914 2.45578 6.63914C2.25645 6.83851 2.25643 7.16168 2.45575 7.36103C2.45574 7.36103 2.45576 7.36104 2.45575 7.36103L3.96413 8.86936C4.22041 9.12564 4.22042 9.54115 3.96414 9.79744C3.70787 10.0537 3.29235 10.0537 3.03607 9.79745L1.52769 8.28913C0.815811 7.57721 0.815803 6.42302 1.52766 5.7111L3.03606 4.20272C3.29234 3.94644 3.70786 3.94644 3.96414 4.20273ZM10.0361 4.20273C10.2923 3.94644 10.7078 3.94644 10.9641 4.20272L12.4725 5.71108C13.1843 6.423 13.1844 7.57717 12.4725 8.28909L10.9641 9.79745C10.7078 10.0537 10.2923 10.0537 10.036 9.79744C9.77977 9.54115 9.77978 9.12564 10.0361 8.86936L11.5444 7.36107C11.7437 7.16172 11.7438 6.83854 11.5444 6.63917C11.5444 6.63915 11.5445 6.63918 11.5444 6.63917L10.0361 5.13081C9.77978 4.87453 9.77978 4.45901 10.0361 4.20273Z" fill="currentColor"></path></g></svg>
|
||||
<svg width="14" height="14" viewBox="0 0 14 14" fill="none" xmlns="http://www.w3.org/2000/svg" class="w-3.5 h-3.5" data-icon="Code" aria-hidden="true"><g id="icons/code"><path id="Vector (Stroke)" fill-rule="evenodd" clip-rule="evenodd" d="M8.32593 1.69675C8.67754 1.78466 8.89132 2.14096 8.80342 2.49257L6.47009 11.8259C6.38218 12.1775 6.02588 12.3913 5.67427 12.3034C5.32265 12.2155 5.10887 11.8592 5.19678 11.5076L7.53011 2.17424C7.61801 1.82263 7.97431 1.60885 8.32593 1.69675ZM3.96414 4.20273C4.22042 4.45901 4.22042 4.87453 3.96413 5.13081L2.45578 6.63914C2.45577 6.63915 2.45578 6.63914 2.45578 6.63914C2.25645 6.83851 2.25643 7.16168 2.45575 7.36103C2.45574 7.36103 2.45576 7.36104 2.45575 7.36103L3.96413 8.86936C4.22041 9.12564 4.22042 9.54115 3.96414 9.79744C3.70787 10.0537 3.29235 10.0537 3.03607 9.79745L1.52769 8.28913C0.815811 7.57721 0.815803 6.42302 1.52766 5.7111L3.03606 4.20272C3.29234 3.94644 3.70786 3.94644 3.96414 4.20273ZM10.0361 4.20273C10.2923 3.94644 10.7078 3.94644 10.9641 4.20272L12.4725 5.71108C13.1843 6.423 13.1844 7.57717 12.4725 8.28909L10.9641 9.79745C10.7078 10.0537 10.2923 10.0537 10.036 9.79744C9.77977 9.54115 9.77978 9.12564 10.0361 8.86936L11.5444 7.36107C11.7437 7.16172 11.7438 6.83854 11.5444 6.63917C11.5444 6.63915 11.5445 6.63918 11.5444 6.63917L10.0361 5.13081C9.77978 4.87453 9.77978 4.45901 10.0361 4.20273Z" fill="#2e90fa"></path></g></svg>
|
||||
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 1.4 KiB |
@ -396,6 +396,10 @@ class ApiTool(Tool):
|
||||
# assemble invoke message based on response type
|
||||
if parsed_response.is_json and isinstance(parsed_response.content, dict):
|
||||
yield self.create_json_message(parsed_response.content)
|
||||
|
||||
# FIXES: https://github.com/langgenius/dify/pull/23456#issuecomment-3182413088
|
||||
# We need never break the original flows
|
||||
yield self.create_text_message(response.text)
|
||||
else:
|
||||
# Convert to string if needed and create text message
|
||||
text_response = (
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@ -16,10 +17,10 @@ class ToolApiEntity(BaseModel):
|
||||
description: I18nObject
|
||||
parameters: list[ToolParameter] | None = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: dict | None = None
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]]
|
||||
ToolProviderTypeApiLiteral = Literal["builtin", "api", "workflow", "mcp"] | None
|
||||
|
||||
|
||||
class ToolProviderApiEntity(BaseModel):
|
||||
@ -27,17 +28,17 @@ class ToolProviderApiEntity(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str | dict
|
||||
icon_dark: str | dict | None = Field(default=None, description="The dark icon of the tool")
|
||||
icon: str | Mapping[str, str]
|
||||
icon_dark: str | Mapping[str, str] = ""
|
||||
label: I18nObject # label
|
||||
type: ToolProviderType
|
||||
masked_credentials: dict | None = None
|
||||
original_credentials: dict | None = None
|
||||
masked_credentials: Mapping[str, object] = Field(default_factory=dict)
|
||||
original_credentials: Mapping[str, object] = Field(default_factory=dict)
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: str | None = Field(default="", description="The plugin id of the tool")
|
||||
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the tool")
|
||||
tools: list[ToolApiEntity] = Field(default_factory=list)
|
||||
tools: list[ToolApiEntity] = Field(default_factory=list[ToolApiEntity])
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
# MCP
|
||||
server_url: str | None = Field(default="", description="The server url of the tool")
|
||||
@ -113,7 +114,7 @@ class ToolProviderCredentialApiEntity(BaseModel):
|
||||
is_default: bool = Field(
|
||||
default=False, description="Whether the credential is the default credential for the provider in the workspace"
|
||||
)
|
||||
credentials: dict = Field(description="The credentials of the provider")
|
||||
credentials: Mapping[str, object] = Field(description="The credentials of the provider", default_factory=dict)
|
||||
|
||||
|
||||
class ToolProviderCredentialInfoApiEntity(BaseModel):
|
||||
|
||||
@ -22,22 +22,23 @@ from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
|
||||
|
||||
|
||||
class ToolLabelEnum(StrEnum):
|
||||
SEARCH = auto()
|
||||
IMAGE = auto()
|
||||
VIDEOS = auto()
|
||||
WEATHER = auto()
|
||||
FINANCE = auto()
|
||||
DESIGN = auto()
|
||||
TRAVEL = auto()
|
||||
SOCIAL = auto()
|
||||
NEWS = auto()
|
||||
MEDICAL = auto()
|
||||
PRODUCTIVITY = auto()
|
||||
EDUCATION = auto()
|
||||
BUSINESS = auto()
|
||||
ENTERTAINMENT = auto()
|
||||
UTILITIES = auto()
|
||||
OTHER = auto()
|
||||
SEARCH = "search"
|
||||
IMAGE = "image"
|
||||
VIDEOS = "videos"
|
||||
WEATHER = "weather"
|
||||
FINANCE = "finance"
|
||||
DESIGN = "design"
|
||||
TRAVEL = "travel"
|
||||
SOCIAL = "social"
|
||||
NEWS = "news"
|
||||
MEDICAL = "medical"
|
||||
PRODUCTIVITY = "productivity"
|
||||
EDUCATION = "education"
|
||||
BUSINESS = "business"
|
||||
ENTERTAINMENT = "entertainment"
|
||||
UTILITIES = "utilities"
|
||||
RAG = "rag"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class ToolProviderType(StrEnum):
|
||||
@ -186,7 +187,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
error: str | None = Field(default=None, description="The error message")
|
||||
status: LogStatus = Field(..., description="The status of the log")
|
||||
data: Mapping[str, Any] = Field(..., description="Detailed log data")
|
||||
metadata: Mapping[str, Any] | None = Field(default=None, description="The metadata of the log")
|
||||
metadata: Mapping[str, Any] = Field(default_factory=dict, description="The metadata of the log")
|
||||
|
||||
class RetrieverResourceMessage(BaseModel):
|
||||
retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
|
||||
@ -362,9 +363,9 @@ class ToolDescription(BaseModel):
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
identity: ToolIdentity
|
||||
parameters: list[ToolParameter] = Field(default_factory=list)
|
||||
parameters: list[ToolParameter] = Field(default_factory=list[ToolParameter])
|
||||
description: ToolDescription | None = None
|
||||
output_schema: dict | None = None
|
||||
output_schema: Mapping[str, object] = Field(default_factory=dict)
|
||||
has_runtime_parameters: bool = Field(default=False, description="Whether the tool has runtime parameters")
|
||||
|
||||
# pydantic configs
|
||||
@ -377,21 +378,23 @@ class ToolEntity(BaseModel):
|
||||
|
||||
|
||||
class OAuthSchema(BaseModel):
|
||||
client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
|
||||
client_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list[ProviderConfig], description="The schema of the OAuth client"
|
||||
)
|
||||
credentials_schema: list[ProviderConfig] = Field(
|
||||
default_factory=list, description="The schema of the OAuth credentials"
|
||||
default_factory=list[ProviderConfig], description="The schema of the OAuth credentials"
|
||||
)
|
||||
|
||||
|
||||
class ToolProviderEntity(BaseModel):
|
||||
identity: ToolProviderIdentity
|
||||
plugin_id: str | None = None
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list[ProviderConfig])
|
||||
oauth_schema: OAuthSchema | None = None
|
||||
|
||||
|
||||
class ToolProviderEntityWithPlugin(ToolProviderEntity):
|
||||
tools: list[ToolEntity] = Field(default_factory=list)
|
||||
tools: list[ToolEntity] = Field(default_factory=list[ToolEntity])
|
||||
|
||||
|
||||
class WorkflowToolParameterConfiguration(BaseModel):
|
||||
@ -502,9 +505,9 @@ class CredentialType(StrEnum):
|
||||
@classmethod
|
||||
def of(cls, credential_type: str) -> "CredentialType":
|
||||
type_name = credential_type.lower()
|
||||
if type_name == "api-key":
|
||||
if type_name in {"api-key", "api_key"}:
|
||||
return cls.API_KEY
|
||||
elif type_name == "oauth2":
|
||||
elif type_name in {"oauth2", "oauth"}:
|
||||
return cls.OAUTH2
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
@ -49,6 +49,9 @@ ICONS = {
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.OTHER: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00052 0.666748L4.00065 7.33342H12.0007L8.00052 0.666748ZM8.00052 3.25828L9.64572 6.00008H6.35553L8.00052 3.25828ZM4.50065 13.3334C3.48813 13.3334 2.66732 12.5126 2.66732 11.5001C2.66732 10.4875 3.48813 9.66675 4.50065 9.66675C5.51317 9.66675 6.33398 10.4875 6.33398 11.5001C6.33398 12.5126 5.51317 13.3334 4.50065 13.3334ZM4.50065 14.6667C6.24955 14.6667 7.66732 13.249 7.66732 11.5001C7.66732 9.75115 6.24955 8.33342 4.50065 8.33342C2.75175 8.33342 1.33398 9.75115 1.33398 11.5001C1.33398 13.249 2.75175 14.6667 4.50065 14.6667ZM10.0007 10.3334V13.0001H12.6673V10.3334H10.0007ZM8.66732 14.3334V9.00008H14.0007V14.3334H8.66732Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
ToolLabelEnum.RAG: """<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 16 16" fill="none">
|
||||
<path d="M8.00065 1.3335H9.33398V2.66683H8.00065V1.3335ZM5.33398 1.3335H6.66732V2.66683H5.33398V1.3335ZM3.99935 2.66683C3.99935 2.29864 4.29783 2.00016 4.66602 2.00016H12.3327C12.7009 2.00016 13.0007 2.29864 13.0007 2.66683V13.3335C13.0007 13.7017 12.7009 14.0002 12.3327 14.0002H4.66602C4.29783 14.0002 3.99935 13.7017 3.99935 13.3335V2.66683ZM4.66602 12.6668C4.29783 12.6668 3.99935 12.3683 3.99935 12.0002V10.6668H5.33398V12.0002C5.33398 12.3683 5.0355 12.6668 4.66602 12.6668ZM5.33398 8.66683H6.66732V10.0002H5.33398V8.66683ZM5.33398 6.66683H6.66732V8.00016H5.33398V6.66683ZM3.99935 4.66683H6.66602V6.00016H3.99935V4.66683ZM6.66602 1.3335H12.3327V2.66683H6.66602V1.3335Z" fill="#344054"/>
|
||||
</svg>""", # noqa: E501
|
||||
}
|
||||
|
||||
@ -105,6 +108,9 @@ default_tool_label_dict = {
|
||||
ToolLabelEnum.OTHER: ToolLabel(
|
||||
name="other", label=I18nObject(en_US="Other", zh_Hans="其他"), icon=ICONS[ToolLabelEnum.OTHER]
|
||||
),
|
||||
ToolLabelEnum.RAG: ToolLabel(
|
||||
name="rag", label=I18nObject(en_US="RAG", zh_Hans="RAG"), icon=ICONS[ToolLabelEnum.RAG]
|
||||
),
|
||||
}
|
||||
|
||||
default_tool_labels = list(default_tool_label_dict.values())
|
||||
|
||||
@ -152,7 +152,6 @@ class ToolEngine:
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
thread_pool_id: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
@ -166,7 +165,6 @@ class ToolEngine:
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
tool.thread_pool_id = thread_pool_id
|
||||
|
||||
if tool.runtime and tool.runtime.runtime_parameters:
|
||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
||||
|
||||
@ -5,7 +5,7 @@ import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from os import listdir, path
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import TypeAdapter
|
||||
@ -16,7 +16,6 @@ from yarl import URL
|
||||
import contexts
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@ -63,6 +62,10 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import VariablePool
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -117,6 +120,7 @@ class ToolManager:
|
||||
get the plugin provider
|
||||
"""
|
||||
# check if context is set
|
||||
|
||||
try:
|
||||
contexts.plugin_tool_providers.get()
|
||||
except LookupError:
|
||||
@ -172,6 +176,7 @@ class ToolManager:
|
||||
|
||||
:return: the tool
|
||||
"""
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN:
|
||||
# check if the builtin tool need credentials
|
||||
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
|
||||
@ -213,16 +218,16 @@ class ToolManager:
|
||||
# fallback to the default provider
|
||||
if builtin_provider is None:
|
||||
# use the default provider
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
with Session(db.engine) as session:
|
||||
builtin_provider = session.scalar(
|
||||
sa.select(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == str(provider_id_entity))
|
||||
| (BuiltinToolProvider.provider == provider_id_entity.provider_name),
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
)
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
@ -263,6 +268,7 @@ class ToolManager:
|
||||
# check if the credentials is expired
|
||||
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
|
||||
# TODO: circular import
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
# refresh the credentials
|
||||
@ -270,6 +276,7 @@ class ToolManager:
|
||||
provider_name = tool_provider.provider_name
|
||||
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
|
||||
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
|
||||
|
||||
oauth_handler = OAuthHandler()
|
||||
# refresh the credentials
|
||||
refreshed_credentials = oauth_handler.refresh_credentials(
|
||||
@ -358,7 +365,7 @@ class ToolManager:
|
||||
app_id: str,
|
||||
agent_tool: AgentToolEntity,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: VariablePool | None = None,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
@ -400,7 +407,7 @@ class ToolManager:
|
||||
node_id: str,
|
||||
workflow_tool: "ToolEntity",
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
variable_pool: VariablePool | None = None,
|
||||
variable_pool: Optional["VariablePool"] = None,
|
||||
) -> Tool:
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
@ -516,6 +523,7 @@ class ToolManager:
|
||||
"""
|
||||
list all the plugin providers
|
||||
"""
|
||||
|
||||
manager = PluginToolManager()
|
||||
provider_entities = manager.fetch_tool_providers(tenant_id)
|
||||
return [
|
||||
@ -879,7 +887,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str):
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
@ -890,13 +898,13 @@ class ToolManager:
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
icon: dict = json.loads(workflow_provider.icon)
|
||||
icon = json.loads(workflow_provider.icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str):
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]:
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
@ -907,13 +915,13 @@ class ToolManager:
|
||||
if api_provider is None:
|
||||
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
|
||||
|
||||
icon: dict = json.loads(api_provider.icon)
|
||||
icon = json.loads(api_provider.icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
@ -933,7 +941,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
provider_type: ToolProviderType,
|
||||
provider_id: str,
|
||||
) -> Union[str, dict[str, Any]]:
|
||||
) -> str | Mapping[str, str]:
|
||||
"""
|
||||
get the tool icon
|
||||
|
||||
@ -958,11 +966,10 @@ class ToolManager:
|
||||
return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
|
||||
elif provider_type == ToolProviderType.PLUGIN:
|
||||
provider = ToolManager.get_plugin_provider(provider_id, tenant_id)
|
||||
if isinstance(provider, PluginToolProviderController):
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
try:
|
||||
return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
raise ValueError(f"plugin provider {provider_id} not found")
|
||||
elif provider_type == ToolProviderType.MCP:
|
||||
return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
|
||||
@ -973,7 +980,7 @@ class ToolManager:
|
||||
def _convert_tool_parameters_type(
|
||||
cls,
|
||||
parameters: list[ToolParameter],
|
||||
variable_pool: VariablePool | None,
|
||||
variable_pool: Optional["VariablePool"],
|
||||
tool_configurations: dict[str, Any],
|
||||
typ: Literal["agent", "workflow", "tool"] = "workflow",
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@ -18,6 +18,10 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
|
||||
retriever_from: str
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def run(self, query: str) -> str:
|
||||
"""Use the tool."""
|
||||
return self._run(query)
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, query: str) -> str:
|
||||
"""Use the tool.
|
||||
|
||||
@ -124,7 +124,7 @@ class DatasetRetrieverTool(Tool):
|
||||
yield self.create_text_message(text="please input query")
|
||||
else:
|
||||
# invoke dataset retriever tool
|
||||
result = self.retrieval_tool._run(query=query)
|
||||
result = self.retrieval_tool.run(query=query)
|
||||
yield self.create_text_message(text=result)
|
||||
|
||||
def validate_credentials(
|
||||
|
||||
@ -123,11 +123,15 @@ class ProviderConfigEncrypter:
|
||||
return data
|
||||
|
||||
|
||||
def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
|
||||
def create_provider_encrypter(
|
||||
tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
|
||||
|
||||
|
||||
def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
|
||||
def create_tool_provider_encrypter(
|
||||
tenant_id: str, controller: ToolProviderController
|
||||
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
||||
cache = SingletonProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
provider_type=controller.provider_type.value,
|
||||
|
||||
@ -2,6 +2,7 @@ import re
|
||||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
@ -127,34 +128,34 @@ class ApiBasedToolSchemaParser:
|
||||
if "allOf" in prop_dict:
|
||||
del prop_dict["allOf"]
|
||||
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
# parse body parameters
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
placeholder=I18nObject(
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||
if typ:
|
||||
tool.type = typ
|
||||
|
||||
parameters.append(tool)
|
||||
parameters.append(tool)
|
||||
|
||||
# check if parameters is duplicated
|
||||
parameters_count = {}
|
||||
@ -241,7 +242,9 @@ class ApiBasedToolSchemaParser:
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None):
|
||||
def parse_swagger_to_openapi(
|
||||
swagger: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||
) -> dict[str, Any]:
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
@ -257,7 +260,7 @@ class ApiBasedToolSchemaParser:
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
converted_openapi: dict[str, Any] = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
@ -275,7 +278,7 @@ class ApiBasedToolSchemaParser:
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {}
|
||||
converted_openapi["paths"][path] = {}
|
||||
for method, operation in path_item.items():
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
@ -286,7 +289,7 @@ class ApiBasedToolSchemaParser:
|
||||
if warning is not None:
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = {
|
||||
converted_openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
@ -295,13 +298,14 @@ class ApiBasedToolSchemaParser:
|
||||
}
|
||||
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
converted_openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition
|
||||
if "definitions" in swagger:
|
||||
for name, definition in swagger["definitions"].items():
|
||||
converted_openapi["components"]["schemas"][name] = definition
|
||||
|
||||
return openapi
|
||||
return converted_openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
|
||||
@ -39,14 +39,12 @@ class WorkflowTool(Tool):
|
||||
entity: ToolEntity,
|
||||
runtime: ToolRuntime,
|
||||
label: str = "Workflow",
|
||||
thread_pool_id: str | None = None,
|
||||
):
|
||||
self.workflow_app_id = workflow_app_id
|
||||
self.workflow_as_tool_id = workflow_as_tool_id
|
||||
self.version = version
|
||||
self.workflow_entities = workflow_entities
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.thread_pool_id = thread_pool_id
|
||||
self.label = label
|
||||
|
||||
super().__init__(entity=entity, runtime=runtime)
|
||||
@ -90,7 +88,6 @@ class WorkflowTool(Tool):
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
workflow_thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
data = result.get("data", {})
|
||||
|
||||
Reference in New Issue
Block a user