[BugFix] Fix vllm_flash_attn install issues (#17267)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
This commit is contained in:
Lucas Wilkinson
2025-04-27 20:27:56 -04:00
committed by GitHub
parent 20e489eaa1
commit d8bccde686
11 changed files with 28 additions and 284 deletions

1
.github/CODEOWNERS vendored
View File

@ -12,6 +12,7 @@
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
/vllm/model_executor/guided_decoding @mgoin @russellb
/vllm/multimodal @DarkLight1337 @ywang96
/vllm/vllm_flash_attn @LucasWilkinson
CMakeLists.txt @tlrmchlsmth
# vLLM V1

2
.gitignore vendored
View File

@ -3,8 +3,6 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*
!vllm/vllm_flash_attn/__init__.py
!vllm/vllm_flash_attn/fa_utils.py
# Byte-compiled / optimized / DLL files
__pycache__/

View File

@ -269,15 +269,17 @@ class cmake_build_ext(build_ext):
# First, run the standard build_ext command to compile the extensions
super().run()
# copy vllm/vllm_flash_attn/*.py from self.build_lib to current
# copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
# directory so that they can be included in the editable build
import glob
files = glob.glob(
os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py"))
files = glob.glob(os.path.join(self.build_lib, "vllm",
"vllm_flash_attn", "**", "*.py"),
recursive=True)
for file in files:
dst_file = os.path.join("vllm/vllm_flash_attn",
os.path.basename(file))
file.split("vllm/vllm_flash_attn/")[-1])
print(f"Copying {file} to {dst_file}")
os.makedirs(os.path.dirname(dst_file), exist_ok=True)
self.copy_file(file, dst_file)
@ -377,12 +379,22 @@ class repackage_wheel(build_ext):
"vllm/_flashmla_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/vllm_flash_attn/flash_attn_interface.py",
"vllm/cumem_allocator.abi3.so",
# "vllm/_version.py", # not available in nightly wheels yet
]
file_members = filter(lambda x: x.filename in files_to_copy,
wheel.filelist)
file_members = list(
filter(lambda x: x.filename in files_to_copy, wheel.filelist))
# vllm_flash_attn python code:
# Regex from
# `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)`
import re
compiled_regex = re.compile(
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
file_members += list(
filter(lambda x: compiled_regex.match(x.filename),
wheel.filelist))
for file in file_members:
print(f"Extracting and including {file.filename} "

View File

@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@ -689,7 +689,7 @@ class FlashAttentionImpl(AttentionImpl):
assert output is not None, "Output tensor must be provided."
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
assert (
layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
"key/v_scale is only supported in FlashAttention 3 with "

View File

@ -205,6 +205,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
@ -214,7 +215,6 @@ from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if HAS_TRITON:
from vllm.attention.ops.triton_flash_attention import triton_attention

View File

@ -1377,7 +1377,7 @@ class EngineArgs:
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False
if fp8_attention and will_use_fa:
from vllm.vllm_flash_attn.fa_utils import (
from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8)
supported = flash_attn_supports_fp8()
if not supported:

View File

@ -11,11 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput

View File

@ -197,6 +197,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
MLAAttentionImpl)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
@ -204,7 +205,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func

View File

@ -1,22 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.metadata
try:
__version__ = importlib.metadata.version("vllm-flash-attn")
except importlib.metadata.PackageNotFoundError:
# in this case, vllm-flash-attn is built from installing vllm editable
__version__ = "0.0.0.dev0"
from .flash_attn_interface import (fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
get_scheduler_metadata,
is_fa_version_supported, sparse_attn_func,
sparse_attn_varlen_func)
__all__ = [
'flash_attn_varlen_func', 'flash_attn_with_kvcache',
'get_scheduler_metadata', 'sparse_attn_func', 'sparse_attn_varlen_func',
'is_fa_version_supported', 'fa_version_unsupported_reason'
]

View File

@ -1,245 +0,0 @@
# ruff: ignore
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from typing import Any, Literal, overload
import torch
def get_scheduler_metadata(
batch_size: int,
max_seqlen_q: int,
max_seqlen_k: int,
num_heads_q: int,
num_heads_kv: int,
headdim: int,
cache_seqlens: torch.Tensor,
qkv_dtype: torch.dtype = ...,
headdim_v: int | None = ...,
cu_seqlens_q: torch.Tensor | None = ...,
cu_seqlens_k_new: torch.Tensor | None = ...,
cache_leftpad: torch.Tensor | None = ...,
page_size: int = ...,
max_seqlen_k_new: int = ...,
causal: bool = ...,
window_size: tuple[int, int] = ...,
has_softcap: bool = ...,
num_splits: int = ...,
pack_gqa: Any | None = ...,
sm_margin: int = ...,
): ...
@overload
def flash_attn_varlen_func(
q: tuple[int, int, int],
k: tuple[int, int, int],
v: tuple[int, int, int],
max_seqlen_q: int,
cu_seqlens_q: torch.Tensor | None,
max_seqlen_k: int,
cu_seqlens_k: torch.Tensor | None = ...,
seqused_k: Any | None = ...,
q_v: Any | None = ...,
dropout_p: float = ...,
causal: bool = ...,
window_size: list[int] | None = ...,
softmax_scale: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
block_table: Any | None = ...,
return_softmax_lse: Literal[False] = ...,
out: Any = ...,
# FA3 Only
scheduler_metadata: Any | None = ...,
q_descale: Any | None = ...,
k_descale: Any | None = ...,
v_descale: Any | None = ...,
# Version selector
fa_version: int = ...,
) -> tuple[int, int, int]: ...
@overload
def flash_attn_varlen_func(
q: tuple[int, int, int],
k: tuple[int, int, int],
v: tuple[int, int, int],
max_seqlen_q: int,
cu_seqlens_q: torch.Tensor | None,
max_seqlen_k: int,
cu_seqlens_k: torch.Tensor | None = ...,
seqused_k: Any | None = ...,
q_v: Any | None = ...,
dropout_p: float = ...,
causal: bool = ...,
window_size: list[int] | None = ...,
softmax_scale: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
block_table: Any | None = ...,
return_softmax_lse: Literal[True] = ...,
out: Any = ...,
# FA3 Only
scheduler_metadata: Any | None = ...,
q_descale: Any | None = ...,
k_descale: Any | None = ...,
v_descale: Any | None = ...,
# Version selector
fa_version: int = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
@overload
def flash_attn_with_kvcache(
q: tuple[int, int, int, int],
k_cache: tuple[int, int, int, int],
v_cache: tuple[int, int, int, int],
k: tuple[int, int, int, int] | None = ...,
v: tuple[int, int, int, int] | None = ...,
rotary_cos: tuple[int, int] | None = ...,
rotary_sin: tuple[int, int] | None = ...,
cache_seqlens: int | torch.Tensor | None = None,
cache_batch_idx: torch.Tensor | None = None,
cache_leftpad: torch.Tensor | None = ...,
block_table: torch.Tensor | None = ...,
softmax_scale: float = ...,
causal: bool = ...,
window_size: tuple[int, int] = ..., # -1 means infinite context window
softcap: float = ...,
rotary_interleaved: bool = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
num_splits: int = ...,
return_softmax_lse: Literal[False] = ...,
*,
out: Any = ...,
# FA3 Only
scheduler_metadata: Any | None = ...,
q_descale: Any | None = ...,
k_descale: Any | None = ...,
v_descale: Any | None = ...,
# Version selector
fa_version: int = ...,
) -> tuple[int, int, int, int]: ...
@overload
def flash_attn_with_kvcache(
q: tuple[int, int, int, int],
k_cache: tuple[int, int, int, int],
v_cache: tuple[int, int, int, int],
k: tuple[int, int, int, int] | None = ...,
v: tuple[int, int, int, int] | None = ...,
rotary_cos: tuple[int, int] | None = ...,
rotary_sin: tuple[int, int] | None = ...,
cache_seqlens: int | torch.Tensor | None = None,
cache_batch_idx: torch.Tensor | None = None,
cache_leftpad: torch.Tensor | None = ...,
block_table: torch.Tensor | None = ...,
softmax_scale: float = ...,
causal: bool = ...,
window_size: tuple[int, int] = ..., # -1 means infinite context window
softcap: float = ...,
rotary_interleaved: bool = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
num_splits: int = ...,
return_softmax_lse: Literal[True] = ...,
*,
out: Any = ...,
# FA3 Only
scheduler_metadata: Any | None = ...,
q_descale: Any | None = ...,
k_descale: Any | None = ...,
v_descale: Any | None = ...,
# Version selector
fa_version: int = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
@overload
def sparse_attn_func(
q: tuple[int, int, int, int],
k: tuple[int, int, int, int],
v: tuple[int, int, int, int],
block_count: tuple[int, int, float],
block_offset: tuple[int, int, float, int],
column_count: tuple[int, int, float],
column_index: tuple[int, int, float, int],
dropout_p: float = ...,
softmax_scale: float = ...,
causal: bool = ...,
softcap: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
*,
return_softmax_lse: Literal[False] = ...,
out: Any = ...,
) -> tuple[int, int, int]: ...
@overload
def sparse_attn_func(
q: tuple[int, int, int, int],
k: tuple[int, int, int, int],
v: tuple[int, int, int, int],
block_count: tuple[int, int, float],
block_offset: tuple[int, int, float, int],
column_count: tuple[int, int, float],
column_index: tuple[int, int, float, int],
dropout_p: float = ...,
softmax_scale: float = ...,
causal: bool = ...,
softcap: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
*,
return_softmax_lse: Literal[True] = ...,
out: Any = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
@overload
def sparse_attn_varlen_func(
q: tuple[int, int, int],
k: tuple[int, int, int],
v: tuple[int, int, int],
block_count: tuple[int, int, float],
block_offset: tuple[int, int, float, int],
column_count: tuple[int, int, float],
column_index: tuple[int, int, float, int],
cu_seqlens_q: torch.Tensor | None,
cu_seqlens_k: torch.Tensor | None,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float = ...,
softmax_scale: float = ...,
causal: bool = ...,
softcap: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
*,
return_softmax_lse: Literal[False] = ...,
out: Any = ...,
) -> tuple[int, int, int]: ...
@overload
def sparse_attn_varlen_func(
q: tuple[int, int, int],
k: tuple[int, int, int],
v: tuple[int, int, int],
block_count: tuple[int, int, float],
block_offset: tuple[int, int, float, int],
column_count: tuple[int, int, float],
column_index: tuple[int, int, float, int],
cu_seqlens_q: torch.Tensor | None,
cu_seqlens_k: torch.Tensor | None,
max_seqlen_q: int,
max_seqlen_k: int,
dropout_p: float = ...,
softmax_scale: float = ...,
causal: bool = ...,
softcap: float = ...,
alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
deterministic: bool = ...,
return_attn_probs: bool = ...,
*,
return_softmax_lse: Literal[True] = ...,
out: Any = ...,
) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
def is_fa_version_supported(
fa_version: int, device: torch.device | None = None
) -> bool: ...
def fa_version_unsupported_reason(
fa_version: int, device: torch.device | None = None
) -> str | None: ...