[BugFix] Fix vllm_flash_attn install issues (#17267)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>

.github/CODEOWNERS (vendored, 1 change)
@@ -12,6 +12,7 @@
 /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
 /vllm/model_executor/guided_decoding @mgoin @russellb
 /vllm/multimodal @DarkLight1337 @ywang96
+/vllm/vllm_flash_attn @LucasWilkinson
 CMakeLists.txt @tlrmchlsmth

 # vLLM V1

.gitignore (vendored, 2 changes)
@@ -3,8 +3,6 @@

 # vllm-flash-attn built from source
 vllm/vllm_flash_attn/*
-!vllm/vllm_flash_attn/__init__.py
-!vllm/vllm_flash_attn/fa_utils.py

 # Byte-compiled / optimized / DLL files
 __pycache__/

setup.py (26 changes)
@@ -269,15 +269,17 @@ class cmake_build_ext(build_ext):
         # First, run the standard build_ext command to compile the extensions
         super().run()

-        # copy vllm/vllm_flash_attn/*.py from self.build_lib to current
+        # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
         # directory so that they can be included in the editable build
         import glob
-        files = glob.glob(
-            os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py"))
+        files = glob.glob(os.path.join(self.build_lib, "vllm",
+                                       "vllm_flash_attn", "**", "*.py"),
+                          recursive=True)
         for file in files:
             dst_file = os.path.join("vllm/vllm_flash_attn",
-                                    os.path.basename(file))
+                                    file.split("vllm/vllm_flash_attn/")[-1])
             print(f"Copying {file} to {dst_file}")
+            os.makedirs(os.path.dirname(dst_file), exist_ok=True)
             self.copy_file(file, dst_file)
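Note on the hunk above: the editable-install copy step previously picked up only top-level vllm/vllm_flash_attn/*.py files; with the recursive glob and a destination path computed relative to "vllm/vllm_flash_attn/", nested subpackages are copied back into the source tree as well. A minimal standalone sketch of the same idea (the build directory and nested file names are hypothetical, and shutil.copy2 stands in for setuptools' copy_file):

    import glob
    import os
    import shutil

    build_lib = "build/lib.linux-x86_64-cpython-312"  # hypothetical build output dir

    # The recursive glob also finds nested files such as ".../vllm_flash_attn/layers/rotary.py".
    files = glob.glob(
        os.path.join(build_lib, "vllm", "vllm_flash_attn", "**", "*.py"),
        recursive=True)
    for file in files:
        # Keep everything after "vllm/vllm_flash_attn/" so subdirectories survive the copy,
        # e.g. ".../vllm_flash_attn/layers/rotary.py" -> "vllm/vllm_flash_attn/layers/rotary.py".
        dst_file = os.path.join("vllm/vllm_flash_attn",
                                file.split("vllm/vllm_flash_attn/")[-1])
        os.makedirs(os.path.dirname(dst_file), exist_ok=True)
        shutil.copy2(file, dst_file)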

@@ -377,12 +379,22 @@ class repackage_wheel(build_ext):
             "vllm/_flashmla_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
-            "vllm/vllm_flash_attn/flash_attn_interface.py",
             "vllm/cumem_allocator.abi3.so",
             # "vllm/_version.py", # not available in nightly wheels yet
         ]
-        file_members = filter(lambda x: x.filename in files_to_copy,
-                              wheel.filelist)
+
+        file_members = list(
+            filter(lambda x: x.filename in files_to_copy, wheel.filelist))
+
+        # vllm_flash_attn python code:
+        # Regex from
+        # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)`
+        import re
+        compiled_regex = re.compile(
+            r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
+        file_members += list(
+            filter(lambda x: compiled_regex.match(x.filename),
+                   wheel.filelist))

         for file in file_members:
             print(f"Extracting and including {file.filename} "
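The hard-coded pattern above is the pre-expanded form of the recursive glob 'vllm/vllm_flash_attn/**/*.py' (glob.translate exists only on Python 3.13+, which is presumably why the already-translated regex is inlined rather than computed at build time). A quick illustrative self-check of what it does and does not match (example paths are made up):

    import re

    compiled_regex = re.compile(
        r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")

    # Top-level and nested python files inside the package are picked up ...
    assert compiled_regex.match("vllm/vllm_flash_attn/flash_attn_interface.py")
    assert compiled_regex.match("vllm/vllm_flash_attn/layers/rotary.py")
    # ... while hidden files and files outside the package are not.
    assert not compiled_regex.match("vllm/vllm_flash_attn/.hidden.py")
    assert not compiled_regex.match("vllm/attention/backends/flash_attn.py")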

@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
+from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
-from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)

 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,

@@ -689,7 +689,7 @@ class FlashAttentionImpl(AttentionImpl):
         assert output is not None, "Output tensor must be provided."

         # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
-        if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
+        if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
             assert (
                 layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
                     "key/v_scale is only supported in FlashAttention 3 with "
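The FP8 guard above now goes through flash_attn_supports_fp8() from vllm.attention.utils.fa_utils instead of comparing the flash-attention version directly. As a rough, illustrative sketch (not the actual vLLM helper), such a check boils down to "FlashAttention 3 is in use and the GPU generation ships FP8 attention kernels":

    import torch

    def flash_attn_supports_fp8_sketch(fa_version: int) -> bool:
        # Hypothetical stand-in: FP8 KV-cache support is only expected from FA3,
        # which targets Hopper-class (compute capability 9.x) GPUs.
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability()
        return fa_version == 3 and major == 9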

@@ -205,6 +205,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)

@@ -214,7 +215,6 @@ from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
-from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version

 if HAS_TRITON:
     from vllm.attention.ops.triton_flash_attention import triton_attention

@@ -1377,7 +1377,7 @@ class EngineArgs:
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
             if fp8_attention and will_use_fa:
-                from vllm.vllm_flash_attn.fa_utils import (
+                from vllm.attention.utils.fa_utils import (
                     flash_attn_supports_fp8)
                 supported = flash_attn_supports_fp8()
             if not supported:

@@ -11,11 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput

@@ -197,6 +197,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,

@@ -204,7 +205,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
-from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func

@@ -1,22 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-
-import importlib.metadata
-
-try:
-    __version__ = importlib.metadata.version("vllm-flash-attn")
-except importlib.metadata.PackageNotFoundError:
-    # in this case, vllm-flash-attn is built from installing vllm editable
-    __version__ = "0.0.0.dev0"
-
-from .flash_attn_interface import (fa_version_unsupported_reason,
-                                   flash_attn_varlen_func,
-                                   flash_attn_with_kvcache,
-                                   get_scheduler_metadata,
-                                   is_fa_version_supported, sparse_attn_func,
-                                   sparse_attn_varlen_func)
-
-__all__ = [
-    'flash_attn_varlen_func', 'flash_attn_with_kvcache',
-    'get_scheduler_metadata', 'sparse_attn_func', 'sparse_attn_varlen_func',
-    'is_fa_version_supported', 'fa_version_unsupported_reason'
-]

@@ -1,245 +0,0 @@
-# ruff: ignore
-# SPDX-License-Identifier: Apache-2.0
-
-from __future__ import annotations
-
-from typing import Any, Literal, overload
-
-import torch
-
-def get_scheduler_metadata(
-    batch_size: int,
-    max_seqlen_q: int,
-    max_seqlen_k: int,
-    num_heads_q: int,
-    num_heads_kv: int,
-    headdim: int,
-    cache_seqlens: torch.Tensor,
-    qkv_dtype: torch.dtype = ...,
-    headdim_v: int | None = ...,
-    cu_seqlens_q: torch.Tensor | None = ...,
-    cu_seqlens_k_new: torch.Tensor | None = ...,
-    cache_leftpad: torch.Tensor | None = ...,
-    page_size: int = ...,
-    max_seqlen_k_new: int = ...,
-    causal: bool = ...,
-    window_size: tuple[int, int] = ...,
-    has_softcap: bool = ...,
-    num_splits: int = ...,
-    pack_gqa: Any | None = ...,
-    sm_margin: int = ...,
-): ...
-
-@overload
-def flash_attn_varlen_func(
-    q: tuple[int, int, int],
-    k: tuple[int, int, int],
-    v: tuple[int, int, int],
-    max_seqlen_q: int,
-    cu_seqlens_q: torch.Tensor | None,
-    max_seqlen_k: int,
-    cu_seqlens_k: torch.Tensor | None = ...,
-    seqused_k: Any | None = ...,
-    q_v: Any | None = ...,
-    dropout_p: float = ...,
-    causal: bool = ...,
-    window_size: list[int] | None = ...,
-    softmax_scale: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    block_table: Any | None = ...,
-    return_softmax_lse: Literal[False] = ...,
-    out: Any = ...,
-    # FA3 Only
-    scheduler_metadata: Any | None = ...,
-    q_descale: Any | None = ...,
-    k_descale: Any | None = ...,
-    v_descale: Any | None = ...,
-    # Version selector
-    fa_version: int = ...,
-) -> tuple[int, int, int]: ...
-
-@overload
-def flash_attn_varlen_func(
-    q: tuple[int, int, int],
-    k: tuple[int, int, int],
-    v: tuple[int, int, int],
-    max_seqlen_q: int,
-    cu_seqlens_q: torch.Tensor | None,
-    max_seqlen_k: int,
-    cu_seqlens_k: torch.Tensor | None = ...,
-    seqused_k: Any | None = ...,
-    q_v: Any | None = ...,
-    dropout_p: float = ...,
-    causal: bool = ...,
-    window_size: list[int] | None = ...,
-    softmax_scale: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    block_table: Any | None = ...,
-    return_softmax_lse: Literal[True] = ...,
-    out: Any = ...,
-    # FA3 Only
-    scheduler_metadata: Any | None = ...,
-    q_descale: Any | None = ...,
-    k_descale: Any | None = ...,
-    v_descale: Any | None = ...,
-    # Version selector
-    fa_version: int = ...,
-) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
-
-@overload
-def flash_attn_with_kvcache(
-    q: tuple[int, int, int, int],
-    k_cache: tuple[int, int, int, int],
-    v_cache: tuple[int, int, int, int],
-    k: tuple[int, int, int, int] | None = ...,
-    v: tuple[int, int, int, int] | None = ...,
-    rotary_cos: tuple[int, int] | None = ...,
-    rotary_sin: tuple[int, int] | None = ...,
-    cache_seqlens: int | torch.Tensor | None = None,
-    cache_batch_idx: torch.Tensor | None = None,
-    cache_leftpad: torch.Tensor | None = ...,
-    block_table: torch.Tensor | None = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    window_size: tuple[int, int] = ...,  # -1 means infinite context window
-    softcap: float = ...,
-    rotary_interleaved: bool = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    num_splits: int = ...,
-    return_softmax_lse: Literal[False] = ...,
-    *,
-    out: Any = ...,
-    # FA3 Only
-    scheduler_metadata: Any | None = ...,
-    q_descale: Any | None = ...,
-    k_descale: Any | None = ...,
-    v_descale: Any | None = ...,
-    # Version selector
-    fa_version: int = ...,
-) -> tuple[int, int, int, int]: ...
-
-@overload
-def flash_attn_with_kvcache(
-    q: tuple[int, int, int, int],
-    k_cache: tuple[int, int, int, int],
-    v_cache: tuple[int, int, int, int],
-    k: tuple[int, int, int, int] | None = ...,
-    v: tuple[int, int, int, int] | None = ...,
-    rotary_cos: tuple[int, int] | None = ...,
-    rotary_sin: tuple[int, int] | None = ...,
-    cache_seqlens: int | torch.Tensor | None = None,
-    cache_batch_idx: torch.Tensor | None = None,
-    cache_leftpad: torch.Tensor | None = ...,
-    block_table: torch.Tensor | None = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    window_size: tuple[int, int] = ...,  # -1 means infinite context window
-    softcap: float = ...,
-    rotary_interleaved: bool = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    num_splits: int = ...,
-    return_softmax_lse: Literal[True] = ...,
-    *,
-    out: Any = ...,
-    # FA3 Only
-    scheduler_metadata: Any | None = ...,
-    q_descale: Any | None = ...,
-    k_descale: Any | None = ...,
-    v_descale: Any | None = ...,
-    # Version selector
-    fa_version: int = ...,
-) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
-
-@overload
-def sparse_attn_func(
-    q: tuple[int, int, int, int],
-    k: tuple[int, int, int, int],
-    v: tuple[int, int, int, int],
-    block_count: tuple[int, int, float],
-    block_offset: tuple[int, int, float, int],
-    column_count: tuple[int, int, float],
-    column_index: tuple[int, int, float, int],
-    dropout_p: float = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    softcap: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    *,
-    return_softmax_lse: Literal[False] = ...,
-    out: Any = ...,
-) -> tuple[int, int, int]: ...
-
-@overload
-def sparse_attn_func(
-    q: tuple[int, int, int, int],
-    k: tuple[int, int, int, int],
-    v: tuple[int, int, int, int],
-    block_count: tuple[int, int, float],
-    block_offset: tuple[int, int, float, int],
-    column_count: tuple[int, int, float],
-    column_index: tuple[int, int, float, int],
-    dropout_p: float = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    softcap: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    *,
-    return_softmax_lse: Literal[True] = ...,
-    out: Any = ...,
-) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
-
-@overload
-def sparse_attn_varlen_func(
-    q: tuple[int, int, int],
-    k: tuple[int, int, int],
-    v: tuple[int, int, int],
-    block_count: tuple[int, int, float],
-    block_offset: tuple[int, int, float, int],
-    column_count: tuple[int, int, float],
-    column_index: tuple[int, int, float, int],
-    cu_seqlens_q: torch.Tensor | None,
-    cu_seqlens_k: torch.Tensor | None,
-    max_seqlen_q: int,
-    max_seqlen_k: int,
-    dropout_p: float = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    softcap: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    *,
-    return_softmax_lse: Literal[False] = ...,
-    out: Any = ...,
-) -> tuple[int, int, int]: ...
-
-@overload
-def sparse_attn_varlen_func(
-    q: tuple[int, int, int],
-    k: tuple[int, int, int],
-    v: tuple[int, int, int],
-    block_count: tuple[int, int, float],
-    block_offset: tuple[int, int, float, int],
-    column_count: tuple[int, int, float],
-    column_index: tuple[int, int, float, int],
-    cu_seqlens_q: torch.Tensor | None,
-    cu_seqlens_k: torch.Tensor | None,
-    max_seqlen_q: int,
-    max_seqlen_k: int,
-    dropout_p: float = ...,
-    softmax_scale: float = ...,
-    causal: bool = ...,
-    softcap: float = ...,
-    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
-    deterministic: bool = ...,
-    return_attn_probs: bool = ...,
-    *,
-    return_softmax_lse: Literal[True] = ...,
-    out: Any = ...,
-) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
-
-def is_fa_version_supported(
-    fa_version: int, device: torch.device | None = None
-) -> bool: ...
-
-def fa_version_unsupported_reason(
-    fa_version: int, device: torch.device | None = None
-) -> str | None: ...