refactor: streamable http app server

This commit is contained in:
Novice
2025-07-08 15:26:25 +08:00
parent 1a67cb77cc
commit 978598f06e
6 changed files with 71 additions and 35 deletions

View File

@ -1,15 +1,12 @@
from flask_restful import Resource, reqparse
from pydantic import ValidationError
from werkzeug.exceptions import NotFound
from controllers.console.app.mcp_server import AppMCPServerStatus
from controllers.mcp import api
from controllers.web.error import (
AppUnavailableError,
)
from core.app.app_config.entities import VariableEntity
from core.mcp.server.handler import MCPServerRequestHandler
from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
from core.mcp.types import ClientNotification, ClientRequest
from core.mcp.utils import create_mcp_error_response
from extensions.ext_database import db
from libs import helper
from models.model import App, AppMCPServer, AppMode
@ -30,33 +27,59 @@ class MCPAppApi(Resource):
parser.add_argument("id", type=int_or_str, required=False, location="json")
args = parser.parse_args()
request_id = args.get("id")
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
if not server:
raise NotFound("Server Not Found")
return helper.compact_generate_response(create_mcp_error_response(request_id, -32001, "Server Not Found"))
if server.status != AppMCPServerStatus.ACTIVE:
raise NotFound("Server is not active")
return helper.compact_generate_response(
create_mcp_error_response(request_id, -32001, "Server is not active")
)
app = db.session.query(App).filter(App.id == server.app_id).first()
if not app:
raise NotFound("App Not Found")
return helper.compact_generate_response(create_mcp_error_response(request_id, -32001, "App Not Found"))
if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
workflow = app.workflow
if workflow is None:
raise AppUnavailableError()
return helper.compact_generate_response(
create_mcp_error_response(request_id, -32001, "App is unavailable")
)
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app.app_model_config
if app_model_config is None:
raise AppUnavailableError()
return helper.compact_generate_response(
create_mcp_error_response(request_id, -32001, "App is unavailable")
)
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
converted_user_input_form: list[VariableEntity] = []
try:
user_input_form = [VariableEntity.model_validate(list(item.values())[0]) for item in user_input_form]
for item in user_input_form:
variable_type = item.get("type", "") or list(item.keys())[0]
variable = item[variable_type]
converted_user_input_form.append(
VariableEntity(
type=variable_type,
variable=variable.get("variable"),
description=variable.get("description") or "",
label=variable.get("label"),
required=variable.get("required", False),
max_length=variable.get("max_length"),
options=variable.get("options") or [],
)
)
except ValidationError as e:
raise ValueError(f"Invalid user_input_form: {str(e)}")
return helper.compact_generate_response(
create_mcp_error_response(request_id, -32602, f"Invalid user_input_form: {str(e)}")
)
try:
request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
except ValidationError as e:
@ -64,9 +87,11 @@ class MCPAppApi(Resource):
notification = ClientNotification.model_validate(args)
request = notification
except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}")
return helper.compact_generate_response(
create_mcp_error_response(request_id, -32602, f"Invalid MCP request: {str(e)}")
)
mcp_server_handler = MCPServerRequestHandler(app, request, user_input_form)
mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
response = mcp_server_handler.handle()
return helper.compact_generate_response(response)

View File

@ -36,7 +36,6 @@ class InvokeFrom(Enum):
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
MCP_SERVER = "mcp-server"
@classmethod
def value_of(cls, value: str):
@ -65,8 +64,6 @@ class InvokeFrom(Enum):
return "explore_app"
elif self == InvokeFrom.SERVICE_API:
return "api"
elif self == InvokeFrom.MCP_SERVER:
return "mcp_server"
return "dev"

View File

@ -9,6 +9,7 @@ from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.mcp.utils import create_mcp_error_response
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.model import App, AppMCPServer, AppMode, EndUser
@ -20,7 +21,7 @@ Apply to MCP HTTP streamable server with stateless http
logger = logging.getLogger(__name__)
class MCPServerRequestHandler:
class MCPServerStreamableHTTPRequestHandler:
def __init__(
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
):
@ -78,17 +79,8 @@ class MCPServerRequestHandler:
yield sse_content
def error_response(self, code: int, message: str, data=None):
error_data = types.ErrorData(code=code, message=message, data=data)
json_response = types.JSONRPCError(
jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1) or 1,
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
yield sse_content
request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
return create_mcp_error_response(request_id, code, message, data)
def handle(self):
handle_map = {
@ -158,7 +150,7 @@ class MCPServerRequestHandler:
args = {"inputs": args}
else:
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False)
response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.SERVICE_API, streaming=False)
if isinstance(response, Mapping):
answer = ""
if self.app.mode in {
@ -196,7 +188,12 @@ class MCPServerRequestHandler:
continue
if item.required:
required.append(item.variable)
description = self.mcp_server.parameters_dict[item.variable]
# if the workflow republished, the parameters not changed
# we should not raise error here
try:
description = self.mcp_server.parameters_dict[item.variable]
except KeyError:
description = ""
parameters[item.variable]["description"] = description
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"

View File

@ -359,7 +359,7 @@ class BaseSession(
if response_queue is not None:
response_queue.put(message.message.root)
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
self._handle_incoming(RuntimeError(f"Server Error: {message}"))
except queue.Empty:
continue
except Exception as e:

View File

@ -41,7 +41,7 @@ def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
raise message
raise ValueError(str(message))
elif isinstance(message, (types.ServerNotification | RequestResponder)):
pass

View File

@ -1,6 +1,10 @@
import json
import httpx
from configs import dify_config
from core.mcp.types import ErrorData, JSONRPCError
from core.model_runtime.utils.encoders import jsonable_encoder
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
@ -98,3 +102,16 @@ def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
if not client_provided:
client.close()
raise
def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
"""Create MCP error response"""
error_data = ErrorData(code=code, message=message, data=data)
json_response = JSONRPCError(
jsonrpc="2.0",
id=request_id or 1,
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
yield sse_content