feat: add multi app mode's server support

This commit is contained in:
Novice
2025-06-12 16:22:11 +08:00
parent 642693c79b
commit 0f668be415
7 changed files with 408 additions and 201 deletions

View File

@ -1,5 +1,6 @@
from flask_restful import Resource, reqparse
from pydantic import ValidationError
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.mcp import api
@ -59,8 +60,9 @@ class MCPAppApi(Resource):
request = ClientRequest.model_validate(args)
except ValidationError as e:
raise ValueError(f"Invalid MCP request: {str(e)}")
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form)
return helper.compact_generate_response(mcp_server_handler.handle())
with Session(db.engine) as session:
mcp_server_handler = MCPServerReuqestHandler(app, request, user_input_form, session)
return helper.compact_generate_response(mcp_server_handler.handle())
api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")

View File

@ -108,94 +108,3 @@ def delete(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
def head(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
return make_request("HEAD", url, max_retries=max_retries, **kwargs)
def create_ssrf_proxy_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
Args:
headers: Optional headers to include in the client
timeout: Optional timeout configuration
Returns:
Configured httpx.Client with proxy settings
"""
if dify_config.SSRF_PROXY_ALL_URL:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
proxy=dify_config.SSRF_PROXY_ALL_URL,
)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
mounts=proxy_mounts,
)
else:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
)
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints.
Args:
url: The SSE endpoint URL
max_retries: Maximum number of retry attempts
**kwargs: Additional arguments passed to the SSE connection
Returns:
EventSource object for SSE streaming
"""
from httpx_sse import connect_sse
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
if client is None:
# Create client with SSRF proxy configuration
timeout = kwargs.pop(
"timeout",
httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
),
)
headers = kwargs.pop("headers", {})
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
client_provided = False
else:
client_provided = True
# Extract method if provided, default to GET
method = kwargs.pop("method", "GET")
try:
return connect_sse(client, method, url, **kwargs)
except Exception as e:
# If we created the client, we need to clean it up on error
if not client_provided:
client.close()
raise

View File

@ -9,128 +9,276 @@ from urllib.parse import urljoin, urlparse
import httpx
from sseclient import SSEClient
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp import types
from core.mcp.error import MCPAuthError, MCPConnectionError
from core.mcp.types import SessionMessage
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
logger = logging.getLogger(__name__)
DEFAULT_QUEUE_READ_TIMEOUT = 3
# Type aliases for better readability
ReadQueue = queue.Queue[SessionMessage | Exception | None]
WriteQueue = queue.Queue[SessionMessage | Exception | None]
StatusQueue = queue.Queue[tuple[str, str | Exception]]
def remove_request_params(url: str) -> str:
"""Remove request parameters from URL, keeping only the path."""
return urljoin(url, urlparse(url).path)
class SSETransport:
"""SSE client transport implementation."""
def __init__(
self,
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
) -> None:
"""Initialize the SSE transport.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout
self.sse_read_timeout = sse_read_timeout
self.endpoint_url: str | None = None
def _validate_endpoint_url(self, endpoint_url: str) -> bool:
"""Validate that the endpoint URL matches the connection origin.
Args:
endpoint_url: The endpoint URL to validate.
Returns:
True if valid, False otherwise.
"""
url_parsed = urlparse(self.url)
endpoint_parsed = urlparse(endpoint_url)
return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
"""Handle an 'endpoint' SSE event.
Args:
sse_data: The SSE event data.
status_queue: Queue to put status updates.
"""
endpoint_url = urljoin(self.url, sse_data)
logger.info(f"Received endpoint URL: {endpoint_url}")
if not self._validate_endpoint_url(endpoint_url):
error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
logger.error(error_msg)
status_queue.put(("error", ValueError(error_msg)))
return
status_queue.put(("ready", endpoint_url))
def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
"""Handle a 'message' SSE event.
Args:
sse_data: The SSE event data.
read_queue: Queue to put parsed messages.
"""
try:
message = types.JSONRPCMessage.model_validate_json(sse_data)
logger.debug(f"Received server message: {message}")
session_message = SessionMessage(message)
read_queue.put(session_message)
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
"""Handle a single SSE event.
Args:
sse: The SSE event object.
read_queue: Queue for message events.
status_queue: Queue for status events.
"""
match sse.event:
case "endpoint":
self._handle_endpoint_event(sse.data, status_queue)
case "message":
self._handle_message_event(sse.data, read_queue)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
"""Read and process SSE events.
Args:
event_source: The SSE event source.
read_queue: Queue to put received messages.
status_queue: Queue to put status updates.
"""
try:
for sse in event_source.iter_sse():
self._handle_sse_event(sse, read_queue, status_queue)
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
read_queue.put(exc)
finally:
read_queue.put(None)
def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
"""Send a single message to the server.
Args:
client: HTTP client to use.
endpoint_url: The endpoint URL to send to.
message: The message to send.
"""
response = client.post(
endpoint_url,
json=message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
"""Handle writing messages to the server.
Args:
client: HTTP client to use.
endpoint_url: The endpoint URL to send messages to.
write_queue: Queue to read messages from.
"""
try:
while True:
try:
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, Exception):
write_queue.put(message)
continue
self._send_message(client, endpoint_url, message)
except queue.Empty:
continue
except httpx.ReadError as exc:
logger.debug(f"Post writer shutting down normally: {exc}")
except Exception as exc:
logger.exception("Error writing messages")
write_queue.put(exc)
finally:
write_queue.put(None)
def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
"""Wait for the endpoint URL from the status queue.
Args:
status_queue: Queue to read status from.
Returns:
The endpoint URL.
Raises:
ValueError: If endpoint URL is not received or there's an error.
"""
try:
status, endpoint_url_or_error = status_queue.get(timeout=1)
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if status != "ready":
raise ValueError("failed to get endpoint URL")
if status == "error" and isinstance(endpoint_url_or_error, Exception):
raise endpoint_url_or_error
return cast(str, endpoint_url_or_error)
def connect(
self,
executor: ThreadPoolExecutor,
client: httpx.Client,
event_source,
) -> tuple[ReadQueue, WriteQueue]:
"""Establish connection and start worker threads.
Args:
executor: Thread pool executor.
client: HTTP client.
event_source: SSE event source.
Returns:
Tuple of (read_queue, write_queue).
"""
read_queue: ReadQueue = queue.Queue()
write_queue: WriteQueue = queue.Queue()
status_queue: StatusQueue = queue.Queue()
# Start SSE reader thread
executor.submit(self.sse_reader, event_source, read_queue, status_queue)
# Wait for endpoint URL
endpoint_url = self._wait_for_endpoint(status_queue)
self.endpoint_url = endpoint_url
# Start post writer thread
executor.submit(self.post_writer, client, endpoint_url, write_queue)
return read_queue, write_queue
@contextmanager
def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5.0,
sse_read_timeout: float = 5 * 60,
) -> Generator[tuple[queue.Queue, queue.Queue], None, None]:
) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
"""
if headers is None:
headers = {}
read_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue()
write_queue: queue.Queue[SessionMessage | Exception | None] = queue.Queue()
status_queue: queue.Queue[tuple[str, str | Exception]] = queue.Queue()
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
Yields:
Tuple of (read_queue, write_queue) for message communication.
"""
transport = SSETransport(url, headers, timeout, sse_read_timeout)
read_queue: ReadQueue | None = None
write_queue: WriteQueue | None = None
with ThreadPoolExecutor() as executor:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
with create_ssrf_proxy_mcp_http_client(headers=headers) as client:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
def sse_reader(status_queue: queue.Queue):
try:
for sse in event_source.iter_sse():
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
status_queue.put(("error", ValueError(error_msg)))
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
read_queue.put(exc)
finally:
read_queue.put(None)
def post_writer(endpoint_url: str):
try:
while True:
try:
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
break
if isinstance(message, Exception):
write_queue.put(message)
continue
response = client.post(
endpoint_url,
json=message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except queue.Empty:
continue
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
logger.exception("Error writing messages")
write_queue.put(exc)
finally:
write_queue.put(None)
executor.submit(sse_reader, status_queue)
try:
status, endpoint_url_or_error = status_queue.get(timeout=1)
except queue.Empty:
raise ValueError("failed to get endpoint URL")
if status != "ready":
raise ValueError("failed to get endpoint URL")
if status == "error" and isinstance(endpoint_url_or_error, Exception):
raise endpoint_url_or_error
endpoint_url = cast(str, endpoint_url_or_error)
executor.submit(post_writer, endpoint_url)
read_queue, write_queue = transport.connect(executor, client, event_source)
yield read_queue, write_queue
@ -142,8 +290,11 @@ def sse_client(
logger.exception("Error connecting to SSE endpoint")
raise exc
finally:
read_queue.put(None)
write_queue.put(None)
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:

View File

@ -18,7 +18,6 @@ from typing import Any, cast
import httpx
from httpx_sse import EventSource, ServerSentEvent
from core.helper.ssrf_proxy import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
from core.mcp.types import (
ClientMessageMetadata,
ErrorData,
@ -30,6 +29,7 @@ from core.mcp.types import (
RequestId,
SessionMessage,
)
from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
logger = logging.getLogger(__name__)

View File

@ -2,6 +2,8 @@ import json
from collections.abc import Mapping
from typing import Any, cast
from sqlalchemy.orm import Session
from configs import dify_config
from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
@ -9,8 +11,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.model import App, AppMCPServer, EndUser
from models.model import App, AppMCPServer, AppMode, EndUser
from services.app_generate_service import AppGenerateService
"""
@ -19,12 +20,13 @@ Apply to MCP HTTP streamable server with stateless http
class MCPServerReuqestHandler:
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]):
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity], session: Session):
self.app = app
self.request = request
if not self.app.mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = self.app.mcp_server
self._session = session
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
@ -35,19 +37,19 @@ class MCPServerReuqestHandler:
@property
def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
return {
"type": "object",
"properties": parameters,
"required": required,
}
return {
"type": "object",
"properties": {
"query": {"type": "string", "description": "User Input/Question content"},
"inputs": {
"type": "object",
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
"default": {},
"properties": parameters,
"required": required,
},
**parameters,
},
"required": ["query", "inputs"],
"required": ["query", *required],
}
@property
@ -110,9 +112,8 @@ class MCPServerReuqestHandler:
session_id=generate_session_id(),
external_user_id=self.mcp_server.id,
)
db.session.add(end_user)
db.session.commit()
self._session.add(end_user)
self._session.commit()
return types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self.capabilities,
@ -140,14 +141,31 @@ class MCPServerReuqestHandler:
args = request.params.arguments
if not args:
raise ValueError("No arguments provided")
if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
args = {"inputs": args}
else:
args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
response = AppGenerateService.generate(self.app, self.end_user, args, InvokeFrom.MCP_SERVER, streaming=False)
if isinstance(response, Mapping):
return types.CallToolResult(content=[types.TextContent(text=response["answer"], type="text")])
answer = ""
if self.app.mode in {
AppMode.ADVANCED_CHAT.value,
AppMode.COMPLETION.value,
AppMode.CHAT.value,
AppMode.AGENT_CHAT.value,
}:
answer = response["answer"]
elif self.app.mode in {AppMode.WORKFLOW.value}:
answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
else:
raise ValueError("Invalid app mode")
# Not support image yet
return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
return None
def retrieve_end_user(self):
return (
db.session.query(EndUser)
self._session.query(EndUser)
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)

