Compare commits

..

7 Commits

Author SHA1 Message Date
f0945e311d stash
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-24 00:33:37 +00:00
4ec76caafa updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-23 20:02:41 +00:00
1588294a88 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-23 18:58:49 +00:00
e82e9afeb7 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-23 18:43:20 +00:00
10abfaf309 Merge branch 'fix-connector-agg' into debug-logging 2025-07-23 18:20:39 +00:00
9ff1a2b537 [BugFix] Fix KVConnector TP worker aggregation
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-07-23 18:29:06 +01:00
0abe10e4a7 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-23 15:21:46 +00:00
14 changed files with 66 additions and 495 deletions

View File

@ -166,7 +166,6 @@ steps:
- tests/v1/test_async_llm_dp.py
- tests/v1/test_external_lb_dp.py
- tests/v1/test_internal_lb_dp.py
- tests/v1/test_hybrid_lb_dp.py
- tests/v1/engine/test_engine_core_client.py
commands:
# test with tp=2 and external_dp=2
@ -179,7 +178,6 @@ steps:
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py

View File

@ -565,8 +565,8 @@ def test_engine_core_proc_instantiation_cuda_empty(
from vllm.v1.engine.utils import EngineZmqAddresses
def mock_startup_handshake(self, handshake_socket, local_client,
headless, parallel_config):
def mock_startup_handshake(self, handshake_socket, on_head_node,
parallel_config):
return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
outputs=["tcp://127.0.0.1:5556"],
coordinator_input=None,

View File

@ -1,352 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import threading
import time
from contextlib import AsyncExitStack
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from tests.v1.test_utils import check_request_balancing
from vllm.platforms import Platform
MODEL_NAME = "ibm-research/PowerMoE-3b"
# Number of data parallel ranks for hybrid LB testing (4 total)
DP_SIZE = int(os.getenv("DP_SIZE", "4"))
# Default tensor parallel size to use
TP_SIZE = int(os.getenv("TP_SIZE", "1"))
# Number of nodes (2 nodes, each with 2 DP ranks)
NUM_NODES = 2
DP_SIZE_LOCAL = DP_SIZE // NUM_NODES # 2 ranks per node
class HybridLBServerManager:
"""Manages hybrid data parallel vLLM server instances where each node
runs a single logical API server that balances requests only to the
DP engines running on that same node."""
def __init__(self,
model_name: str,
dp_size: int,
api_server_count: int,
base_server_args: list,
dp_size_local: int = DP_SIZE_LOCAL,
tp_size: int = TP_SIZE):
self.model_name = model_name
self.dp_size = dp_size
self.dp_size_local = dp_size_local
self.tp_size = tp_size
self.api_server_count = api_server_count
self.base_server_args = base_server_args
self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = []
self.server_threads: list[threading.Thread] = []
self.num_nodes = dp_size // dp_size_local
def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]:
"""Start all server instances for hybrid LB mode."""
for node_id in range(self.num_nodes):
# Create server args for this specific node
server_args = self.base_server_args.copy()
# Calculate start rank for this node
start_rank = node_id * self.dp_size_local
# Add hybrid LB specific arguments
server_args.extend([
"--data-parallel-size",
str(self.dp_size),
"--data-parallel-size-local",
str(self.dp_size_local),
"--data-parallel-start-rank",
str(start_rank),
"--data-parallel-hybrid-lb", # Enable hybrid LB mode
"--tensor-parallel-size",
str(self.tp_size),
"--port",
str(8000 + node_id), # Different port for each node
"--api-server-count",
str(self.api_server_count),
"--data-parallel-address",
"127.0.0.1",
"--data-parallel-rpc-port",
"13345",
])
# Use a thread to start each server to allow parallel initialization
def start_server(node: int, sargs: list[str]):
try:
# Calculate GPU devices for this node
gpus_per_node = self.dp_size_local * self.tp_size
gpu_start = node * gpus_per_node
gpu_end = gpu_start + gpus_per_node
# Start the server
server = RemoteOpenAIServer(
self.model_name,
sargs,
auto_port=False,
env_dict={
"CUDA_VISIBLE_DEVICES":
",".join(
str(Platform.device_id_to_physical_device_id(
i)) for i in range(gpu_start, gpu_end))
})
server.__enter__()
print(f"Hybrid LB node {node} started successfully with "
f"{self.dp_size_local} local DP ranks and "
f"{self.api_server_count} API servers")
self.servers.append((server, sargs))
except Exception as e:
print(f"Failed to start hybrid LB node {node}: {e}")
raise
thread = threading.Thread(target=start_server,
args=(node_id, server_args))
thread.start()
self.server_threads.append(thread)
# Wait for all servers to start
for thread in self.server_threads:
thread.join()
# Give servers additional time to fully initialize and coordinate
time.sleep(3)
if len(self.servers) != self.num_nodes:
raise Exception("Servers failed to start")
return self.servers
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop all server instances."""
while self.servers:
try:
self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb)
except Exception as e:
print(f"Error stopping server: {e}")
@pytest.fixture(scope="module")
def default_server_args():
return [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enforce-eager",
]
@pytest.fixture(scope="module", params=[1]) # Only 1 API server for now
def servers(request, default_server_args):
api_server_count = request.param
with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
default_server_args, DP_SIZE_LOCAL,
TP_SIZE) as server_list:
yield server_list
@pytest_asyncio.fixture
async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
# Create a client for each node (each node has its own API endpoint)
async with AsyncExitStack() as stack:
yield [
await stack.enter_async_context(server.get_async_client())
for server, _ in servers
]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI],
servers: list[tuple[RemoteOpenAIServer,
list[str]]],
model_name: str) -> None:
async def make_request(client: openai.AsyncOpenAI):
completion = await client.completions.create(
model=model_name,
prompt="Hello, my name is",
max_tokens=10,
temperature=1.0)
assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 1
choice = completion.choices[0]
# The exact number of tokens can vary slightly with temperature=1.0,
# so we check for a reasonable minimum length.
assert len(choice.text) >= 1
# Finish reason might not always be 'length' if the model finishes early
# or due to other reasons, especially with high temperature.
# So, we'll accept 'length' or 'stop'.
assert choice.finish_reason in ("length", "stop")
# Token counts can also vary, so we check they are positive.
assert completion.usage.completion_tokens > 0
assert completion.usage.prompt_tokens > 0
assert completion.usage.total_tokens > 0
return completion
# Test single request to each node
for i, client in enumerate(clients):
result = await make_request(client)
assert result is not None
print(
f"Hybrid LB node {i} handled single completion request successfully"
)
await asyncio.sleep(0.5)
# Send requests to all nodes - each should balance within its local DP ranks
num_requests_per_node = 25 # Total 50 requests across 2 nodes
all_tasks = []
for i, client in enumerate(clients):
tasks = [make_request(client) for _ in range(num_requests_per_node)]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(completion is not None for completion in results)
await asyncio.sleep(0.5)
# Second burst of requests
all_tasks = []
for i, client in enumerate(clients):
tasks = [make_request(client) for _ in range(num_requests_per_node)]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(completion is not None for completion in results)
_, server_args = servers[0]
api_server_count = (
server_args.count('--api-server-count')
and server_args[server_args.index('--api-server-count') + 1] or 1)
print(
f"Successfully completed hybrid LB test with {len(clients)} nodes "
f"({DP_SIZE_LOCAL} DP ranks each, API server count: {api_server_count})"
)
# Check request balancing within each node
for i, (server, _) in enumerate(servers):
print(f"Checking request balancing for node {i}")
check_request_balancing(server, DP_SIZE_LOCAL)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_hybrid_lb_completion_streaming(clients: list[
openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]],
model_name: str) -> None:
prompt = "What is an LLM?"
async def make_streaming_request(client: openai.AsyncOpenAI):
# Perform a non-streaming request to get the expected full output
single_completion = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
)
single_output = single_completion.choices[0].text
# Perform the streaming request
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True)
chunks: list[str] = []
finish_reason_count = 0
last_chunk = None
async for chunk in stream:
chunks.append(chunk.choices[0].text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
last_chunk = chunk # Keep track of the last chunk
# finish reason should only return in the last block for OpenAI API
assert finish_reason_count == 1, (
"Finish reason should appear exactly once.")
assert last_chunk is not None, (
"Stream should have yielded at least one chunk.")
assert last_chunk.choices[
0].finish_reason == "length", "Finish reason should be 'length'."
# Check that the combined text matches the non-streamed version.
assert "".join(
chunks
) == single_output, "Streamed output should match non-streamed output."
return True # Indicate success for this request
# Test single request to each node
for i, client in enumerate(clients):
result = await make_streaming_request(client)
assert result is not None
print(
f"Hybrid LB node {i} handled single streaming request successfully"
)
await asyncio.sleep(0.5)
# Send streaming requests to all nodes
num_requests_per_node = 25 # Total 50 requests across 2 nodes
all_tasks = []
for i, client in enumerate(clients):
tasks = [
make_streaming_request(client)
for _ in range(num_requests_per_node)
]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(results), "Not all streaming requests completed successfully."
await asyncio.sleep(0.5)
# Second burst of streaming requests
all_tasks = []
for i, client in enumerate(clients):
tasks = [
make_streaming_request(client)
for _ in range(num_requests_per_node)
]
all_tasks.extend(tasks)
results = await asyncio.gather(*all_tasks)
assert len(results) == num_requests_per_node * len(clients)
assert all(results), "Not all streaming requests completed successfully."
_, server_args = servers[0]
api_server_count = (
server_args.count('--api-server-count')
and server_args[server_args.index('--api-server-count') + 1] or 1)
print(f"Successfully completed hybrid LB streaming test with "
f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, "
f"API server count: {api_server_count})")
# Check request balancing within each node
for i, (server, _) in enumerate(servers):
print(f"Checking streaming request balancing for node {i}")
check_request_balancing(server, DP_SIZE_LOCAL)

View File

@ -1906,16 +1906,8 @@ class ParallelConfig:
"""Backend to use for data parallel, either "mp" or "ray"."""
data_parallel_external_lb: bool = False
"""Whether to use "external" DP LB mode. Applies only to online serving
and when data_parallel_size > 0. This is useful for a "one-pod-per-rank"
wide-EP setup in Kuberentes. Set implicitly when --data-parallel-rank
is provided explicitly to vllm serve."""
data_parallel_hybrid_lb: bool = False
"""Whether to use "hybrid" DP LB mode. Applies only to online serving
and when data_parallel_size > 0. Enables running an AsyncLLM
and API server on a "per-node" basis where vLLM load balances
between local data parallel ranks, but an external LB balances
between vLLM nodes/replicas. Set explicitly in conjunction with
--data-parallel-start-rank."""
and when data_parallel_size > 0. Set implicitly when
data_parallel_rank is provided explicitly to vllm serve."""
enable_expert_parallel: bool = False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
enable_eplb: bool = False

