mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 05:58:14 +08:00
refactor(mcp): clean the client service code
This commit is contained in:
@ -1,16 +1,12 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, cast
|
||||
from urllib.parse import urlparse
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy import ForeignKey, String, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import encrypter
|
||||
from core.mcp.types import Tool
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
|
||||
@ -20,6 +16,9 @@ from .engine import db
|
||||
from .model import Account, App, Tenant
|
||||
from .types import StringUUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
|
||||
|
||||
# system level tool oauth client params (client_id, client_secret, etc.)
|
||||
class ToolOAuthSystemClient(TypeBase):
|
||||
@ -286,119 +285,34 @@ class MCPToolProvider(Base):
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).where(Account.id == self.user_id).first()
|
||||
|
||||
@property
|
||||
def tenant(self) -> Tenant | None:
|
||||
return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict[str, Any]:
|
||||
try:
|
||||
return cast(dict[str, Any], json.loads(self.encrypted_credentials)) or {}
|
||||
return json.loads(self.encrypted_credentials)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def mcp_tools(self) -> list[Tool]:
|
||||
return [Tool(**tool) for tool in json.loads(self.tools)]
|
||||
|
||||
@property
|
||||
def provider_icon(self) -> dict[str, str] | str:
|
||||
def headers(self) -> dict[str, Any]:
|
||||
if self.encrypted_headers is None:
|
||||
return {}
|
||||
try:
|
||||
return cast(dict[str, str], json.loads(self.icon))
|
||||
except json.JSONDecodeError:
|
||||
return file_helpers.get_signed_file_url(self.icon)
|
||||
|
||||
@property
|
||||
def decrypted_server_url(self) -> str:
|
||||
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
||||
|
||||
@property
|
||||
def decrypted_headers(self) -> dict[str, Any]:
|
||||
"""Get decrypted headers for MCP server requests."""
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
try:
|
||||
if not self.encrypted_headers:
|
||||
return {}
|
||||
|
||||
headers_data = json.loads(self.encrypted_headers)
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
result = encrypter_instance.decrypt(headers_data)
|
||||
return result
|
||||
return json.loads(self.encrypted_headers)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def masked_headers(self) -> dict[str, Any]:
|
||||
"""Get masked headers for frontend display."""
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
def tool_dict(self) -> list[dict[str, Any]]:
|
||||
try:
|
||||
if not self.encrypted_headers:
|
||||
return {}
|
||||
return json.loads(self.tools) if self.tools else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
|
||||
headers_data = json.loads(self.encrypted_headers)
|
||||
def to_entity(self) -> "MCPProviderEntity":
|
||||
"""Convert to domain entity"""
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
|
||||
# Create dynamic config for all headers as SECRET_INPUT
|
||||
config = [BasicProviderConfig(type=BasicProviderConfig.Type.SECRET_INPUT, name=key) for key in headers_data]
|
||||
|
||||
encrypter_instance, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=config,
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# First decrypt, then mask
|
||||
decrypted_headers = encrypter_instance.decrypt(headers_data)
|
||||
result = encrypter_instance.mask_tool_credentials(decrypted_headers)
|
||||
return result
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def masked_server_url(self) -> str:
|
||||
def mask_url(url: str, mask_char: str = "*") -> str:
|
||||
"""
|
||||
mask the url to a simple string
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
if parsed.path and parsed.path != "/":
|
||||
return f"{base_url}/{mask_char * 6}"
|
||||
else:
|
||||
return base_url
|
||||
|
||||
return mask_url(self.decrypted_server_url)
|
||||
|
||||
@property
|
||||
def decrypted_credentials(self) -> dict[str, Any]:
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.utils.encryption import create_provider_encrypter
|
||||
|
||||
provider_controller = MCPToolProviderController.from_db(self)
|
||||
|
||||
encrypter, _ = create_provider_encrypter(
|
||||
tenant_id=self.tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
return encrypter.decrypt(self.credentials)
|
||||
return MCPProviderEntity.from_db_model(self)
|
||||
|
||||
|
||||
class ToolModelInvoke(Base):
|
||||
|
||||
Reference in New Issue
Block a user