Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any, cast
from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
@ -443,9 +444,7 @@ class ApiToolManageService:
list api tools
"""
# get all api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
)
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
result: list[ToolProviderApiEntity] = []

View File

@ -1,18 +1,19 @@
import json
import logging
import re
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.name_generator import generate_incremental_name
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
# from core.plugin.entities.plugin import ToolProviderID
from core.plugin.entities.plugin_daemon import CredentialType
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -30,6 +31,7 @@ from core.tools.utils.encryption import create_provider_encrypter
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider_ids import ToolProviderID
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
@ -222,8 +224,8 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
try:
with Session(db.engine) as session:
with Session(db.engine) as session:
try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@ -282,9 +284,9 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
@ -308,42 +310,20 @@ class BuiltinToolManageService:
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try:
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
# Get the default name pattern
default_pattern = f"{credential_type.get_name()}"
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# If no default pattern names found, start with 1
if not numbers:
return f"{default_pattern} 1"
# Find the next number
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning("Error generating next provider name for %s: %s", provider, str(e))
# fallback
return f"{credential_type.get_name()} 1"
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
return generate_incremental_name(
[provider.name for provider in db_providers],
f"{credential_type.get_name()}",
)
@staticmethod
def get_builtin_tool_provider_credentials(
@ -570,7 +550,7 @@ class BuiltinToolManageService:
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller,
name_func=lambda x: x.identity.name,
name_func=lambda x: x.entity.identity.name,
):
continue
@ -601,7 +581,7 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
@ -662,8 +642,8 @@ class BuiltinToolManageService:
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
client_params: dict | None = None,
enable_oauth_custom_client: bool | None = None,
):
"""
setup oauth custom client

View File

@ -1,7 +1,7 @@
import hashlib
import json
from datetime import datetime
from typing import Any
from typing import Any, cast
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
@ -27,6 +27,36 @@ class MCPToolManageService:
Service class for managing mcp tools.
"""
@staticmethod
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
"""
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
Args:
headers: Dictionary of headers to encrypt
tenant_id: Tenant ID for encryption
Returns:
Dictionary with all headers encrypted
"""
if not headers:
return {}
from core.entities.provider_entities import BasicProviderConfig
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.utils.encryption import create_provider_encrypter
# Create dynamic config for all headers as SECRET_INPUT
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
encrypter_instance, _ = create_provider_encrypter(
tenant_id=tenant_id,
config=config,
cache=NoOpProviderCredentialCache(),
)
return cast(dict[str, str], encrypter_instance.encrypt(headers))
@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
@ -61,6 +91,7 @@ class MCPToolManageService:
server_identifier: str,
timeout: float,
sse_read_timeout: float,
headers: dict[str, str] | None = None,
) -> ToolProviderApiEntity:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
@ -83,6 +114,12 @@ class MCPToolManageService:
if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
# Encrypt headers
encrypted_headers = None
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
encrypted_headers = json.dumps(encrypted_headers_dict)
mcp_tool = MCPToolProvider(
tenant_id=tenant_id,
name=name,
@ -95,6 +132,7 @@ class MCPToolManageService:
server_identifier=server_identifier,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
encrypted_headers=encrypted_headers,
)
db.session.add(mcp_tool)
db.session.commit()
@ -118,9 +156,21 @@ class MCPToolManageService:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
try:
with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=authed,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError:
raise ValueError("Please auth the tool first")
@ -172,6 +222,7 @@ class MCPToolManageService:
server_identifier: str,
timeout: float | None = None,
sse_read_timeout: float | None = None,
headers: dict[str, str] | None = None,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
@ -207,6 +258,32 @@ class MCPToolManageService:
mcp_provider.timeout = timeout
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
if headers is not None:
# Merge masked headers from frontend with existing real values
if headers:
# existing decrypted and masked headers
existing_decrypted = mcp_provider.decrypted_headers
existing_masked = mcp_provider.masked_headers
# Build final headers: if value equals masked existing, keep original decrypted value
final_headers: dict[str, str] = {}
for key, incoming_value in headers.items():
if (
key in existing_masked
and key in existing_decrypted
and isinstance(incoming_value, str)
and incoming_value == existing_masked.get(key)
):
# unchanged, use original decrypted value
final_headers[key] = str(existing_decrypted[key])
else:
final_headers[key] = incoming_value
encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id)
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
else:
# Explicitly clear headers if empty dict passed
mcp_provider.encrypted_headers = None
db.session.commit()
except IntegrityError as e:
db.session.rollback()
@ -226,10 +303,10 @@ class MCPToolManageService:
def update_mcp_provider_credentials(
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
):
provider_controller = MCPToolProviderController._from_db(mcp_provider)
provider_controller = MCPToolProviderController.from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
@ -242,6 +319,12 @@ class MCPToolManageService:
@classmethod
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
# Get the existing provider to access headers and timeout settings
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
headers = mcp_provider.decrypted_headers
timeout = mcp_provider.timeout
sse_read_timeout = mcp_provider.sse_read_timeout
try:
with MCPClient(
server_url,
@ -249,6 +332,9 @@ class MCPToolManageService:
tenant_id,
authed=False,
for_list=True,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
return {

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
@ -10,7 +9,7 @@ logger = logging.getLogger(__name__)
class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
"""
list tool providers

