mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 17:38:04 +08:00
Merge remote-tracking branch 'upstream/main' into feat/rag-2
This commit is contained in:
@ -4,6 +4,7 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from httpx import get
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -443,9 +444,7 @@ class ApiToolManageService:
|
||||
list api tools
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
|
||||
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
@ -223,8 +223,8 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
@ -285,9 +285,9 @@ class BuiltinToolManageService:
|
||||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
@ -582,7 +582,7 @@ class BuiltinToolManageService:
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
||||
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
|
||||
"""
|
||||
This method is used to fetch the builtin provider from the database
|
||||
1.if the default provider exists, return the default provider
|
||||
@ -643,8 +643,8 @@ class BuiltinToolManageService:
|
||||
def save_custom_oauth_client_params(
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
client_params: Optional[dict] = None,
|
||||
enable_oauth_custom_client: Optional[bool] = None,
|
||||
client_params: dict | None = None,
|
||||
enable_oauth_custom_client: bool | None = None,
|
||||
):
|
||||
"""
|
||||
setup oauth custom client
|
||||
|
||||
@ -259,11 +259,30 @@ class MCPToolManageService:
|
||||
if sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = sse_read_timeout
|
||||
if headers is not None:
|
||||
# Encrypt headers
|
||||
# Merge masked headers from frontend with existing real values
|
||||
if headers:
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
||||
# existing decrypted and masked headers
|
||||
existing_decrypted = mcp_provider.decrypted_headers
|
||||
existing_masked = mcp_provider.masked_headers
|
||||
|
||||
# Build final headers: if value equals masked existing, keep original decrypted value
|
||||
final_headers: dict[str, str] = {}
|
||||
for key, incoming_value in headers.items():
|
||||
if (
|
||||
key in existing_masked
|
||||
and key in existing_decrypted
|
||||
and isinstance(incoming_value, str)
|
||||
and incoming_value == existing_masked.get(key)
|
||||
):
|
||||
# unchanged, use original decrypted value
|
||||
final_headers[key] = str(existing_decrypted[key])
|
||||
else:
|
||||
final_headers[key] = incoming_value
|
||||
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id)
|
||||
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
else:
|
||||
# Explicitly clear headers if empty dict passed
|
||||
mcp_provider.encrypted_headers = None
|
||||
db.session.commit()
|
||||
except IntegrityError as e:
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from core.tools.tool_manager import ToolManager
|
||||
@ -10,7 +9,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
|
||||
"""
|
||||
list tool providers
|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
@ -107,7 +107,7 @@ class ToolTransformService:
|
||||
def builtin_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
db_provider: Optional[BuiltinToolProvider],
|
||||
db_provider: BuiltinToolProvider | None,
|
||||
decrypt_credentials: bool = True,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
|
||||
@ -3,7 +3,7 @@ from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@ -186,7 +186,9 @@ class WorkflowToolManageService:
|
||||
:param tenant_id: the tenant id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
db_tools = db.session.scalars(
|
||||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
|
||||
Reference in New Issue
Block a user