Merge remote-tracking branch 'upstream/main' into feat/rag-2

This commit is contained in:
QuantumGhost
2025-09-16 14:59:35 +08:00
791 changed files with 24372 additions and 7085 deletions

View File

@ -4,6 +4,7 @@ from collections.abc import Mapping
from typing import Any, cast
from httpx import get
from sqlalchemy import select
from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
@ -443,9 +444,7 @@ class ApiToolManageService:
list api tools
"""
# get all api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
)
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
result: list[ToolProviderApiEntity] = []

View File

@ -2,7 +2,7 @@ import json
import logging
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
from typing import Any
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
@ -223,8 +223,8 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
try:
with Session(db.engine) as session:
with Session(db.engine) as session:
try:
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
@ -285,9 +285,9 @@ class BuiltinToolManageService:
session.add(db_provider)
session.commit()
except Exception as e:
session.rollback()
raise ValueError(str(e))
except Exception as e:
session.rollback()
raise ValueError(str(e))
return {"result": "success"}
@staticmethod
@ -582,7 +582,7 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
@ -643,8 +643,8 @@ class BuiltinToolManageService:
def save_custom_oauth_client_params(
tenant_id: str,
provider: str,
client_params: Optional[dict] = None,
enable_oauth_custom_client: Optional[bool] = None,
client_params: dict | None = None,
enable_oauth_custom_client: bool | None = None,
):
"""
setup oauth custom client

View File

@ -259,11 +259,30 @@ class MCPToolManageService:
if sse_read_timeout is not None:
mcp_provider.sse_read_timeout = sse_read_timeout
if headers is not None:
# Encrypt headers
# Merge masked headers from frontend with existing real values
if headers:
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
# existing decrypted and masked headers
existing_decrypted = mcp_provider.decrypted_headers
existing_masked = mcp_provider.masked_headers
# Build final headers: if value equals masked existing, keep original decrypted value
final_headers: dict[str, str] = {}
for key, incoming_value in headers.items():
if (
key in existing_masked
and key in existing_decrypted
and isinstance(incoming_value, str)
and incoming_value == existing_masked.get(key)
):
# unchanged, use original decrypted value
final_headers[key] = str(existing_decrypted[key])
else:
final_headers[key] = incoming_value
encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id)
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
else:
# Explicitly clear headers if empty dict passed
mcp_provider.encrypted_headers = None
db.session.commit()
except IntegrityError as e:

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
@ -10,7 +9,7 @@ logger = logging.getLogger(__name__)
class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
"""
list tool providers

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Optional, Union, cast
from typing import Any, Union, cast
from yarl import URL
@ -107,7 +107,7 @@ class ToolTransformService:
def builtin_provider_to_user_provider(
cls,
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
db_provider: Optional[BuiltinToolProvider],
db_provider: BuiltinToolProvider | None,
decrypt_credentials: bool = True,
) -> ToolProviderApiEntity:
"""

View File

@ -3,7 +3,7 @@ from collections.abc import Mapping
from datetime import datetime
from typing import Any
from sqlalchemy import or_
from sqlalchemy import or_, select
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
@ -186,7 +186,9 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
db_tools = db.session.scalars(
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
).all()
tools: list[WorkflowToolProviderController] = []
for provider in db_tools: