[Bugfix] Check NVIDIA artifactory is accessible before using flashinfer cubin kernels (#21893)
@@ -44,9 +44,9 @@ from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                        make_tensor_with_pad)
from vllm.utils.flashinfer import use_trtllm_decode_attention

logger = init_logger(__name__)
@@ -56,7 +56,6 @@ if TYPE_CHECKING:


class FlashInferBackend(AttentionBackend):
    cached_sm100a_supported: Optional[bool] = None

    @staticmethod
    def get_name() -> str:
@@ -123,47 +122,6 @@ class FlashInferBackend(AttentionBackend):
        else:
            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    @staticmethod
    def use_trtllm_decode_attention(
        batch_size: int,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_qo_heads: Optional[int],
        num_kv_heads: Optional[int],
        attn_head_size: Optional[int],
    ) -> bool:
        if FlashInferBackend.cached_sm100a_supported is None:
            FlashInferBackend.cached_sm100a_supported = (
                current_platform.has_device_capability(100))
        if not FlashInferBackend.cached_sm100a_supported:
            return False
        # Check if the dimensions are supported by TRTLLM decode attention
        if (attn_head_size is None or num_qo_heads is None
                or num_kv_heads is None or num_qo_heads // num_kv_heads > 8
                or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
            return False
        env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
        if env_value is not None:
            logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                             env_value)
            # Environment variable is set - respect it
            # Making the conditional check for zero because
            # the path is automatically enabled if the batch size condition
            # is satisfied.
            no_use_trtllm = (env_value == "0")
            if not no_use_trtllm:
                logger.info_once("Using TRTLLM decode attention.")
            return not no_use_trtllm
        else:
            # Environment variable not set - use auto-detection
            use_trtllm = (FlashInferBackend.cached_sm100a_supported
                          and batch_size <= 256 and max_seq_len < 131072
                          and kv_cache_dtype == "auto")
            if use_trtllm:
                logger.warning_once(
                    "Using TRTLLM decode attention (auto-detected).")
            return use_trtllm


@dataclass
class PerLayerParameters:
@@ -1156,7 +1114,7 @@ class FlashInferImpl(AttentionImpl):
            assert decode_meta.decode_wrapper._sm_scale == softmax_scale
            # TODO: @pavanimajety Remove this once the switch happens
            # inside flashinfer.
            if not FlashInferBackend.use_trtllm_decode_attention(
            if not use_trtllm_decode_attention(
                    num_decode_tokens, attn_metadata.max_decode_seq_len,
                    kv_cache_dtype, attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
@@ -10,12 +10,25 @@ import contextlib
import functools
import importlib
import importlib.util
from typing import Any, Callable, NoReturn
import os
from typing import Any, Callable, NoReturn, Optional

import requests

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

# This is the storage path for the cubins, it can be replaced
# with a local path for testing.
# Referenced from https://github.com/flashinfer-ai/flashinfer/blob/0c9a92c3d9a7e043ab6f3f7b2273269caf6ab044/flashinfer/jit/cubin_loader.py#L35 # noqa: E501
FLASHINFER_CUBINS_REPOSITORY = os.environ.get(
    "FLASHINFER_CUBINS_REPOSITORY",
    "https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-local/", # noqa: E501
)
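Since the repository URL is read from the environment at import time, it can be pointed at a local mirror for testing. A minimal sketch of that, assuming a hypothetical localhost mirror; only the variable name and has_nvidia_artifactory come from this change:

# Illustrative sketch: override the cubin repository before importing the
# helpers, so the accessibility probe (and any cubin downloads) target a
# local mirror. The localhost URL is a made-up placeholder.
import os

os.environ["FLASHINFER_CUBINS_REPOSITORY"] = "http://localhost:8080/cubins/"

from vllm.utils.flashinfer import has_nvidia_artifactory

print(has_nvidia_artifactory())  # probes the local mirror instead of NVIDIA's server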
@functools.cache
def has_flashinfer() -> bool:
@@ -108,6 +121,70 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
    return True


@functools.cache
def has_nvidia_artifactory() -> bool:
    """Return ``True`` if NVIDIA's artifactory is accessible.

    This checks connectivity to the kernel inference library artifactory
    which is required for downloading certain cubin kernels like TRTLLM FHMA.
    """
    try:
        # Use a short timeout to avoid blocking for too long
        response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)
        accessible = response.status_code == 200
        if accessible:
            logger.debug_once("NVIDIA artifactory is accessible")
        else:
            logger.warning_once(
                "NVIDIA artifactory returned failed status code: %d",
                response.status_code)
        return accessible
    except Exception as e:
        logger.warning_once("Failed to connect to NVIDIA artifactory: %s", e)
        return False
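For orientation, a minimal sketch of how a caller might branch on this probe; the backend labels below are hypothetical, and only has_nvidia_artifactory and its functools.cache behavior come from the code above:

# Illustrative sketch: the first call performs the HTTP probe (5 s timeout),
# later calls return the cached result, so gating hot paths on it is cheap.
from vllm.utils.flashinfer import has_nvidia_artifactory

if has_nvidia_artifactory():
    decode_path = "trtllm-cubin"    # hypothetical label for the cubin-backed path
else:
    decode_path = "flashinfer-jit"  # hypothetical fallback label
print(decode_path)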
def use_trtllm_decode_attention(
    num_tokens: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: Optional[int],
    num_kv_heads: Optional[int],
    attn_head_size: Optional[int],
) -> bool:
    # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
    if not (current_platform.is_device_capability(100)
            and has_nvidia_artifactory()):
        return False

    # Check if the dimensions are supported by TRTLLM decode attention
    if (attn_head_size is None or num_qo_heads is None or num_kv_heads is None
            or num_qo_heads // num_kv_heads > 8
            or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
        return False

    env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if env_value is not None:
        logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                         env_value)
        # Environment variable is set - respect it
        # Making the conditional check for zero because
        # the path is automatically enabled if the batch size condition
        # is satisfied.
        no_use_trtllm = (env_value == "0")
        if not no_use_trtllm:
            logger.info_once("Using TRTLLM decode attention.")
        return not no_use_trtllm
    else:
        # Environment variable not set - use auto-detection
        use_trtllm = (num_tokens <= 256 and max_seq_len < 131072
                      and kv_cache_dtype == "auto")
        if use_trtllm:
            logger.warning_once(
                "Using TRTLLM decode attention (auto-detected).")
        return use_trtllm
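A usage sketch for the relocated helper, with made-up shapes; the parameter names, the environment-variable semantics ("0" forces the path off, any other value forces it on, unset means auto-detect), and the SM100/artifactory requirement all come from the code above, everything else is an assumption:

# Illustrative sketch: on non-SM100 hardware or without artifactory access
# this simply returns False; otherwise auto-detection applies.
import os

os.environ.pop("VLLM_USE_TRTLLM_DECODE_ATTENTION", None)  # unset -> auto-detect

from vllm.utils.flashinfer import use_trtllm_decode_attention

enabled = use_trtllm_decode_attention(
    num_tokens=128,        # decode batch size, within the <= 256 auto limit
    max_seq_len=8192,
    kv_cache_dtype="auto",
    num_qo_heads=32,
    num_kv_heads=8,        # GQA ratio 4, within the <= 8 limit
    attn_head_size=128,    # only head size 128 is supported
)
print(enabled)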
__all__ = [
    "has_flashinfer",
    "flashinfer_trtllm_fp8_block_scale_moe",
@@ -117,4 +194,6 @@ __all__ = [
    "autotune",
    "has_flashinfer_moe",
    "has_flashinfer_cutlass_fused_moe",
    "has_nvidia_artifactory",
    "use_trtllm_decode_attention",
]
@@ -17,8 +17,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionType)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.utils.flashinfer import use_trtllm_decode_attention
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
@@ -38,7 +38,6 @@ logger = init_logger(__name__)
class FlashInferBackend(AttentionBackend):

    accept_output_buffer: bool = True
    cached_sm100a_supported: Optional[bool] = None

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -98,48 +97,6 @@ class FlashInferBackend(AttentionBackend):
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order

    @staticmethod
    def use_trtllm_decode_attention(
        batch_size: int,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_qo_heads: int,
        num_kv_heads: int,
        attn_head_size: int,
    ) -> bool:
        if FlashInferBackend.cached_sm100a_supported is None:
            FlashInferBackend.cached_sm100a_supported = (
                current_platform.has_device_capability(100))
        if not FlashInferBackend.cached_sm100a_supported:
            return False
        if (num_qo_heads // num_kv_heads > 8
                or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128):
            return False
        env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
        if env_value is not None:
            logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                             env_value)
            # Environment variable is set - respect it
            # Making the conditional check for zero because
            # the path is automatically enabled if the batch size condition
            # is satisfied.
            no_use_trtllm = env_value == "0"
            if not no_use_trtllm:
                logger.info_once(
                    "VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, "
                    "using TRTLLM decode attention.")
            return not no_use_trtllm
        else:
            # Environment variable not set - use auto-detection
            # Only supports attention head size of 128
            use_trtllm = (FlashInferBackend.cached_sm100a_supported
                          and batch_size <= 256 and max_seq_len < 131072
                          and kv_cache_dtype == "auto")
            if use_trtllm:
                logger.warning_once(
                    "Using TRTLLM decode attention (auto-detected).")
            return use_trtllm

    @staticmethod
    def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
@@ -352,7 +309,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

        if num_decodes > 0:
            attn_metadata.decode_wrapper = self._get_decode_wrapper()
            if not FlashInferBackend.use_trtllm_decode_attention(
            if not use_trtllm_decode_attention(
                    num_decodes, attn_metadata.max_seq_len,
                    self.cache_config.cache_dtype,
                    attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
@@ -636,7 +593,7 @@ class FlashInferImpl(AttentionImpl):
            decode_query = query[:num_decode_tokens]
            assert decode_query.shape[0] == num_decode_tokens
            assert decode_wrapper is not None
            if not FlashInferBackend.use_trtllm_decode_attention(
            if not use_trtllm_decode_attention(
                    attn_metadata.num_decodes, attn_metadata.max_seq_len,
                    self.kv_cache_dtype, attn_metadata.num_qo_heads,
                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
@@ -209,6 +209,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.utils.flashinfer import has_nvidia_artifactory
from vllm.v1.attention.backends.utils import (
    AttentionMetadataBuilder, CommonAttentionMetadata,
    get_per_layer_parameters, infer_global_hyperparameters,
@@ -379,17 +380,16 @@ M = TypeVar("M", bound=MLACommonMetadata)


def use_flashinfer_prefill() -> bool:
    if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL:
        # For blackwell default to flashinfer prefill if its available since
        # its faster than FA2.
        return current_platform.has_device_capability(100)
    return False
    # For blackwell default to flashinfer prefill if its available since
    # it is faster than FA2.
    return (flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL
            and current_platform.is_device_capability(100))


def use_cudnn_prefill() -> bool:
    if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL:
        return current_platform.has_device_capability(100)
    return False
    return (flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL
            and current_platform.is_device_capability(100)
            and has_nvidia_artifactory())
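A sketch of how these two predicates might drive prefill selection; the module path is inferred from context, and the labels and fallback ordering are assumptions, not part of the change:

# Illustrative sketch: pick an MLA prefill implementation from the predicates.
from vllm.v1.attention.backends.mla.common import (use_cudnn_prefill,
                                                   use_flashinfer_prefill)

if use_cudnn_prefill():            # explicit env opt-in, needs SM100 + artifactory
    prefill_impl = "cudnn"
elif use_flashinfer_prefill():     # SM100 default when FlashInfer is present
    prefill_impl = "flashinfer"
else:
    prefill_impl = "fa2"           # hypothetical fallback
print(prefill_impl)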
# Currently 394MB, this can be tuned based on GEMM sizes used.