mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 21:55:58 +08:00
chore(refactor): queries in service and auth components
This commit is contained in:
@ -23,6 +23,7 @@ from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
from services.tools.tool_labels_service import ToolLabelsService
|
||||
from services.tools.tools_manage_service import ToolCommonService
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
|
||||
@ -693,27 +694,26 @@ class ToolMCPAuthApi(Resource):
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if not provider:
|
||||
raise ValueError("provider not found")
|
||||
server_url = MCPToolManageService.get_mcp_provider_server_url(tenant_id, provider_id)
|
||||
try:
|
||||
with MCPClient(
|
||||
server_url,
|
||||
provider.decrypted_server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=False,
|
||||
authorization_code=args["authorization_code"],
|
||||
for_list=True,
|
||||
):
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=provider_id,
|
||||
credentials=MCPToolManageService.get_mcp_provider_decrypted_credentials(tenant_id, provider_id),
|
||||
mcp_provider=provider,
|
||||
credentials=provider.decrypted_credentials,
|
||||
authed=True,
|
||||
)
|
||||
return {"result": "success"}
|
||||
|
||||
except MCPAuthError:
|
||||
auth_provider = OAuthClientProvider(provider_id, tenant_id)
|
||||
auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
|
||||
|
||||
return auth(auth_provider, server_url, args["authorization_code"])
|
||||
return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
|
||||
|
||||
|
||||
class ToolMCPDetailApi(Resource):
|
||||
@ -722,12 +722,8 @@ class ToolMCPDetailApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, provider_id):
|
||||
user = current_user
|
||||
return jsonable_encoder(
|
||||
MCPToolManageService.retrieve_mcp_provider(
|
||||
tenant_id=user.current_tenant_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
)
|
||||
provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
|
||||
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider))
|
||||
|
||||
|
||||
class ToolMCPListAllApi(Resource):
|
||||
|
||||
@ -258,10 +258,14 @@ class BaseAgentRunner(AppRunner):
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
prompt_tool.parameters["properties"][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
prompt_tool.parameters["properties"][parameter.name] = (
|
||||
{
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else parameter.input_schema
|
||||
)
|
||||
|
||||
if len(enum) > 0:
|
||||
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
@ -98,7 +98,7 @@ def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackSta
|
||||
full_state_data.code_verifier,
|
||||
full_state_data.redirect_uri,
|
||||
)
|
||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id)
|
||||
provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
|
||||
provider.save_tokens(tokens)
|
||||
return full_state_data
|
||||
|
||||
@ -275,6 +275,7 @@ def auth(
|
||||
server_url: str,
|
||||
authorization_code: Optional[str] = None,
|
||||
state_param: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
) -> dict[str, str]:
|
||||
"""Orchestrates the full auth flow with a server using secure Redis state storage."""
|
||||
metadata = discover_oauth_metadata(server_url)
|
||||
@ -337,8 +338,8 @@ def auth(
|
||||
metadata,
|
||||
client_information,
|
||||
provider.redirect_url,
|
||||
provider.provider_id,
|
||||
provider.tenant_id,
|
||||
provider.mcp_provider.id,
|
||||
provider.mcp_provider.tenant_id,
|
||||
)
|
||||
|
||||
provider.save_code_verifier(code_verifier)
|
||||
|
||||
@ -7,18 +7,20 @@ from core.mcp.types import (
|
||||
OAuthClientMetadata,
|
||||
OAuthTokens,
|
||||
)
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
LATEST_PROTOCOL_VERSION = "1.0"
|
||||
|
||||
|
||||
class OAuthClientProvider:
|
||||
provider_id: str
|
||||
tenant_id: str
|
||||
mcp_provider: MCPToolProvider
|
||||
|
||||
def __init__(self, provider_id: str, tenant_id: str):
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
|
||||
if for_list:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
else:
|
||||
self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
|
||||
|
||||
@property
|
||||
def redirect_url(self) -> str:
|
||||
@ -39,12 +41,7 @@ class OAuthClientProvider:
|
||||
|
||||
def client_information(self) -> Optional[OAuthClientInformation]:
|
||||
"""Loads information about this OAuth client."""
|
||||
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
||||
if not mcp_provider:
|
||||
return None
|
||||
client_information = MCPToolManageService.get_mcp_provider_decrypted_credentials(
|
||||
self.tenant_id, self.provider_id
|
||||
).get("client_information", {})
|
||||
client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
|
||||
if not client_information:
|
||||
return None
|
||||
return OAuthClientInformation.model_validate(client_information)
|
||||
@ -52,15 +49,13 @@ class OAuthClientProvider:
|
||||
def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
|
||||
"""Saves client information after dynamic registration."""
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
self.tenant_id, self.provider_id, {"client_information": client_information.model_dump()}
|
||||
self.mcp_provider,
|
||||
{"client_information": client_information.model_dump()},
|
||||
)
|
||||
|
||||
def tokens(self) -> Optional[OAuthTokens]:
|
||||
"""Loads any existing OAuth tokens for the current session."""
|
||||
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
||||
if not mcp_provider:
|
||||
return None
|
||||
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
|
||||
credentials = self.mcp_provider.decrypted_credentials
|
||||
if not credentials:
|
||||
return None
|
||||
return OAuthTokens(
|
||||
@ -74,20 +69,13 @@ class OAuthClientProvider:
|
||||
"""Stores new OAuth tokens for the current session."""
|
||||
# update mcp provider credentials
|
||||
token_dict = tokens.model_dump()
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.tenant_id, self.provider_id, token_dict, authed=True)
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
|
||||
|
||||
def save_code_verifier(self, code_verifier: str) -> None:
|
||||
"""Saves a PKCE code verifier for the current session."""
|
||||
# update mcp provider credentials
|
||||
MCPToolManageService.update_mcp_provider_credentials(
|
||||
self.tenant_id, self.provider_id, {"code_verifier": code_verifier}
|
||||
)
|
||||
MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
|
||||
|
||||
def code_verifier(self) -> str:
|
||||
"""Loads the PKCE code verifier for the current session."""
|
||||
# get code verifier from mcp provider credentials
|
||||
mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(self.provider_id, self.tenant_id)
|
||||
if not mcp_provider:
|
||||
return ""
|
||||
credentials = MCPToolManageService.get_mcp_provider_decrypted_credentials(self.tenant_id, self.provider_id)
|
||||
return str(credentials.get("code_verifier", ""))
|
||||
return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
|
||||
|
||||
@ -22,6 +22,7 @@ class MCPClient:
|
||||
tenant_id: str,
|
||||
authed: bool = True,
|
||||
authorization_code: Optional[str] = None,
|
||||
for_list: bool = False,
|
||||
):
|
||||
# Initialize info
|
||||
self.provider_id = provider_id
|
||||
@ -35,7 +36,7 @@ class MCPClient:
|
||||
if authed:
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
|
||||
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id)
|
||||
self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
|
||||
self.token = self.provider.tokens()
|
||||
|
||||
# Initialize session and client objects
|
||||
|
||||
@ -40,8 +40,6 @@ class MCPToolProviderController(ToolProviderController):
|
||||
|
||||
@classmethod
|
||||
def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
"""
|
||||
from db provider
|
||||
"""
|
||||
@ -55,7 +53,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
author=db_provider.user.name if db_provider.user else "Anonymous",
|
||||
name=remote_mcp_tool.name,
|
||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||
provider=db_provider.id,
|
||||
provider=db_provider.server_identifier,
|
||||
icon=db_provider.icon,
|
||||
),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||
@ -84,9 +82,9 @@ class MCPToolProviderController(ToolProviderController):
|
||||
credentials_schema=[],
|
||||
tools=tools,
|
||||
),
|
||||
provider_id=db_provider.id or "",
|
||||
provider_id=db_provider.server_identifier or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
server_url=MCPToolManageService.get_mcp_provider_server_url(db_provider.tenant_id, db_provider.id),
|
||||
server_url=db_provider.decrypted_server_url,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import sqlalchemy as sa
|
||||
from deprecated import deprecated
|
||||
@ -8,6 +9,7 @@ from sqlalchemy import ForeignKey, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import encrypter
|
||||
from core.mcp.types import Tool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
@ -258,6 +260,41 @@ class MCPToolProvider(Base):
|
||||
except json.JSONDecodeError:
|
||||
return file_helpers.get_signed_file_url(self.icon)
|
||||
|
||||
@property
|
||||
def decrypted_server_url(self) -> str:
|
||||
return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
|
||||
|
||||
@property
|
||||
def masked_server_url(self) -> str:
|
||||
def mask_url(url: str, mask_char: str = "*") -> str:
|
||||
"""
|
||||
mask the url to a simple string
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
if parsed.path and parsed.path != "/":
|
||||
return f"{base_url}/{mask_char * 6}"
|
||||
else:
|
||||
return base_url
|
||||
|
||||
return mask_url(self.decrypted_server_url)
|
||||
|
||||
@property
|
||||
def decrypted_credentials(self) -> dict:
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
|
||||
provider_controller = MCPToolProviderController._from_db(self)
|
||||
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.provider_id,
|
||||
)
|
||||
return tool_configuration.decrypt(self.credentials, use_cache=False)
|
||||
|
||||
|
||||
class ToolModelInvoke(Base):
|
||||
"""
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@ -18,18 +17,7 @@ from extensions.ext_database import db
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
def mask_url(url: str, mask_char: str = "*"):
|
||||
"""
|
||||
mask the url to a simple string
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
if parsed.path and parsed.path != "/":
|
||||
return f"{base_url}/{mask_char * 6}"
|
||||
else:
|
||||
return base_url
|
||||
UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
@ -38,15 +26,26 @@ class MCPToolManageService:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider | None:
|
||||
return (
|
||||
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(
|
||||
MCPToolProvider.id == provider_id,
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
|
||||
.first()
|
||||
)
|
||||
if not res:
|
||||
raise ValueError("MCP tool not found")
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
|
||||
res = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
|
||||
.first()
|
||||
)
|
||||
if not res:
|
||||
raise ValueError("MCP tool not found")
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def create_mcp_provider(
|
||||
@ -109,11 +108,11 @@ class MCPToolManageService:
|
||||
@classmethod
|
||||
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id)
|
||||
if mcp_provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
|
||||
try:
|
||||
with MCPClient(server_url, provider_id, tenant_id, authed=mcp_provider.authed) as mcp_client:
|
||||
with MCPClient(
|
||||
mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
except MCPAuthError as e:
|
||||
raise ValueError("Please auth the tool first")
|
||||
@ -130,25 +129,17 @@ class MCPToolManageService:
|
||||
type=ToolProviderType.MCP,
|
||||
icon=mcp_provider.icon,
|
||||
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
|
||||
server_url=cls.get_masked_mcp_provider_server_url(tenant_id, provider_id),
|
||||
server_url=mcp_provider.masked_server_url,
|
||||
updated_at=int(mcp_provider.updated_at.timestamp()),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
|
||||
plugin_unique_identifier=mcp_provider.server_identifier,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def retrieve_mcp_provider(cls, tenant_id: str, provider_id: str):
|
||||
provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
return ToolTransformService.mcp_provider_to_user_provider(provider).to_dict()
|
||||
|
||||
@classmethod
|
||||
def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
|
||||
mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_tool is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
|
||||
db.session.delete(mcp_tool)
|
||||
db.session.commit()
|
||||
|
||||
@ -165,60 +156,38 @@ class MCPToolManageService:
|
||||
server_identifier: str,
|
||||
):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
|
||||
mcp_provider.name = name
|
||||
mcp_provider.icon = (
|
||||
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
|
||||
)
|
||||
mcp_provider.server_identifier = server_identifier
|
||||
|
||||
if "[__HIDDEN__]" in server_url:
|
||||
db.session.commit()
|
||||
return
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
mcp_provider.server_url = encrypted_server_url
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
# if the server url is changed, we need to re-auth the tool
|
||||
try:
|
||||
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
|
||||
mcp_provider.server_url = encrypted_server_url
|
||||
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
|
||||
|
||||
if server_url_hash != mcp_provider.server_url_hash:
|
||||
try:
|
||||
with MCPClient(
|
||||
server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=False,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
mcp_provider.authed = True
|
||||
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
except MCPAuthError:
|
||||
mcp_provider.authed = False
|
||||
mcp_provider.tools = "[]"
|
||||
mcp_provider.encrypted_credentials = "{}"
|
||||
cls._re_auth_mcp_provider(mcp_provider, provider_id, tenant_id)
|
||||
mcp_provider.server_url_hash = server_url_hash
|
||||
try:
|
||||
db.session.commit()
|
||||
except IntegrityError as e:
|
||||
db.session.rollback()
|
||||
# Check if the error message contains the constraint name
|
||||
if "unique_mcp_provider_name" in str(e.orig):
|
||||
# Raise your custom exception
|
||||
raise ValueError(f"A provider with name '{name}' already exists.")
|
||||
elif "unique_mcp_provider_server_url" in str(e.orig):
|
||||
# You can define another custom exception for the other constraint
|
||||
raise ValueError(f"A provider for server URL '{server_url}' already exists.")
|
||||
error_msg = str(e.orig)
|
||||
if "unique_mcp_provider_name" in error_msg:
|
||||
raise ValueError(f"MCP tool {name} already exists")
|
||||
elif "unique_mcp_provider_server_url" in error_msg:
|
||||
raise ValueError(f"MCP tool {server_url} already exists")
|
||||
else:
|
||||
# Re-raise the original exception if it's not the one you're handling
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def update_mcp_provider_credentials(cls, tenant_id: str, provider_id: str, credentials: dict, authed: bool = False):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
def update_mcp_provider_credentials(cls, mcp_provider: MCPToolProvider, credentials: dict, authed: bool = False):
|
||||
provider_controller = MCPToolProviderController._from_db(mcp_provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
tenant_id=mcp_provider.tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.provider_id,
|
||||
@ -229,27 +198,22 @@ class MCPToolManageService:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_mcp_provider_decrypted_credentials(cls, tenant_id: str, provider_id: str):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
provider_controller = MCPToolProviderController._from_db(mcp_provider)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.provider_id,
|
||||
)
|
||||
return tool_configuration.decrypt(mcp_provider.credentials, use_cache=False)
|
||||
def _re_auth_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str):
|
||||
"""re-auth mcp provider"""
|
||||
try:
|
||||
with MCPClient(
|
||||
mcp_provider.decrypted_server_url,
|
||||
provider_id,
|
||||
tenant_id,
|
||||
authed=False,
|
||||
for_list=True,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
mcp_provider.authed = True
|
||||
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
except MCPAuthError:
|
||||
mcp_provider.authed = False
|
||||
mcp_provider.tools = "[]"
|
||||
|
||||
@classmethod
|
||||
def get_mcp_provider_server_url(cls, tenant_id: str, provider_id: str):
|
||||
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
|
||||
if mcp_provider is None:
|
||||
raise ValueError("MCP tool not found")
|
||||
return encrypter.decrypt_token(tenant_id, mcp_provider.server_url)
|
||||
|
||||
@classmethod
|
||||
def get_masked_mcp_provider_server_url(cls, tenant_id: str, provider_id: str):
|
||||
server_url = cls.get_mcp_provider_server_url(tenant_id, provider_id)
|
||||
return mask_url(server_url)
|
||||
# reset credentials
|
||||
mcp_provider.encrypted_credentials = "{}"
|
||||
|
||||
@ -191,8 +191,6 @@ class ToolTransformService:
|
||||
|
||||
@staticmethod
|
||||
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
|
||||
from services.tools.mcp_tools_mange_service import MCPToolManageService
|
||||
|
||||
return ToolProviderApiEntity(
|
||||
id=db_provider.server_identifier if not for_list else db_provider.id,
|
||||
author=db_provider.user.name if db_provider.user else "Anonymous",
|
||||
@ -200,7 +198,7 @@ class ToolTransformService:
|
||||
icon=db_provider.provider_icon,
|
||||
type=ToolProviderType.MCP,
|
||||
is_team_authorization=db_provider.authed,
|
||||
server_url=MCPToolManageService.get_masked_mcp_provider_server_url(db_provider.tenant_id, db_provider.id),
|
||||
server_url=db_provider.masked_server_url,
|
||||
tools=ToolTransformService.mcp_tool_to_user_tool(
|
||||
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user