[Bugfix] Check NVIDIA artifactory is accessible before using flashinfer cubin kernels (#21893)

Michael Goin
2025-08-01 08:28:45 -04:00
committed by GitHub
parent fb0e0d46fc
commit f81c1bb055
4 changed files with 93 additions and 99 deletions
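The commit folds the duplicated FlashInferBackend.use_trtllm_decode_attention staticmethods into a shared helper in vllm.utils.flashinfer and gates the TRTLLM cubin decode path on whether NVIDIA's artifactory can actually be reached. A rough sketch of that gating pattern, using illustrative names rather than vLLM's real API (only the repository URL and the 5-second timeout are taken from the diff below):

import functools

import requests

CUBIN_REPO = "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/"  # noqa: E501


@functools.cache
def cubin_repo_reachable() -> bool:
    # Probe once per process; a short timeout keeps startup from hanging.
    try:
        return requests.get(CUBIN_REPO, timeout=5).status_code == 200
    except requests.RequestException:
        return False


def pick_decode_path() -> str:
    # Only choose the TRTLLM cubin kernels when they can actually be fetched.
    return "trtllm-cubin" if cubin_repo_reachable() else "flashinfer-default"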

View File

@@ -44,9 +44,9 @@ from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                        make_tensor_with_pad)
from vllm.utils.flashinfer import use_trtllm_decode_attention

logger = init_logger(__name__)
@@ -56,7 +56,6 @@ if TYPE_CHECKING:
class FlashInferBackend(AttentionBackend):
    cached_sm100a_supported: Optional[bool] = None

    @staticmethod
    def get_name() -> str:
@@ -123,47 +122,6 @@ class FlashInferBackend(AttentionBackend):
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    @staticmethod
    def use_trtllm_decode_attention(
            batch_size: int,
            max_seq_len: int,
            kv_cache_dtype: str,
            num_qo_heads: Optional[int],
            num_kv_heads: Optional[int],
            attn_head_size: Optional[int],
    ) -> bool:
        if FlashInferBackend.cached_sm100a_supported is None:
            FlashInferBackend.cached_sm100a_supported = (
                current_platform.has_device_capability(100))
        if not FlashInferBackend.cached_sm100a_supported:
            return False

        # Check if the dimensions are supported by TRTLLM decode attention
        if (attn_head_size is None or num_qo_heads is None
                or num_kv_heads is None or num_qo_heads // num_kv_heads > 8
                or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
            return False

        env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
        if env_value is not None:
            logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                             env_value)
            # Environment variable is set - respect it
            # Making the conditional check for zero because
            # the path is automatically enabled if the batch size condition
            # is satisfied.
            no_use_trtllm = (env_value == "0")
            if not no_use_trtllm:
                logger.info_once("Using TRTLLM decode attention.")
            return not no_use_trtllm
        else:
            # Environment variable not set - use auto-detection
            use_trtllm = (FlashInferBackend.cached_sm100a_supported
                          and batch_size <= 256 and max_seq_len < 131072
                          and kv_cache_dtype == "auto")
            if use_trtllm:
                logger.warning_once(
                    "Using TRTLLM decode attention (auto-detected).")
            return use_trtllm
@dataclass
class PerLayerParameters:
@@ -1156,7 +1114,7 @@ class FlashInferImpl(AttentionImpl):
                assert decode_meta.decode_wrapper._sm_scale == softmax_scale
                # TODO: @pavanimajety Remove this once the switch happens
                # inside flashinfer.
                if not FlashInferBackend.use_trtllm_decode_attention(
                if not use_trtllm_decode_attention(
                        num_decode_tokens, attn_metadata.max_decode_seq_len,
                        kv_cache_dtype, attn_metadata.num_qo_heads,
                        attn_metadata.num_kv_heads, attn_metadata.head_dim):

View File

@@ -10,12 +10,25 @@ import contextlib
import functools
import importlib
import importlib.util
from typing import Any, Callable, NoReturn
import os
from typing import Any, Callable, NoReturn, Optional

import requests

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# This is the storage path for the cubins, it can be replaced
# with a local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/",  # noqa: E501
)
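As the comment above notes, the repository location can be redirected (for example to a local mirror) for testing. A small sketch of that, assuming a hypothetical mirror URL; the variable has to be set before vllm.utils.flashinfer is imported, since the module reads it at import time:

import os

# Hypothetical local mirror; export before importing vllm.utils.flashinfer,
# because FLASHINFER_CUBINS_REPOSITORY is resolved at import time.
os.environ["FLASHINFER_CUBINS_REPOSITORY"] = "http://localhost:8080/cubins/"

from vllm.utils.flashinfer import FLASHINFER_CUBINS_REPOSITORY

print(FLASHINFER_CUBINS_REPOSITORY)  # http://localhost:8080/cubins/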
@functools.cache
def has_flashinfer() -> bool:
@@ -108,6 +121,70 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
    return True


@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return ``True`` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    which is required for downloading certain cubin kernels like TRTLLM FMHA.
    """
    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
        accessible = response.status_code == 200
        if accessible:
            logger.debug_once("NVIDIA artifactory is accessible")
        else:
            logger.warning_once(
                "NVIDIA artifactory returned failed status code: %d",
                response.status_code)
        return accessible
    except Exception as e:
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False
def use_trtllm_decode_attention(
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: Optional[int],
    num_kv_heads: Optional[int],
    attn_head_size: Optional[int],
) -> bool:
    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    if not (current_platform.is_device_capability(100)
            and has_nvidia_artifactory()):
        return False

    # Check if the dimensions are supported by TRTLLM decode attention
    if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
            or num_qo_heads // num_kv_heads > 8
            or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
        return False

    env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                         env_value)
        # Environment variable is set - respect it
        # Making the conditional check for zero because
        # the path is automatically enabled if the batch size condition
        # is satisfied.
        no_use_trtllm = (env_value == "0")
        if not no_use_trtllm:
            logger.info_once("Using TRTLLM decode attention.")
        return not no_use_trtllm
    else:
        # Environment variable not set - use auto-detection
        use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
                      and kv_cache_dtype == "auto")
        if use_trtllm:
            logger.warning_once(
                "Using TRTLLM decode attention (auto-detected).")
        return use_trtllm


__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
@@ -117,4 +194,6 @@ __all__ = [
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "use_trtllm_decode_attention",
]
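Putting the two new helpers together, callers are expected to consult them at dispatch time rather than assume the cubin path is available. A usage sketch with illustrative arguments only; setting VLLM_USE_TRTLLM_DECODE_ATTENTION to 0 or 1 still overrides the auto-detection, as the function above shows:

from vllm.utils.flashinfer import (has_nvidia_artifactory,
                                   use_trtllm_decode_attention)

# Illustrative values: 64 decode tokens, 8 query heads per KV head,
# head size 128, auto kv-cache dtype.
if use_trtllm_decode_attention(num_tokens=64, max_seq_len=4096,
                               kv_cache_dtype="auto", num_qo_heads=32,
                               num_kv_heads=8, attn_head_size=128):
    pass  # take the TRTLLM cubin decode path
else:
    pass  # fall back to the regular FlashInfer decode wrapper

# The connectivity probe can also be checked on its own, as the cuDNN
# prefill gate in the MLA backend further below does.
if not has_nvidia_artifactory():
    print("cubin repository unreachable; TRTLLM kernels disabled")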

View File

@@ -17,8 +17,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionType)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.utils.flashinfer import use_trtllm_decode_attention
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
@@ -38,7 +38,6 @@ logger = init_logger(__name__)
class FlashInferBackend(AttentionBackend):
    accept_output_buffer: bool = True
    cached_sm100a_supported: Optional[bool] = None

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -98,48 +97,6 @@ class FlashInferBackend(AttentionBackend):
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def use_trtllm_decode_attention(
            batch_size: int,
            max_seq_len: int,
            kv_cache_dtype: str,
            num_qo_heads: int,
            num_kv_heads: int,
            attn_head_size: int,
    ) -> bool:
        if FlashInferBackend.cached_sm100a_supported is None:
            FlashInferBackend.cached_sm100a_supported = (
                current_platform.has_device_capability(100))
        if not FlashInferBackend.cached_sm100a_supported:
            return False

        if (num_qo_heads // num_kv_heads > 8
                or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
            return False

        env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
        if env_value is not None:
            logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                             env_value)
            # Environment variable is set - respect it
            # Making the conditional check for zero because
            # the path is automatically enabled if the batch size condition
            # is satisfied.
            no_use_trtllm = env_value == "0"
            if not no_use_trtllm:
                logger.info_once(
                    "VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, "
                    "using TRTLLM decode attention.")
            return not no_use_trtllm
        else:
            # Environment variable not set - use auto-detection
            # Only supports attention head size of 128
            use_trtllm = (FlashInferBackend.cached_sm100a_supported
                          and batch_size <= 256 and max_seq_len < 131072
                          and kv_cache_dtype == "auto")
            if use_trtllm:
                logger.warning_once(
                    "Using TRTLLM decode attention (auto-detected).")
            return use_trtllm

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
@@ -352,7 +309,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
        if num_decodes > 0:
            attn_metadata.decode_wrapper = self._get_decode_wrapper()
            if not FlashInferBackend.use_trtllm_decode_attention(
            if not use_trtllm_decode_attention(
                    num_decodes, attn_metadata.max_seq_len,
                    self.cache_config.cache_dtype,
                    attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
@@ -636,7 +593,7 @@ class FlashInferImpl(AttentionImpl):
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
            assert decode_wrapper is not None
            if not FlashInferBackend.use_trtllm_decode_attention(
            if not use_trtllm_decode_attention(
                    attn_metadata.num_decodes, attn_metadata.max_seq_len,
                    self.kv_cache_dtype, attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads, attn_metadata.head_dim):

View File

@@ -209,6 +209,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
    get_per_layer_parameters, infer_global_hyperparameters,
@@ -379,17 +380,16 @@ M = TypeVar("M", bound=MLACommonMetadata)

def use_flashinfer_prefill() -> bool:
    if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
        # For blackwell default to flashinfer prefill if its available since
        # its faster than FA2.
        return current_platform.has_device_capability(100)
    return False
    # For blackwell default to flashinfer prefill if its available since
    # it is faster than FA2.
    return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL
            and current_platform.is_device_capability(100))


def use_cudnn_prefill() -> bool:
    if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL:
        return current_platform.has_device_capability(100)
    return False
    return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL
            and current_platform.is_device_capability(100)
            and has_nvidia_artifactory())

# Currently 394MB, this can be tuned based on GEMM sizes used.