create proxy sockets in the proxy function for thread safety
Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
@ -7,7 +7,7 @@ import aiohttp
|
||||
# Test connect completions; we assume prefill and decode are on the same node.
|
||||
# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 \
|
||||
# --chat-template ~/vllm/examples/template_chatglm2.jinja
|
||||
# 2. vllm connect --prefill-addr nodeIp:7010 --decode-addr nodeIp:7010
|
||||
# 2. vllm connect --prefill-addr 127.0.0.1:7010 --decode-addr 127.0.0.1:7010
|
||||
# 3. python test_request.py
|
||||
async def test_connect_completions(session):
|
||||
try:
|
||||
@ -68,11 +68,12 @@ def is_json(data):
|
||||
return False
|
||||
|
||||
def extract_data(responseText):
|
||||
reply = ""
|
||||
if responseText == "":
|
||||
return ""
|
||||
return reply
|
||||
if is_json(responseText):
|
||||
return responseText
|
||||
reply = ""
|
||||
|
||||
for data in responseText.split("\n\n"):
|
||||
if data.startswith('data: '):
|
||||
content = data[6:]
|
||||
|
||||
@ -78,6 +78,20 @@ async def serve_http(app: FastAPI,
|
||||
return server.shutdown()
|
||||
|
||||
|
||||
def proxy(clients_addr: str, workers_addr: str,
          ctx: zmq.asyncio.Context) -> None:
    """Run a blocking ZMQ proxy between a ROUTER and a DEALER socket.

    Both sockets are created, bound and closed inside this function, so
    they are only ever touched by the thread that calls it (serve_zmq
    hands this function to ``asyncio.to_thread``).

    Args:
        clients_addr: Address the ROUTER (client-facing) socket binds to.
        workers_addr: Address the DEALER (worker-facing) socket binds to.
        ctx: ZMQ context used to create both sockets.

    The call blocks inside ``zmq.proxy`` until it returns or raises.
    ``zmq.ContextTerminated`` raised by the proxy is reported and
    swallowed; any other exception propagates to the caller. In every
    case both sockets are closed before the function exits.
    """
    in_socket = ctx.socket(zmq.ROUTER)
    try:
        in_socket.bind(clients_addr)
        out_socket = ctx.socket(zmq.DEALER)
        try:
            out_socket.bind(workers_addr)
            try:
                zmq.proxy(in_socket, out_socket)
            except zmq.ContextTerminated:
                # Expected shutdown path: the context was terminated
                # while the proxy was blocked.
                print("proxy terminated")
        finally:
            # Always release the worker-facing socket, including when
            # bind() or the proxy fails with an unexpected error.
            out_socket.close()
    finally:
        # Always release the client-facing socket as well.
        in_socket.close()
|
||||
|
||||
|
||||
async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
|
||||
"""Server routine"""
|
||||
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
|
||||
@ -85,24 +99,15 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
|
||||
workers_addr = "inproc://workers"
|
||||
clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
|
||||
# Prepare our context and sockets
|
||||
context = zmq.asyncio.Context()
|
||||
|
||||
# Socket to talk to clients
|
||||
clients = context.socket(zmq.ROUTER)
|
||||
clients.bind(clients_addr)
|
||||
logger.info("ZMQ Server ROUTER started at %s", clients_addr)
|
||||
# Socket to talk to workers
|
||||
workers = context.socket(zmq.DEALER)
|
||||
workers.bind(workers_addr)
|
||||
logger.info("ZMQ Worker DEALER started at %s", workers_addr)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(worker_routine(workers_addr, app, context, i))
|
||||
for i in range(5)
|
||||
]
|
||||
proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
|
||||
|
||||
context = zmq.asyncio.Context.instance()
|
||||
try:
|
||||
tasks = [
|
||||
asyncio.create_task(worker_routine(workers_addr, app, context, i))
|
||||
for i in range(5)
|
||||
]
|
||||
logger.info("zmq tasks: %s", tasks)
|
||||
proxy_task = asyncio.to_thread(proxy, clients_addr, workers_addr,
|
||||
context)
|
||||
await asyncio.gather(*tasks, proxy_task)
|
||||
except KeyboardInterrupt:
|
||||
print("ZMQ Server interrupted")
|
||||
@ -110,8 +115,6 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
|
||||
print("ZMQError:", e)
|
||||
finally:
|
||||
# We never get here but clean up anyhow
|
||||
clients.close()
|
||||
workers.close()
|
||||
context.destroy(linger=0)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user