feat: upgrade streamable http client

This commit is contained in:
Novice
2025-05-27 13:14:51 +08:00
parent 1fd4839eca
commit 41bbcb9540
16 changed files with 167 additions and 155 deletions

View File

@ -59,7 +59,6 @@ def start_authorization(
metadata: Optional[OAuthMetadata],
client_information: OAuthClientInformation,
redirect_url: str,
scope: Optional[str] = None,
) -> tuple[str, str]:
"""Begins the authorization flow."""
response_type = "code"
@ -85,11 +84,9 @@ def start_authorization(
"code_challenge": code_challenge,
"code_challenge_method": code_challenge_method,
"redirect_uri": redirect_url,
"state": "/tools?provider_id=" + client_information.client_id,
}
if scope:
params["scope"] = scope
authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
return authorization_url, code_verifier
@ -187,7 +184,6 @@ def auth(
provider: OAuthClientProvider,
server_url: str,
authorization_code: Optional[str] = None,
scope: Optional[str] = None,
) -> dict[str, str]:
"""Orchestrates the full auth flow with a server."""
metadata = discover_oauth_metadata(server_url)
@ -233,7 +229,6 @@ def auth(
metadata,
client_information,
provider.redirect_url,
scope or provider.client_metadata.scope,
)
provider.save_code_verifier(code_verifier)

View File

@ -1,6 +1,6 @@
from typing import Optional
from configs.app_config import DifyConfig
from configs import dify_config
from core.mcp.types import (
OAuthClientInformation,
OAuthClientInformationFull,
@ -11,8 +11,6 @@ from services.tools.mcp_tools_mange_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"
dify_config = DifyConfig()
class OAuthClientProvider:
provider_id: str
@ -25,7 +23,7 @@ class OAuthClientProvider:
@property
def redirect_url(self) -> str:
"""The URL to redirect the user agent to after authorization."""
return dify_config.CONSOLE_WEB_URL
return dify_config.CONSOLE_WEB_URL + "/tools"
@property
def client_metadata(self) -> OAuthClientMetadata:
@ -37,7 +35,6 @@ class OAuthClientProvider:
response_types=["code"],
client_name="Dify",
client_uri="https://github.com/langgenius/dify",
scope="read write",
)
def client_information(self) -> Optional[OAuthClientInformation]:
@ -91,7 +88,3 @@ class OAuthClientProvider:
if not mcp_provider:
return ""
return mcp_provider.credentials.get("code_verifier", "")
class UnauthorizedError(Exception):
pass

View File

