diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index b23a8f0a5e..3353aaf760 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import os
-from typing import Optional
+from typing import ClassVar, Optional

 import torch

@@ -12,11 +12,22 @@ from vllm.attention.backends.abstract import (AttentionType,
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
-                                                   MLACommonMetadata)
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 logger = init_logger(__name__)


+class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable full CUDA Graph support for decode-only capture
+    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
+
+
 class CutlassMLABackend(MLACommonBackend):

     @staticmethod
@@ -27,6 +38,10 @@ class CutlassMLABackend(MLACommonBackend):
     def get_impl_cls() -> type["CutlassMLAImpl"]:
         return CutlassMLAImpl

+    @staticmethod
+    def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
+        return CutlassMLAMetadataBuilder
+

 class SM100Workspace:

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a821d4e8c2..f0f4942677 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1920,21 +1920,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             elif num_scheduled_tokens in self.cudagraphs \
                 and not skip_cuda_graphs:
                 cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
-                # if is_global_first_rank():
-                #     logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
+                if is_global_first_rank():
+                    logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
                 cudagraph_metadata.cudagraph.replay()
                 return cudagraph_metadata.outputs
             else:
-                # if is_global_first_rank():
-                #     logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
+                if is_global_first_rank():
+                    logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
                 return self._run_ubatches(ubatch_metadata, self.model)
         # run normal batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 self.model_inputs(slice(0, num_scheduled_tokens),
                                   scheduler_output, is_dummy_run)
-            # if is_global_first_rank():
-            #     logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
+            if is_global_first_rank():
+                logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
             skip_cuda_graphs = self.parallel_config.enable_microbatching

             with set_forward_context(attn_metadata,
                                      vllm_config=self.vllm_config,