[Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and fix the reshape of the KVCache (#14369)
Signed-off-by: Mathis Felardos <mathis@mistral.ai>
This commit is contained in:
@ -2,7 +2,7 @@
|
||||
"""
|
||||
Simple KV Cache Connector for Distributed Machine Learning Inference
|
||||
|
||||
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
|
||||
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
|
||||
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
|
||||
MooncakePipe.
|
||||
|
||||
@ -159,6 +159,7 @@ class SimpleConnector(KVConnectorBase):
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
|
||||
@ -166,7 +167,8 @@ class SimpleConnector(KVConnectorBase):
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(hidden_size / num_attention_heads)
|
||||
head_size = getattr(model_config, "head_dim",
|
||||
int(hidden_size // num_attention_heads))
|
||||
|
||||
# query_lens contains new KV caches that are added to vLLM.
|
||||
# so we will send them to decode instance
|
||||
@ -174,6 +176,15 @@ class SimpleConnector(KVConnectorBase):
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
|
||||
if start_pos >= num_prefill_tokens:
|
||||
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
|
||||
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You have some decode requests while using "
|
||||
"SimpleConnector. Their KVCache won't be sent.")
|
||||
break
|
||||
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
|
||||
keys, values = [], []
|
||||
@ -236,7 +247,7 @@ class SimpleConnector(KVConnectorBase):
|
||||
# - input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
logger.warning("You should set --enable_chunked_prefill=False "
|
||||
"and --max_num_batched_tokens "
|
||||
"should be equal to max_seq_len_to_capture")
|
||||
"should be equal to --max_seq_len_to_capture")
|
||||
bypass_model_exec = False
|
||||
assert start_pos == num_prefill_tokens
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user