[Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and fix the reshape of the KVCache (#14369)

Signed-off-by: Mathis Felardos <mathis@mistral.ai>
Mathis Felardos
2025-03-08 07:39:31 +01:00
committed by GitHub
parent ca7a2d5f28
commit 980385f8c1


@@ -2,7 +2,7 @@
 """
 Simple KV Cache Connector for Distributed Machine Learning Inference
 
-The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
+The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
 producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
 MooncakePipe.
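For context, the producer/consumer pairing described in this docstring is wired up through the engine's KV-transfer configuration. Below is a minimal sketch of a prefill-side launch, assuming the KVTransferConfig fields used in vLLM's disaggregated-prefill examples (kv_connector, kv_role, kv_rank, kv_parallel_size); the model name is a placeholder and none of this is part of the commit itself.

# Hypothetical prefill-side (producer) setup; the decode side would use
# kv_role="kv_consumer". Sketch only, not taken from this commit.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",        # placeholder model
    kv_transfer_config=KVTransferConfig(
        kv_connector="PyNcclConnector",              # backed by PyNcclPipe
        kv_role="kv_producer",
        kv_rank=0,
        kv_parallel_size=2,
    ),
)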
@@ -159,6 +159,7 @@ class SimpleConnector(KVConnectorBase):
         input_tokens_tensor = model_input.input_tokens
         seq_lens = model_input.attn_metadata.seq_lens
         slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
+        num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
         start_layer = model_executable.model.start_layer
         end_layer = model_executable.model.end_layer
@@ -166,7 +167,8 @@ class SimpleConnector(KVConnectorBase):
         num_heads = int(model_config.num_key_value_heads / self.tp_size)
         hidden_size = model_config.hidden_size
         num_attention_heads = model_config.num_attention_heads
-        head_size = int(hidden_size / num_attention_heads)
+        head_size = getattr(model_config, "head_dim",
+                            int(hidden_size // num_attention_heads))
 
         # query_lens contains new KV caches that are added to vLLM.
         # so we will send them to decode instance
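The head_size change matters for models whose per-head dimension is not simply hidden_size / num_attention_heads: when the HF config exposes head_dim, that value is used, and the integer division only serves as a fallback. A minimal, self-contained sketch of that fallback behaviour (the config object and its numbers are made up):

# Illustrative only: mimics the getattr fallback above with invented values.
from types import SimpleNamespace

hf_config = SimpleNamespace(hidden_size=2048, num_attention_heads=16, head_dim=256)
head_size = getattr(hf_config, "head_dim",
                    hf_config.hidden_size // hf_config.num_attention_heads)
assert head_size == 256      # explicit head_dim wins

del hf_config.head_dim
head_size = getattr(hf_config, "head_dim",
                    hf_config.hidden_size // hf_config.num_attention_heads)
assert head_size == 128      # falls back to hidden_size // num_attention_heads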
@@ -174,6 +176,15 @@ class SimpleConnector(KVConnectorBase):
         for idx, slen in enumerate(seq_lens):
             start_pos = sum(seq_lens[:idx])
             end_pos = start_pos + slen
+
+            if start_pos >= num_prefill_tokens:
+                # vllm/worker/model_runner.py::_prepare_model_input_tensors:
+                # - input_tokens[:num_prefill_tokens] contains prefill tokens.
+                # - input_tokens[num_prefill_tokens:] contains decode tokens.
+                logger.warning("You have some decode requests while using "
+                               "SimpleConnector. Their KVCache won't be sent.")
+                break
+
             current_tokens = input_tokens_tensor[start_pos:end_pos]
 
             keys, values = [], []
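The new guard relies on the token layout built by the model runner: input_tokens holds all prefill tokens first and decode tokens after them, so once start_pos reaches num_prefill_tokens every remaining entry of seq_lens belongs to a decode request and has nothing new to send. A small standalone sketch of that bookkeeping, with made-up lengths (the last two entries play the role of decode requests):

# Made-up example of the slicing the loop above performs.
seq_lens = [4, 3, 9, 12]      # two prefill requests, then two decode requests
num_prefill_tokens = 7        # the 4 + 3 prefill tokens sit at the front of input_tokens

sent_ranges = []
for idx, slen in enumerate(seq_lens):
    start_pos = sum(seq_lens[:idx])
    end_pos = start_pos + slen
    if start_pos >= num_prefill_tokens:
        # Everything from here on is decode; its KV cache is not transferred.
        break
    sent_ranges.append((start_pos, end_pos))

assert sent_ranges == [(0, 4), (4, 7)]   # only the prefill slices are sent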
@@ -236,7 +247,7 @@ class SimpleConnector(KVConnectorBase):
                 # - input_tokens[num_prefill_tokens:] contains decode tokens.
                 logger.warning("You should set --enable_chunked_prefill=False "
                                "and --max_num_batched_tokens "
-                               "should be equal to max_seq_len_to_capture")
+                               "should be equal to --max_seq_len_to_capture")
                 bypass_model_exec = False
                 assert start_pos == num_prefill_tokens
                 break
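The warning in this last hunk (now spelling --max_seq_len_to_capture as a flag) refers to engine arguments that already exist in vLLM: chunked prefill should be disabled and max_num_batched_tokens should match max_seq_len_to_capture, so that a request's prefill is not split across batches. A hedged example of setting them; the model name and the 8192 value are placeholders:

# Sketch of the settings the warning asks for; values are placeholders.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    enable_chunked_prefill=False,
    max_num_batched_tokens=8192,               # keep equal to max_seq_len_to_capture
    max_seq_len_to_capture=8192,
)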