Add support for CUTLASS MLA full CUDA graphs

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore
2025-08-13 00:14:40 -04:00
parent 6d76bd034a
commit 090f485aa1
2 changed files with 23 additions and 8 deletions

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional
from typing import ClassVar, Optional
import torch
@ -12,11 +12,22 @@ from vllm.attention.backends.abstract import (AttentionType,
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
logger = init_logger(__name__)
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder for the CUTLASS MLA attention backend.

    Opts in to full CUDA Graph capture, but only for decode-only batches
    (every request contributes exactly one query token).
    """

    # Advertise full CUDA Graph support to the runner (decode-only capture).
    full_cudagraph_supported: ClassVar[bool] = True

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        """Return True when the batch is pure decode (max query length 1)."""
        is_decode_only = common_attn_metadata.max_query_len == 1
        return is_decode_only
class CutlassMLABackend(MLACommonBackend):
@staticmethod
@ -27,6 +38,10 @@ class CutlassMLABackend(MLACommonBackend):
def get_impl_cls() -> type["CutlassMLAImpl"]:
    """Return the attention implementation class used by this backend."""
    return CutlassMLAImpl
@staticmethod
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
    """Return the metadata builder class for this backend.

    The CUTLASS-specific builder advertises (decode-only) full CUDA Graph
    support via its ``full_cudagraph_supported`` flag.
    """
    return CutlassMLAMetadataBuilder
class SM100Workspace:

View File

@ -1920,21 +1920,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif num_scheduled_tokens in self.cudagraphs \
and not skip_cuda_graphs:
cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
# if is_global_first_rank():
# logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
if is_global_first_rank():
logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
cudagraph_metadata.cudagraph.replay()
return cudagraph_metadata.outputs
else:
# if is_global_first_rank():
# logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
if is_global_first_rank():
logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
return self._run_ubatches(ubatch_metadata, self.model)
# run normal batch
else:
input_ids, positions, inputs_embeds, intermediate_tensors = \
self.model_inputs(slice(0, num_scheduled_tokens),
scheduler_output, is_dummy_run)
# if is_global_first_rank():
# logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
if is_global_first_rank():
logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
skip_cuda_graphs = self.parallel_config.enable_microbatching
with set_forward_context(attn_metadata,
vllm_config=self.vllm_config,