[P/D][NixlConnector] Support tp_size > num_kv_heads deployments (#19691)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Nicolò Lucchesi
2025-06-23 07:41:50 +02:00
committed by GitHub
parent f17aec0d63
commit 2ebff5b77c


@@ -22,6 +22,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
     get_tp_group)
+from vllm.distributed.utils import divide
 from vllm.logger import init_logger
 from vllm.platforms import _Backend
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
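For reference, `divide` from `vllm.distributed.utils` is the Megatron-style helper that asserts exact divisibility before returning the integer quotient, which is why the explicit modulo assert in the next hunk can be dropped. A minimal sketch of that behaviour (illustrative reimplementation, not the vLLM source):

```python
def ensure_divisibility(numerator: int, denominator: int) -> None:
    """Assert that numerator is exactly divisible by denominator."""
    assert numerator % denominator == 0, (
        f"{numerator} is not divisible by {denominator}")


def divide(numerator: int, denominator: int) -> int:
    """Integer division that fails loudly when there is a remainder."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
```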
@@ -679,11 +680,15 @@ class NixlConnectorWorker:
         # Number of D TP workers reading from a single P TP worker. This is
         # 1 when P and D `--tensor-parallel-size` match.
-        assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, (
-            "Local TP size must be divisible by remote TP size.")
-        tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
+        tp_ratio = divide(self._tp_size[self.engine_id],
+                          self._tp_size[engine_id])
         assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
-        if self.use_mla:
+        # Handle tp_size>num_kv_heads: replicate KV cache.
+        total_num_kv_heads = self.model_config.get_total_num_kv_heads()
+        is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
+        if self.use_mla or is_kv_replicated:
             # With MLA the only difference is in the number of blocks.
             remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes)
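In words: `tp_ratio` is how many decode (D) TP workers read from a single prefill (P) TP worker, and `is_kv_replicated` is true when the prefill TP size is at least the model's total number of KV heads, i.e. every prefill rank already holds a full copy of the KV cache. In that case the decode side can take the same path as MLA and skip per-rank head slicing. A self-contained sketch of that decision with hypothetical values (function and argument names are illustrative, not part of vLLM):

```python
def remote_read_mode(local_tp: int, remote_tp: int,
                     total_num_kv_heads: int, use_mla: bool) -> str:
    """Summarize how decode ranks should read a prefill worker's KV blocks."""
    # Number of decode TP workers reading from one prefill TP worker.
    assert local_tp % remote_tp == 0, "Local TP must be divisible by remote TP"
    tp_ratio = local_tp // remote_tp

    # If the prefill TP size is >= the number of KV heads, every prefill rank
    # stores the full set of heads (replicated rather than sharded).
    is_kv_replicated = remote_tp // total_num_kv_heads >= 1

    if use_mla or is_kv_replicated:
        return "read full remote blocks (no head slicing)"
    return f"each decode rank reads 1/{tp_ratio} of every remote block"


# Example: decode TP=8 pulling from prefill TP=4 for a model with 4 KV heads.
print(remote_read_mode(local_tp=8, remote_tp=4,
                       total_num_kv_heads=4, use_mla=False))
# -> read full remote blocks (no head slicing)
```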
@@ -720,7 +725,7 @@
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
         rank_offset = self.tp_rank % tp_ratio * self.block_len \
-            if not self.use_mla else 0
+            if not (self.use_mla or is_kv_replicated) else 0
         # Register all remote blocks, but only the corresponding kv heads.
         for base_addr in nixl_agent_meta.kv_caches_base_addr:
             for block_id in range(nixl_agent_meta.num_blocks):
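`rank_offset` in this last hunk is the byte offset at which a decode rank starts reading inside a prefill worker's block: the `tp_ratio` decode ranks that share one prefill rank each take a contiguous `block_len`-sized slice (their subset of KV heads), except under MLA or replicated KV, where every rank reads the full block from offset 0. A hypothetical walk-through with made-up sizes (helper name and values are illustrative):

```python
def remote_block_offset(tp_rank: int, tp_ratio: int, block_len: int,
                        use_mla: bool, is_kv_replicated: bool) -> int:
    """Byte offset into a remote (prefill) block for this decode rank."""
    if use_mla or is_kv_replicated:
        # The whole block is relevant to every reader; start at the beginning.
        return 0
    # tp_ratio decode ranks split one prefill block; reader i within the
    # group starts i * block_len bytes in (% and * bind left-to-right).
    return (tp_rank % tp_ratio) * block_len


# Decode TP=8 over prefill TP=2 -> tp_ratio=4. Decode rank 6 is reader
# index 6 % 4 == 2 within its group, so it skips two local-sized slices.
assert remote_block_offset(6, tp_ratio=4, block_len=4096,
                           use_mla=False, is_kv_replicated=False) == 2 * 4096
assert remote_block_offset(6, tp_ratio=4, block_len=4096,
                           use_mla=False, is_kv_replicated=True) == 0
```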