mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 06:58:05 +08:00
feat: add mcp server
This commit is contained in:
@ -56,6 +56,7 @@ from .app import (
|
||||
conversation,
|
||||
conversation_variables,
|
||||
generator,
|
||||
mcp_server,
|
||||
message,
|
||||
model_config,
|
||||
ops_trace,
|
||||
|
||||
83
api/controllers/console/app/mcp_server.py
Normal file
83
api/controllers/console/app/mcp_server.py
Normal file
@ -0,0 +1,83 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from extensions.ext_database import db
|
||||
from fields.app_fields import app_server_fields
|
||||
from libs.login import login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
class AppMCPServerStatus(str, Enum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
|
||||
class AppMCPServerController(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, app_model):
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
server = AppMCPServer(
|
||||
name=app_model.name,
|
||||
description=args["description"],
|
||||
parameters=json.dumps(args["parameters"], ensure_ascii=False),
|
||||
status=AppMCPServerStatus.ACTIVE,
|
||||
app_id=app_model.id,
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
server_code=AppMCPServer.generate_server_code(16),
|
||||
)
|
||||
db.session.add(server)
|
||||
db.session.commit()
|
||||
|
||||
return server
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model
|
||||
@marshal_with(app_server_fields)
|
||||
def put(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
parser.add_argument("status", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
raise Forbidden()
|
||||
server.description = args["description"]
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
server.status = AppMCPServerStatus(args["status"])
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
|
||||
api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask_restful import reqparse
|
||||
from flask_restful import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@ -24,10 +25,13 @@ from core.errors.error import (
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from core.mcp.server.handler import MCPServerReuqestHandler
|
||||
from core.mcp.types import ClientRequest
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.model import AppMode
|
||||
from models.model import App, AppMCPServer, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
@ -149,7 +153,38 @@ class ChatStopApi(WebApiResource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
class ChatMCPApi(Resource):
|
||||
def post(self, server_code):
|
||||
def int_or_str(value):
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
elif isinstance(value, str):
|
||||
return int(value)
|
||||
else:
|
||||
raise ValueError("Invalid id")
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("jsonrpc", type=str, required=True, location="json")
|
||||
parser.add_argument("method", type=str, required=True, location="json")
|
||||
parser.add_argument("params", type=dict, required=True, location="json")
|
||||
parser.add_argument("id", type=int_or_str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
|
||||
if not server:
|
||||
raise NotFound("Server Not Found")
|
||||
app = db.session.query(App).filter(App.id == server.app_id).first()
|
||||
if not app:
|
||||
raise NotFound("App Not Found")
|
||||
try:
|
||||
request = ClientRequest.model_validate(args)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid MCP request: {str(e)}")
|
||||
mcp_server_handler = MCPServerReuqestHandler(app, request)
|
||||
return helper.compact_generate_response(mcp_server_handler.handle())
|
||||
|
||||
|
||||
api.add_resource(CompletionApi, "/completion-messages")
|
||||
api.add_resource(CompletionStopApi, "/completion-messages/<string:task_id>/stop")
|
||||
api.add_resource(ChatApi, "/chat-messages")
|
||||
api.add_resource(ChatMCPApi, "/server/<string:server_code>/mcp")
|
||||
api.add_resource(ChatStopApi, "/chat-messages/<string:task_id>/stop")
|
||||
|
||||
@ -21,6 +21,7 @@ class InvokeFrom(Enum):
|
||||
WEB_APP = "web-app"
|
||||
EXPLORE = "explore"
|
||||
DEBUGGER = "debugger"
|
||||
MCP_SERVER = "mcp-server"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str):
|
||||
@ -49,6 +50,8 @@ class InvokeFrom(Enum):
|
||||
return "explore_app"
|
||||
elif self == InvokeFrom.SERVICE_API:
|
||||
return "api"
|
||||
elif self == InvokeFrom.MCP_SERVER:
|
||||
return "mcp_server"
|
||||
|
||||
return "dev"
|
||||
|
||||
|
||||
154
api/core/mcp/server/handler.py
Normal file
154
api/core/mcp/server/handler.py
Normal file
@ -0,0 +1,154 @@
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import cast
|
||||
|
||||
from configs.app_config import DifyConfig
|
||||
from controllers.web.passport import generate_session_id
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.mcp import types
|
||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
"""
|
||||
Apply to MCP HTTP streamable server with stateless http
|
||||
"""
|
||||
dify_config = DifyConfig()
|
||||
|
||||
|
||||
class MCPServerReuqestHandler:
|
||||
def __init__(self, app: App, request: types.ClientRequest):
|
||||
self.app = app
|
||||
self.request = request
|
||||
if not self.app.mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server = self.app.mcp_server
|
||||
self.end_user = self.retrieve_end_user()
|
||||
|
||||
@property
|
||||
def request_type(self):
|
||||
return type(self.request.root)
|
||||
|
||||
@property
|
||||
def parameter_schema(self):
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
"inputs": {
|
||||
"type": "object",
|
||||
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
|
||||
"default": {},
|
||||
# TODO: add input parameters
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
@property
|
||||
def output_parameters(self):
|
||||
return self.app.output_schema
|
||||
|
||||
@property
|
||||
def capabilities(self):
|
||||
return types.ServerCapabilities(
|
||||
tools=types.ToolsCapability(listChanged=False),
|
||||
)
|
||||
|
||||
def response(self, response: types.Result):
|
||||
json_response = types.JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
|
||||
yield sse_content
|
||||
|
||||
def error_response(self, code: int, message: str, data=None):
|
||||
error_data = types.ErrorData(code=code, message=message, data=data)
|
||||
json_response = types.JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=(self.request.root.model_extra or {}).get("id", 1),
|
||||
error=error_data,
|
||||
)
|
||||
json_data = json.dumps(jsonable_encoder(json_response))
|
||||
|
||||
sse_content = f"event: message\ndata: {json_data}\n\n".encode()
|
||||
|
||||
yield sse_content
|
||||
|
||||
def handle(self):
|
||||
handle_map = {
|
||||
types.InitializeRequest: self.initialize,
|
||||
types.ListToolsRequest: self.list_tools,
|
||||
types.CallToolRequest: self.invoke_tool,
|
||||
}
|
||||
try:
|
||||
if self.request_type in handle_map:
|
||||
return self.response(handle_map[self.request_type]())
|
||||
else:
|
||||
return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
|
||||
except ValueError as e:
|
||||
return self.error_response(INVALID_PARAMS, str(e))
|
||||
except Exception as e:
|
||||
return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
|
||||
|
||||
def initialize(self):
|
||||
request = cast(types.InitializeRequest, self.request.root)
|
||||
client_info = request.params.clientInfo
|
||||
clinet_name = f"{client_info.name}@{client_info.version}"
|
||||
if not self.end_user:
|
||||
end_user = EndUser(
|
||||
tenant_id=self.app.tenant_id,
|
||||
app_id=self.app.id,
|
||||
type="mcp",
|
||||
name=clinet_name,
|
||||
session_id=generate_session_id(),
|
||||
external_user_id=self.mcp_server.id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
return types.InitializeResult(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self.capabilities,
|
||||
serverInfo=types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION),
|
||||
instructions=self.mcp_server.description,
|
||||
)
|
||||
|
||||
def list_tools(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
return types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name=self.mcp_server.name,
|
||||
description=self.mcp_server.description,
|
||||
inputSchema=self.parameter_schema,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def invoke_tool(self):
|
||||
if not self.end_user:
|
||||
raise ValueError("User not found")
|
||||
request = cast(types.CallToolRequest, self.request.root)
|
||||
args = request.params.arguments
|
||||
if not args:
|
||||
raise ValueError("No arguments provided")
|
||||
response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False)
|
||||
if isinstance(response, Mapping):
|
||||
return types.CallToolResult(content=[types.TextContent(text=response["answer"], type="text")])
|
||||
return None
|
||||
|
||||
def retrieve_end_user(self):
|
||||
return (
|
||||
db.session.query(EndUser)
|
||||
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
@ -3,11 +3,13 @@ from typing import Any, Protocol
|
||||
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
|
||||
from configs.app_config import DifyConfig
|
||||
from core.mcp import types
|
||||
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
|
||||
from core.mcp.session.base_session import BaseSession, RequestResponder
|
||||
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
|
||||
dify_config = DifyConfig()
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION)
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
|
||||
@ -213,3 +213,14 @@ app_import_fields = {
|
||||
app_import_check_dependencies_fields = {
|
||||
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
|
||||
}
|
||||
|
||||
app_server_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"server_code": fields.String,
|
||||
"description": fields.String,
|
||||
"status": fields.String,
|
||||
"parameters": fields.Raw,
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
@ -0,0 +1,41 @@
|
||||
"""add app mcp server
|
||||
|
||||
Revision ID: ca4c4abcc347
|
||||
Revises: 1e67f2654a08
|
||||
Create Date: 2025-05-22 16:23:44.206102
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'ca4c4abcc347'
|
||||
down_revision = '1e67f2654a08'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('app_mcp_servers',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.String(length=255), nullable=False),
|
||||
sa.Column('server_code', sa.String(length=255), nullable=False),
|
||||
sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
|
||||
sa.Column('parameters', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('app_mcp_servers')
|
||||
# ### end Alembic commands ###
|
||||
@ -294,6 +294,10 @@ class App(Base):
|
||||
|
||||
return tags or []
|
||||
|
||||
@property
|
||||
def mcp_server(self):
|
||||
return db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.id).first()
|
||||
|
||||
|
||||
class AppModelConfig(Base):
|
||||
__tablename__ = "app_model_configs"
|
||||
@ -1433,6 +1437,31 @@ class EndUser(Base, UserMixin):
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class AppMCPServer(Base):
|
||||
__tablename__ = "app_mcp_servers"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),)
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
app_id = db.Column(StringUUID, nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
description = db.Column(db.String(255), nullable=False)
|
||||
server_code = db.Column(db.String(255), nullable=False)
|
||||
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
parameters = db.Column(db.Text, nullable=False)
|
||||
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@staticmethod
|
||||
def generate_server_code(n):
|
||||
while True:
|
||||
result = generate_string(n)
|
||||
while db.session.query(AppMCPServer).filter(AppMCPServer.server_code == result).count() > 0:
|
||||
result = generate_string(n)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Site(Base):
|
||||
__tablename__ = "sites"
|
||||
__table_args__ = (
|
||||
|
||||
@ -217,7 +217,7 @@ class MCPToolProvider(Base):
|
||||
# who created this tool
|
||||
user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
# encrypted credentials
|
||||
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
|
||||
# authed
|
||||
authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
|
||||
# tools
|
||||
|
||||
@ -14,6 +14,7 @@ from models.model import (
|
||||
ApiToken,
|
||||
AppAnnotationHitHistory,
|
||||
AppAnnotationSetting,
|
||||
AppMCPServer,
|
||||
AppModelConfig,
|
||||
Conversation,
|
||||
EndUser,
|
||||
@ -42,6 +43,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
|
||||
# Delete related data
|
||||
_delete_app_model_configs(tenant_id, app_id)
|
||||
_delete_app_site(tenant_id, app_id)
|
||||
_delete_app_mcp_servers(tenant_id, app_id)
|
||||
_delete_app_api_tokens(tenant_id, app_id)
|
||||
_delete_installed_apps(tenant_id, app_id)
|
||||
_delete_recommended_apps(tenant_id, app_id)
|
||||
@ -90,6 +92,18 @@ def _delete_app_site(tenant_id: str, app_id: str):
|
||||
_delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site")
|
||||
|
||||
|
||||
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
|
||||
def del_mcp_server(mcp_server_id: str):
|
||||
db.session.query(AppMCPServer).filter(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
|
||||
|
||||
_delete_records(
|
||||
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
|
||||
{"app_id": app_id},
|
||||
del_mcp_server,
|
||||
"app mcp server",
|
||||
)
|
||||
|
||||
|
||||
def _delete_app_api_tokens(tenant_id: str, app_id: str):
|
||||
def del_api_token(api_token_id: str):
|
||||
db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False)
|
||||
|
||||
Reference in New Issue
Block a user