From 9ff4511e43bb95efefd4e28048ca257e408277fb Mon Sep 17 00:00:00 2001
From: Elfie Guo <164945471+elfiegg@users.noreply.github.com>
Date: Wed, 30 Oct 2024 09:33:53 -0700
Subject: [PATCH] [Misc] Add chunked-prefill support on FlashInfer. (#9781)

---
 .../basic_correctness/test_chunked_prefill.py | 12 +++
 vllm/attention/backends/flashinfer.py         | 88 +++++++++++++------
 2 files changed, 72 insertions(+), 28 deletions(-)

diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py
index 51aec8c873..cc5bc2aca2 100644
--- a/tests/basic_correctness/test_chunked_prefill.py
+++ b/tests/basic_correctness/test_chunked_prefill.py
@@ -11,6 +11,8 @@ from contextlib import nullcontext
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
+
 from ..models.utils import check_logprobs_close, check_outputs_equal
 from ..utils import multi_gpu_test
 
@@ -28,6 +30,7 @@ MODELS = [
 # NOTE: Increasing this in this suite will fail CI because we currently cannot
 # reset distributed env properly. Use a value > 1 just when you test.
 @pytest.mark.parametrize("tensor_parallel_size", [1])
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -38,11 +41,15 @@ def test_models(
     chunked_prefill_token_size: int,
     enforce_eager: bool,
     tensor_parallel_size: int,
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """
     Checks exact match decode between huggingface model and vllm runner with
     chunked prefill.
     """
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
 
@@ -71,13 +78,18 @@ def test_models(
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
 @pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 def test_models_distributed(
     hf_runner,
     vllm_runner,
     example_prompts,
     model: str,
     distributed_executor_backend: str,
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     if (model == "meta-llama/Llama-2-7b-hf"
             and distributed_executor_backend == "ray"):
         # test ray adag
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index e43fb134a6..5ea101ae04 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -268,6 +268,11 @@ class FlashInferMetadata(AttentionMetadata):
     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
     max_prefill_seq_len: int
+    # Number of query tokens for each request in the batch.
+    # Currently, we require that all requests have the same number of query
+    # tokens during the decoding phase. When speculative decoding is enabled,
+    # decode_query_len might be greater than 1. In all other cases, it is 1.
+    decode_query_len: Optional[int] = 1
 
     use_cuda_graph: bool = True
 
@@ -335,6 +340,7 @@ class FlashInferMetadata(AttentionMetadata):
             assert self.paged_kv_last_page_len is not None
             assert self.block_table_bound is not None
             assert self.seq_lens_tensor is not None
+            self.query_start_loc = self.query_start_loc[:self.num_prefills + 1]
             batch_size = self.query_start_loc.shape[0] - 1
             assert batch_size >= 0
             # We will use flash attention for profiling to
@@ -349,11 +355,13 @@
             self.paged_kv_indices = self.paged_kv_indices.to(self.device)
             self.prefill_wrapper.end_forward()
             self.prefill_wrapper.begin_forward(
-                self.query_start_loc, self.paged_kv_indptr,
-                self.paged_kv_indices, self.paged_kv_last_page_len,
+                self.query_start_loc,
+                self.paged_kv_indptr[:self.num_prefills + 1],
+                self.paged_kv_indices,
+                self.paged_kv_last_page_len[:self.num_prefills],
                 self.num_qo_heads, self.num_kv_heads, self.head_dim,
                 self.page_size)
-        else:
+        if self.num_decode_tokens > 0:
             assert self.paged_kv_indices is not None
             assert self.paged_kv_indptr is not None
             assert self.paged_kv_last_page_len is not None
@@ -370,9 +378,9 @@
             assert self.decode_wrapper is not None
             self.decode_wrapper.end_forward()
             self.decode_wrapper.begin_forward(
-                self.paged_kv_indptr,
+                self.paged_kv_indptr[self.num_prefills:],
                 self.paged_kv_indices,
-                self.paged_kv_last_page_len,
+                self.paged_kv_last_page_len[self.num_prefills:],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
@@ -397,21 +405,14 @@
 
     @property
     def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
-        # Currently chunked prefill is not supported
-        if self.num_decode_tokens == 0:
-            assert self.num_prefills > 0
-            return self
-
-        return None
+        if self.num_prefills == 0:
+            return None
+        return self
 
     @property
     def decode_metadata(self) -> Optional["FlashInferMetadata"]:
-        # Currently chunked prefill is not supported
-        if self.num_prefills > 0:
-            assert self.num_decode_tokens == 0, (
-                "Chunked prefill is not supported with flashinfer yet.")
+        if self.num_decode_tokens == 0:
             return None
-
         return self
 
     def advance_step(self,
@@ -599,11 +600,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
+        decode_query_len = max(query_lens[self.num_prefills:], default=1)
 
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size
+            num_decode_tokens = batch_size - self.num_prefill_tokens
 
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
@@ -689,6 +691,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self.runner.kv_cache_dtype, self.runner.model_config.dtype)
 
         return FlashInferMetadata(
+            decode_query_len=decode_query_len,
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
             num_prefill_tokens=self.num_prefill_tokens,
@@ -811,12 +814,6 @@ def unified_flash_infer(
     key = key.view(-1, num_kv_heads, head_size)
     value = value.view(-1, num_kv_heads, head_size)
 
-    if attn_metadata.num_prefill_tokens > 0:
-        assert attn_metadata.num_decode_tokens == 0, (
-            "Chunked prefill is not supported with flashinfer yet.")
-    if attn_metadata.num_decode_tokens > 0:
-        assert attn_metadata.num_prefill_tokens == 0, (
-            "Chunked prefill is not supported with flashinfer yet.")
     if kv_cache.numel() > 0:
         # Use the same reshape and cache kernel as flash attention.
         ops.reshape_and_cache_flash(
@@ -836,14 +833,33 @@ def unified_flash_infer(
                 kv_cache_dtype)
             kv_cache = kv_cache.view(torch_dtype)
 
+    num_prefill_tokens = attn_metadata.num_prefill_tokens
+    num_decode_tokens = attn_metadata.num_decode_tokens
+    assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
+                f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}"  # noqa
+    assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
+                f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}"  # noqa
     query = query.contiguous()  # Flashinfer requires query to be contiguous
+    # Query for decode. KV is not needed because it is already cached.
+    # QKV for prefill.
+    decode_query = query[num_prefill_tokens:]
+    query = query[:num_prefill_tokens]
+
+    key = key[:num_prefill_tokens]
+    value = value[:num_prefill_tokens]
+
+    assert query.shape[0] == num_prefill_tokens
+    assert decode_query.shape[0] == num_decode_tokens
+
+    prefill_output: Optional[torch.Tensor] = None
+    decode_output: Optional[torch.Tensor] = None
     if prefill_meta := attn_metadata.prefill_metadata:
         # We will use flash attention for prefill
         # when kv_cache is not provided.
         # This happens when vllm runs the profiling to
         # determine the number of blocks.
         if kv_cache.numel() == 0:
-            output = flash_attn_varlen_func(
+            prefill_output = flash_attn_varlen_func(
                 q=query,
                 k=key,
                 v=value,
@@ -859,18 +875,34 @@
         else:
            assert prefill_meta is not None
            assert prefill_meta.prefill_wrapper is not None
-            output = prefill_meta.prefill_wrapper.forward(
+            prefill_output = prefill_meta.prefill_wrapper.forward(
                 query,
                 kv_cache,
                 logits_soft_cap=logits_soft_cap,
                 causal=True)
-    else:
+    if decode_meta := attn_metadata.decode_metadata:
         assert attn_metadata.decode_metadata is not None
         assert attn_metadata.decode_metadata.decode_wrapper is not None
-        output = attn_metadata.decode_metadata.decode_wrapper.forward(
-            query,
+        decode_output = attn_metadata.decode_metadata.decode_wrapper.forward(
+            decode_query,
             kv_cache,
             sm_scale=softmax_scale,
             logits_soft_cap=logits_soft_cap,
             k_scale=k_scale,
             v_scale=v_scale)
+
+    if prefill_output is None and decode_output is not None:
+        # Decode only batch.
+        output, num_tokens = decode_output, num_decode_tokens
+    elif decode_output is None and prefill_output is not None:
+        # Prefill only batch.
+        output, num_tokens = prefill_output, num_prefill_tokens
+    else:
+        # Chunked prefill batch does not work with speculative decoding in
+        # FlashInfer backend, so the query length for decode should be 1.
+        assert prefill_output is not None
+        assert decode_output is not None
+        assert decode_meta is not None
+        assert decode_meta.decode_query_len == 1
+        decode_output = decode_output.squeeze(1)
+        output = torch.cat([prefill_output, decode_output], dim=0)
     return output.view(num_tokens, hidden_size)
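
Below is a minimal usage sketch, not part of the patch, showing how the combination introduced here (chunked prefill with the FlashInfer attention backend) might be exercised end to end. It assumes that VLLM_ATTENTION_BACKEND is the environment variable that override_backend_env_variable sets in the test above; the model name is only a placeholder, and the engine arguments mirror the small chunk sizes used in the test matrix.

    # Hedged sketch: select the FlashInfer backend and enable chunked prefill
    # with a small per-step token budget, roughly as the test above does.
    import os

    # Assumption: this is the env var the test helper overrides.
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-125m",      # placeholder; any supported model works
        enable_chunked_prefill=True,    # prefills longer than the budget are chunked
        max_num_batched_tokens=16,      # small token budget, as in the test matrix
        max_num_seqs=16,
        enforce_eager=True,
    )
    outputs = llm.generate(["The future of AI is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)

With settings like these a long prompt is prefilled in 16-token chunks, and once decode requests are in flight a single step can contain both prefill and decode tokens, which is the mixed-batch path this patch adds to the FlashInfer backend.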