add support for cutlass mla full cudagraphs
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -12,11 +12,22 @@ from vllm.attention.backends.abstract import (AttentionType,
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata)
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder for the CUTLASS MLA backend.

    Advertises full CUDA-graph support, but only for decode-only
    captures: a batch qualifies for graph replay solely when every
    request contributes a single query token.
    """

    # Full CUDA-graph capture is enabled, restricted to decode steps.
    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        """Return True iff the batch is pure decode (max query length 1),
        the only shape this backend replays from a captured graph."""
        longest_query = common_attn_metadata.max_query_len
        return longest_query == 1
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
|
||||
|
||||
@staticmethod
|
||||
@@ -27,6 +38,10 @@ class CutlassMLABackend(MLACommonBackend):
|
||||
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
||||
return CutlassMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
|
||||
return CutlassMLAMetadataBuilder
|
||||
|
||||
|
||||
class SM100Workspace:
|
||||
|
||||
|
||||
@@ -1920,21 +1920,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
elif num_scheduled_tokens in self.cudagraphs \
|
||||
and not skip_cuda_graphs:
|
||||
cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
|
||||
# if is_global_first_rank():
|
||||
# logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
|
||||
if is_global_first_rank():
|
||||
logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
|
||||
cudagraph_metadata.cudagraph.replay()
|
||||
return cudagraph_metadata.outputs
|
||||
else:
|
||||
# if is_global_first_rank():
|
||||
# logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
|
||||
if is_global_first_rank():
|
||||
logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
|
||||
return self._run_ubatches(ubatch_metadata, self.model)
|
||||
# run normal batch
|
||||
else:
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
self.model_inputs(slice(0, num_scheduled_tokens),
|
||||
scheduler_output, is_dummy_run)
|
||||
# if is_global_first_rank():
|
||||
# logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
|
||||
if is_global_first_rank():
|
||||
logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
|
||||
skip_cuda_graphs = self.parallel_config.enable_microbatching
|
||||
with set_forward_context(attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
|
||||
Reference in New Issue
Block a user