Compare commits

...

2 Commits

9 changed files with 52 additions and 76 deletions

View File

@ -459,6 +459,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashInferMetadata:
# For full cudagraph capture, ensure decode-only mode
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
# This is likely a cudagraph capture scenario
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
@ -577,22 +584,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with FlashInfer.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
return self.build(0, m)
def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting

View File

@ -43,6 +43,13 @@ class Mamba1AttentionMetadataBuilder(
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba1AttentionMetadata:
# For full cudagraph capture, ensure decode-only mode
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
# This is likely a cudagraph capture scenario
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
"Mamba only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
query_start_loc = common_attn_metadata.query_start_loc
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

View File

@ -101,6 +101,13 @@ class Mamba2AttentionMetadataBuilder(
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> Mamba2AttentionMetadata:
# For full cudagraph capture, ensure decode-only mode
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
# This is likely a cudagraph capture scenario
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
"Mamba only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens

View File

@ -38,18 +38,3 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
device=device,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"Mamba only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
return self.build(0, m)

View File

@ -561,25 +561,17 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens=seq_lens,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
assert m.max_query_len == 1 # decode-only
return self.build(0, m)
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> M:
# For full cudagraph capture, ensure decode-only mode
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
# This is likely a cudagraph capture scenario
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len

View File

@ -254,19 +254,19 @@ class AiterFlashAttentionMetadataBuilder(
self.aot_sliding_window: Optional[tuple[int, int]] = None
self.total_tokens: int = 0
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
self.total_tokens = self.model_config.max_model_len \
* self.vllm_config.scheduler_config.max_num_partial_prefills
res = self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
self.total_tokens = 0
return res
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
# Handle total_tokens for cudagraph capture scenarios
is_cudagraph_capture = (common_prefix_len == 0 and
common_attn_metadata.max_query_len == 1)
if is_cudagraph_capture:
original_total_tokens = self.total_tokens
self.total_tokens = self.model_config.max_model_len \
* self.vllm_config.scheduler_config.max_num_partial_prefills
else:
original_total_tokens = None
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
@ -310,6 +310,11 @@ class AiterFlashAttentionMetadataBuilder(
common_prefix_len=common_prefix_len,
total_tokens=self.total_tokens,
)
# Restore total_tokens value if this was a cudagraph capture
if is_cudagraph_capture:
self.total_tokens = original_total_tokens
return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:

View File

@ -73,16 +73,6 @@ class TritonAttentionMetadataBuilder(
vllm_config.parallel_config)
self.headdim = model_config.get_head_size()
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
attn_metadata = self.build(0, common_attn_metadata)
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
return attn_metadata
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
@ -129,6 +119,14 @@ class TritonAttentionMetadataBuilder(
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
)
# Handle cudagraph capture optimizations
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
return attn_metadata

View File

@ -205,16 +205,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
raise NotImplementedError
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
Build attention metadata for CUDA graph capture. Uses build by default.
Subclasses that override this method should call self.build or
super().build_for_cudagraph_capture.
"""
return self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,

View File

@ -2306,7 +2306,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for attn_group in self.attn_groups[kv_cache_group_id]:
attn_metadata_i = attn_group.metadata_builder\
.build_for_cudagraph_capture(common_attn_metadata)
.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i