View File

@ -177,6 +177,15 @@ class BaseSession(
self._receiver_future = self._executor.submit(self._receive_loop)
return self
def check_receiver_status(self) -> None:
if self._receiver_future.done():
try:
# 如果Future已完成获取结果如果有异常会在这里抛出
self._receiver_future.result()
except Exception as e:
# 重新抛出线程中的异常
raise e
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
@ -199,6 +208,7 @@ class BaseSession(
Do not use this method to emit notifications! Use send_notification()
instead.
"""
self.check_receiver_status()
request_id = self._request_id
self._request_id = request_id + 1
@ -224,6 +234,8 @@ class BaseSession(
response_or_error = response_queue.get(timeout=timeout)
break
except queue.Empty:
# 在等待响应的过程中也检查接收线程状态
self.check_receiver_status()
continue
if response_or_error is None:
@ -257,6 +269,8 @@ class BaseSession(
Emits a notification, which is a one-way message that does not expect
a response.
"""
self.check_receiver_status()
# Some transport implementations may need to set the related_request_id
# to attribute to the notifications to the request that triggered them.
jsonrpc_notification = JSONRPCNotification(
@ -353,6 +367,7 @@ class BaseSession(
continue
except Exception as e:
logging.exception("Error in message processing loop")
raise
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""

112
api/core/mcp/utils.py Normal file
View File

@ -0,0 +1,112 @@
import httpx
from configs import dify_config
SSRF_DEFAULT_MAX_RETRIES = dify_config.SSRF_DEFAULT_MAX_RETRIES
HTTP_REQUEST_NODE_SSL_VERIFY = True # Default value for HTTP_REQUEST_NODE_SSL_VERIFY is True
try:
HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
http_request_node_ssl_verify_lower = str(HTTP_REQUEST_NODE_SSL_VERIFY).lower()
if http_request_node_ssl_verify_lower == "true":
HTTP_REQUEST_NODE_SSL_VERIFY = True
elif http_request_node_ssl_verify_lower == "false":
HTTP_REQUEST_NODE_SSL_VERIFY = False
else:
raise ValueError("Invalid value. HTTP_REQUEST_NODE_SSL_VERIFY should be 'True' or 'False'")
except NameError:
HTTP_REQUEST_NODE_SSL_VERIFY = True
BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
def create_ssrf_proxy_mcp_http_client(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
) -> httpx.Client:
"""Create an HTTPX client with SSRF proxy configuration for MCP connections.
Args:
headers: Optional headers to include in the client
timeout: Optional timeout configuration
Returns:
Configured httpx.Client with proxy settings
"""
if dify_config.SSRF_PROXY_ALL_URL:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
proxy=dify_config.SSRF_PROXY_ALL_URL,
)
elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
proxy_mounts = {
"http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
"https://": httpx.HTTPTransport(
proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
),
}
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
mounts=proxy_mounts,
)
else:
return httpx.Client(
verify=HTTP_REQUEST_NODE_SSL_VERIFY,
headers=headers or {},
timeout=timeout,
follow_redirects=True,
)
def ssrf_proxy_sse_connect(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
"""Connect to SSE endpoint with SSRF proxy protection.
This function creates an SSE connection using the configured proxy settings
to prevent SSRF attacks when connecting to external endpoints.
Args:
url: The SSE endpoint URL
max_retries: Maximum number of retry attempts
**kwargs: Additional arguments passed to the SSE connection
Returns:
EventSource object for SSE streaming
"""
from httpx_sse import connect_sse
# Extract client if provided, otherwise create one
client = kwargs.pop("client", None)
if client is None:
# Create client with SSRF proxy configuration
timeout = kwargs.pop(
"timeout",
httpx.Timeout(
timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
),
)
headers = kwargs.pop("headers", {})
client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
client_provided = False
else:
client_provided = True
# Extract method if provided, default to GET
method = kwargs.pop("method", "GET")
try:
return connect_sse(client, method, url, **kwargs)
except Exception:
# If we created the client, we need to clean it up on error
if not client_provided:
client.close()
raise