View File

@ -893,6 +893,7 @@ class NixlConnectorWorker:
return done_req_ids
def start_load_kv(self, metadata: NixlConnectorMetadata):
start = time.perf_counter()
"""
Start loading by triggering non-blocking nixl_xfer.
We check for these trnxs to complete in each step().
@ -921,6 +922,11 @@ class NixlConnectorWorker:
# Add to requests that are waiting to be read and track expiration.
self._reqs_to_send.update(metadata.reqs_to_send)
end = time.perf_counter()
if self.tp_rank == 0:
logger.info(
f"===== {len(metadata.reqs_to_recv)}: start_load_kv time: {end-start: 0.5f}s"
)
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
logger.debug(
@ -1019,10 +1025,15 @@ class NixlConnectorWorker:
remote_xfer_side_handle,
remote_block_descs_ids,
notif_msg=notif_id,
skip_desc_merge=True,
)
# Begin async xfer.
start = time.perf_counter()
self.nixl_wrapper.transfer(handle)
end = time.perf_counter()
if self.tp_rank == 0:
logger.info(f"TRANSFER TIME: {end-start :0.4f}s")
# Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time

View File

@ -295,11 +295,9 @@ class EngineArgs:
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None
data_parallel_start_rank: Optional[int] = None
data_parallel_size_local: Optional[int] = None
data_parallel_address: Optional[str] = None
data_parallel_rpc_port: Optional[int] = None
data_parallel_hybrid_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
enable_eplb: bool = ParallelConfig.enable_eplb
@ -609,11 +607,6 @@ class EngineArgs:
type=int,
help='Data parallel rank of this instance. '
'When set, enables external load balancer mode.')
parallel_group.add_argument('--data-parallel-start-rank',
'-dpr',
type=int,
help='Starting data parallel rank '
'for secondary nodes.')
parallel_group.add_argument('--data-parallel-size-local',
'-dpl',
type=int,
@ -635,9 +628,6 @@ class EngineArgs:
default='mp',
help='Backend for data parallel, either '
'"mp" or "ray".')
parallel_group.add_argument(
"--data-parallel-hybrid-lb",
**parallel_kwargs["data_parallel_hybrid_lb"])
parallel_group.add_argument(
"--enable-expert-parallel",
**parallel_kwargs["enable_expert_parallel"])
@ -996,7 +986,6 @@ class EngineArgs:
def create_engine_config(
self,
usage_context: Optional[UsageContext] = None,
headless: bool = False,
) -> VllmConfig:
"""
Create the VllmConfig.
@ -1085,41 +1074,15 @@ class EngineArgs:
# but we should not do this here.
placement_group = ray.util.get_current_placement_group()
assert not headless or not self.data_parallel_hybrid_lb, (
"data_parallel_hybrid_lb is not applicable in "
"headless mode")
data_parallel_external_lb = self.data_parallel_rank is not None
# Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb:
assert self.data_parallel_size_local in (1, None), (
"data_parallel_size_local must be 1 when data_parallel_rank "
"is set")
data_parallel_size_local = 1
# Use full external lb if we have local_size of 1.
self.data_parallel_hybrid_lb = False
elif self.data_parallel_size_local is not None:
data_parallel_size_local = self.data_parallel_size_local
if self.data_parallel_start_rank and not headless:
# Infer hybrid LB mode.
self.data_parallel_hybrid_lb = True
if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
# Use full external lb if we have local_size of 1.
data_parallel_external_lb = True
self.data_parallel_hybrid_lb = False
if data_parallel_size_local == self.data_parallel_size:
# Disable hybrid LB mode if set for a single node
self.data_parallel_hybrid_lb = False
self.data_parallel_rank = self.data_parallel_start_rank or 0
else:
assert not self.data_parallel_hybrid_lb, (
"data_parallel_size_local must be set to use "
"data_parallel_hybrid_lb.")
# Local DP size defaults to global DP size if not set.
data_parallel_size_local = self.data_parallel_size
@ -1176,7 +1139,6 @@ class EngineArgs:
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend,
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
enable_expert_parallel=self.enable_expert_parallel,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.num_redundant_experts,

