mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 23:18:05 +08:00
Merge branch 'feat/queue-based-graph-engine' into chore/merge-graph-engine
This commit is contained in:
@ -20,7 +20,7 @@ class Tool(ABC):
|
||||
The base class of a tool
|
||||
"""
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime):
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ToolProviderController(ABC):
|
||||
def __init__(self, entity: ToolProviderEntity) -> None:
|
||||
def __init__(self, entity: ToolProviderEntity):
|
||||
self.entity = entity
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
@ -41,7 +41,7 @@ class ToolProviderController(ABC):
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
tools: list[BuiltinTool]
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
def __init__(self, **data: Any):
|
||||
self.tools = []
|
||||
|
||||
# load provider yaml
|
||||
@ -197,7 +197,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.entity.identity.tags or []
|
||||
|
||||
def validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
@ -211,7 +211,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
self._validate_credentials(user_id, credentials)
|
||||
|
||||
@abstractmethod
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
|
||||
@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AudioToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@ -4,7 +4,7 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WebscraperProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
Validate credentials
|
||||
"""
|
||||
|
||||
@ -24,7 +24,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
tenant_id: str
|
||||
tools: list[ApiTool] = Field(default_factory=list)
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
|
||||
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str):
|
||||
super().__init__(entity)
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@ -302,7 +302,7 @@ class ApiTool(Tool):
|
||||
|
||||
def _convert_body_property_any_of(
|
||||
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
|
||||
) -> Any:
|
||||
):
|
||||
if max_recursive <= 0:
|
||||
raise Exception("Max recursion depth reached")
|
||||
for option in any_of or []:
|
||||
@ -337,7 +337,7 @@ class ApiTool(Tool):
|
||||
# If no option succeeded, you might want to return the value as is or raise an error
|
||||
return value # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf")
|
||||
|
||||
def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
|
||||
def _convert_body_property_type(self, property: dict[str, Any], value: Any):
|
||||
try:
|
||||
if "type" in property:
|
||||
if property["type"] == "integer" or property["type"] == "int":
|
||||
|
||||
@ -49,7 +49,7 @@ class ToolProviderApiEntity(BaseModel):
|
||||
def convert_none_to_empty_list(cls, v):
|
||||
return v if v is not None else []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self):
|
||||
# -------------
|
||||
# overwrite tool parameter types for temp fix
|
||||
tools = jsonable_encoder(self.tools)
|
||||
@ -84,7 +84,7 @@ class ToolProviderApiEntity(BaseModel):
|
||||
**optional_fields,
|
||||
}
|
||||
|
||||
def optional_field(self, key: str, value: Any) -> dict:
|
||||
def optional_field(self, key: str, value: Any):
|
||||
"""Return dict with key-value if value is truthy, empty dict otherwise."""
|
||||
return {key: value} if value else {}
|
||||
|
||||
|
||||
@ -19,5 +19,5 @@ class I18nObject(BaseModel):
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self):
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
|
||||
@ -151,7 +151,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def transform_variable_value(cls, values) -> Any:
|
||||
def transform_variable_value(cls, values):
|
||||
"""
|
||||
Only basic types and lists are allowed.
|
||||
"""
|
||||
@ -429,7 +429,7 @@ class ToolInvokeMeta(BaseModel):
|
||||
"""
|
||||
return cls(time_cost=0.0, error=error, tool_config={})
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self):
|
||||
return {
|
||||
"time_cost": self.time_cost,
|
||||
"error": self.error,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Self
|
||||
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@ -28,7 +28,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
sse_read_timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(entity)
|
||||
self.entity: ToolProviderEntityWithPlugin = entity
|
||||
self.tenant_id = tenant_id
|
||||
@ -48,7 +48,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
return ToolProviderType.MCP
|
||||
|
||||
@classmethod
|
||||
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
|
||||
def from_db(cls, db_provider: MCPToolProvider) -> Self:
|
||||
"""
|
||||
from db provider
|
||||
"""
|
||||
@ -99,7 +99,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
|
||||
@ -23,7 +23,7 @@ class MCPTool(Tool):
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
sse_read_timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
|
||||
@ -16,7 +16,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
):
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
@ -31,7 +31,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
||||
"""
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
|
||||
@ -11,7 +11,7 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
|
||||
class PluginTool(Tool):
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
|
||||
@ -98,6 +98,7 @@ class ToolFileManager:
|
||||
mimetype=mimetype,
|
||||
name=present_filename,
|
||||
size=len(file_binary),
|
||||
original_url=None,
|
||||
)
|
||||
|
||||
session.add(tool_file)
|
||||
@ -131,7 +132,6 @@ class ToolFileManager:
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
|
||||
@ -646,7 +646,7 @@ class ToolManager:
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name,
|
||||
name_func=lambda x: x.entity.identity.name,
|
||||
):
|
||||
continue
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
@ -777,12 +777,12 @@ class ToolManager:
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
controller = MCPToolProviderController._from_db(provider)
|
||||
controller = MCPToolProviderController.from_db(provider)
|
||||
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
|
||||
def user_get_api_provider(cls, provider: str, tenant_id: str):
|
||||
"""
|
||||
get api provider
|
||||
"""
|
||||
@ -877,7 +877,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str):
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
@ -894,7 +894,7 @@ class ToolManager:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str):
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
@ -932,7 +932,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
provider_type: ToolProviderType,
|
||||
provider_id: str,
|
||||
) -> Union[str, dict]:
|
||||
) -> Union[str, dict[str, Any]]:
|
||||
"""
|
||||
get the tool icon
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ class ToolParameterConfigurationManager:
|
||||
|
||||
def __init__(
|
||||
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
|
||||
) -> None:
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.tool_runtime = tool_runtime
|
||||
self.provider_name = provider_name
|
||||
|
||||
@ -20,7 +20,7 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas
|
||||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool):
|
||||
super().__init__(entity, runtime)
|
||||
self.retrieval_tool = retrieval_tool
|
||||
|
||||
|
||||
@ -17,11 +17,11 @@ class ProviderConfigCache(Protocol):
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
def set(self, config: dict[str, Any]):
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
def delete(self):
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
@ -242,7 +242,7 @@ class ApiBasedToolSchemaParser:
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None):
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
@ -8,7 +8,7 @@ from yaml import YAMLError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}):
|
||||
"""
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
|
||||
@ -220,7 +220,7 @@ class WorkflowTool(Tool):
|
||||
|
||||
return result, files
|
||||
|
||||
def _update_file_mapping(self, file_dict: dict) -> dict:
|
||||
def _update_file_mapping(self, file_dict: dict):
|
||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||
if transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_dict.get("related_id")
|
||||
|
||||
Reference in New Issue
Block a user