From 978598f06ef16c0dfb3ea2f89613ee9210efcf12 Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 8 Jul 2025 15:26:25 +0800 Subject: [PATCH] refactor: streamable http app server --- api/controllers/mcp/mcp.py | 57 +++++++++++++------ api/core/app/entities/app_invoke_entities.py | 3 - .../server/{handler.py => streamable_http.py} | 25 ++++---- api/core/mcp/session/base_session.py | 2 +- api/core/mcp/session/client_session.py | 2 +- api/core/mcp/utils.py | 17 ++++++ 6 files changed, 71 insertions(+), 35 deletions(-) rename api/core/mcp/server/{handler.py => streamable_http.py} (92%) diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 22b313dbfc..3286ddb2be 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -1,15 +1,12 @@ from flask_restful import Resource, reqparse from pydantic import ValidationError -from werkzeug.exceptions import NotFound from controllers.console.app.mcp_server import AppMCPServerStatus from controllers.mcp import api -from controllers.web.error import ( - AppUnavailableError, -) from core.app.app_config.entities import VariableEntity -from core.mcp.server.handler import MCPServerRequestHandler +from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler from core.mcp.types import ClientNotification, ClientRequest +from core.mcp.utils import create_mcp_error_response from extensions.ext_database import db from libs import helper from models.model import App, AppMCPServer, AppMode @@ -30,33 +27,59 @@ class MCPAppApi(Resource): parser.add_argument("id", type=int_or_str, required=False, location="json") args = parser.parse_args() + request_id = args.get("id") + server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first() if not server: - raise NotFound("Server Not Found") + return helper.compact_generate_response(create_mcp_error_response(request_id, -32001, "Server Not Found")) + if server.status != AppMCPServerStatus.ACTIVE: - raise NotFound("Server is not active") + return helper.compact_generate_response( + create_mcp_error_response(request_id, -32001, "Server is not active") + ) + app = db.session.query(App).filter(App.id == server.app_id).first() if not app: - raise NotFound("App Not Found") + return helper.compact_generate_response(create_mcp_error_response(request_id, -32001, "App Not Found")) + if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app.workflow if workflow is None: - raise AppUnavailableError() + return helper.compact_generate_response( + create_mcp_error_response(request_id, -32001, "App is unavailable") + ) - features_dict = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app.app_model_config if app_model_config is None: - raise AppUnavailableError() + return helper.compact_generate_response( + create_mcp_error_response(request_id, -32001, "App is unavailable") + ) features_dict = app_model_config.to_dict() - user_input_form = features_dict.get("user_input_form", []) + converted_user_input_form: list[VariableEntity] = [] try: - user_input_form = [VariableEntity.model_validate(list(item.values())[0]) for item in user_input_form] + for item in user_input_form: + variable_type = item.get("type", "") or list(item.keys())[0] + variable = item[variable_type] + converted_user_input_form.append( + VariableEntity( + type=variable_type, + variable=variable.get("variable"), + description=variable.get("description") or "", + label=variable.get("label"), + required=variable.get("required", False), + max_length=variable.get("max_length"), + options=variable.get("options") or [], + ) + ) except ValidationError as e: - raise ValueError(f"Invalid user_input_form: {str(e)}") + return helper.compact_generate_response( + create_mcp_error_response(request_id, -32602, f"Invalid user_input_form: {str(e)}") + ) + try: request: ClientRequest | ClientNotification = ClientRequest.model_validate(args) except ValidationError as e: @@ -64,9 +87,11 @@ class MCPAppApi(Resource): notification = ClientNotification.model_validate(args) request = notification except ValidationError as e: - raise ValueError(f"Invalid MCP request: {str(e)}") + return helper.compact_generate_response( + create_mcp_error_response(request_id, -32602, f"Invalid MCP request: {str(e)}") + ) - mcp_server_handler = MCPServerRequestHandler(app, request, user_input_form) + mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form) response = mcp_server_handler.handle() return helper.compact_generate_response(response) diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 138b5680a4..65ed267959 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -36,7 +36,6 @@ class InvokeFrom(Enum): # DEBUGGER indicates that this invocation is from # the workflow (or chatflow) edit page. DEBUGGER = "debugger" - MCP_SERVER = "mcp-server" @classmethod def value_of(cls, value: str): @@ -65,8 +64,6 @@ class InvokeFrom(Enum): return "explore_app" elif self == InvokeFrom.SERVICE_API: return "api" - elif self == InvokeFrom.MCP_SERVER: - return "mcp_server" return "dev" diff --git a/api/core/mcp/server/handler.py b/api/core/mcp/server/streamable_http.py similarity index 92% rename from api/core/mcp/server/handler.py rename to api/core/mcp/server/streamable_http.py index 86d4c31aa5..158557271f 100644 --- a/api/core/mcp/server/handler.py +++ b/api/core/mcp/server/streamable_http.py @@ -9,6 +9,7 @@ from core.app.app_config.entities import VariableEntity, VariableEntityType from core.app.entities.app_invoke_entities import InvokeFrom from core.mcp import types from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND +from core.mcp.utils import create_mcp_error_response from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.model import App, AppMCPServer, AppMode, EndUser @@ -20,7 +21,7 @@ Apply to MCP HTTP streamable server with stateless http logger = logging.getLogger(__name__) -class MCPServerRequestHandler: +class MCPServerStreamableHTTPRequestHandler: def __init__( self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity] ): @@ -78,17 +79,8 @@ class MCPServerRequestHandler: yield sse_content def error_response(self, code: int, message: str, data=None): - error_data = types.ErrorData(code=code, message=message, data=data) - json_response = types.JSONRPCError( - jsonrpc="2.0", - id=(self.request.root.model_extra or {}).get("id", 1) or 1, - error=error_data, - ) - json_data = json.dumps(jsonable_encoder(json_response)) - - sse_content = f"event: message\ndata: {json_data}\n\n".encode() - - yield sse_content + request_id = (self.request.root.model_extra or {}).get("id", 1) or 1 + return create_mcp_error_response(request_id, code, message, data) def handle(self): handle_map = { @@ -158,7 +150,7 @@ class MCPServerRequestHandler: args = {"inputs": args} else: args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}} - response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False) + response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.SERVICE_API, streaming=False) if isinstance(response, Mapping): answer = "" if self.app.mode in { @@ -196,7 +188,12 @@ class MCPServerRequestHandler: continue if item.required: required.append(item.variable) - description = self.mcp_server.parameters_dict[item.variable] + # if the workflow republished, the parameters not changed + # we should not raise error here + try: + description = self.mcp_server.parameters_dict[item.variable] + except KeyError: + description = "" parameters[item.variable]["description"] = description if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): parameters[item.variable]["type"] = "string" diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index a85eee5219..ac344ec395 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -359,7 +359,7 @@ class BaseSession( if response_queue is not None: response_queue.put(message.message.root) else: - self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) + self._handle_incoming(RuntimeError(f"Server Error: {message}")) except queue.Empty: continue except Exception as e: diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 518920c60b..7660db45c5 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -41,7 +41,7 @@ def _default_message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: if isinstance(message, Exception): - raise message + raise ValueError(str(message)) elif isinstance(message, (types.ServerNotification | RequestResponder)): pass diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index a8a603b3f2..b177f34f9b 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -1,6 +1,10 @@ +import json + import httpx from configs import dify_config +from core.mcp.types import ErrorData, JSONRPCError +from core.model_runtime.utils.encoders import jsonable_encoder SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES @@ -98,3 +102,16 @@ def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if not client_provided: client.close() raise + + +def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None): + """Create MCP error response""" + error_data = ErrorData(code=code, message=message, data=data) + json_response = JSONRPCError( + jsonrpc="2.0", + id=request_id or 1, + error=error_data, + ) + json_data = json.dumps(jsonable_encoder(json_response)) + sse_content = f"event: message\ndata: {json_data}\n\n".encode() + yield sse_content