View File

@ -1,13 +1,14 @@
import json
import logging
from typing import Any, Optional, Union, cast
from collections.abc import Mapping
from typing import Any, Union
from yarl import URL
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -32,7 +33,9 @@ logger = logging.getLogger(__name__)
class ToolTransformService:
@classmethod
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
def get_tool_provider_icon_url(
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
) -> str | Mapping[str, str]:
"""
get tool provider icon url
"""
@ -45,7 +48,7 @@ class ToolTransformService:
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
try:
if isinstance(icon, str):
return cast(dict, json.loads(icon))
return json.loads(icon)
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
@ -54,7 +57,7 @@ class ToolTransformService:
return ""
@staticmethod
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]):
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
"""
repack provider
@ -68,7 +71,9 @@ class ToolTransformService:
elif isinstance(provider, ToolProviderApiEntity):
if provider.plugin_id:
if isinstance(provider.icon, str):
provider.icon = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon)
provider.icon = PluginService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.icon
)
if isinstance(provider.icon_dark, str) and provider.icon_dark:
provider.icon_dark = PluginService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.icon_dark
@ -81,12 +86,18 @@ class ToolTransformService:
provider.icon_dark = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
)
elif isinstance(provider, PluginDatasourceProviderEntity):
if provider.plugin_id:
if isinstance(provider.declaration.identity.icon, str):
provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.declaration.identity.icon
)
@classmethod
def builtin_provider_to_user_provider(
cls,
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
db_provider: Optional[BuiltinToolProvider],
db_provider: BuiltinToolProvider | None,
decrypt_credentials: bool = True,
) -> ToolProviderApiEntity:
"""
@ -98,7 +109,7 @@ class ToolTransformService:
name=provider_controller.entity.identity.name,
description=provider_controller.entity.identity.description,
icon=provider_controller.entity.identity.icon,
icon_dark=provider_controller.entity.identity.icon_dark,
icon_dark=provider_controller.entity.identity.icon_dark or "",
label=provider_controller.entity.identity.label,
type=ToolProviderType.BUILT_IN,
masked_credentials={},
@ -120,9 +131,10 @@ class ToolTransformService:
)
}
for name, value in schema.items():
if result.masked_credentials:
result.masked_credentials[name] = ""
masked_creds = {}
for name in schema:
masked_creds[name] = ""
result.masked_credentials = masked_creds
# check if the provider need credentials
if not provider_controller.need_credentials:
@ -200,7 +212,7 @@ class ToolTransformService:
name=provider_controller.entity.identity.name,
description=provider_controller.entity.identity.description,
icon=provider_controller.entity.identity.icon,
icon_dark=provider_controller.entity.identity.icon_dark,
icon_dark=provider_controller.entity.identity.icon_dark or "",
label=provider_controller.entity.identity.label,
type=ToolProviderType.WORKFLOW,
masked_credentials={},
@ -229,6 +241,10 @@ class ToolTransformService:
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),
server_identifier=db_provider.server_identifier,
timeout=db_provider.timeout,
sse_read_timeout=db_provider.sse_read_timeout,
masked_headers=db_provider.masked_headers,
original_headers=db_provider.decrypted_headers,
)
@staticmethod
@ -239,7 +255,7 @@ class ToolTransformService:
author=user.name if user else "Anonymous",
name=tool.name,
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
description=I18nObject(en_US=tool.description, zh_Hans=tool.description),
description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
labels=[],
)
@ -309,7 +325,7 @@ class ToolTransformService:
@staticmethod
def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
tool: ApiToolBundle | WorkflowTool | Tool,
tenant_id: str,
labels: list[str] | None = None,
) -> ToolApiEntity:
@ -363,7 +379,7 @@ class ToolTransformService:
parameters=merged_parameters,
labels=labels or [],
)
elif isinstance(tool, ApiToolBundle):
else:
return ToolApiEntity(
author=tool.author,
name=tool.operation_id or "",
@ -372,9 +388,6 @@ class ToolTransformService:
parameters=tool.parameters,
labels=labels or [],
)
else:
# Handle WorkflowTool case
raise ValueError(f"Unsupported tool type: {type(tool)}")
@staticmethod
def convert_builtin_provider_to_credential_entity(

View File

@ -3,7 +3,7 @@ from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_
from sqlalchemy import or_, select
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
@ -37,7 +37,7 @@ class WorkflowToolManageService:
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
) -> dict:
):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
@ -103,7 +103,7 @@ class WorkflowToolManageService:
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
) -> dict:
):
"""
Update a workflow tool.
:param user_id: the user id
@ -186,7 +186,9 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
db_tools = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
@ -217,7 +219,7 @@ class WorkflowToolManageService:
return result
@classmethod
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Delete a workflow tool.
:param user_id: the user id
@ -233,7 +235,7 @@ class WorkflowToolManageService:
return {"result": "success"}
@classmethod
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Get a workflow tool.
:param user_id: the user id
@ -249,7 +251,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
"""
Get a workflow tool.
:param user_id: the user id
@ -265,7 +267,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool