create proxy sockets in the proxy function for thread safety

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark
2025-01-11 23:10:15 +08:00
parent 7fbf70db57
commit ee6607332e
2 changed files with 26 additions and 22 deletions

View File

@ -7,7 +7,7 @@ import aiohttp
# test connect completions we assume prefill and decode are on the same node
# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 \
# --chat-template ~/vllm/examples/template_chatglm2.jinja
# 2. vllm connect --prefill-addr nodeIp:7010 --decode-addr nodeIp:7010
# 2. vllm connect --prefill-addr 127.0.0.1:7010 --decode-addr 127.0.0.1:7010
# 3. python test_request.py
async def test_connect_completions(session):
try:
@ -68,11 +68,12 @@ def is_json(data):
return False
def extract_data(responseText):
reply = ""
if responseText == "":
return ""
return reply
if is_json(responseText):
return responseText
reply = ""
for data in responseText.split("\n\n"):
if data.startswith('data: '):
content = data[6:]

View File

@ -78,6 +78,20 @@ async def serve_http(app: FastAPI,
return server.shutdown()
def proxy(clients_addr: str, workers_addr: str,
ctx: zmq.asyncio.Context) -> None:
in_socket = ctx.socket(zmq.ROUTER)
in_socket.bind(clients_addr)
out_socket = ctx.socket(zmq.DEALER)
out_socket.bind(workers_addr)
try:
zmq.proxy(in_socket, out_socket)
except zmq.ContextTerminated:
print("proxy terminated")
in_socket.close()
out_socket.close()
async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
"""Server routine"""
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
@ -85,24 +99,15 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
workers_addr = "inproc://workers"
clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
# Prepare our context and sockets
context = zmq.asyncio.Context()
# Socket to talk to clients
clients = context.socket(zmq.ROUTER)
clients.bind(clients_addr)
logger.info("ZMQ Server ROUTER started at %s", clients_addr)
# Socket to talk to workers
workers = context.socket(zmq.DEALER)
workers.bind(workers_addr)
logger.info("ZMQ Worker DEALER started at %s", workers_addr)
tasks = [
asyncio.create_task(worker_routine(workers_addr, app, context, i))
for i in range(5)
]
proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
context = zmq.asyncio.Context.instance()
try:
tasks = [
asyncio.create_task(worker_routine(workers_addr, app, context, i))
for i in range(5)
]
logger.info("zmq tasks: %s", tasks)
proxy_task = asyncio.to_thread(proxy, clients_addr, workers_addr,
context)
await asyncio.gather(*tasks, proxy_task)
except KeyboardInterrupt:
print("ZMQ Server interrupted")
@ -110,8 +115,6 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
print("ZMQError:", e)
finally:
# We never get here but clean up anyhow
clients.close()
workers.close()
context.destroy(linger=0)