Merge branch 'feat/mcp' into deploy/dev

This commit is contained in:
Novice
2025-07-02 15:25:17 +08:00
15 changed files with 90 additions and 79 deletions

View File

@ -1,9 +1,9 @@
import json
from enum import Enum
from enum import StrEnum
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.app.wraps import get_app_model
@ -14,7 +14,7 @@ from libs.login import login_required
from models.model import AppMCPServer
class AppMCPServerStatus(str, Enum):
class AppMCPServerStatus(StrEnum):
ACTIVE = "active"
INACTIVE = "inactive"
@ -37,7 +37,7 @@ class AppMCPServerController(Resource):
def post(self, app_model):
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
raise Forbidden()
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
@ -62,7 +62,7 @@ class AppMCPServerController(Resource):
@marshal_with(app_server_fields)
def put(self, app_model):
if not current_user.is_editor:
raise Forbidden()
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=True, location="json")
@ -71,7 +71,7 @@ class AppMCPServerController(Resource):
args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
if not server:
raise Forbidden()
raise NotFound()
server.description = args["description"]
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
if args["status"]:
@ -89,10 +89,10 @@ class AppMCPServerRefreshController(Resource):
@marshal_with(app_server_fields)
def get(self, server_id):
if not current_user.is_editor:
raise Forbidden()
raise NotFound()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
if not server:
raise Forbidden()
raise NotFound()
server.server_code = AppMCPServer.generate_server_code(16)
db.session.commit()
return server

View File

