Merge branch 'feat/queue-based-graph-engine' into chore/merge-graph-engine

This commit is contained in:
-LAN-
2025-09-08 14:25:10 +08:00
824 changed files with 7235 additions and 2941 deletions

View File

@ -20,7 +20,7 @@ class Tool(ABC):
The base class of a tool
"""
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
def __init__(self, entity: ToolEntity, runtime: ToolRuntime):
self.entity = entity
self.runtime = runtime

View File

@ -12,7 +12,7 @@ from core.tools.errors import ToolProviderCredentialValidationError
class ToolProviderController(ABC):
def __init__(self, entity: ToolProviderEntity) -> None:
def __init__(self, entity: ToolProviderEntity):
self.entity = entity
def get_credentials_schema(self) -> list[ProviderConfig]:
@ -41,7 +41,7 @@ class ToolProviderController(ABC):
"""
return ToolProviderType.BUILT_IN
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
def validate_credentials_format(self, credentials: dict[str, Any]):
"""
validate the format of the credentials of the provider and set the default value if needed

View File

@ -24,7 +24,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
class BuiltinToolProviderController(ToolProviderController):
tools: list[BuiltinTool]
def __init__(self, **data: Any) -> None:
def __init__(self, **data: Any):
self.tools = []
# load provider yaml
@ -197,7 +197,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return self.entity.identity.tags or []
def validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
def validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
validate the credentials of the provider
@ -211,7 +211,7 @@ class BuiltinToolProviderController(ToolProviderController):
self._validate_credentials(user_id, credentials)
@abstractmethod
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
validate the credentials of the provider

View File

@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
class AudioToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass

View File

@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
class CodeToolProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass

View File

@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
class WikiPediaProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
pass

View File

@ -4,7 +4,7 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
class WebscraperProvider(BuiltinToolProviderController):
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
Validate credentials
"""

View File

@ -24,7 +24,7 @@ class ApiToolProviderController(ToolProviderController):
tenant_id: str
tools: list[ApiTool] = Field(default_factory=list)
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str):
super().__init__(entity)
self.provider_id = provider_id
self.tenant_id = tenant_id

View File

@ -302,7 +302,7 @@ class ApiTool(Tool):
def _convert_body_property_any_of(
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
) -> Any:
):
if max_recursive <= 0:
raise Exception("Max recursion depth reached")
for option in any_of or []:
@ -337,7 +337,7 @@ class ApiTool(Tool):
# If no option succeeded, you might want to return the value as is or raise an error
return value # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf")
def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
def _convert_body_property_type(self, property: dict[str, Any], value: Any):
try:
if "type" in property:
if property["type"] == "integer" or property["type"] == "int":

View File

@ -49,7 +49,7 @@ class ToolProviderApiEntity(BaseModel):
def convert_none_to_empty_list(cls, v):
return v if v is not None else []
def to_dict(self) -> dict:
def to_dict(self):
# -------------
# overwrite tool parameter types for temp fix
tools = jsonable_encoder(self.tools)
@ -84,7 +84,7 @@ class ToolProviderApiEntity(BaseModel):
**optional_fields,
}
def optional_field(self, key: str, value: Any) -> dict:
def optional_field(self, key: str, value: Any):
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}

View File

@ -19,5 +19,5 @@ class I18nObject(BaseModel):
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
def to_dict(self) -> dict:
def to_dict(self):
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -151,7 +151,7 @@ class ToolInvokeMessage(BaseModel):
@model_validator(mode="before")
@classmethod
def transform_variable_value(cls, values) -> Any:
def transform_variable_value(cls, values):
"""
Only basic types and lists are allowed.
"""
@ -429,7 +429,7 @@ class ToolInvokeMeta(BaseModel):
"""
return cls(time_cost=0.0, error=error, tool_config={})
def to_dict(self) -> dict:
def to_dict(self):
return {
"time_cost": self.time_cost,
"error": self.error,

View File

@ -1,5 +1,5 @@
import json
from typing import Any, Optional
from typing import Any, Optional, Self
from core.mcp.types import Tool as RemoteMCPTool
from core.tools.__base.tool_provider import ToolProviderController
@ -28,7 +28,7 @@ class MCPToolProviderController(ToolProviderController):
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
):
super().__init__(entity)
self.entity: ToolProviderEntityWithPlugin = entity
self.tenant_id = tenant_id
@ -48,7 +48,7 @@ class MCPToolProviderController(ToolProviderController):
return ToolProviderType.MCP
@classmethod
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
def from_db(cls, db_provider: MCPToolProvider) -> Self:
"""
from db provider
"""
@ -99,7 +99,7 @@ class MCPToolProviderController(ToolProviderController):
sse_read_timeout=db_provider.sse_read_timeout,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
validate the credentials of the provider
"""

