chore(refactor): queries in service and auth components

This commit is contained in:
Novice
2025-06-25 14:09:19 +08:00
parent 01922f2d02
commit f783ad68e4
9 changed files with 135 additions and 148 deletions

View File

@ -23,6 +23,7 @@ from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from services.tools.mcp_tools_mange_service import MCPToolManageService
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
@ -693,27 +694,26 @@ class ToolMCPAuthApi(Resource):
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if not provider:
raise ValueError("provider not found")
server_url = MCPToolManageService.get_mcp_provider_server_url(tenant_id, provider_id)
try:
with MCPClient(
server_url,
provider.decrypted_server_url,
provider_id,
tenant_id,
authed=False,
authorization_code=args["authorization_code"],
for_list=True,
):
MCPToolManageService.update_mcp_provider_credentials(
tenant_id=tenant_id,
provider_id=provider_id,
credentials=MCPToolManageService.get_mcp_provider_decrypted_credentials(tenant_id, provider_id),
mcp_provider=provider,
credentials=provider.decrypted_credentials,
authed=True,
)
return {"result": "success"}
except MCPAuthError:
auth_provider = OAuthClientProvider(provider_id, tenant_id)
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
return auth(auth_provider, server_url, args["authorization_code"])
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
class ToolMCPDetailApi(Resource):
@ -722,12 +722,8 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
user = current_user
return jsonable_encoder(
MCPToolManageService.retrieve_mcp_provider(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
)
)
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider))
class ToolMCPListAllApi(Resource):

View File

@ -258,10 +258,14 @@ class BaseAgentRunner(AppRunner):
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
prompt_tool.parameters["properties"][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or "",
}
prompt_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum

View File

@ -98,7 +98,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
full_state_data.code_verifier,
full_state_data.redirect_uri,
)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id)
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
provider.save_tokens(tokens)
return full_state_data
@ -275,6 +275,7 @@ def auth(
server_url: str,
authorization_code: Optional[str] = None,
state_param: Optional[str] = None,
for_list: bool = False,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
metadata = discover_oauth_metadata(server_url)
@ -337,8 +338,8 @@ def auth(
metadata,
client_information,
provider.redirect_url,
provider.provider_id,
provider.tenant_id,
provider.mcp_provider.id,
provider.mcp_provider.tenant_id,
)
provider.save_code_verifier(code_verifier)

View File

@ -7,18 +7,20 @@ from core.mcp.types import (
OAuthClientMetadata,
OAuthTokens,
)
from models.tools import MCPToolProvider
from services.tools.mcp_tools_mange_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"
class OAuthClientProvider:
provider_id: str
tenant_id: str
mcp_provider: MCPToolProvider
def __init__(self, provider_id: str, tenant_id: str):
self.provider_id = provider_id
self.tenant_id = tenant_id
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
if for_list:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
else:
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
@property
def redirect_url(self) -> str:
@ -39,12 +41,7 @@ class OAuthClientProvider:
def client_information(self) -> Optional[OAuthClientInformation]:
"""Loads information about this OAuth client."""
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return None
client_information = MCPToolManageService.get_mcp_provider_decrypted_credentials(
self.tenant_id, self.provider_id
).get("client_information", {})
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
if not client_information:
return None
return OAuthClientInformation.model_validate(client_information)
@ -52,15 +49,13 @@ class OAuthClientProvider:
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
"""Saves client information after dynamic registration."""
MCPToolManageService.update_mcp_provider_credentials(
self.tenant_id, self.provider_id, {"client_information": client_information.model_dump()}
self.mcp_provider,
{"client_information": client_information.model_dump()},
)
def tokens(self) -> Optional[OAuthTokens]:
"""Loads any existing OAuth tokens for the current session."""
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return None
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
credentials = self.mcp_provider.decrypted_credentials
if not credentials:
return None
return OAuthTokens(
@ -74,20 +69,13 @@ class OAuthClientProvider:
"""Stores new OAuth tokens for the current session."""
# update mcp provider credentials
token_dict = tokens.model_dump()
MCPToolManageService.update_mcp_provider_credentials(self.tenant_id, self.provider_id, token_dict, authed=True)
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
def save_code_verifier(self, code_verifier: str) -> None:
"""Saves a PKCE code verifier for the current session."""
# update mcp provider credentials
MCPToolManageService.update_mcp_provider_credentials(
self.tenant_id, self.provider_id, {"code_verifier": code_verifier}
)
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
def code_verifier(self) -> str:
"""Loads the PKCE code verifier for the current session."""
# get code verifier from mcp provider credentials
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
if not mcp_provider:
return ""
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
return str(credentials.get("code_verifier", ""))
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))

View File

@ -22,6 +22,7 @@ class MCPClient:
tenant_id: str,
authed: bool = True,
authorization_code: Optional[str] = None,
for_list: bool = False,
):
# Initialize info
self.provider_id = provider_id
@ -35,7 +36,7 @@ class MCPClient:
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id)
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
self.token = self.provider.tokens()
# Initialize session and client objects

View File

@ -40,8 +40,6 @@ class MCPToolProviderController(ToolProviderController):
@classmethod
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
from services.tools.mcp_tools_mange_service import MCPToolManageService
"""
from db provider
"""
@ -55,7 +53,7 @@ class MCPToolProviderController(ToolProviderController):
author=db_provider.user.name if db_provider.user else "Anonymous",
name=remote_mcp_tool.name,
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
provider=db_provider.id,
provider=db_provider.server_identifier,
icon=db_provider.icon,
),
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
@ -84,9 +82,9 @@ class MCPToolProviderController(ToolProviderController):
credentials_schema=[],
tools=tools,
),
provider_id=db_provider.id or "",
provider_id=db_provider.server_identifier or "",
tenant_id=db_provider.tenant_id or "",
server_url=MCPToolManageService.get_mcp_provider_server_url(db_provider.tenant_id, db_provider.id),
server_url=db_provider.decrypted_server_url,
)
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:

View File

@ -1,6 +1,7 @@
import json
from datetime import datetime
from typing import Any, cast
from urllib.parse import urlparse
import sqlalchemy as sa
from deprecated import deprecated
@ -8,6 +9,7 @@ from sqlalchemy import ForeignKey, func
from sqlalchemy.orm import Mapped, mapped_column
from core.file import helpers as file_helpers
from core.helper import encrypter
from core.mcp.types import Tool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
@ -258,6 +260,41 @@ class MCPToolProvider(Base):
except json.JSONDecodeError:
return file_helpers.get_signed_file_url(self.icon)
@property
def decrypted_server_url(self) -> str:
return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
@property
def masked_server_url(self) -> str:
def mask_url(url: str, mask_char: str = "*") -> str:
"""
mask the url to a simple string
"""
parsed = urlparse(url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
if parsed.path and parsed.path != "/":
return f"{base_url}/{mask_char * 6}"
else:
return base_url
return mask_url(self.decrypted_server_url)
@property
def decrypted_credentials(self) -> dict:
from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
provider_controller = MCPToolProviderController._from_db(self)
tool_configuration = ProviderConfigEncrypter(
tenant_id=self.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
)
return tool_configuration.decrypt(self.credentials, use_cache=False)
class ToolModelInvoke(Base):
"""

View File

@ -1,7 +1,6 @@
import hashlib
import json
from datetime import datetime
from urllib.parse import urlparse
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
@ -18,18 +17,7 @@ from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
def mask_url(url: str, mask_char: str = "*"):
"""
mask the url to a simple string
"""
parsed = urlparse(url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
if parsed.path and parsed.path != "/":
return f"{base_url}/{mask_char * 6}"
else:
return base_url
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
class MCPToolManageService:
@ -38,15 +26,26 @@ class MCPToolManageService:
"""
@staticmethod
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider | None:
return (
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.filter(
MCPToolProvider.id == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
.first()
)
if not res:
raise ValueError("MCP tool not found")
return res
@staticmethod
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
.first()
)
if not res:
raise ValueError("MCP tool not found")
return res
@staticmethod
def create_mcp_provider(
@ -109,11 +108,11 @@ class MCPToolManageService:
@classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
try:
with MCPClient(server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
with MCPClient(
mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True
) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError as e:
raise ValueError("Please auth the tool first")
@ -130,25 +129,17 @@ class MCPToolManageService:
type=ToolProviderType.MCP,
icon=mcp_provider.icon,
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
server_url=cls.get_masked_mcp_provider_server_url(tenant_id, provider_id),
server_url=mcp_provider.masked_server_url,
updated_at=int(mcp_provider.updated_at.timestamp()),
description=I18nObject(en_US="", zh_Hans=""),
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
plugin_unique_identifier=mcp_provider.server_identifier,
)
@classmethod
def retrieve_mcp_provider(cls, tenant_id: str, provider_id: str):
provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if provider is None:
raise ValueError("MCP tool not found")
return ToolTransformService.mcp_provider_to_user_provider(provider).to_dict()
@classmethod
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_tool is None:
raise ValueError("MCP tool not found")
db.session.delete(mcp_tool)
db.session.commit()
@ -165,60 +156,38 @@ class MCPToolManageService:
server_identifier: str,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
mcp_provider.name = name
mcp_provider.icon = (
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
)
mcp_provider.server_identifier = server_identifier
if "[__HIDDEN__]" in server_url:
db.session.commit()
return
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
mcp_provider.server_url = encrypted_server_url
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
# if the server url is changed, we need to re-auth the tool
try:
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
mcp_provider.server_url = encrypted_server_url
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
if server_url_hash != mcp_provider.server_url_hash:
try:
with MCPClient(
server_url,
provider_id,
tenant_id,
authed=False,
) as mcp_client:
tools = mcp_client.list_tools()
mcp_provider.authed = True
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
except MCPAuthError:
mcp_provider.authed = False
mcp_provider.tools = "[]"
mcp_provider.encrypted_credentials = "{}"
cls._re_auth_mcp_provider(mcp_provider, provider_id, tenant_id)
mcp_provider.server_url_hash = server_url_hash
try:
db.session.commit()
except IntegrityError as e:
db.session.rollback()
# Check if the error message contains the constraint name
if "unique_mcp_provider_name" in str(e.orig):
# Raise your custom exception
raise ValueError(f"A provider with name '{name}' already exists.")
elif "unique_mcp_provider_server_url" in str(e.orig):
# You can define another custom exception for the other constraint
raise ValueError(f"A provider for server URL '{server_url}' already exists.")
error_msg = str(e.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
elif "unique_mcp_provider_server_url" in error_msg:
raise ValueError(f"MCP tool {server_url} already exists")
else:
# Re-raise the original exception if it's not the one you're handling
raise
@classmethod
def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
def update_mcp_provider_credentials(cls, mcp_provider: MCPToolProvider, credentials: dict, authed: bool = False):
provider_controller = MCPToolProviderController._from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
@ -229,27 +198,22 @@ class MCPToolManageService:
db.session.commit()
@classmethod
def get_mcp_provider_decrypted_credentials(cls, tenant_id: str, provider_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
provider_controller = MCPToolProviderController._from_db(mcp_provider)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.provider_id,
)
return tool_configuration.decrypt(mcp_provider.credentials, use_cache=False)
def _re_auth_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str):
"""re-auth mcp provider"""
try:
with MCPClient(
mcp_provider.decrypted_server_url,
provider_id,
tenant_id,
authed=False,
for_list=True,
) as mcp_client:
tools = mcp_client.list_tools()
mcp_provider.authed = True
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
except MCPAuthError:
mcp_provider.authed = False
mcp_provider.tools = "[]"
@classmethod
def get_mcp_provider_server_url(cls, tenant_id: str, provider_id: str):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
if mcp_provider is None:
raise ValueError("MCP tool not found")
return encrypter.decrypt_token(tenant_id, mcp_provider.server_url)
@classmethod
def get_masked_mcp_provider_server_url(cls, tenant_id: str, provider_id: str):
server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id)
return mask_url(server_url)
# reset credentials
mcp_provider.encrypted_credentials = "{}"

View File

@ -191,8 +191,6 @@ class ToolTransformService:
@staticmethod
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
from services.tools.mcp_tools_mange_service import MCPToolManageService
return ToolProviderApiEntity(
id=db_provider.server_identifier if not for_list else db_provider.id,
author=db_provider.user.name if db_provider.user else "Anonymous",
@ -200,7 +198,7 @@ class ToolTransformService:
icon=db_provider.provider_icon,
type=ToolProviderType.MCP,
is_team_authorization=db_provider.authed,
server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id),
server_url=db_provider.masked_server_url,
tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
),