@ -1,6 +1,6 @@
import io
from urllib.parse import urlparse
import validators
from flask import redirect, send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
@ -27,6 +27,17 @@ from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
def is_valid_url(url: str) -> bool:
if not url:
return False
try:
parsed = urlparse(url)
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
except Exception:
return False
class ToolProviderListApi(Resource):
@setup_required
@login_required
@ -634,7 +645,7 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
user = current_user
if not validators.url(args["server_url"]):
if not is_valid_url(args["server_url"]):
raise ValueError("Server URL is not valid.")
return jsonable_encoder(
MCPToolManageService.create_mcp_provider(
@ -662,7 +673,7 @@ class ToolProviderMCPApi(Resource):
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
if not validators.url(args["server_url"]):
if not is_valid_url(args["server_url"]):
if "[__HIDDEN__]" in args["server_url"]:
pass
else:

View File

@ -8,7 +8,7 @@ from controllers.web.error import (
AppUnavailableError,
)
from core.app.app_config.entities import VariableEntity
from core.mcp.server.handler import MCPServerReuqestHandler
from core.mcp.server.handler import MCPServerRequestHandler
from core.mcp.types import ClientNotification, ClientRequest
from extensions.ext_database import db
from libs import helper
@ -66,7 +66,7 @@ class MCPAppApi(Resource):
except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
mcp_server_handler = MCPServerRequestHandler(app, request, user_input_form)
response = mcp_server_handler.handle()
return helper.compact_generate_response(response)

View File

@ -8,7 +8,7 @@ from typing import Optional
from urllib.parse import urljoin
import requests
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.types import (
@ -60,7 +60,7 @@ def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
"""Retrieve and decode OAuth state data from Redis using the state key."""
"""Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
# Get state data from Redis
@ -69,27 +69,23 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
if not state_data:
raise ValueError("State parameter has expired or does not exist")
# Delete the state data from Redis immediately after retrieval to prevent reuse
redis_client.delete(redis_key)
try:
# Parse and validate the state data
if isinstance(state_data, bytes):
state_data = state_data.decode("utf-8")
oauth_state = OAuthCallbackState.model_validate_json(state_data)
return oauth_state
except Exception as e:
except ValidationError as e:
raise ValueError(f"Invalid state parameter: {str(e)}")
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
"""Handle the callback from the OAuth provider."""
# Retrieve state data from Redis
# Retrieve state data from Redis (state is automatically deleted after retrieval)
full_state_data = _retrieve_redis_state(state_key)
# Clean up the state data from Redis after successful retrieval
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
redis_client.delete(redis_key)
tokens = exchange_authorization(
full_state_data.server_url,
full_state_data.metadata,

View File

@ -3,7 +3,7 @@ import queue
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from typing import Any, cast
from typing import Any, TypeAlias, final
from urllib.parse import urljoin, urlparse
import httpx
@ -18,10 +18,23 @@ logger = logging.getLogger(__name__)
DEFAULT_QUEUE_READ_TIMEOUT = 3
@final
class _StatusReady:
def __init__(self, endpoint_url: str):
self._endpoint_url = endpoint_url
@final
class _StatusError:
def __init__(self, exc: Exception):
self._exc = exc
# Type aliases for better readability
ReadQueue = queue.Queue[SessionMessage | Exception | None]
WriteQueue = queue.Queue[SessionMessage | Exception | None]
StatusQueue = queue.Queue[tuple[str, str | Exception]]
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
def remove_request_params(url: str) -> str:
@ -80,10 +93,10 @@ class SSETransport:
if not self._validate_endpoint_url(endpoint_url):
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
logger.error(error_msg)
status_queue.put(("error", ValueError(error_msg)))
status_queue.put(_StatusError(ValueError(error_msg)))
return
status_queue.put(("ready", endpoint_url))
status_queue.put(_StatusReady(endpoint_url))
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
"""Handle a 'message' SSE event.
@ -197,18 +210,17 @@ class SSETransport:
ValueError: If endpoint URL is not received or there's an error.
"""
try:
status, endpoint_url_or_error = status_queue.get(timeout=1)
status = status_queue.get(timeout=1)
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if status != "ready":
if isinstance(status, _StatusReady):
return status._endpoint_url
elif isinstance(status, _StatusError):
raise status._exc
else:
raise ValueError("failed to get endpoint URL")
if status == "error" and isinstance(endpoint_url_or_error, Exception):
raise endpoint_url_or_error
return cast(str, endpoint_url_or_error)
def connect(
self,
executor: ThreadPoolExecutor,
@ -284,9 +296,9 @@ def sse_client(
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception as exc:
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise exc
raise
finally:
# Clean up queues
if read_queue:

View File

@ -94,14 +94,15 @@ class MCPClient:
if self._streams_context is None:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
if method_name == "mcp":
read_stream, write_stream, _ = self._streams_context.__enter__()
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
streams = (read_stream, write_stream)
else: # sse_client
streams = self._streams_context.__enter__()
streams = self.exit_stack.enter_context(self._streams_context)
self._session_context = ClientSession(*streams)
self._session = self._session_context.__enter__()
self._session = self.exit_stack.enter_context(self._session_context)
session = cast(ClientSession, self._session)
session.initialize()
return
@ -138,14 +139,12 @@ class MCPClient:
def cleanup(self):
"""Clean up resources"""
try:
if self._session_context:
self._session_context.__exit__(None, None, None)
if self._streams_context:
self._streams_context.__exit__(None, None, None)
self._session = None
self._initialized = False
# ExitStack will handle proper cleanup of all managed context managers
self.exit_stack.close()
self._session = None
self._session_context = None
self._streams_context = None
self._initialized = False
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

View File

@ -18,15 +18,16 @@ Apply to MCP HTTP streamable server with stateless http
"""
class MCPServerReuqestHandler:
class MCPServerRequestHandler:
def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):
self.app = app
self.request = request
if not self.app.mcp_server:
mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
if not mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = self.app.mcp_server
self.mcp_server: AppMCPServer = mcp_server
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form

View File

@ -4,20 +4,8 @@ from configs import dify_config
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
try:
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
if http_request_node_ssl_verify_lower == "true":
HTTP_REQUEST_NODE_SSL_VERIFY = True
elif http_request_node_ssl_verify_lower == "false":
HTTP_REQUEST_NODE_SSL_VERIFY = False
else:
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
except NameError:
HTTP_REQUEST_NODE_SSL_VERIFY = True
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]

View File

@ -46,11 +46,11 @@ class MCPToolProviderController(ToolProviderController):
tools = []
tools_data = json.loads(db_provider.tools)
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
user = db_provider.load_user()
tools = [
ToolEntity(
identity=ToolIdentity(
author=db_provider.user.name if db_provider.user else "Anonymous",
author=user.name if user else "Anonymous",
name=remote_mcp_tool.name,
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
provider=db_provider.server_identifier,
@ -72,7 +72,7 @@ class MCPToolProviderController(ToolProviderController):
return cls(
entity=ToolProviderEntityWithPlugin(
identity=ToolProviderIdentity(
author=db_provider.user.name if db_provider.user else "Anonymous",
author=user.name if user else "Anonymous",
name=db_provider.name,
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
description=I18nObject(en_US="", zh_Hans=""),

View File

@ -1,4 +1,5 @@
import base64
import json
from collections.abc import Generator
from typing import Any, Optional
@ -49,6 +50,11 @@ class MCPTool(Tool):
for content in result.content:
if isinstance(content, TextContent):
yield self.create_text_message(content.text)
try:
yield self.create_json_message(json.loads(content.text))
except json.JSONDecodeError:
pass
elif isinstance(content, ImageContent):
yield self.create_blob_message(
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}

View File

@ -32,7 +32,7 @@ def upgrade():
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey'),
sa.UniqueConstraint('tenant_id', 'app_id', name='unique_app_mcp_server_tenant_app_id'),
sa.UniqueConstraint('tenant_id', 'server_code', name='unique_app_mcp_server_tenant_server_code')
sa.UniqueConstraint('server_code', name='unique_app_mcp_server_server_code')
)
op.create_table('tool_mcp_providers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),

View File

@ -294,10 +294,6 @@ class App(Base):
return tags or []
@property
def mcp_server(self):
return db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.id).first()
@property
def author_name(self):
if self.created_by:
@ -1465,7 +1461,7 @@ class AppMCPServer(Base):
__table_args__ = (
db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
db.UniqueConstraint("tenant_id", "server_code", name="unique_app_mcp_server_tenant_server_code"),
db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)

View File

@ -234,8 +234,7 @@ class MCPToolProvider(Base):
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
@property
def user(self) -> Account | None:
def load_user(self) -> Account | None:
return db.session.query(Account).filter(Account.id == self.user_id).first()
@property

View File

@ -125,13 +125,14 @@ class MCPToolManageService:
mcp_provider.authed = True
mcp_provider.updated_at = datetime.now()
db.session.commit()
user = mcp_provider.load_user()
return ToolProviderApiEntity(
id=mcp_provider.id,
name=mcp_provider.name,
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
type=ToolProviderType.MCP,
icon=mcp_provider.icon,
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
author=user.name if user else "Anonymous",
server_url=mcp_provider.masked_server_url,
updated_at=int(mcp_provider.updated_at.timestamp()),
description=I18nObject(en_US="", zh_Hans=""),

View File

@ -191,9 +191,10 @@ class ToolTransformService:
@staticmethod
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
user = db_provider.load_user()
return ToolProviderApiEntity(
id=db_provider.server_identifier if not for_list else db_provider.id,
author=db_provider.user.name if db_provider.user else "Anonymous",
author=user.name if user else "Anonymous",
name=db_provider.name,
icon=db_provider.provider_icon,
type=ToolProviderType.MCP,
@ -210,9 +211,10 @@ class ToolTransformService:
@staticmethod
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
user = mcp_provider.load_user()
return [
ToolApiEntity(
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
author=user.name if user else "Anonymous",
name=tool.name,
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
description=I18nObject(en_US=tool.description, zh_Hans=tool.description),