Compare commits
2 Commits
khluu/nccl ... copilot/fi

| Author | SHA1 | Date |
|---|---|---|
|  | fb0089c536 |  |
|  | b8d7f55dbd |  |
@@ -459,6 +459,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> FlashInferMetadata:
+        # For full cudagraph capture, ensure decode-only mode
+        if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
+            # This is likely a cudagraph capture scenario
+            assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
+                "FlashInfer only supports decode-only full CUDAGraph capture. " \
+                "Make sure all cudagraph capture sizes <= max_num_seq."
+
         num_reqs = common_attn_metadata.num_reqs
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
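Note on the guard added above: it infers capture mode from the shape of the batch rather than from an explicit flag. In a pure-decode batch every request contributes exactly one query token, so `num_reqs == num_actual_tokens` must hold. Below is a minimal sketch of that invariant; `FakeCommonMeta` and `looks_like_decode_only_capture` are hypothetical stand-ins for `CommonAttentionMetadata` and the inlined check, not vLLM code.

```python
from dataclasses import dataclass


@dataclass
class FakeCommonMeta:
    # Hypothetical stand-in exposing only the fields the guard reads.
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int


def looks_like_decode_only_capture(common_prefix_len: int,
                                   m: FakeCommonMeta) -> bool:
    # Mirrors the heuristic added to build(): no shared prefix and
    # single-token queries suggest a full-cudagraph capture batch.
    return common_prefix_len == 0 and m.max_query_len == 1


decode_batch = FakeCommonMeta(num_reqs=8, num_actual_tokens=8,
                              max_query_len=1)
assert looks_like_decode_only_capture(0, decode_batch)
# One token per request in pure decode, hence the asserted equality:
assert decode_batch.num_reqs == decode_batch.num_actual_tokens
```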
@@ -577,22 +584,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
 
         return attn_metadata
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        """
-        This method builds the metadata for full cudagraph capture.
-        Currently, only decode is supported for full cudagraphs with FlashInfer.
-        """
-        m = common_attn_metadata
-
-        assert m.num_reqs == m.num_actual_tokens, \
-            "FlashInfer only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        m.max_query_len = 1  # decode-only
-
-        return self.build(0, m)
-
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
@@ -43,6 +43,13 @@ class Mamba1AttentionMetadataBuilder(
         common_attn_metadata: CommonAttentionMetadata,
         fast_build: bool = False,
     ) -> Mamba1AttentionMetadata:
+        # For full cudagraph capture, ensure decode-only mode
+        if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
+            # This is likely a cudagraph capture scenario
+            assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
+                "Mamba only supports decode-only full CUDAGraph capture. " \
+                "Make sure all cudagraph capture sizes <= max_num_seq."
+
         query_start_loc = common_attn_metadata.query_start_loc
 
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
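The same decode-only guard is now inlined in the FlashInfer, Mamba1, Mamba2, and MLA builders with only the backend name in the message differing. A hypothetical refactor, not part of this diff, could factor it into one shared helper:

```python
def assert_decode_only_capture(m, backend_name: str) -> None:
    # 'm' is assumed to expose the num_reqs / num_actual_tokens fields,
    # as CommonAttentionMetadata does in the hunks above.
    assert m.num_reqs == m.num_actual_tokens, (
        f"{backend_name} only supports decode-only full CUDAGraph capture. "
        "Make sure all cudagraph capture sizes <= max_num_seq.")


# Hypothetical usage inside a builder's build(), mirroring the inlined
# checks in this diff:
#   if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
#       assert_decode_only_capture(common_attn_metadata, "Mamba")
```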
@@ -101,6 +101,13 @@ class Mamba2AttentionMetadataBuilder(
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> Mamba2AttentionMetadata:
+        # For full cudagraph capture, ensure decode-only mode
+        if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
+            # This is likely a cudagraph capture scenario
+            assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
+                "Mamba only supports decode-only full CUDAGraph capture. " \
+                "Make sure all cudagraph capture sizes <= max_num_seq."
+
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
@@ -38,18 +38,3 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
             device=device,
         )
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata) -> M:
-        """
-        This method builds the metadata for full cudagraph capture.
-        Currently, only decode is supported for full cudagraphs with Mamba.
-        """
-        m = common_attn_metadata
-
-        assert m.num_reqs == m.num_actual_tokens, \
-            "Mamba only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        m.max_query_len = 1  # decode-only
-
-        return self.build(0, m)
@@ -561,25 +561,17 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             seq_lens=seq_lens,
         )
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata) -> M:
-        """
-        This method builds the metadata for full cudagraph capture.
-        Currently, only decode is supported for full cudagraphs with MLA.
-        """
-        m = common_attn_metadata
-        assert m.num_reqs == m.num_actual_tokens, \
-            "MLA only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        assert m.max_query_len == 1  # decode-only
-
-        return self.build(0, m)
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> M:
+        # For full cudagraph capture, ensure decode-only mode
+        if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
+            # This is likely a cudagraph capture scenario
+            assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
+                "MLA only supports decode-only full CUDAGraph capture. " \
+                "Make sure all cudagraph capture sizes <= max_num_seq."
+
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
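One behavioral nuance in the MLA hunk: the deleted helper asserted `m.max_query_len == 1` unconditionally on the capture path, whereas the new heuristic only runs its check when the metadata already looks decode-shaped. A capture call with prefill-shaped metadata would now pass through `build()` without tripping any guard. Reusing the hypothetical `FakeCommonMeta` sketch from above:

```python
# Prefill-shaped metadata: multi-token queries, tokens != requests.
prefill_shaped = FakeCommonMeta(num_reqs=2, num_actual_tokens=10,
                                max_query_len=5)
# The new heuristic does not fire, so no assert runs; the old
# build_for_cudagraph_capture() would have failed its assert here.
assert not looks_like_decode_only_capture(0, prefill_shaped)
```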
@@ -254,19 +254,19 @@ class AiterFlashAttentionMetadataBuilder(
         self.aot_sliding_window: Optional[tuple[int, int]] = None
         self.total_tokens: int = 0
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        self.total_tokens = self.model_config.max_model_len \
-            * self.vllm_config.scheduler_config.max_num_partial_prefills
-        res = self.build(common_prefix_len=0,
-                         common_attn_metadata=common_attn_metadata)
-        self.total_tokens = 0
-        return res
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
               fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
+        # Handle total_tokens for cudagraph capture scenarios
+        is_cudagraph_capture = (common_prefix_len == 0 and
+                                common_attn_metadata.max_query_len == 1)
+        if is_cudagraph_capture:
+            original_total_tokens = self.total_tokens
+            self.total_tokens = self.model_config.max_model_len \
+                * self.vllm_config.scheduler_config.max_num_partial_prefills
+        else:
+            original_total_tokens = None
 
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
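Design note on the save/restore above: `original_total_tokens` is only written back on the normal path (see the next hunk), so if metadata construction raised partway through, the inflated value would stick on the builder. A hypothetical alternative, not in the diff, expresses the swap with `try`/`finally` so restoration always happens:

```python
import contextlib


@contextlib.contextmanager
def scoped_total_tokens(builder, capture_value: int):
    # Temporarily swap builder.total_tokens, restoring the original value
    # even if the body raises. 'builder' is assumed to carry a
    # total_tokens attribute like AiterFlashAttentionMetadataBuilder does.
    original = builder.total_tokens
    builder.total_tokens = capture_value
    try:
        yield
    finally:
        builder.total_tokens = original


# Hypothetical usage inside build():
#   with scoped_total_tokens(self, max_model_len * max_num_partial_prefills):
#       attn_metadata = ...  # construct metadata with the inflated value
```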
@@ -310,6 +310,11 @@ class AiterFlashAttentionMetadataBuilder(
             common_prefix_len=common_prefix_len,
             total_tokens=self.total_tokens,
         )
+
+        # Restore total_tokens value if this was a cudagraph capture
+        if is_cudagraph_capture:
+            self.total_tokens = original_total_tokens
+
         return attn_metadata
 
     def use_cascade_attention(self, *args, **kwargs) -> bool:
@@ -73,16 +73,6 @@ class TritonAttentionMetadataBuilder(
                                                  vllm_config.parallel_config)
         self.headdim = model_config.get_head_size()
 
-    def build_for_cudagraph_capture(
-        self, common_attn_metadata: CommonAttentionMetadata
-    ) -> TritonAttentionMetadata:
-        attn_metadata = self.build(0, common_attn_metadata)
-        # When doing full graph capture, setting seq_lens to
-        # max_model_len will cause graph capture to be extremely
-        # slow, so here we set it to 1.
-        attn_metadata.seq_lens.fill_(1)
-        return attn_metadata
-
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -129,6 +119,14 @@ class TritonAttentionMetadataBuilder(
             suffix_kv_lens=suffix_kv_lens,
             prefix_scheduler_metadata=prefix_scheduler_metadata,
         )
+
+        # Handle cudagraph capture optimizations
+        if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
+            # When doing full graph capture, setting seq_lens to
+            # max_model_len will cause graph capture to be extremely
+            # slow, so here we set it to 1.
+            attn_metadata.seq_lens.fill_(1)
+
         return attn_metadata
 
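The `seq_lens.fill_(1)` above is an in-place overwrite performed after the metadata object is built: per the comment carried over from the deleted helper, capture time scales with the sequence lengths the kernels iterate over, so decode-shaped lengths keep capture fast. A toy illustration of the in-place reset, with made-up sizes:

```python
import torch

# As if seq_lens had been initialized to max_model_len-sized sequences
# (8 requests, length 4096; values here are invented for illustration).
seq_lens = torch.full((8,), 4096, dtype=torch.int32)
seq_lens.fill_(1)  # decode-shaped lengths for fast graph capture
assert int(seq_lens.max()) == 1
```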
@@ -205,16 +205,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         """
         raise NotImplementedError
 
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata) -> M:
-        """
-        Build attention metadata for CUDA graph capture. Uses build by default.
-        Subclasses that override this method should call self.build or
-        super().build_for_cudagraph_capture.
-        """
-        return self.build(common_prefix_len=0,
-                          common_attn_metadata=common_attn_metadata)
-
     def build_for_drafting(
         self,
         common_attn_metadata: CommonAttentionMetadata,
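With `build_for_cudagraph_capture()` deleted from the abstract base, `build()` becomes the single entry point, and a capture-shaped call is expressed purely through its arguments. A stub-level sketch of the narrowed contract; `StubBuilder` is hypothetical, not vLLM code:

```python
class StubBuilder:
    # Stand-in for an AttentionMetadataBuilder subclass after this diff.
    def build(self, common_prefix_len, common_attn_metadata,
              fast_build=False):
        # Real builders construct backend-specific metadata here; capture
        # mode is now inferred from the arguments, not a separate method.
        return {"prefix": common_prefix_len, "meta": common_attn_metadata}


builder = StubBuilder()
attn_metadata = builder.build(common_prefix_len=0,
                              common_attn_metadata={"max_query_len": 1})
```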
@@ -2306,7 +2306,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
             for attn_group in self.attn_groups[kv_cache_group_id]:
                 attn_metadata_i = attn_group.metadata_builder\
-                    .build_for_cudagraph_capture(common_attn_metadata)
+                    .build(common_prefix_len=0,
+                           common_attn_metadata=common_attn_metadata)
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
 
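Caveat on the runner change above: the removed helpers forced `m.max_query_len = 1` before delegating to `build()`; with them gone, the in-build guards only fire if the runner already supplies decode-shaped metadata. An illustrative, runnable sketch with stubs standing in for the runner's objects (`attn_group` and `common_attn_metadata` are names from the diff; their classes here are hypothetical):

```python
from types import SimpleNamespace


def fake_build(common_prefix_len, common_attn_metadata):
    # Stand-in for AttentionMetadataBuilder.build(); the real method also
    # accepts fast_build and returns backend-specific metadata.
    assert common_prefix_len == 0
    assert common_attn_metadata.max_query_len == 1  # decode-shaped input
    return "attn_metadata"


common_attn_metadata = SimpleNamespace(max_query_len=1)
attn_group = SimpleNamespace(
    metadata_builder=SimpleNamespace(build=fake_build))

# The new capture-time call path, as in the hunk above:
attn_metadata_i = attn_group.metadata_builder.build(
    common_prefix_len=0,
    common_attn_metadata=common_attn_metadata)
```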