diff --git a/benchmarks/disagg_benchmarks/zmq/test_request.py b/benchmarks/disagg_benchmarks/zmq/test_request.py index 1e3ce27bd8..5aa66ebaf7 100644 --- a/benchmarks/disagg_benchmarks/zmq/test_request.py +++ b/benchmarks/disagg_benchmarks/zmq/test_request.py @@ -7,7 +7,7 @@ import aiohttp # test connect completions we assume prefill and decode are on the same node # 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 \ # --chat-template ~/vllm/examples/template_chatglm2.jinja -# 2. vllm connect --prefill-addr nodeIp:7010 --decode-addr nodeIp:7010 +# 2. vllm connect --prefill-addr 127.0.0.1:7010 --decode-addr 127.0.0.1:7010 # 3. python test_request.py async def test_connect_completions(session): try: @@ -68,11 +68,12 @@ def is_json(data): return False def extract_data(responseText): + reply = "" if responseText == "": - return "" + return reply if is_json(responseText): return responseText - reply = "" + for data in responseText.split("\n\n"): if data.startswith('data: '): content = data[6:] diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index e9aa540d6a..35fbd517c8 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -78,6 +78,20 @@ async def serve_http(app: FastAPI, return server.shutdown() +def proxy(clients_addr: str, workers_addr: str, + ctx: zmq.asyncio.Context) -> None: + in_socket = ctx.socket(zmq.ROUTER) + in_socket.bind(clients_addr) + out_socket = ctx.socket(zmq.DEALER) + out_socket.bind(workers_addr) + try: + zmq.proxy(in_socket, out_socket) + except zmq.ContextTerminated: + print("proxy terminated") + in_socket.close() + out_socket.close() + + async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None: """Server routine""" logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg, @@ -85,24 +99,15 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None: workers_addr = "inproc://workers" clients_addr = f"ipc://127.0.0.1:{zmq_server_port}" # Prepare our context and sockets - context = zmq.asyncio.Context() - - # Socket to talk to clients - clients = context.socket(zmq.ROUTER) - clients.bind(clients_addr) - logger.info("ZMQ Server ROUTER started at %s", clients_addr) - # Socket to talk to workers - workers = context.socket(zmq.DEALER) - workers.bind(workers_addr) - logger.info("ZMQ Worker DEALER started at %s", workers_addr) - - tasks = [ - asyncio.create_task(worker_routine(workers_addr, app, context, i)) - for i in range(5) - ] - proxy_task = asyncio.to_thread(zmq.proxy, clients, workers) - + context = zmq.asyncio.Context.instance() try: + tasks = [ + asyncio.create_task(worker_routine(workers_addr, app, context, i)) + for i in range(5) + ] + logger.info("zmq tasks: %s", tasks) + proxy_task = asyncio.to_thread(proxy, clients_addr, workers_addr, + context) await asyncio.gather(*tasks, proxy_task) except KeyboardInterrupt: print("ZMQ Server interrupted") @@ -110,8 +115,6 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None: print("ZMQError:", e) finally: # We never get here but clean up anyhow - clients.close() - workers.close() context.destroy(linger=0)