[P/D][NixlConnector] Support tp_size > num_kv_heads deployments (#19691)
Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -22,6 +22,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
|
||||
@ -679,11 +680,15 @@ class NixlConnectorWorker:
|
||||
|
||||
# Number of D TP workers reading from a single P TP worker. This is
|
||||
# 1 when P and D `--tensor-parallel-size` match.
|
||||
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, (
|
||||
"Local TP size must be divisible by remote TP size.")
|
||||
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
|
||||
tp_ratio = divide(self._tp_size[self.engine_id],
|
||||
self._tp_size[engine_id])
|
||||
assert tp_ratio > 0, "Decode TP cannot be smaller than prefill TP"
|
||||
if self.use_mla:
|
||||
|
||||
# Handle tp_size>num_kv_heads: replicate KV cache.
|
||||
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
|
||||
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
|
||||
|
||||
if self.use_mla or is_kv_replicated:
|
||||
# With MLA the only difference is in the number of blocks.
|
||||
remote_block_size = nixl_agent_meta.block_len // (
|
||||
self.slot_size_bytes)
|
||||
@ -720,7 +725,7 @@ class NixlConnectorWorker:
|
||||
self.kv_caches_base_addr[
|
||||
engine_id] = nixl_agent_meta.kv_caches_base_addr
|
||||
rank_offset = self.tp_rank % tp_ratio * self.block_len \
|
||||
if not self.use_mla else 0
|
||||
if not (self.use_mla or is_kv_replicated) else 0
|
||||
# Register all remote blocks, but only the corresponding kv heads.
|
||||
for base_addr in nixl_agent_meta.kv_caches_base_addr:
|
||||
for block_id in range(nixl_agent_meta.num_blocks):
|
||||
|
||||
Reference in New Issue
Block a user