View File

@ -23,7 +23,7 @@ class MCPTool(Tool):
headers: Optional[dict[str, str]] = None,
timeout: Optional[float] = None,
sse_read_timeout: Optional[float] = None,
) -> None:
):
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon

View File

@ -16,7 +16,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
def __init__(
self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
) -> None:
):
self.entity = entity
self.tenant_id = tenant_id
self.plugin_id = plugin_id
@ -31,7 +31,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
"""
return ToolProviderType.PLUGIN
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
"""
validate the credentials of the provider
"""

View File

@ -11,7 +11,7 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
class PluginTool(Tool):
def __init__(
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
) -> None:
):
super().__init__(entity, runtime)
self.tenant_id = tenant_id
self.icon = icon

View File

@ -98,6 +98,7 @@ class ToolFileManager:
mimetype=mimetype,
name=present_filename,
size=len(file_binary),
original_url=None,
)
session.add(tool_file)
@ -131,7 +132,6 @@ class ToolFileManager:
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
with Session(self._engine, expire_on_commit=False) as session:
tool_file = ToolFile(
user_id=user_id,

View File

@ -646,7 +646,7 @@ class ToolManager:
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name,
name_func=lambda x: x.entity.identity.name,
):
continue
user_provider = ToolTransformService.builtin_provider_to_user_provider(
@ -777,12 +777,12 @@ class ToolManager:
if provider is None:
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
controller = MCPToolProviderController._from_db(provider)
controller = MCPToolProviderController.from_db(provider)
return controller
@classmethod
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
def user_get_api_provider(cls, provider: str, tenant_id: str):
"""
get api provider
"""
@ -877,7 +877,7 @@ class ToolManager:
)
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str):
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
@ -894,7 +894,7 @@ class ToolManager:
return {"background": "#252525", "content": "\ud83d\ude01"}
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str):
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
@ -932,7 +932,7 @@ class ToolManager:
tenant_id: str,
provider_type: ToolProviderType,
provider_id: str,
) -> Union[str, dict]:
) -> Union[str, dict[str, Any]]:
"""
get the tool icon

View File

@ -24,7 +24,7 @@ class ToolParameterConfigurationManager:
def __init__(
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
) -> None:
):
self.tenant_id = tenant_id
self.tool_runtime = tool_runtime
self.provider_name = provider_name

View File

@ -20,7 +20,7 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas
class DatasetRetrieverTool(Tool):
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool):
super().__init__(entity, runtime)
self.retrieval_tool = retrieval_tool

View File

@ -17,11 +17,11 @@ class ProviderConfigCache(Protocol):
"""Get cached provider configuration"""
...
def set(self, config: dict[str, Any]) -> None:
def set(self, config: dict[str, Any]):
"""Cache provider configuration"""
...
def delete(self) -> None:
def delete(self):
"""Delete cached provider configuration"""
...

View File

@ -242,7 +242,7 @@ class ApiBasedToolSchemaParser:
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None):
warning = warning or {}
"""
parse swagger to openapi

View File

@ -8,7 +8,7 @@ from yaml import YAMLError
logger = logging.getLogger(__name__)
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}):
"""
Safe loading a YAML file
:param file_path: the path of the YAML file

View File

@ -220,7 +220,7 @@ class WorkflowTool(Tool):
return result, files
def _update_file_mapping(self, file_dict: dict) -> dict:
def _update_file_mapping(self, file_dict: dict):
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
if transfer_method == FileTransferMethod.TOOL_FILE:
file_dict["tool_file_id"] = file_dict.get("related_id")