@ -1,6 +1,5 @@
import logging
import queue
import threading
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
@ -42,7 +41,7 @@ def sse_client(
read_queue = queue.Queue()
write_queue = queue.Queue()
status_queue = queue.Queue()
cancel_event = threading.Event()
with ThreadPoolExecutor() as executor:
try:
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
@ -51,54 +50,49 @@ def sse_client(
url, 2, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
def sse_reader(status_queue: queue.Queue):
try:
while not cancel_event.is_set():
for sse in event_source.iter_sse():
if cancel_event.is_set():
break
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
raise ValueError(error_msg)
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
for sse in event_source.iter_sse():
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.info(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = (
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg)
status_queue.put(("error", ValueError(error_msg)))
status_queue.put(("ready", endpoint_url))
case "message":
try:
message = types.JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"Received server message: {message}")
except Exception as exc:
logger.exception("Error parsing server message")
read_queue.put(exc)
continue
session_message = SessionMessage(message)
read_queue.put(session_message)
case _:
logger.warning(f"Unknown SSE event: {sse.event}")
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
if not cancel_event.is_set():
logger.exception("Error reading SSE messages")
read_queue.put(exc)
read_queue.put(exc)
finally:
read_queue.put(None)
def post_writer(endpoint_url: str):
try:
while not cancel_event.is_set():
while True:
try:
message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
if message is None:
@ -113,14 +107,13 @@ def sse_client(
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
if cancel_event.is_set():
break
except queue.Empty:
if cancel_event.is_set():
break
continue
except Exception:
except httpx.ReadError as exc:
logger.debug(f"SSE reader shutting down normally: {exc}")
except Exception as exc:
logger.exception("Error writing messages")
write_queue.put(exc)
finally:
write_queue.put(None)
@ -131,11 +124,12 @@ def sse_client(
raise ValueError("failed to get endpoint URL")
if status != "ready":
raise ValueError("failed to get endpoint URL")
if status == "error":
raise endpoint_url
executor.submit(post_writer, endpoint_url)
try:
yield read_queue, write_queue
finally:
cancel_event.set()
yield read_queue, write_queue
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()

View File

@ -8,7 +8,6 @@ and session management.
import logging
import queue
import threading
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
@ -106,11 +105,6 @@ class StreamableHTTPTransport:
CONTENT_TYPE: JSON,
**self.headers,
}
self.stop_event = threading.Event()
def stop(self):
"""Signal to stop all operations."""
self.stop_event.set()
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID if available."""
@ -170,6 +164,9 @@ class StreamableHTTPTransport:
# Put exception in queue that goes to client
server_to_client_queue.put(exc)
return False
elif sse.event == "ping":
logger.debug("Received ping event")
return False
else:
logger.warning(f"Unknown SSE event: {sse.event}")
return False
@ -198,8 +195,6 @@ class StreamableHTTPTransport:
logger.debug("GET SSE connection established")
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
self._handle_sse_event(sse, server_to_client_queue)
except Exception as exc:
@ -230,8 +225,6 @@ class StreamableHTTPTransport:
logger.debug("Resumption GET SSE connection established")
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
@ -300,13 +293,13 @@ class StreamableHTTPTransport:
try:
event_source = EventSource(response)
for sse in event_source.iter_sse():
if self.stop_event.is_set():
break
self._handle_sse_event(
is_complete = self._handle_sse_event(
sse,
ctx.server_to_client_queue,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
if is_complete:
break
except Exception as e:
ctx.server_to_client_queue.put(e)
@ -346,7 +339,7 @@ class StreamableHTTPTransport:
This method processes messages from the client_to_server_queue and sends them to the server.
Responses are written to the server_to_client_queue.
"""
while not self.stop_event.is_set():
while True:
try:
# Read message from client queue with timeout to check stop_event periodically
session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
@ -382,10 +375,8 @@ class StreamableHTTPTransport:
else:
self._handle_post_request(ctx)
except queue.Empty:
# Timeout - continue loop to check stop_event
continue
except Exception as exc:
# Send exception to client
server_to_client_queue.put(exc)
def terminate_session(self, client: httpx.Client) -> None:
@ -478,9 +469,6 @@ def streamablehttp_client(
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clean up
transport.stop()
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():

View File

@ -21,7 +21,6 @@ class MCPClient:
tenant_id: str,
authed: bool = True,
authorization_code: Optional[str] = None,
scope: Optional[str] = None,
):
# Initialize info
self.provider_id = provider_id
@ -32,7 +31,6 @@ class MCPClient:
# Authentication info
self.authed = authed
self.authorization_code = authorization_code
self.scope = scope
if authed:
from core.mcp.auth.auth_provider import OAuthClientProvider
@ -49,7 +47,7 @@ class MCPClient:
self._initialized = False
def __enter__(self):
self._initialize(first_try=True)
self._initialize()
self._initialized = True
return self
@ -58,7 +56,6 @@ class MCPClient:
def _initialize(
self,
first_try: bool = True,
):
"""Initialize the client with fallback to SSE if streamable connection fails"""
connection_methods = {"mcp": streamablehttp_client, "sse": sse_client}
@ -71,9 +68,9 @@ class MCPClient:
self.connect_server(client_factory, method_name)
except KeyError:
try:
self.connect_server(streamablehttp_client, "sse")
self.connect_server(sse_client, "sse")
except MCPConnectionError:
self.connect_server(sse_client, "mcp")
self.connect_server(streamablehttp_client, "mcp")
def connect_server(self, client_factory: Callable, method_name: str, first_try: bool = True):
from core.mcp.auth.auth_flow import auth
@ -100,8 +97,8 @@ class MCPClient:
except MCPAuthError:
if not self.authed:
raise
auth(self.provider, self.server_url, self.authorization_code, self.scope)
auth(self.provider, self.server_url, self.authorization_code)
self.token = self.provider.tokens()
if first_try:
return self.connect_server(client_factory, method_name, first_try=False)
@ -134,5 +131,6 @@ class MCPClient:
self._session = None
self._initialized = False
self.exit_stack.close()
except Exception:
except Exception as e:
logging.exception("Error during cleanup")
raise ValueError(f"Error during cleanup: {e}")

View File

@ -2,30 +2,31 @@ import json
from collections.abc import Mapping
from typing import cast
from configs.app_config import DifyConfig
from configs import dify_config
from controllers.web.passport import generate_session_id
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.entities.app_invoke_entities import InvokeFrom
from core.mcp import types
from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.model import App, EndUser
from models.model import App, AppMCPServer, EndUser
from services.app_generate_service import AppGenerateService
"""
Apply to MCP HTTP streamable server with stateless http
"""
dify_config = DifyConfig()
class MCPServerReuqestHandler:
def __init__(self, app: App, request: types.ClientRequest):
def __init__(self, app: App, request: types.ClientRequest, user_input_form: list[VariableEntity]):
self.app = app
self.request = request
if not self.app.mcp_server:
self.mcp_server: AppMCPServer = self.app.mcp_server
if not self.mcp_server:
raise ValueError("MCP server not found")
self.mcp_server = self.app.mcp_server
self.end_user = self.retrieve_end_user()
self.user_input_form = user_input_form
@property
def request_type(self):
@ -33,6 +34,7 @@ class MCPServerReuqestHandler:
@property
def parameter_schema(self):
parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
return {
"type": "object",
"properties": {
@ -41,10 +43,11 @@ class MCPServerReuqestHandler:
"type": "object",
"description": "Allows the entry of various variable values defined by the App. The `inputs` parameter contains multiple key/value pairs, with each key corresponding to a specific variable and each value being the specific value for that variable. If the variable is of file type, specify an object that has the keys described in `files`.", # noqa: E501
"default": {},
# TODO: add input parameters
"properties": parameters,
"required": required,
},
},
"required": ["query"],
"required": "query",
}
@property
@ -152,3 +155,25 @@ class MCPServerReuqestHandler:
.filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)
def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
parameters = {}
required = []
for item in user_input_form:
if item.type in (
VariableEntityType.FILE,
VariableEntityType.FILE_LIST,
VariableEntityType.EXTERNAL_DATA_TOOL,
):
continue
if item.required:
required.append(item.variable)
parameters[item.variable]["description"] = self.mcp_server.parameters_dict[item.label]["description"]
if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
parameters[item.variable]["type"] = "string"
elif item.type == VariableEntityType.SELECT:
parameters[item.variable]["type"] = "string"
parameters[item.variable]["enum"] = item.options
elif item.type == VariableEntityType.NUMBER:
parameters[item.variable]["type"] = "number"
return parameters, required

View File

@ -1,7 +1,5 @@
import logging
import queue
import threading
import time
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
@ -80,13 +78,10 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._completed = False
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager
self._cancel_event = threading.Event()
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
self._cancel_event = threading.Event()
self._cancel_event.clear()
return self
def __exit__(
@ -101,9 +96,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
self._on_complete(self)
finally:
self._entered = False
if not self._cancel_event:
raise RuntimeError("No active cancel scope")
self._cancel_event.set()
def respond(self, response: SendResultT | ErrorData) -> None:
"""Send a response for this request.
@ -117,17 +109,15 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to"
if not self.cancelled:
self._completed = True
self._completed = True
self._session._send_response(request_id=self.request_id, response=response)
self._session._send_response(request_id=self.request_id, response=response)
def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
self._cancel_event.set()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
self._session._send_response(
@ -135,14 +125,6 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
response=ErrorData(code=0, message="Request cancelled", data=None),
)
@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled
@property
def cancelled(self) -> bool:
return self._cancel_event.is_set()
class BaseSession(
Generic[
@ -184,11 +166,9 @@ class BaseSession(
self._in_flight = {}
self._exit_stack = ExitStack()
self._futures = []
self._request_id_lock = threading.Lock()
def __enter__(self) -> Self:
self._executor = ThreadPoolExecutor()
self._stop_event = threading.Event()
self._receiver_future = self._executor.submit(self._receive_loop)
return self
@ -196,21 +176,8 @@ class BaseSession(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
self._exit_stack.close()
self._stop_event.set()
self._wait_for_futures(timeout=5)
def _wait_for_futures(self, timeout=None):
end_time = time.time() + timeout if timeout else None
for future in list(self._futures):
try:
remaining = end_time - time.time() if end_time else None
if remaining is not None and remaining <= 0:
break
future.result(timeout=remaining)
except Exception as e:
logging.exception(f"Error waiting for task: {e}")
self._read_stream.put(None)
self._write_stream.put(None)
def send_request(
self,
@ -247,7 +214,7 @@ class BaseSession(
timeout = request_read_timeout_seconds.total_seconds()
elif self._session_read_timeout_seconds is not None:
timeout = self._session_read_timeout_seconds.total_seconds()
while not self._stop_event.is_set():
while True:
try:
response_or_error = response_queue.get(timeout=timeout)
break
@ -316,7 +283,7 @@ class BaseSession(
Main message processing loop.
In a real synchronous implementation, this would likely run in a separate thread.
"""
while not self._stop_event.is_set():
while True:
try:
# Attempt to receive a message (this would be blocking in a synchronous context)
message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
@ -378,12 +345,9 @@ class BaseSession(
else:
self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
except queue.Empty:
if self._stop_event.is_set():
break
continue
except Exception as e:
logging.exception("Error in message processing loop")
self._stop_event.set()
def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
"""

View File

@ -3,12 +3,11 @@ from typing import Any, Protocol
from pydantic import AnyUrl, TypeAdapter
from configs.app_config import DifyConfig
from configs import dify_config
from core.mcp import types
from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
from core.mcp.session.base_session import BaseSession, RequestResponder
dify_config = DifyConfig()
DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.CURRENT_VERSION)

View File

@ -1177,7 +1177,6 @@ class SessionMessage:
class OAuthClientMetadata(BaseModel):
client_name: str
redirect_uris: list[str]
scope: str
grant_types: Optional[list[str]] = None
response_types: Optional[list[str]] = None
token_endpoint_auth_method: Optional[str] = None