mirror of
https://github.com/langgenius/dify.git
synced 2026-05-02 00:18:03 +08:00
feat: add multi app mode's server support
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.mcp import api
|
||||
@ -59,8 +60,9 @@ class MCPAppApi(Resource):
|
||||
request = ClientRequest.model_validate(args)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid MCP request: {str(e)}")
|
||||
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
|
||||
return helper.compact_generate_response(mcp_server_handler.handle())
|
||||
with Session(db.engine) as session:
|
||||
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form, session)
|
||||
return helper.compact_generate_response(mcp_server_handler.handle())
|
||||
|
||||
|
||||
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")
|
||||
|
||||
@ -108,94 +108,3 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
|
||||
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
|
||||
|
||||
|
||||
def create_ssrf_proxy_mcp_http_client(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
) -> httpx.Client:
|
||||
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
|
||||
|
||||
Args:
|
||||
headers: Optional headers to include in the client
|
||||
timeout: Optional timeout configuration
|
||||
|
||||
Returns:
|
||||
Configured httpx.Client with proxy settings
|
||||
"""
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
proxy=dify_config.SSRF_PROXY_ALL_URL,
|
||||
)
|
||||
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxy_mounts = {
|
||||
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
),
|
||||
}
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
mounts=proxy_mounts,
|
||||
)
|
||||
else:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
|
||||
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
"""Connect to SSE endpoint with SSRF proxy protection.
|
||||
|
||||
This function creates an SSE connection using the configured proxy settings
|
||||
to prevent SSRF attacks when connecting to external endpoints.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL
|
||||
max_retries: Maximum number of retry attempts
|
||||
**kwargs: Additional arguments passed to the SSE connection
|
||||
|
||||
Returns:
|
||||
EventSource object for SSE streaming
|
||||
"""
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
# Extract client if provided, otherwise create one
|
||||
client = kwargs.pop("client", None)
|
||||
if client is None:
|
||||
# Create client with SSRF proxy configuration
|
||||
timeout = kwargs.pop(
|
||||
"timeout",
|
||||
httpx.Timeout(
|
||||
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
|
||||
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
|
||||
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
|
||||
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||
),
|
||||
)
|
||||
headers = kwargs.pop("headers", {})
|
||||
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
|
||||
client_provided = False
|
||||
else:
|
||||
client_provided = True
|
||||
|
||||
# Extract method if provided, default to GET
|
||||
method = kwargs.pop("method", "GET")
|
||||
|
||||
try:
|
||||
return connect_sse(client, method, url, **kwargs)
|
||||
except Exception as e:
|
||||
# If we created the client, we need to clean it up on error
|
||||
if not client_provided:
|
||||
client.close()
|
||||
raise
|
||||
|
||||
@ -9,128 +9,276 @@ from urllib.parse import urljoin, urlparse
|
||||
import httpx
|
||||
from sseclient import SSEClient
|
||||
|
||||
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
from core.mcp import types
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.types import SessionMessage
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_QUEUE_READ_TIMEOUT = 3
|
||||
|
||||
# Type aliases for better readability
|
||||
ReadQueue = queue.Queue[SessionMessage | Exception | None]
|
||||
WriteQueue = queue.Queue[SessionMessage | Exception | None]
|
||||
StatusQueue = queue.Queue[tuple[str, str | Exception]]
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
"""Remove request parameters from URL, keeping only the path."""
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
class SSETransport:
|
||||
"""SSE client transport implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
) -> None:
|
||||
"""Initialize the SSE transport.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
"""
|
||||
self.url = url
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self.endpoint_url: str | None = None
|
||||
|
||||
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
|
||||
"""Validate that the endpoint URL matches the connection origin.
|
||||
|
||||
Args:
|
||||
endpoint_url: The endpoint URL to validate.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise.
|
||||
"""
|
||||
url_parsed = urlparse(self.url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
|
||||
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
|
||||
|
||||
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
|
||||
"""Handle an 'endpoint' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
endpoint_url = urljoin(self.url, sse_data)
|
||||
logger.info(f"Received endpoint URL: {endpoint_url}")
|
||||
|
||||
if not self._validate_endpoint_url(endpoint_url):
|
||||
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
logger.error(error_msg)
|
||||
status_queue.put(("error", ValueError(error_msg)))
|
||||
return
|
||||
|
||||
status_queue.put(("ready", endpoint_url))
|
||||
|
||||
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
|
||||
"""Handle a 'message' SSE event.
|
||||
|
||||
Args:
|
||||
sse_data: The SSE event data.
|
||||
read_queue: Queue to put parsed messages.
|
||||
"""
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse_data)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
session_message = SessionMessage(message)
|
||||
read_queue.put(session_message)
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
read_queue.put(exc)
|
||||
|
||||
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Handle a single SSE event.
|
||||
|
||||
Args:
|
||||
sse: The SSE event object.
|
||||
read_queue: Queue for message events.
|
||||
status_queue: Queue for status events.
|
||||
"""
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
self._handle_endpoint_event(sse.data, status_queue)
|
||||
case "message":
|
||||
self._handle_message_event(sse.data, read_queue)
|
||||
case _:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
|
||||
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
|
||||
"""Read and process SSE events.
|
||||
|
||||
Args:
|
||||
event_source: The SSE event source.
|
||||
read_queue: Queue to put received messages.
|
||||
status_queue: Queue to put status updates.
|
||||
"""
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
self._handle_sse_event(sse, read_queue, status_queue)
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"SSE reader shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
read_queue.put(exc)
|
||||
finally:
|
||||
read_queue.put(None)
|
||||
|
||||
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
|
||||
"""Send a single message to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send to.
|
||||
message: The message to send.
|
||||
"""
|
||||
response = client.post(
|
||||
endpoint_url,
|
||||
json=message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
|
||||
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
|
||||
"""Handle writing messages to the server.
|
||||
|
||||
Args:
|
||||
client: HTTP client to use.
|
||||
endpoint_url: The endpoint URL to send messages to.
|
||||
write_queue: Queue to read messages from.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, Exception):
|
||||
write_queue.put(message)
|
||||
continue
|
||||
|
||||
self._send_message(client, endpoint_url, message)
|
||||
|
||||
except queue.Empty:
|
||||
continue
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"Post writer shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error writing messages")
|
||||
write_queue.put(exc)
|
||||
finally:
|
||||
write_queue.put(None)
|
||||
|
||||
def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
|
||||
"""Wait for the endpoint URL from the status queue.
|
||||
|
||||
Args:
|
||||
status_queue: Queue to read status from.
|
||||
|
||||
Returns:
|
||||
The endpoint URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If endpoint URL is not received or there's an error.
|
||||
"""
|
||||
try:
|
||||
status, endpoint_url_or_error = status_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if status != "ready":
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
|
||||
if status == "error" and isinstance(endpoint_url_or_error, Exception):
|
||||
raise endpoint_url_or_error
|
||||
|
||||
return cast(str, endpoint_url_or_error)
|
||||
|
||||
def connect(
|
||||
self,
|
||||
executor: ThreadPoolExecutor,
|
||||
client: httpx.Client,
|
||||
event_source,
|
||||
) -> tuple[ReadQueue, WriteQueue]:
|
||||
"""Establish connection and start worker threads.
|
||||
|
||||
Args:
|
||||
executor: Thread pool executor.
|
||||
client: HTTP client.
|
||||
event_source: SSE event source.
|
||||
|
||||
Returns:
|
||||
Tuple of (read_queue, write_queue).
|
||||
"""
|
||||
read_queue: ReadQueue = queue.Queue()
|
||||
write_queue: WriteQueue = queue.Queue()
|
||||
status_queue: StatusQueue = queue.Queue()
|
||||
|
||||
# Start SSE reader thread
|
||||
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
|
||||
|
||||
# Wait for endpoint URL
|
||||
endpoint_url = self._wait_for_endpoint(status_queue)
|
||||
self.endpoint_url = endpoint_url
|
||||
|
||||
# Start post writer thread
|
||||
executor.submit(self.post_writer, client, endpoint_url, write_queue)
|
||||
|
||||
return read_queue, write_queue
|
||||
|
||||
|
||||
@contextmanager
|
||||
def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5.0,
|
||||
sse_read_timeout: float = 5 * 60,
|
||||
) -> Generator[tuple[queue.Queue, queue.Queue], None, None]:
|
||||
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
|
||||
"""
|
||||
Client transport for SSE.
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
read_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue()
|
||||
write_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue()
|
||||
status_queue: queue.Queue[tuple[str, str | Exception]] = queue.Queue()
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
|
||||
Yields:
|
||||
Tuple of (read_queue, write_queue) for message communication.
|
||||
"""
|
||||
transport = SSETransport(url, headers, timeout, sse_read_timeout)
|
||||
|
||||
read_queue: ReadQueue | None = None
|
||||
write_queue: WriteQueue | None = None
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||
with create_ssrf_proxy_mcp_http_client(headers=headers) as client:
|
||||
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
def sse_reader(status_queue: queue.Queue):
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
endpoint_url = urljoin(url, sse.data)
|
||||
logger.info(f"Received endpoint URL: {endpoint_url}")
|
||||
url_parsed = urlparse(url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
|
||||
if (
|
||||
url_parsed.netloc != endpoint_parsed.netloc
|
||||
or url_parsed.scheme != endpoint_parsed.scheme
|
||||
):
|
||||
error_msg = (
|
||||
f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
status_queue.put(("error", ValueError(error_msg)))
|
||||
status_queue.put(("ready", endpoint_url))
|
||||
case "message":
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error parsing server message")
|
||||
read_queue.put(exc)
|
||||
continue
|
||||
session_message = SessionMessage(message)
|
||||
read_queue.put(session_message)
|
||||
case _:
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"SSE reader shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
read_queue.put(exc)
|
||||
finally:
|
||||
read_queue.put(None)
|
||||
|
||||
def post_writer(endpoint_url: str):
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
|
||||
if message is None:
|
||||
break
|
||||
if isinstance(message, Exception):
|
||||
write_queue.put(message)
|
||||
continue
|
||||
response = client.post(
|
||||
endpoint_url,
|
||||
json=message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
except queue.Empty:
|
||||
continue
|
||||
except httpx.ReadError as exc:
|
||||
logger.debug(f"SSE reader shutting down normally: {exc}")
|
||||
except Exception as exc:
|
||||
logger.exception("Error writing messages")
|
||||
write_queue.put(exc)
|
||||
finally:
|
||||
write_queue.put(None)
|
||||
|
||||
executor.submit(sse_reader, status_queue)
|
||||
try:
|
||||
status, endpoint_url_or_error = status_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
if status != "ready":
|
||||
raise ValueError("failed to get endpoint URL")
|
||||
if status == "error" and isinstance(endpoint_url_or_error, Exception):
|
||||
raise endpoint_url_or_error
|
||||
endpoint_url = cast(str, endpoint_url_or_error)
|
||||
executor.submit(post_writer, endpoint_url)
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
|
||||
yield read_queue, write_queue
|
||||
|
||||
@ -142,8 +290,11 @@ def sse_client(
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise exc
|
||||
finally:
|
||||
read_queue.put(None)
|
||||
write_queue.put(None)
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
|
||||
|
||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
|
||||
|
||||
@ -18,7 +18,6 @@ from typing import Any, cast
|
||||
import httpx
|
||||
from httpx_sse import EventSource, ServerSentEvent
|
||||
|
||||
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
from core.mcp.types import (
|
||||
ClientMessageMetadata,
|
||||
ErrorData,
|
||||
@ -30,6 +29,7 @@ from core.mcp.types import (
|
||||
RequestId,
|
||||
SessionMessage,
|
||||
)
|
||||
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -2,6 +2,8 @@ import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web.passport import generate_session_id
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
@ -9,8 +11,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.mcp import types
|
||||
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppMCPServer, EndUser
|
||||
from models.model import App, AppMCPServer, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
"""
|
||||
@ -19,12 +20,13 @@ Apply to MCP HTTP streamable server with stateless http
|
||||
|
||||
|
||||
class MCPServerReuqestHandler:
|
||||
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]):
|
||||
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity], session: Session):
|
||||
self.app = app
|
||||
self.request = request
|
||||
if not self.app.mcp_server:
|
||||
raise ValueError("MCP server not found")
|
||||
self.mcp_server: AppMCPServer = self.app.mcp_server
|
||||
self._session = session
|
||||
self.end_user = self.retrieve_end_user()
|
||||
self.user_input_form = user_input_form
|
||||
|
||||
@ -35,19 +37,19 @@ class MCPServerReuqestHandler:
|
||||
@property
|
||||
def parameter_schema(self):
|
||||
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
|
||||
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
}
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "User Input/Question content"},
|
||||
"inputs": {
|
||||
"type": "object",
|
||||
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
|
||||
"default": {},
|
||||
"properties": parameters,
|
||||
"required": required,
|
||||
},
|
||||
**parameters,
|
||||
},
|
||||
"required": ["query", "inputs"],
|
||||
"required": ["query", *required],
|
||||
}
|
||||
|
||||
@property
|
||||
@ -110,9 +112,8 @@ class MCPServerReuqestHandler:
|
||||
session_id=generate_session_id(),
|
||||
external_user_id=self.mcp_server.id,
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
|
||||
self._session.add(end_user)
|
||||
self._session.commit()
|
||||
return types.InitializeResult(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self.capabilities,
|
||||
@ -140,14 +141,31 @@ class MCPServerReuqestHandler:
|
||||
args = request.params.arguments
|
||||
if not args:
|
||||
raise ValueError("No arguments provided")
|
||||
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
|
||||
args = {"inputs": args}
|
||||
else:
|
||||
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
|
||||
response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False)
|
||||
if isinstance(response, Mapping):
|
||||
return types.CallToolResult(content=[types.TextContent(text=response["answer"], type="text")])
|
||||
answer = ""
|
||||
if self.app.mode in {
|
||||
AppMode.ADVANCED_CHAT.value,
|
||||
AppMode.COMPLETION.value,
|
||||
AppMode.CHAT.value,
|
||||
AppMode.AGENT_CHAT.value,
|
||||
}:
|
||||
answer = response["answer"]
|
||||
elif self.app.mode in {AppMode.WORKFLOW.value}:
|
||||
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
# Not support image yet
|
||||
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
|
||||
return None
|
||||
|
||||
def retrieve_end_user(self):
|
||||
return (
|
||||
db.session.query(EndUser)
|
||||
self._session.query(EndUser)
|
||||
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -177,6 +177,15 @@ class BaseSession(
|
||||
self._receiver_future = self._executor.submit(self._receive_loop)
|
||||
return self
|
||||
|
||||
def check_receiver_status(self) -> None:
|
||||
if self._receiver_future.done():
|
||||
try:
|
||||
# 如果Future已完成,获取结果(如果有异常会在这里抛出)
|
||||
self._receiver_future.result()
|
||||
except Exception as e:
|
||||
# 重新抛出线程中的异常
|
||||
raise e
|
||||
|
||||
def __exit__(
|
||||
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
|
||||
) -> None:
|
||||
@ -199,6 +208,7 @@ class BaseSession(
|
||||
Do not use this method to emit notifications! Use send_notification()
|
||||
instead.
|
||||
"""
|
||||
self.check_receiver_status()
|
||||
|
||||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
@ -224,6 +234,8 @@ class BaseSession(
|
||||
response_or_error = response_queue.get(timeout=timeout)
|
||||
break
|
||||
except queue.Empty:
|
||||
# 在等待响应的过程中也检查接收线程状态
|
||||
self.check_receiver_status()
|
||||
continue
|
||||
|
||||
if response_or_error is None:
|
||||
@ -257,6 +269,8 @@ class BaseSession(
|
||||
Emits a notification, which is a one-way message that does not expect
|
||||
a response.
|
||||
"""
|
||||
self.check_receiver_status()
|
||||
|
||||
# Some transport implementations may need to set the related_request_id
|
||||
# to attribute to the notifications to the request that triggered them.
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
@ -353,6 +367,7 @@ class BaseSession(
|
||||
continue
|
||||
except Exception as e:
|
||||
logging.exception("Error in message processing loop")
|
||||
raise
|
||||
|
||||
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||
"""
|
||||
|
||||
112
api/core/mcp/utils.py
Normal file
112
api/core/mcp/utils.py
Normal file
@ -0,0 +1,112 @@
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
|
||||
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
|
||||
try:
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
|
||||
if http_request_node_ssl_verify_lower == "true":
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = True
|
||||
elif http_request_node_ssl_verify_lower == "false":
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = False
|
||||
else:
|
||||
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
|
||||
except NameError:
|
||||
HTTP_REQUEST_NODE_SSL_VERIFY = True
|
||||
|
||||
BACKOFF_FACTOR = 0.5
|
||||
STATUS_FORCELIST = [429, 500, 502, 503, 504]
|
||||
|
||||
|
||||
def create_ssrf_proxy_mcp_http_client(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
) -> httpx.Client:
|
||||
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
|
||||
|
||||
Args:
|
||||
headers: Optional headers to include in the client
|
||||
timeout: Optional timeout configuration
|
||||
|
||||
Returns:
|
||||
Configured httpx.Client with proxy settings
|
||||
"""
|
||||
if dify_config.SSRF_PROXY_ALL_URL:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
proxy=dify_config.SSRF_PROXY_ALL_URL,
|
||||
)
|
||||
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
|
||||
proxy_mounts = {
|
||||
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
|
||||
"https://": httpx.HTTPTransport(
|
||||
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
|
||||
),
|
||||
}
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
mounts=proxy_mounts,
|
||||
)
|
||||
else:
|
||||
return httpx.Client(
|
||||
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
headers=headers or {},
|
||||
timeout=timeout,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
|
||||
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||
"""Connect to SSE endpoint with SSRF proxy protection.
|
||||
|
||||
This function creates an SSE connection using the configured proxy settings
|
||||
to prevent SSRF attacks when connecting to external endpoints.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL
|
||||
max_retries: Maximum number of retry attempts
|
||||
**kwargs: Additional arguments passed to the SSE connection
|
||||
|
||||
Returns:
|
||||
EventSource object for SSE streaming
|
||||
"""
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
# Extract client if provided, otherwise create one
|
||||
client = kwargs.pop("client", None)
|
||||
if client is None:
|
||||
# Create client with SSRF proxy configuration
|
||||
timeout = kwargs.pop(
|
||||
"timeout",
|
||||
httpx.Timeout(
|
||||
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
|
||||
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
|
||||
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
|
||||
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
|
||||
),
|
||||
)
|
||||
headers = kwargs.pop("headers", {})
|
||||
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
|
||||
client_provided = False
|
||||
else:
|
||||
client_provided = True
|
||||
|
||||
# Extract method if provided, default to GET
|
||||
method = kwargs.pop("method", "GET")
|
||||
|
||||
try:
|
||||
return connect_sse(client, method, url, **kwargs)
|
||||
except Exception:
|
||||
# If we created the client, we need to clean it up on error
|
||||
if not client_provided:
|
||||
client.close()
|
||||
raise
|
||||
Reference in New Issue
Block a user