mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 18:08:07 +08:00
Merge branch 'feat/mcp' into deploy/dev
This commit is contained in:
@ -1,9 +1,9 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from enum import StrEnum
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@ -14,7 +14,7 @@ from libs.login import login_required
|
||||
from models.model import AppMCPServer
|
||||
|
||||
|
||||
class AppMCPServerStatus(str, Enum):
|
||||
class AppMCPServerStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
INACTIVE = "inactive"
|
||||
|
||||
@ -37,7 +37,7 @@ class AppMCPServerController(Resource):
|
||||
def post(self, app_model):
|
||||
# The role of the current user in the ta table must be editor, admin, or owner
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
parser.add_argument("parameters", type=dict, required=True, location="json")
|
||||
@ -62,7 +62,7 @@ class AppMCPServerController(Resource):
|
||||
@marshal_with(app_server_fields)
|
||||
def put(self, app_model):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
raise NotFound()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("id", type=str, required=True, location="json")
|
||||
parser.add_argument("description", type=str, required=True, location="json")
|
||||
@ -71,7 +71,7 @@ class AppMCPServerController(Resource):
|
||||
args = parser.parse_args()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
|
||||
if not server:
|
||||
raise Forbidden()
|
||||
raise NotFound()
|
||||
server.description = args["description"]
|
||||
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
|
||||
if args["status"]:
|
||||
@ -89,10 +89,10 @@ class AppMCPServerRefreshController(Resource):
|
||||
@marshal_with(app_server_fields)
|
||||
def get(self, server_id):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
raise NotFound()
|
||||
server = db.session.query(AppMCPServer).filter(AppMCPServer.id == server_id).first()
|
||||
if not server:
|
||||
raise Forbidden()
|
||||
raise NotFound()
|
||||
server.server_code = AppMCPServer.generate_server_code(16)
|
||||
db.session.commit()
|
||||
return server
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import io
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import validators
|
||||
from flask import redirect, send_file
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
@ -27,6 +27,17 @@ from services.tools.tools_transform_service import ToolTransformService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
if not url:
|
||||
return False
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class ToolProviderListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@ -634,7 +645,7 @@ class ToolProviderMCPApi(Resource):
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
user = current_user
|
||||
if not validators.url(args["server_url"]):
|
||||
if not is_valid_url(args["server_url"]):
|
||||
raise ValueError("Server URL is not valid.")
|
||||
return jsonable_encoder(
|
||||
MCPToolManageService.create_mcp_provider(
|
||||
@ -662,7 +673,7 @@ class ToolProviderMCPApi(Resource):
|
||||
parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
if not validators.url(args["server_url"]):
|
||||
if not is_valid_url(args["server_url"]):
|
||||
if "[__HIDDEN__]" in args["server_url"]:
|
||||
pass
|
||||
else:
|
||||
|
||||
@ -8,7 +8,7 @@ from controllers.web.error import (
|
||||
AppUnavailableError,
|
||||
)
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.mcp.server.handler import MCPServerReuqestHandler
|
||||
from core.mcp.server.handler import MCPServerRequestHandler
|
||||
from core.mcp.types import ClientNotification, ClientRequest
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
@ -66,7 +66,7 @@ class MCPAppApi(Resource):
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid MCP request: {str(e)}")
|
||||
|
||||
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
|
||||
mcp_server_handler = MCPServerRequestHandler(app, request, user_input_form)
|
||||
response = mcp_server_handler.handle()
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import Optional
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from core.mcp.auth.auth_provider import OAuthClientProvider
|
||||
from core.mcp.types import (
|
||||
@ -60,7 +60,7 @@ def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
|
||||
|
||||
|
||||
def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
||||
"""Retrieve and decode OAuth state data from Redis using the state key."""
|
||||
"""Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
|
||||
# Get state data from Redis
|
||||
@ -69,27 +69,23 @@ def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
|
||||
if not state_data:
|
||||
raise ValueError("State parameter has expired or does not exist")
|
||||
|
||||
# Delete the state data from Redis immediately after retrieval to prevent reuse
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
try:
|
||||
# Parse and validate the state data
|
||||
if isinstance(state_data, bytes):
|
||||
state_data = state_data.decode("utf-8")
|
||||
|
||||
oauth_state = OAuthCallbackState.model_validate_json(state_data)
|
||||
|
||||
return oauth_state
|
||||
except Exception as e:
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid state parameter: {str(e)}")
|
||||
|
||||
|
||||
def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
|
||||
"""Handle the callback from the OAuth provider."""
|
||||
# Retrieve state data from Redis
|
||||
# Retrieve state data from Redis (state is automatically deleted after retrieval)
|
||||
full_state_data = _retrieve_redis_state(state_key)
|
||||
|
||||
# Clean up the state data from Redis after successful retrieval
|
||||
redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
tokens = exchange_authorization(
|
||||
full_state_data.server_url,
|
||||
full_state_data.metadata,
|
||||
|
||||
@ -3,7 +3,7 @@ import queue
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypeAlias, final
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import httpx
|
||||
@ -18,10 +18,23 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
|
||||
@final
|
||||
class _StatusReady:
|
||||
def __init__(self, endpoint_url: str):
|
||||
self._endpoint_url = endpoint_url
|
||||
|
||||
|
||||
@final
|
||||
class _StatusError:
|
||||
def __init__(self, exc: Exception):
|
||||
self._exc = exc
|
||||
|
||||
|
||||
# Type aliases for better readability
|
||||
ReadQueue = queue.Queue[SessionMessage | Exception | None]
|
||||
WriteQueue = queue.Queue[SessionMessage | Exception | None]
|
||||
StatusQueue = queue.Queue[tuple[str, str | Exception]]
|
||||
ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
|
||||
StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
@ -80,10 +93,10 @@ class SSETransport:
|
||||
if not self._validate_endpoint_url(endpoint_url):
|
||||
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
logger.error(error_msg)
|
||||
status_queue.put(("error", ValueError(error_msg)))
|
||||
status_queue.put(_StatusError(ValueError(error_msg)))
|
||||
return
|
||||
|
||||
status_queue.put(("ready", endpoint_url))
|
||||
status_queue.put(_StatusReady(endpoint_url))
|
||||
|
||||
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
|
||||
"""Handle a 'message' SSE event.
|
||||
@ -197,18 +210,17 @@ class SSETransport:
|
||||
ValueError: If endpoint URL is not received or there's an error.
|
||||
"""
|
||||
try:
|
||||
status, endpoint_url_or_error = status_queue.get(timeout=1)
|
||||
status = status_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if status != "ready":
|
||||
if isinstance(status, _StatusReady):
|
||||
return status._endpoint_url
|
||||
elif isinstance(status, _StatusError):
|
||||
raise status._exc
|
||||
else:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if status == "error" and isinstance(endpoint_url_or_error, Exception):
|
||||
raise endpoint_url_or_error
|
||||
|
||||
return cast(str, endpoint_url_or_error)
|
||||
|
||||
def connect(
|
||||
self,
|
||||
executor: ThreadPoolExecutor,
|
||||
@ -284,9 +296,9 @@ def sse_client(
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception as exc:
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise exc
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
|
||||
@ -94,14 +94,15 @@ class MCPClient:
|
||||
if self._streams_context is None:
|
||||
raise MCPConnectionError("Failed to create connection context")
|
||||
|
||||
# Use exit_stack to manage context managers properly
|
||||
if method_name == "mcp":
|
||||
read_stream, write_stream, _ = self._streams_context.__enter__()
|
||||
read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
|
||||
streams = (read_stream, write_stream)
|
||||
else: # sse_client
|
||||
streams = self._streams_context.__enter__()
|
||||
streams = self.exit_stack.enter_context(self._streams_context)
|
||||
|
||||
self._session_context = ClientSession(*streams)
|
||||
self._session = self._session_context.__enter__()
|
||||
self._session = self.exit_stack.enter_context(self._session_context)
|
||||
session = cast(ClientSession, self._session)
|
||||
session.initialize()
|
||||
return
|
||||
@ -138,14 +139,12 @@ class MCPClient:
|
||||
def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
try:
|
||||
if self._session_context:
|
||||
self._session_context.__exit__(None, None, None)
|
||||
|
||||
if self._streams_context:
|
||||
self._streams_context.__exit__(None, None, None)
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
# ExitStack will handle proper cleanup of all managed context managers
|
||||
self.exit_stack.close()
|
||||
self._session = None
|
||||
self._session_context = None
|
||||
self._streams_context = None
|
||||
self._initialized = False
|
||||
except Exception as e:
|
||||
logging.exception("Error during cleanup")
|
||||
raise ValueError(f"Error during cleanup: {e}")
|
||||
|
||||
@ -18,15 +18,16 @@ Apply to MCP HTTP streamable server with stateless http
|
||||
"""
|
||||
|
||||
|
||||
class MCPServerReuqestHandler:
|
||||
class MCPServerRequestHandler:
|
||||
def __init__(
|
||||
self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
|
||||
):
|
||||
self.app = app
|
||||
self.request = request
|
||||
if not self.app.mcp_server:
|
||||
mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
|
||||
if not mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server: AppMCPServer = self.app.mcp_server
|
||||
self.mcp_server: AppMCPServer = mcp_server
|
||||
self.end_user = self.retrieve_end_user()
|
||||
self.user_input_form = user_input_form
|
||||
|
||||
|
||||
@ -4,20 +4,8 @@ from configs import dify_config
|
||||
|
||||
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
|
||||
try:
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
|
||||
if http_request_node_ssl_verify_lower == "true":
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = True
|
||||
elif http_request_node_ssl_verify_lower == "false":
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = False
|
||||
else:
|
||||
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
|
||||
except NameError:
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = True
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
|
||||
BACKOFF_FACTOR = 0.5
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
|
||||
|
||||
@ -46,11 +46,11 @@ class MCPToolProviderController(ToolProviderController):
|
||||
tools = []
|
||||
tools_data = json.loads(db_provider.tools)
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
|
||||
|
||||
user = db_provider.load_user()
|
||||
tools = [
|
||||
ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author=db_provider.user.name if db_provider.user else "Anonymous",
|
||||
author=user.name if user else "Anonymous",
|
||||
name=remote_mcp_tool.name,
|
||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||
provider=db_provider.server_identifier,
|
||||
@ -72,7 +72,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
author=db_provider.user.name if db_provider.user else "Anonymous",
|
||||
author=user.name if user else "Anonymous",
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -49,6 +50,11 @@ class MCPTool(Tool):
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
yield self.create_text_message(content.text)
|
||||
try:
|
||||
yield self.create_json_message(json.loads(content.text))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self.create_blob_message(
|
||||
blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
|
||||
|
||||
@ -32,7 +32,7 @@ def upgrade():
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', 'app_id', name='unique_app_mcp_server_tenant_app_id'),
|
||||
sa.UniqueConstraint('tenant_id', 'server_code', name='unique_app_mcp_server_tenant_server_code')
|
||||
sa.UniqueConstraint('server_code', name='unique_app_mcp_server_server_code')
|
||||
)
|
||||
op.create_table('tool_mcp_providers',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
|
||||
@ -294,10 +294,6 @@ class App(Base):
|
||||
|
||||
return tags or []
|
||||
|
||||
@property
|
||||
def mcp_server(self):
|
||||
return db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.id).first()
|
||||
|
||||
@property
|
||||
def author_name(self):
|
||||
if self.created_by:
|
||||
@ -1465,7 +1461,7 @@ class AppMCPServer(Base):
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
|
||||
db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
|
||||
db.UniqueConstraint("tenant_id", "server_code", name="unique_app_mcp_server_tenant_server_code"),
|
||||
db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
|
||||
)
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
|
||||
@ -234,8 +234,7 @@ class MCPToolProvider(Base):
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
def load_user(self) -> Account | None:
|
||||
return db.session.query(Account).filter(Account.id == self.user_id).first()
|
||||
|
||||
@property
|
||||
|
||||
@ -125,13 +125,14 @@ class MCPToolManageService:
|
||||
mcp_provider.authed = True
|
||||
mcp_provider.updated_at = datetime.now()
|
||||
db.session.commit()
|
||||
user = mcp_provider.load_user()
|
||||
return ToolProviderApiEntity(
|
||||
id=mcp_provider.id,
|
||||
name=mcp_provider.name,
|
||||
tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
|
||||
type=ToolProviderType.MCP,
|
||||
icon=mcp_provider.icon,
|
||||
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
|
||||
author=user.name if user else "Anonymous",
|
||||
server_url=mcp_provider.masked_server_url,
|
||||
updated_at=int(mcp_provider.updated_at.timestamp()),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
|
||||
@ -191,9 +191,10 @@ class ToolTransformService:
|
||||
|
||||
@staticmethod
|
||||
def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
|
||||
user = db_provider.load_user()
|
||||
return ToolProviderApiEntity(
|
||||
id=db_provider.server_identifier if not for_list else db_provider.id,
|
||||
author=db_provider.user.name if db_provider.user else "Anonymous",
|
||||
author=user.name if user else "Anonymous",
|
||||
name=db_provider.name,
|
||||
icon=db_provider.provider_icon,
|
||||
type=ToolProviderType.MCP,
|
||||
@ -210,9 +211,10 @@ class ToolTransformService:
|
||||
|
||||
@staticmethod
|
||||
def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
|
||||
user = mcp_provider.load_user()
|
||||
return [
|
||||
ToolApiEntity(
|
||||
author=mcp_provider.user.name if mcp_provider.user else "Anonymous",
|
||||
author=user.name if user else "Anonymous",
|
||||
name=tool.name,
|
||||
label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
|
||||
description=I18nObject(en_US=tool.description, zh_Hans=tool.description),
|
||||
|
||||
Reference in New Issue
Block a user