feat: add mcp server

This commit is contained in:
Novice
2025-05-22 18:16:46 +08:00
parent c1a58ac160
commit bdb4395319
11 changed files with 377 additions and 4 deletions

View File

@ -56,6 +56,7 @@ from .app import (
conversation,
conversation_variables,
generator,
mcp_server,
message,
model_config,
ops_trace,

View File

@ -0,0 +1,83 @@
import json
from enum import Enum
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_server_fields
from libs.login import login_required
from models.model import AppMCPServer
class AppMCPServerStatus(str, Enum):
ACTIVE = "active"
INACTIVE = "inactive"
class AppMCPServerController(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_server_fields)
def get(self, app_model):
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
return server
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_server_fields)
def post(self, app_model):
# The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
args = parser.parse_args()
server = AppMCPServer(
name=app_model.name,
description=args["description"],
parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
tenant_id=current_user.current_tenant_id,
server_code=AppMCPServer.generate_server_code(16),
)
db.session.add(server)
db.session.commit()
return server
@setup_required
@login_required
@account_initialization_required
@get_app_model
@marshal_with(app_server_fields)
def put(self, app_model):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, location="json")
parser.add_argument("description", type=str, required=True, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=True, location="json")
args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
if not server:
raise Forbidden()
server.description = args["description"]
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
server.status = AppMCPServerStatus(args["status"])
db.session.commit()
return server
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")

View File

@ -1,6 +1,7 @@
import logging
from flask_restful import reqparse
from flask_restful import Resource, reqparse
from pydantic import ValidationError
from werkzeug.exceptions import InternalServerError, NotFound
import services
@ -24,10 +25,13 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.mcp.server.handler import MCPServerReuqestHandler
from core.mcp.types import ClientRequest
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
from libs.helper import uuid_value
from models.model import AppMode
from models.model import App, AppMCPServer, AppMode
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -149,7 +153,38 @@ class ChatStopApi(WebApiResource):
return {"result": "success"}, 200
class ChatMCPApi(Resource):
def post(self, server_code):
def int_or_str(value):
if isinstance(value, int):
return value
elif isinstance(value, str):
return int(value)
else:
raise ValueError("Invalid id")
parser = reqparse.RequestParser()
parser.add_argument("jsonrpc", type=str, required=True, location="json")
parser.add_argument("method", type=str, required=True, location="json")
parser.add_argument("params", type=dict, required=True, location="json")
parser.add_argument("id", type=int_or_str, required=True, location="json")
args = parser.parse_args()
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
if not server:
raise NotFound("Server Not Found")
app = db.session.query(App).filter(App.id == server.app_id).first()
if not app:
raise NotFound("App Not Found")
try:
request = ClientRequest.model_validate(args)
except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerReuqestHandler(app, request)
return helper.compact_generate_response(mcp_server_handler.handle())
api.add_resource(CompletionApi, "/completion-messages")
api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
api.add_resource(ChatApi, "/chat-messages")
api.add_resource(ChatMCPApi, "/server/<string:server_code>/mcp")
api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")

View File

@ -21,6 +21,7 @@ class InvokeFrom(Enum):
WEB_APP = "web-app"
EXPLORE = "explore"
DEBUGGER = "debugger"
MCP_SERVER = "mcp-server"
@classmethod
def value_of(cls, value: str):
@ -49,6 +50,8 @@ class InvokeFrom(Enum):
return "explore_app"
elif self == InvokeFrom.SERVICE_API:
return "api"
elif self == InvokeFrom.MCP_SERVER:
return "mcp_server"
return "dev"

View File