View File

@ -45,6 +45,11 @@ class ServeSubcommand(CLISubcommand):
if args.headless or args.api_server_count < 1:
run_headless(args)
else:
if args.data_parallel_start_rank:
raise ValueError(
"data_parallel_start_rank is only applicable "
"in headless mode. "
"Add --headless flag to enable headless mode.")
if args.api_server_count > 1:
run_multi_api_server(args)
else:
@ -81,14 +86,13 @@ def run_headless(args: argparse.Namespace):
# Create the EngineConfig.
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context,
headless=True)
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
if not envs.VLLM_USE_V1:
raise ValueError("Headless mode is only supported for V1")
if engine_args.data_parallel_hybrid_lb:
raise ValueError("data_parallel_hybrid_lb is not applicable in "
if engine_args.data_parallel_rank is not None:
raise ValueError("data_parallel_rank is not applicable in "
"headless mode")
parallel_config = vllm_config.parallel_config
@ -118,7 +122,7 @@ def run_headless(args: argparse.Namespace):
engine_manager = CoreEngineProcManager(
target_fn=EngineCoreProc.run_engine_core,
local_engine_count=local_engine_count,
start_index=vllm_config.parallel_config.data_parallel_rank,
start_index=args.data_parallel_start_rank,
local_start_index=0,
vllm_config=vllm_config,
local_client=False,
@ -165,11 +169,6 @@ def run_multi_api_server(args: argparse.Namespace):
" api_server_count > 1")
model_config.disable_mm_preprocessor_cache = True
if vllm_config.parallel_config.data_parallel_hybrid_lb:
raise NotImplementedError(
"Hybrid load balancing with --api-server-count > 0"
"is not yet supported.")
executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats

View File

@ -253,6 +253,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=False,
help="Run in headless mode. See multi-node data parallel "
"documentation for more details.")
parser.add_argument(
"--data-parallel-start-rank",
"-dpr",
type=int,
default=0,
help="Starting data parallel rank for secondary nodes. "
"Requires --headless.")
parser.add_argument("--api-server-count",
"-asc",
type=int,

View File

@ -125,7 +125,7 @@ class AsyncLLM(EngineClient):
if self.log_stats:
self.logger_manager = StatLoggerManager(
vllm_config=vllm_config,
engine_idxs=self.engine_core.engine_ranks_managed,
engine_idxs=self.engine_core.engine_ranks,
custom_stat_loggers=stat_loggers,
)
self.logger_manager.log_engine_initialized()

View File

@ -61,12 +61,11 @@ class DPCoordinator:
host = parallel_config.data_parallel_master_ip
external_lb = parallel_config.data_parallel_external_lb
hybrid_lb = parallel_config.data_parallel_hybrid_lb
# Assume coordinator is colocated with front-end procs when not in
# either external or hybrid DP LB mode.
# external DP LB mode.
front_publish_address = get_engine_client_zmq_addr(
local_only=not external_lb and not hybrid_lb, host=host)
local_only=not external_lb, host=host)
local_only_eng = dp_size == parallel_config.data_parallel_size_local
back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)

