Compare commits
1 Commits
debug
...
woosuk/cle
| Author | SHA1 | Date | |
|---|---|---|---|
| 0dba2a36a9 |
@ -39,6 +39,7 @@ from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
|||||||
split_decodes_and_prefills)
|
split_decodes_and_prefills)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
from vllm.v1.worker.utils import CpuGpuBuffer
|
||||||
|
|
||||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||||
|
|
||||||
@ -215,34 +216,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
self.global_hyperparameters = infer_global_hyperparameters(
|
self.global_hyperparameters = infer_global_hyperparameters(
|
||||||
get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
|
get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
|
||||||
|
|
||||||
# Preparing persistent buffers (device-side)
|
# Preparing persistent buffers
|
||||||
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device)
|
|
||||||
self.paged_kv_indices = torch.zeros(
|
|
||||||
max_num_pages, # max num pages possible
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device)
|
|
||||||
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device)
|
|
||||||
# host-side buffer
|
|
||||||
pin_memory = is_pin_memory_available()
|
pin_memory = is_pin_memory_available()
|
||||||
self.paged_kv_indptr_cpu = torch.zeros(max_num_reqs + 1,
|
self.paged_kv_indptr = CpuGpuBuffer(max_num_reqs + 1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device=self.device,
|
||||||
pin_memory=pin_memory)
|
pin_memory=pin_memory)
|
||||||
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
|
self.paged_kv_indices = CpuGpuBuffer(max_num_pages,
|
||||||
self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
|
dtype=torch.int32,
|
||||||
dtype=torch.int32,
|
device=self.device,
|
||||||
device="cpu",
|
pin_memory=pin_memory)
|
||||||
pin_memory=pin_memory)
|
self.paged_kv_last_page_len = CpuGpuBuffer(max_num_reqs,
|
||||||
self.paged_kv_last_page_len_cpu = torch.zeros(max_num_reqs,
|
dtype=torch.int32,
|
||||||
dtype=torch.int32,
|
device=self.device,
|
||||||
device="cpu",
|
pin_memory=pin_memory)
|
||||||
pin_memory=pin_memory)
|
|
||||||
self.paged_kv_last_page_len_np = (
|
|
||||||
self.paged_kv_last_page_len_cpu.numpy())
|
|
||||||
|
|
||||||
def _get_workspace_buffer(self):
|
def _get_workspace_buffer(self):
|
||||||
if self._workspace_buffer is None:
|
if self._workspace_buffer is None:
|
||||||
@ -269,10 +256,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
|
|
||||||
if decode_wrapper is None:
|
if decode_wrapper is None:
|
||||||
if use_cudagraph:
|
if use_cudagraph:
|
||||||
paged_kv_indptr = self.paged_kv_indptr[:batch_size + 1]
|
paged_kv_indptr = self.paged_kv_indptr.gpu[:batch_size + 1]
|
||||||
paged_kv_indices = self.paged_kv_indices
|
paged_kv_indices = self.paged_kv_indices.gpu
|
||||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:
|
paged_kv_last_page_len = (
|
||||||
batch_size]
|
self.paged_kv_last_page_len.gpu[:batch_size])
|
||||||
else:
|
else:
|
||||||
paged_kv_indptr = None
|
paged_kv_indptr = None
|
||||||
paged_kv_indices = None
|
paged_kv_indices = None
|
||||||
@ -355,15 +342,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
np.cumsum(
|
np.cumsum(
|
||||||
num_blocks_np,
|
num_blocks_np,
|
||||||
dtype=np.int32,
|
dtype=np.int32,
|
||||||
out=self.paged_kv_indptr_np[1:num_reqs + 1],
|
out=self.paged_kv_indptr.np[1:num_reqs + 1],
|
||||||
)
|
)
|
||||||
paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
|
paged_kv_indptr = self.paged_kv_indptr.copy_to_gpu(num_reqs + 1)
|
||||||
paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
|
|
||||||
non_blocking=True)
|
|
||||||
|
|
||||||
# write self.paged_kv_indices inplace
|
# write self.paged_kv_indices inplace
|
||||||
num_actual_pages = num_blocks_np.sum().item()
|
num_actual_pages = num_blocks_np.sum().item()
|
||||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
paged_kv_indices = self.paged_kv_indices.gpu[:num_actual_pages]
|
||||||
_copy_page_indices_kernel[(num_reqs, )](
|
_copy_page_indices_kernel[(num_reqs, )](
|
||||||
paged_kv_indices,
|
paged_kv_indices,
|
||||||
block_table_tensor,
|
block_table_tensor,
|
||||||
@ -374,7 +359,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
|
|
||||||
# write self.paged_kv_last_page_len_cpu inplace
|
# write self.paged_kv_last_page_len_cpu inplace
|
||||||
paged_kv_last_page_len_np = seq_lens_np % page_size
|
paged_kv_last_page_len_np = seq_lens_np % page_size
|
||||||
self.paged_kv_last_page_len_np[:num_reqs] = np.where(
|
self.paged_kv_last_page_len.np[:num_reqs] = np.where(
|
||||||
paged_kv_last_page_len_np == 0,
|
paged_kv_last_page_len_np == 0,
|
||||||
page_size,
|
page_size,
|
||||||
paged_kv_last_page_len_np,
|
paged_kv_last_page_len_np,
|
||||||
@ -418,8 +403,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
|
qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[:1 + num_reqs]
|
paged_kv_indptr_cpu = self.paged_kv_indptr.cpu[:1 + num_reqs]
|
||||||
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
|
paged_kv_last_page_len_cpu = self.paged_kv_last_page_len.cpu[:num_reqs]
|
||||||
|
|
||||||
if attn_metadata.use_cascade:
|
if attn_metadata.use_cascade:
|
||||||
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
|
attn_metadata.cascade_wrapper = self._get_cascade_wrapper()
|
||||||
@ -495,14 +480,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
# Carefully fulfill the padding region with reasonable value
|
# Carefully fulfill the padding region with reasonable value
|
||||||
# on cpu.
|
# on cpu.
|
||||||
# Make sure paged_kv_indptr_cpu is not decreasing
|
# Make sure paged_kv_indptr_cpu is not decreasing
|
||||||
self.paged_kv_indptr_cpu[1 + num_decodes:1 +
|
self.paged_kv_indptr.np[1 + num_decodes:1 +
|
||||||
num_input_tokens].fill_(
|
num_input_tokens].fill(
|
||||||
paged_kv_indptr_cpu[-1])
|
paged_kv_indptr_cpu[-1])
|
||||||
# Fill the remaining paged_kv_last_page_len_cpu with 1.
|
# Fill the remaining paged_kv_last_page_len_cpu with 1.
|
||||||
# This is because flashinfer treats 0 as a full page
|
# This is because flashinfer treats 0 as a full page
|
||||||
# instead of empty.
|
# instead of empty.
|
||||||
self.paged_kv_last_page_len_cpu[
|
self.paged_kv_last_page_len.np[
|
||||||
num_decodes:num_input_tokens].fill_(1)
|
num_decodes:num_input_tokens].fill(1)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
num_input_tokens = num_decodes
|
num_input_tokens = num_decodes
|
||||||
@ -515,9 +500,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
# in atten_metadata when using cudagraph.
|
# in atten_metadata when using cudagraph.
|
||||||
fast_plan_decode(
|
fast_plan_decode(
|
||||||
attn_metadata.decode_wrapper,
|
attn_metadata.decode_wrapper,
|
||||||
self.paged_kv_indptr_cpu[:num_input_tokens + 1],
|
self.paged_kv_indptr.cpu[:num_input_tokens + 1],
|
||||||
paged_kv_indices,
|
paged_kv_indices,
|
||||||
self.paged_kv_last_page_len_cpu[:num_input_tokens],
|
self.paged_kv_last_page_len.cpu[:num_input_tokens],
|
||||||
seq_lens_cpu[:num_input_tokens],
|
seq_lens_cpu[:num_input_tokens],
|
||||||
self.num_qo_heads,
|
self.num_qo_heads,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
|
|||||||
Reference in New Issue
Block a user