@ -0,0 +1,154 @@
import json
from collections.abc import Mapping
from typing import cast
from configs.app_config import DifyConfig
from controllers.web.passport import generate_session_id
from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.model import App, EndUser
from services.app_generate_service import AppGenerateService
"""
Apply to MCP HTTP streamable server with stateless http
"""
dify_config = DifyConfig()
class MCPServerReuqestHandler:
def __init__(self, app: App, request: types.ClientRequest):
self.app = app
self.request = request
if not self.app.mcp_server:
raise ValueError("MCP server not found")
self.mcp_server = self.app.mcp_server
self.end_user = self.retrieve_end_user()
@property
def request_type(self):
return type(self.request.root)
@property
def parameter_schema(self):
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
"inputs": {
"type": "object",
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
"default": {},
# TODO: add input parameters
},
},
"required": ["query"],
}
@property
def output_parameters(self):
return self.app.output_schema
@property
def capabilities(self):
return types.ServerCapabilities(
tools=types.ToolsCapability(listChanged=False),
)
def response(self, response: types.Result):
json_response = types.JSONRPCResponse(
jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1),
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
yield sse_content
def error_response(self, code: int, message: str, data=None):
error_data = types.ErrorData(code=code, message=message, data=data)
json_response = types.JSONRPCError(
jsonrpc="2.0",
id=(self.request.root.model_extra or {}).get("id", 1),
error=error_data,
)
json_data = json.dumps(jsonable_encoder(json_response))
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
yield sse_content
def handle(self):
handle_map = {
types.InitializeRequest: self.initialize,
types.ListToolsRequest: self.list_tools,
types.CallToolRequest: self.invoke_tool,
}
try:
if self.request_type in handle_map:
return self.response(handle_map[self.request_type]())
else:
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
except ValueError as e:
return self.error_response(INVALID_PARAMS, str(e))
except Exception as e:
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
def initialize(self):
request = cast(types.InitializeRequest, self.request.root)
client_info = request.params.clientInfo
clinet_name = f"{client_info.name}@{client_info.version}"
if not self.end_user:
end_user = EndUser(
tenant_id=self.app.tenant_id,
app_id=self.app.id,
type="mcp",
name=clinet_name,
session_id=generate_session_id(),
external_user_id=self.mcp_server.id,
)
db.session.add(end_user)
db.session.commit()
return types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self.capabilities,
serverInfo=types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION),
instructions=self.mcp_server.description,
)
def list_tools(self):
if not self.end_user:
raise ValueError("User not found")
return types.ListToolsResult(
tools=[
types.Tool(
name=self.mcp_server.name,
description=self.mcp_server.description,
inputSchema=self.parameter_schema,
)
],
)
def invoke_tool(self):
if not self.end_user:
raise ValueError("User not found")
request = cast(types.CallToolRequest, self.request.root)
args = request.params.arguments
if not args:
raise ValueError("No arguments provided")
response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False)
if isinstance(response, Mapping):
return types.CallToolResult(content=[types.TextContent(text=response["answer"], type="text")])
return None
def retrieve_end_user(self):
return (
db.session.query(EndUser)
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)

View File

@ -3,11 +3,13 @@ from typing import Any, Protocol
from pydantic import AnyUrl, TypeAdapter
from configs.app_config import DifyConfig
from core.mcp import types
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
from core.mcp.session.base_session import BaseSession, RequestResponder
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
dify_config = DifyConfig()
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION)
class SamplingFnT(Protocol):

View File

@ -213,3 +213,14 @@ app_import_fields = {
app_import_check_dependencies_fields = {
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
}
app_server_fields = {
"id": fields.String,
"name": fields.String,
"server_code": fields.String,
"description": fields.String,
"status": fields.String,
"parameters": fields.Raw,
"created_at": TimestampField,
"updated_at": TimestampField,
}

View File

@ -0,0 +1,41 @@
"""add app mcp server
Revision ID: ca4c4abcc347
Revises: 1e67f2654a08
Create Date: 2025-05-22 16:23:44.206102
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'ca4c4abcc347'
down_revision = '1e67f2654a08'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('app_mcp_servers',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.String(length=255), nullable=False),
sa.Column('server_code', sa.String(length=255), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
sa.Column('parameters', sa.Text(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('app_mcp_servers')
# ### end Alembic commands ###

View File

@ -294,6 +294,10 @@ class App(Base):
return tags or []
@property
def mcp_server(self):
return db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.id).first()
class AppModelConfig(Base):
__tablename__ = "app_model_configs"
@ -1433,6 +1437,31 @@ class EndUser(Base, UserMixin):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AppMCPServer(Base):
__tablename__ = "app_mcp_servers"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
app_id = db.Column(StringUUID, nullable=False)
name = db.Column(db.String(255), nullable=False)
description = db.Column(db.String(255), nullable=False)
server_code = db.Column(db.String(255), nullable=False)
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
parameters = db.Column(db.Text, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_server_code(n):
while True:
result = generate_string(n)
while db.session.query(AppMCPServer).filter(AppMCPServer.server_code == result).count() > 0:
result = generate_string(n)
return result
class Site(Base):
__tablename__ = "sites"
__table_args__ = (

View File

@ -217,7 +217,7 @@ class MCPToolProvider(Base):
# who created this tool
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# encrypted credentials
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=False)
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
# authed
authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
# tools

View File

@ -14,6 +14,7 @@ from models.model import (
ApiToken,
AppAnnotationHitHistory,
AppAnnotationSetting,
AppMCPServer,
AppModelConfig,
Conversation,
EndUser,
@ -42,6 +43,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
# Delete related data
_delete_app_model_configs(tenant_id, app_id)
_delete_app_site(tenant_id, app_id)
_delete_app_mcp_servers(tenant_id, app_id)
_delete_app_api_tokens(tenant_id, app_id)
_delete_installed_apps(tenant_id, app_id)
_delete_recommended_apps(tenant_id, app_id)
@ -90,6 +92,18 @@ def _delete_app_site(tenant_id: str, app_id: str):
_delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site")
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def del_mcp_server(mcp_server_id: str):
db.session.query(AppMCPServer).filter(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
{"app_id": app_id},
del_mcp_server,
"app mcp server",
)
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(api_token_id: str):
db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False)