View File

@ -467,14 +467,13 @@ class EngineCoreProc(EngineCore):
For DP>1 with internal loadbalancing this is with the shared front-end
process which may reside on a different node.
For DP>1 with external or hybrid loadbalancing, two handshakes are
performed:
For DP>1 with external loadbalancing, two handshakes are performed:
- With the rank 0 front-end process which retrieves the
DP Coordinator ZMQ addresses and DP process group address.
- With the colocated front-end process which retrieves the
client input/output socket addresses.
with the exception of the rank 0 and colocated engines themselves which
don't require the second handshake.
with the exception of the rank 0 engine itself which doesn't require
the second handshake.
Here, "front-end" process can mean the process containing the engine
core client (which is the API server process in the case the API
@ -483,18 +482,15 @@ class EngineCoreProc(EngineCore):
"""
input_ctx = zmq.Context()
is_local = local_client and client_handshake_address is None
headless = not local_client
handshake = self._perform_handshake(input_ctx, handshake_address,
identity, is_local, headless,
vllm_config,
identity, is_local, vllm_config,
vllm_config.parallel_config)
if client_handshake_address is None:
with handshake as addresses:
yield addresses
else:
assert local_client
local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, True, False,
input_ctx, client_handshake_address, identity, local_client,
vllm_config)
with handshake as addresses, local_handshake as client_addresses:
addresses.inputs = client_addresses.inputs
@ -511,7 +507,6 @@ class EngineCoreProc(EngineCore):
handshake_address: str,
identity: bytes,
local_client: bool,
headless: bool,
vllm_config: VllmConfig,
parallel_config_to_update: Optional[ParallelConfig] = None,
) -> Generator[EngineZmqAddresses, None, None]:
@ -523,7 +518,6 @@ class EngineCoreProc(EngineCore):
bind=False) as handshake_socket:
# Register engine with front-end.
addresses = self.startup_handshake(handshake_socket, local_client,
headless,
parallel_config_to_update)
yield addresses
@ -537,7 +531,6 @@ class EngineCoreProc(EngineCore):
msgspec.msgpack.encode({
"status": "READY",
"local": local_client,
"headless": headless,
"num_gpu_blocks": num_gpu_blocks,
"dp_stats_address": dp_stats_address,
}))
@ -546,7 +539,6 @@ class EngineCoreProc(EngineCore):
def startup_handshake(
handshake_socket: zmq.Socket,
local_client: bool,
headless: bool,
parallel_config: Optional[ParallelConfig] = None,
) -> EngineZmqAddresses:
@ -555,7 +547,6 @@ class EngineCoreProc(EngineCore):
msgspec.msgpack.encode({
"status": "HELLO",
"local": local_client,
"headless": headless,
}))
# Receive initialization message.

View File

@ -429,23 +429,18 @@ class MPClient(EngineCoreClient):
parallel_config = vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
dp_local_size = parallel_config.data_parallel_size_local
offline_mode = parallel_config.data_parallel_rank_local is not None
# Client manages local+remote EngineCores in pure internal LB case.
# Client manages local EngineCores in hybrid and external LB case.
local_engines_only = (parallel_config.data_parallel_hybrid_lb
or parallel_config.data_parallel_external_lb)
external_dp_lb = parallel_config.data_parallel_external_lb
num_ranks = dp_local_size if local_engines_only else dp_size
self.engine_ranks_managed = [dp_rank] if offline_mode else list(
range(dp_rank, dp_rank + num_ranks))
offline_mode = parallel_config.data_parallel_rank_local is not None
self.engine_ranks = ([dp_rank] if
(offline_mode or external_dp_lb) else list(
range(dp_size)))
assert parallel_config.data_parallel_size_local <= len(
self.engine_ranks_managed)
self.engine_ranks)
# ZMQ identity of each engine that this client will talk to.
self.core_engines: list[EngineIdentity] = [
rank.to_bytes(2, "little")
for rank in self.engine_ranks_managed
index.to_bytes(2, "little") for index in self.engine_ranks
]
# Wait for ready messages from each engine on the input socket.
@ -900,12 +895,6 @@ class DPAsyncMPClient(AsyncMPClient):
return
assert self.stats_update_address is not None
assert len(self.engine_ranks_managed) > 0
# NOTE: running and waiting counts are all global from
# the Coordinator include all global EngineCores. This
# slice includes just the cores managed by this client.
count_slice = slice(self.engine_ranks_managed[0],
self.engine_ranks_managed[-1] + 1)
async def run_engine_stats_update_task():
with make_zmq_socket(self.ctx, self.stats_update_address,
@ -970,8 +959,7 @@ class DPAsyncMPClient(AsyncMPClient):
counts, wave, running = msgspec.msgpack.decode(buf)
self.current_wave = wave
self.engines_running = running
self.lb_engines = counts[count_slice]
logger.info(f"{counts=} | {count_slice=}")
self.lb_engines = counts
resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task())

View File

@ -544,8 +544,7 @@ def launch_core_engines(
local_start_index = parallel_config.data_parallel_rank_local
dp_rank = parallel_config.data_parallel_rank
host = parallel_config.data_parallel_master_ip
local_engines_only = (parallel_config.data_parallel_hybrid_lb
or parallel_config.data_parallel_external_lb)
external_dp_lb = parallel_config.data_parallel_external_lb
# In offline mode there is an LLM instance per DP rank and
# one core engine per LLM, see
@ -554,8 +553,8 @@ def launch_core_engines(
# client_local_only = True for cases where this front-end
# sends requests only to colocated engines.
client_local_only = (offline_mode or local_engines_only
or (local_engine_count == dp_size))
client_local_only = offline_mode or external_dp_lb or (local_engine_count
== dp_size)
# Set up input and output addresses.
addresses = EngineZmqAddresses(
@ -599,27 +598,14 @@ def launch_core_engines(
yield engine_actor_manager, coordinator, addresses
return
if offline_mode:
if offline_mode or (external_dp_lb and dp_rank > 0):
assert local_engine_count == 1
engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
elif dp_rank == 0:
# Rank 0 holds Coordinator, so it handshakes with all Cores
# in both external dplb and internal dplb mode.
# Note this also covers the case where we have zero local engines
# and rank 0 is headless.
else:
engines_to_handshake = [
CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size)
]
else:
# Rank > 0 handshakes with just the local cores it is managing.
assert local_engines_only, (
"Attempting to launch core_engines from dp_rank > 0, but "
"found internal DPLB, which is incompatible.")
engines_to_handshake = [
CoreEngine(index=i, local=True)
for i in range(dp_rank, dp_rank + local_engine_count)
]
# Whether the started engines will handshake only with co-located
# front-end processes. In external_dp_lb mode, ranks > 0 handshake with
@ -630,7 +616,7 @@ def launch_core_engines(
handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port)
if local_engines_only and dp_rank > 0:
if external_dp_lb and dp_rank > 0:
assert not handshake_local_only
local_handshake_address = get_open_zmq_ipc_path()
client_handshake_address = local_handshake_address
@ -645,6 +631,8 @@ def launch_core_engines(
# Start local engines.
if local_engine_count:
# In server mode, start_index and local_start_index will
# both be 0.
local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core,
vllm_config=vllm_config,
@ -690,9 +678,6 @@ def wait_for_engine_startup(
poller = zmq.Poller()
poller.register(handshake_socket, zmq.POLLIN)
remote_should_be_headless = not parallel_config.data_parallel_hybrid_lb \
and not parallel_config.data_parallel_external_lb
if proc_manager is not None:
for sentinel in proc_manager.sentinels():
poller.register(sentinel, zmq.POLLIN)
@ -728,24 +713,13 @@ def wait_for_engine_startup(
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local, headless = msg["status"], msg["local"], msg["headless"]
status, local = msg["status"], msg["local"]
if local != engine.local:
raise RuntimeError(f"{status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}, expected it to be "
f"{'local' if engine.local else 'remote'}")
# Remote engines must be headless iff we aren't in hybrid dp lb mode.
if not local and headless != remote_should_be_headless:
if headless:
raise RuntimeError(f"Remote engine {eng_index} must not use "
f"--headless in external or hybrid dp lb "
f"mode")
else:
raise RuntimeError(f"Remote engine {eng_index} must use "
f"--headless unless in external or hybrid "
f"dp lb mode")
if status == "HELLO" and engine.state == CoreEngineState.NEW:
# Send init message with DP config info.

View File

@ -15,7 +15,8 @@ from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -333,19 +334,20 @@ class Worker(WorkerBase):
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
if not has_kv_transfer_group():
return None
# In case of PP with kv transfer, we need to pass through the
# finished_sending and finished_recving buffers.
empty_output = EMPTY_MODEL_RUNNER_OUTPUT
new_output = EMPTY_MODEL_RUNNER_OUTPUT
if output.finished_sending or output.finished_recving:
empty_output = copy.copy(empty_output)
empty_output.finished_sending = output.finished_sending
empty_output.finished_recving = output.finished_recving
output = empty_output
new_output = copy.copy(new_output)
new_output.finished_sending = output.finished_sending
new_output.finished_recving = output.finished_recving
output = new_output
assert isinstance(output, ModelRunnerOutput)
# return output only from the driver worker
return output if self.is_driver_worker else None
return output
def profile(self, is_start: bool = True):
if self.profiler is None: