From 2ebff5b77c049b3e620d5f79f02acbcbbc09bade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Mon, 23 Jun 2025 07:41:50 +0200 Subject: [PATCH] [P/D][NixlConnector] Support `tp_size > num_kv_heads` deployments (#19691) Signed-off-by: NickLucche Co-authored-by: Nick Hill --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index bdab4850d4..94f757e007 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -22,6 +22,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) +from vllm.distributed.utils import divide from vllm.logger import init_logger from vllm.platforms import _Backend from vllm.utils import make_zmq_path, make_zmq_socket, round_down @@ -679,11 +680,15 @@ class NixlConnectorWorker: # Number of D TP workers reading from a single P TP worker. This is # 1 when P and D `--tensor-parallel-size` match. - assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, ( - "Local TP size must be divisible by remote TP size.") - tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id] + tp_ratio = divide(self._tp_size[self.engine_id], + self._tp_size[engine_id]) assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP" - if self.use_mla: + + # Handle tp_size>num_kv_heads: replicate KV cache. + total_num_kv_heads = self.model_config.get_total_num_kv_heads() + is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 + + if self.use_mla or is_kv_replicated: # With MLA the only difference is in the number of blocks. remote_block_size = nixl_agent_meta.block_len // ( self.slot_size_bytes) @@ -720,7 +725,7 @@ class NixlConnectorWorker: self.kv_caches_base_addr[ engine_id] = nixl_agent_meta.kv_caches_base_addr rank_offset = self.tp_rank % tp_ratio * self.block_len \ - if not self.use_mla else 0 + if not (self.use_mla or is_kv_replicated) else 0 # Register all remote blocks, but only the corresponding kv heads. for base_addr in nixl_agent_meta.kv_caches_base_addr: for block_id in range(nixl_agent_meta.num_blocks):