mirror of
https://github.com/langgenius/dify.git
synced 2026-02-22 19:15:47 +08:00
Merge remote-tracking branch 'origin/main' into feat/trigger
This commit is contained in:
@ -4,6 +4,7 @@ from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from httpx import get
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@ -443,9 +444,7 @@ class ApiToolManageService:
|
||||
list api tools
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
db_providers = db.session.scalars(select(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id)).all()
|
||||
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
|
||||
@ -1,18 +1,19 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import exists, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.helper.name_generator import generate_incremental_name
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
|
||||
# from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
@ -30,6 +31,7 @@ from core.tools.utils.encryption import create_provider_encrypter
|
||||
from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider_ids import ToolProviderID
|
||||
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
||||
from services.plugin.plugin_service import PluginService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
@ -222,8 +224,8 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
@ -282,9 +284,9 @@ class BuiltinToolManageService:
|
||||
|
||||
session.add(db_provider)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
raise ValueError(str(e))
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
@ -308,42 +310,20 @@ class BuiltinToolManageService:
|
||||
def generate_builtin_tool_provider_name(
|
||||
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
|
||||
) -> str:
|
||||
try:
|
||||
db_providers = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credential_type=credential_type.value,
|
||||
)
|
||||
.order_by(BuiltinToolProvider.created_at.desc())
|
||||
.all()
|
||||
db_providers = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
credential_type=credential_type.value,
|
||||
)
|
||||
|
||||
# Get the default name pattern
|
||||
default_pattern = f"{credential_type.get_name()}"
|
||||
|
||||
# Find all names that match the default pattern: "{default_pattern} {number}"
|
||||
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
|
||||
numbers = []
|
||||
|
||||
for db_provider in db_providers:
|
||||
if db_provider.name:
|
||||
match = re.match(pattern, db_provider.name.strip())
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
|
||||
# If no default pattern names found, start with 1
|
||||
if not numbers:
|
||||
return f"{default_pattern} 1"
|
||||
|
||||
# Find the next number
|
||||
max_number = max(numbers)
|
||||
return f"{default_pattern} {max_number + 1}"
|
||||
except Exception as e:
|
||||
logger.warning("Error generating next provider name for %s: %s", provider, str(e))
|
||||
# fallback
|
||||
return f"{credential_type.get_name()} 1"
|
||||
.order_by(BuiltinToolProvider.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return generate_incremental_name(
|
||||
[provider.name for provider in db_providers],
|
||||
f"{credential_type.get_name()}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
@ -570,7 +550,7 @@ class BuiltinToolManageService:
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name,
|
||||
name_func=lambda x: x.entity.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
@ -601,7 +581,7 @@ class BuiltinToolManageService:
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
||||
def get_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
|
||||
"""
|
||||
This method is used to fetch the builtin provider from the database
|
||||
1.if the default provider exists, return the default provider
|
||||
@ -662,8 +642,8 @@ class BuiltinToolManageService:
|
||||
def save_custom_oauth_client_params(
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
client_params: Optional[dict] = None,
|
||||
enable_oauth_custom_client: Optional[bool] = None,
|
||||
client_params: dict | None = None,
|
||||
enable_oauth_custom_client: bool | None = None,
|
||||
):
|
||||
"""
|
||||
setup oauth custom client
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@ -27,6 +27,36 @@ class MCPToolManageService:
|
||||
Service class for managing mcp tools.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_headers(headers: dict[str, str], tenant_id: str) -> dict[str, str]:
|
||||
"""
|
||||
Encrypt headers using ProviderConfigEncrypter with all headers as SECRET_INPUT.
|
||||
|
||||
Args:
|
||||
headers: Dictionary of headers to encrypt
|
||||
tenant_id: Tenant ID for encryption
|
||||
|
||||
Returns:
|
||||
Dictionary with all headers encrypted
|
||||
"""
|
||||
if not headers:
|
||||
return {}
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return cast(dict[str, str], encrypter_instance.encrypt(headers))
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
@ -61,6 +91,7 @@ class MCPToolManageService:
|
||||
server_identifier: str,
|
||||
timeout: float,
|
||||
sse_read_timeout: float,
|
||||
headers: dict[str, str] | None = None,
|
||||
) -> ToolProviderApiEntity:
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
existing_provider = (
|
||||
@ -83,6 +114,12 @@ class MCPToolManageService:
|
||||
if existing_provider.server_identifier == server_identifier:
|
||||
raise ValueError(f"MCP tool {server_identifier} already exists")
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
# Encrypt headers
|
||||
encrypted_headers = None
|
||||
if headers:
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(headers, tenant_id)
|
||||
encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
|
||||
mcp_tool = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
@ -95,6 +132,7 @@ class MCPToolManageService:
|
||||
server_identifier=server_identifier,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
encrypted_headers=encrypted_headers,
|
||||
)
|
||||
db.session.add(mcp_tool)
|
||||
db.session.commit()
|
||||
@ -118,9 +156,21 @@ class MCPToolManageService:
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
server_url = mcp_provider.decrypted_server_url
|
||||
authed = mcp_provider.authed
|
||||
headers = mcp_provider.decrypted_headers
|
||||
timeout = mcp_provider.timeout
|
||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
||||
|
||||
try:
|
||||
with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
|
||||
with MCPClient(
|
||||
server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=authed,
|
||||
for_list=True,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
except MCPAuthError:
|
||||
raise ValueError("Please auth the tool first")
|
||||
@ -172,6 +222,7 @@ class MCPToolManageService:
|
||||
server_identifier: str,
|
||||
timeout: float | None = None,
|
||||
sse_read_timeout: float | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
|
||||
@ -207,6 +258,32 @@ class MCPToolManageService:
|
||||
mcp_provider.timeout = timeout
|
||||
if sse_read_timeout is not None:
|
||||
mcp_provider.sse_read_timeout = sse_read_timeout
|
||||
if headers is not None:
|
||||
# Merge masked headers from frontend with existing real values
|
||||
if headers:
|
||||
# existing decrypted and masked headers
|
||||
existing_decrypted = mcp_provider.decrypted_headers
|
||||
existing_masked = mcp_provider.masked_headers
|
||||
|
||||
# Build final headers: if value equals masked existing, keep original decrypted value
|
||||
final_headers: dict[str, str] = {}
|
||||
for key, incoming_value in headers.items():
|
||||
if (
|
||||
key in existing_masked
|
||||
and key in existing_decrypted
|
||||
and isinstance(incoming_value, str)
|
||||
and incoming_value == existing_masked.get(key)
|
||||
):
|
||||
# unchanged, use original decrypted value
|
||||
final_headers[key] = str(existing_decrypted[key])
|
||||
else:
|
||||
final_headers[key] = incoming_value
|
||||
|
||||
encrypted_headers_dict = MCPToolManageService._encrypt_headers(final_headers, tenant_id)
|
||||
mcp_provider.encrypted_headers = json.dumps(encrypted_headers_dict)
|
||||
else:
|
||||
# Explicitly clear headers if empty dict passed
|
||||
mcp_provider.encrypted_headers = None
|
||||
db.session.commit()
|
||||
except IntegrityError as e:
|
||||
db.session.rollback()
|
||||
@ -226,10 +303,10 @@ class MCPToolManageService:
|
||||
def update_mcp_provider_credentials(
|
||||
cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
|
||||
):
|
||||
provider_controller = MCPToolProviderController._from_db(mcp_provider)
|
||||
provider_controller = MCPToolProviderController.from_db(mcp_provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=mcp_provider.tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
|
||||
provider_config_cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
@ -242,6 +319,12 @@ class MCPToolManageService:
|
||||
|
||||
@classmethod
|
||||
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
|
||||
# Get the existing provider to access headers and timeout settings
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
headers = mcp_provider.decrypted_headers
|
||||
timeout = mcp_provider.timeout
|
||||
sse_read_timeout = mcp_provider.sse_read_timeout
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
server_url,
|
||||
@ -249,6 +332,9 @@ class MCPToolManageService:
|
||||
tenant_id,
|
||||
authed=False,
|
||||
for_list=True,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return {
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from core.tools.tool_manager import ToolManager
|
||||
@ -10,7 +9,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: Optional[ToolProviderTypeApiLiteral] = None):
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral | None = None):
|
||||
"""
|
||||
list tool providers
|
||||
|
||||
|
||||
@ -1,13 +1,14 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional, Union, cast
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Union
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.plugin.entities.plugin_daemon import CredentialType, PluginDatasourceProviderEntity
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
@ -32,7 +33,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolTransformService:
|
||||
@classmethod
|
||||
def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str | dict) -> Union[str, dict]:
|
||||
def get_tool_provider_icon_url(
|
||||
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
|
||||
) -> str | Mapping[str, str]:
|
||||
"""
|
||||
get tool provider icon url
|
||||
"""
|
||||
@ -45,7 +48,7 @@ class ToolTransformService:
|
||||
elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}:
|
||||
try:
|
||||
if isinstance(icon, str):
|
||||
return cast(dict, json.loads(icon))
|
||||
return json.loads(icon)
|
||||
return icon
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
@ -54,7 +57,7 @@ class ToolTransformService:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity]):
|
||||
def repack_provider(tenant_id: str, provider: Union[dict, ToolProviderApiEntity, PluginDatasourceProviderEntity]):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
@ -68,7 +71,9 @@ class ToolTransformService:
|
||||
elif isinstance(provider, ToolProviderApiEntity):
|
||||
if provider.plugin_id:
|
||||
if isinstance(provider.icon, str):
|
||||
provider.icon = PluginService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon)
|
||||
provider.icon = PluginService.get_plugin_icon_url(
|
||||
tenant_id=tenant_id, filename=provider.icon
|
||||
)
|
||||
if isinstance(provider.icon_dark, str) and provider.icon_dark:
|
||||
provider.icon_dark = PluginService.get_plugin_icon_url(
|
||||
tenant_id=tenant_id, filename=provider.icon_dark
|
||||
@ -81,12 +86,18 @@ class ToolTransformService:
|
||||
provider.icon_dark = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
|
||||
)
|
||||
elif isinstance(provider, PluginDatasourceProviderEntity):
|
||||
if provider.plugin_id:
|
||||
if isinstance(provider.declaration.identity.icon, str):
|
||||
provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
|
||||
tenant_id=tenant_id, filename=provider.declaration.identity.icon
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def builtin_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
db_provider: Optional[BuiltinToolProvider],
|
||||
db_provider: BuiltinToolProvider | None,
|
||||
decrypt_credentials: bool = True,
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
@ -98,7 +109,7 @@ class ToolTransformService:
|
||||
name=provider_controller.entity.identity.name,
|
||||
description=provider_controller.entity.identity.description,
|
||||
icon=provider_controller.entity.identity.icon,
|
||||
icon_dark=provider_controller.entity.identity.icon_dark,
|
||||
icon_dark=provider_controller.entity.identity.icon_dark or "",
|
||||
label=provider_controller.entity.identity.label,
|
||||
type=ToolProviderType.BUILT_IN,
|
||||
masked_credentials={},
|
||||
@ -120,9 +131,10 @@ class ToolTransformService:
|
||||
)
|
||||
}
|
||||
|
||||
for name, value in schema.items():
|
||||
if result.masked_credentials:
|
||||
result.masked_credentials[name] = ""
|
||||
masked_creds = {}
|
||||
for name in schema:
|
||||
masked_creds[name] = ""
|
||||
result.masked_credentials = masked_creds
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
@ -200,7 +212,7 @@ class ToolTransformService:
|
||||
name=provider_controller.entity.identity.name,
|
||||
description=provider_controller.entity.identity.description,
|
||||
icon=provider_controller.entity.identity.icon,
|
||||
icon_dark=provider_controller.entity.identity.icon_dark,
|
||||
icon_dark=provider_controller.entity.identity.icon_dark or "",
|
||||
label=provider_controller.entity.identity.label,
|
||||
type=ToolProviderType.WORKFLOW,
|
||||
masked_credentials={},
|
||||
@ -229,6 +241,10 @@ class ToolTransformService:
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
server_identifier=db_provider.server_identifier,
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
masked_headers=db_provider.masked_headers,
|
||||
original_headers=db_provider.decrypted_headers,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -239,7 +255,7 @@ class ToolTransformService:
|
||||
author=user.name if user else "Anonymous",
|
||||
name=tool.name,
|
||||
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
|
||||
description=I18nObject(en_US=tool.description, zh_Hans=tool.description),
|
||||
description=I18nObject(en_US=tool.description or "", zh_Hans=tool.description or ""),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
|
||||
labels=[],
|
||||
)
|
||||
@ -309,7 +325,7 @@ class ToolTransformService:
|
||||
|
||||
@staticmethod
|
||||
def convert_tool_entity_to_api_entity(
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
tool: ApiToolBundle | WorkflowTool | Tool,
|
||||
tenant_id: str,
|
||||
labels: list[str] | None = None,
|
||||
) -> ToolApiEntity:
|
||||
@ -363,7 +379,7 @@ class ToolTransformService:
|
||||
parameters=merged_parameters,
|
||||
labels=labels or [],
|
||||
)
|
||||
elif isinstance(tool, ApiToolBundle):
|
||||
else:
|
||||
return ToolApiEntity(
|
||||
author=tool.author,
|
||||
name=tool.operation_id or "",
|
||||
@ -372,9 +388,6 @@ class ToolTransformService:
|
||||
parameters=tool.parameters,
|
||||
labels=labels or [],
|
||||
)
|
||||
else:
|
||||
# Handle WorkflowTool case
|
||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
||||
|
||||
@staticmethod
|
||||
def convert_builtin_provider_to_credential_entity(
|
||||
|
||||
@ -3,7 +3,7 @@ from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import or_, select
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@ -37,7 +37,7 @@ class WorkflowToolManageService:
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
@ -103,7 +103,7 @@ class WorkflowToolManageService:
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Update a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -186,7 +186,9 @@ class WorkflowToolManageService:
|
||||
:param tenant_id: the tenant id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
db_tools = db.session.scalars(
|
||||
select(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id)
|
||||
).all()
|
||||
|
||||
tools: list[WorkflowToolProviderController] = []
|
||||
for provider in db_tools:
|
||||
@ -217,7 +219,7 @@ class WorkflowToolManageService:
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Delete a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -233,7 +235,7 @@ class WorkflowToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -249,7 +251,7 @@ class WorkflowToolManageService:
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -265,7 +267,7 @@ class WorkflowToolManageService:
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:db_tool: the database tool
|
||||
|
||||
Reference in New Issue
Block a user