[Attention] FlashAttn MLA (#14258)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Lucas Wilkinson
Date: 2025-09-04 05:47:59 -04:00
Committed by: GitHub
Parent: 2c301ee2eb
Commit: 402759d472
22 changed files with 480 additions and 200 deletions

View File

@ -5,11 +5,11 @@ import os
import sys
import zipfile
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB
# Note that we have 400 MiB quota, please use it wisely.
# See https://github.com/pypi/support/issues/3792 .
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB
# Note that we have 800 MiB quota, please use it wisely.
# See https://github.com/pypi/support/issues/6326 .
# Please also sync the value with the one in Dockerfile.
VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400))
VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450))
def print_top_10_largest_files(zip_file):

View File

@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f
GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

View File

@ -237,7 +237,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
# Check the size of the wheel if RUN_WHEEL_CHECK is true
COPY .buildkite/check-wheel-size.py check-wheel-size.py
# sync the default value with .buildkite/check-wheel-size.py
ARG VLLM_MAX_SIZE_MB=400
ARG VLLM_MAX_SIZE_MB=450
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
ARG RUN_WHEEL_CHECK=true
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \

View File

@ -22,7 +22,7 @@ def clear_cache():
# Define MLA and non-MLA backends separately
DEVICE_MLA_BACKENDS = {
"cuda": ["TRITON_MLA", "FLASHMLA"],
"cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
"cpu": [],
}
@ -98,21 +98,14 @@ def test_env(
with patch("vllm.attention.selector.current_platform",
RocmPlatform()):
if use_mla:
# Validate HIP MLA backend-block_size combinations
valid_combination = (
(name == "TRITON_MLA" and block_size != 1)
or (name == "ROCM_AITER_MLA" and block_size == 1))
# ROCm MLA backend logic:
# - TRITON_MLA: supported when block_size != 1
# - ROCM_AITER_MLA: supported when block_size == 1
# If backend is forced but doesn't match block_size,
# should raise ValueError
if valid_combination:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
if name == "TRITON_MLA" and block_size == 1:
# TRITON_MLA doesn't support block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
@ -122,6 +115,27 @@ def test_env(
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
elif name == "ROCM_AITER_MLA" and block_size != 1:
# ROCM_AITER_MLA only supports block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
else:
# Valid backend-block_size combination
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,
@ -136,16 +150,22 @@ def test_env(
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
if use_mla:
if name == "FLASHMLA" and block_size == 64:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
# CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128
# and Blackwell GPUs (SM 10.0), V1 only
# - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases
# only on cuda platforms with specific capability.
is_supported, _ = is_flashmla_supported()
if not is_supported:
# if platform is not supported then skip this case.
pytest.skip()
if name == "CUTLASS_MLA":
if not use_v1:
# CUTLASS_MLA only supported on V1 engine
pytest.skip(
"CUTLASS_MLA only supported on V1 engine")
elif block_size != 128:
# CUTLASS_MLA only supports block_size == 128
pytest.skip(
"CUTLASS_MLA only supports block_size 128")
else:
backend = get_attn_backend(16,
torch.float16,
@ -153,9 +173,45 @@ def test_env(
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
expected = "CUTLASS_MLA_VLLM_V1"
assert backend.get_name() == expected
elif name == "FLASHMLA":
if block_size != 64:
# FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
is_supported, _ = is_flashmla_supported()
if not is_supported:
pytest.skip(
"FlashMLA not supported on this platform")
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA":
if not use_v1:
# FlashAttention MLA only supported on V1 engine
pytest.skip(
"FlashAttention MLA only supported on V1 engine"
)
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
else:
# TRITON_MLA or other fallback
backend = get_attn_backend(16,
torch.float16,
torch.float16,
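
Condensed, the CUDA MLA selection matrix that this test encodes is the following (expected_mla_backend is a hypothetical helper written for illustration, not part of the test suite; cases the test skips are expressed as asserts):

def expected_mla_backend(name: str, block_size: int, use_v1: bool) -> str:
    # Mirrors the assertions above.
    if name == "CUTLASS_MLA":
        assert use_v1 and block_size == 128
        return "CUTLASS_MLA_VLLM_V1"
    if name == "FLASHMLA":
        assert block_size == 64  # and is_flashmla_supported() on this platform
        return "FLASHMLA_VLLM_V1" if use_v1 else "FLASHMLA"
    if name == "FLASH_ATTN_MLA":
        assert use_v1
        return "FLASH_ATTN_MLA"
    # TRITON_MLA and any other fallback
    return f"{name}_VLLM_V1" if use_v1 else name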

View File

@ -70,22 +70,6 @@ BATCH_SPECS = {
}
def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
device: torch.device,
num_blocks: int = 100) -> torch.Tensor:
"""Create a dummy KV cache tensor for testing."""
kv_cache = torch.randn(
2, # K and V
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
device=device,
)
return kv_cache
def create_and_prepopulate_kv_cache(
k_contexts: list[torch.Tensor],
v_contexts: list[torch.Tensor],

View File

@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1,
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA,
_Backend.TRITON_MLA_VLLM_V1
]
@ -69,20 +69,6 @@ BATCH_SPECS = {
}
def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
device: torch.device,
num_blocks: int = 100) -> torch.Tensor:
"""Create a dummy KV cache tensor for testing."""
kv_cache = torch.randn(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.head_size, # latent dimension
dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
device=device,
)
return kv_cache
def create_and_prepopulate_kv_cache(
kv_c_contexts: list[torch.Tensor],
k_pe_contexts: list[torch.Tensor],
@ -315,7 +301,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# 2. Generate data and compute SDPA reference output for MLA
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
all_sdpa_outputs = []
all_sdpa_outputs: list[list[torch.Tensor]] = []
kv_c_contexts, k_pe_contexts = [], []
# Create shared MLA weight matrices for consistency across all sequences
@ -331,6 +317,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device=device)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
for i, backend in enumerate(BACKENDS_TO_TEST):
all_sdpa_outputs.append([])
for i in range(batch_size):
s_len = seq_lens[i]
q_len = query_lens[i]
@ -358,85 +347,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
dtype=dtype,
device=device)
# Determine if this is decode (single token)
# or prefill (multiple tokens)
is_decode = q_len == 1
# Determine if this is decode or prefill
is_decode = []
for i, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = get_attention_backend(backend)
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
# Split q into nope and rope components
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
if is_decode:
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK) # [1, num_heads, kv_lora_rank]
#######################################################
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK) # [1, num_heads, kv_lora_rank]
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
# Create custom attention mask for decode path:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their position
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(
0) # [1, num_heads, kv_lora_rank]
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
# Project back to output space: sdpa_out @ W_UV
sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
else:
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full,
kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)
sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
0) # [1, num_heads, kv_lora_rank]
# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1) # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
# Project back to output space: sdpa_out @ W_UV
sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode,
W_UV)
sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len,
s_len,
dtype=torch.bool,
device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
#######################################################
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1) # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
# Single attention call with custom mask
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
attn_mask=attn_mask,
scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
all_sdpa_outputs.append(sdpa_out_i)
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
# Single attention call with custom mask
sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
for i, backend in enumerate(BACKENDS_TO_TEST):
if is_decode[i]:
all_sdpa_outputs[i].append(sdpa_out_i_decode)
else:
all_sdpa_outputs[i].append(sdpa_out_i_prefill)
# Inputs for vLLM MLA backends are just the new tokens
all_q_vllm.append(q_c)
@ -451,7 +448,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_output = torch.cat(all_sdpa_outputs, dim=0)
sdpa_outputs = []
for i, backend in enumerate(BACKENDS_TO_TEST):
sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))
# Create mock kv_b_proj using the same weights as reference implementation
from vllm.model_executor.layers.linear import ColumnParallelLinear
@ -486,7 +485,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
randomize_blocks=True)
# 4. Run vLLM backends and compare
for backend_name in BACKENDS_TO_TEST:
for i, backend_name in enumerate(BACKENDS_TO_TEST):
backend_output = run_attention_backend(
backend_name, kv_cache_spec, ["placeholder"], vllm_config, device,
common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache,
@ -494,12 +493,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
mock_kv_b_proj)
# Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, (
assert backend_output.shape == sdpa_outputs[i].shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_output.shape}")
assert backend_output.dtype == sdpa_output.dtype, (
f"SDPA shape {sdpa_outputs[i].shape}")
assert backend_output.dtype == sdpa_outputs[i].dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_output.dtype}")
f"SDPA dtype {sdpa_outputs[i].dtype}")
assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values")
@ -508,12 +507,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol = 1e-2
atol = 5e-1
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
max_diff = torch.max(torch.abs(backend_output -
sdpa_outputs[i])).item()
max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_output) /
torch.abs(sdpa_output)).item()
torch.abs(backend_output - sdpa_outputs[i]) /
torch.abs(sdpa_outputs[i])).item()
all_close = torch.allclose(backend_output,
sdpa_output,
sdpa_outputs[i],
rtol=rtol,
atol=atol)
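
For orientation, the decode (MQA, latent-space) reference path above reduces to the following shape bookkeeping. This is a minimal sketch for a single new token, with assumed DeepSeek-style dimensions (kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128); only the einsum patterns and the W_UK/W_UV absorption mirror the test.

import torch

num_heads, s_len = 16, 32
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 512, 128, 64, 128
scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5  # softmax scale over the full qk head dim

q_nope = torch.randn(1, num_heads, qk_nope_head_dim)
q_pe = torch.randn(1, num_heads, qk_rope_head_dim)
W_UK = torch.randn(kv_lora_rank, num_heads, qk_nope_head_dim)
W_UV = torch.randn(kv_lora_rank, num_heads, v_head_dim)
kv_c = torch.randn(s_len, kv_lora_rank)       # latent KV cache entries
k_pe = torch.randn(s_len, qk_rope_head_dim)   # rope part of the keys

# Absorb W_UK into the query so attention runs in the latent space (MQA)
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)            # [1, H, kv_lora_rank]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)                      # [1, H, 512 + 64]
k_mqa = torch.cat([kv_c, k_pe], dim=-1).unsqueeze(1).expand(-1, num_heads, -1)
v_mqa = kv_c.unsqueeze(1).expand(-1, num_heads, -1)             # values stay latent

# Single query token, so no attention mask is needed here
attn = torch.nn.functional.scaled_dot_product_attention(
    q_mqa.transpose(0, 1), k_mqa.transpose(0, 1), v_mqa.transpose(0, 1),
    scale=scale).transpose(0, 1)                                # [1, H, kv_lora_rank]

out = torch.einsum("qnl,lnv->qnv", attn, W_UV).flatten(start_dim=-2)  # [1, H * v_head_dim]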

View File

@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
_Backend.FLASHMLA_VLLM_V1:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.FLASH_ATTN_MLA:
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
_Backend.TRITON_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
}

View File

@ -68,5 +68,18 @@ def flash_attn_supports_fp8() -> bool:
current_platform.get_device_capability().major == 9
def flash_attn_supports_mla():
from vllm.platforms import current_platform
if current_platform.is_cuda():
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
is_fa_version_supported)
return is_fa_version_supported(3) \
and current_platform.get_device_capability()[0] == 9
except (ImportError, AssertionError):
pass
return False
def is_flash_attn_varlen_func_available() -> bool:
return current_platform.is_cuda() or current_platform.is_xpu()

View File

@ -1488,6 +1488,8 @@ class EngineArgs:
"TRITON_MLA",
"CUTLASS_MLA",
"FLASHMLA",
"FLASHMLA_VLLM_V1",
"FLASH_ATTN_MLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
"ROCM_AITER_MLA",

View File

@ -463,6 +463,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "ROCM_FLASH": use ROCmFlashAttention
# - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
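
As a usage note, the new backend is selected through this environment variable like the others. A minimal sketch (the model name is illustrative; any MLA-style checkpoint on a Hopper GPU with FA3 applies):

import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN_MLA"  # set before the engine is constructed

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)  # illustrative MLA model
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)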

View File

@ -223,9 +223,30 @@ class CudaPlatformBase(Platform):
if use_mla:
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
if selected_backend == _Backend.CUTLASS_MLA or (
cls.is_device_capability(100) and selected_backend is None
and block_size == 128):
from vllm.attention.ops.flashmla import is_flashmla_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
selected_backend is None and cls.is_device_capability(100)
and block_size == 128)
use_flashmla = selected_backend in [
_Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
] or (selected_backend is None and is_flashmla_supported()[0])
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
selected_backend is None and flash_attn_supports_mla())
use_triton = selected_backend == _Backend.TRITON_MLA or (
selected_backend is None)
def _get_version(name, import_suffix) -> str:
if use_v1:
logger.info_once(f"Using {name} backend on V1 engine.")
return f"vllm.v1.attention.backends.mla.{import_suffix}"
else:
logger.info_once(f"Using {name} backend.")
return f"vllm.attention.backends.{import_suffix}"
if use_cutlassmla:
if use_v1:
logger.info_once("Using Cutlass MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
@ -233,36 +254,27 @@ class CudaPlatformBase(Platform):
else:
logger.warning(
"Cutlass MLA backend is only supported on V1 engine")
if selected_backend == _Backend.TRITON_MLA or block_size != 64:
if use_v1:
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
else:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
else:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
if not is_flashmla_supported()[0]:
logger.warning(
"FlashMLA backend is not supported due to %s",
is_flashmla_supported()[1])
elif block_size != 64:
if use_flashmla:
if block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
if use_v1:
logger.info_once(
"Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else:
logger.info("Using FlashMLA backend.")
return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
return _get_version("FlashMLA", "flashmla.FlashMLABackend")
if use_flashattn:
if use_v1:
logger.info_once(
"Using FlashAttention MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashattn_mla.FlashAttnMLABackend")
else:
logger.warning(
"FlashAttention MLA backend is only supported on V1 "
"engine.")
if use_triton:
return _get_version("Triton MLA",
"triton_mla.TritonMLABackend")
if use_v1:
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
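
Condensed, the resolution order implemented above is the following (a simplified sketch; the V1-only checks and block-size warnings are folded into comments, and resolve_mla_backend is not a real function):

def resolve_mla_backend(selected, block_size, is_sm100, flashmla_ok, fa3_mla_ok):
    # Priority: CUTLASS MLA -> FlashMLA -> FlashAttention MLA -> Triton MLA
    if selected == "CUTLASS_MLA" or (selected is None and is_sm100 and block_size == 128):
        return "CUTLASS_MLA"          # V1 only in the real code
    if selected in ("FLASHMLA", "FLASHMLA_VLLM_V1") or (selected is None and flashmla_ok):
        if block_size == 64:          # otherwise warn and fall through
            return "FLASHMLA"
    if selected == "FLASH_ATTN_MLA" or (selected is None and fa3_mla_ok):
        return "FLASH_ATTN_MLA"       # V1 only in the real code
    return "TRITON_MLA"               # generic fallback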

View File

@ -52,9 +52,10 @@ class _Backend(enum.Enum):
FLASHINFER_VLLM_V1 = enum.auto()
TRITON_MLA = enum.auto() # Supported by V1
TRITON_MLA_VLLM_V1 = enum.auto()
FLASHMLA_VLLM_V1 = enum.auto()
FLASHMLA = enum.auto() # Supported by V1
CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto() # Supported by V1
FLASHMLA_VLLM_V1 = enum.auto()
FLASH_ATTN_MLA = enum.auto() # Supported by V1
PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto()

View File

@ -317,7 +317,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len
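
The recurring change in this and the following metadata builders is that the decode/prefill split now honors the builder's reorder_batch_threshold rather than a hard-coded threshold of 1. Roughly, the semantics are as in this sketch (it assumes the batch has already been reordered so that decode-like requests come first; split_decodes_and_prefills_sketch is illustrative, not the real implementation):

def split_decodes_and_prefills_sketch(query_lens: list[int], decode_threshold: int = 1):
    # Requests with query length <= decode_threshold count as decodes; the first
    # longer request and everything after it are treated as prefills.
    num_decodes = 0
    for q_len in query_lens:
        if q_len > decode_threshold:
            break
        num_decodes += 1
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

With FlashAttn MLA's reorder_batch_threshold of 512, a batch with query lengths [1, 1, 8, 300, 1024] splits into four decodes and one prefill instead of two decodes and three prefills.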

View File

@ -52,8 +52,9 @@ class LinearAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))
attn_metadata = LinearAttentionMetadata(
num_prefills=num_prefills,

View File

@ -50,8 +50,9 @@ class Mamba1AttentionMetadataBuilder(
query_start_loc.device)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))
has_initial_states = None
padded_decodes = num_decodes

View File

@ -115,8 +115,9 @@ class Mamba2AttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:

View File

@ -578,11 +578,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill.prefill_main = self._fi_prefill_main
prefill.prefill_chunks = self._fi_prefill_chunks
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata:
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
seq_lens=seq_lens_device,
)
def build_for_cudagraph_capture(
@ -618,6 +620,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
@ -625,7 +628,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_seq_lens_cpu)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata)
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
@ -725,7 +729,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if num_decodes > 0:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens=seq_lens[:num_decodes],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
query_start_loc_device=query_start_loc[:num_decodes + 1],
)
attn_metadata = self.metadata_cls(

View File

@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
is_quantized_kv_cache)
from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
get_flash_attn_version)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "FLASH_ATTN_MLA"
@staticmethod
def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
return FlashAttnMLAMetadata
@staticmethod
def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
return FlashAttnMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> type["FlashAttnMLAImpl"]:
return FlashAttnMLAImpl
@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
query_start_loc: torch.Tensor
max_query_len: int
max_seq_len: int
scheduler_metadata: Optional[torch.Tensor] = None
@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
pass
class FlashAttnMLAMetadataBuilder(
MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
reorder_batch_threshold: ClassVar[int] = 512
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
FlashAttnMLAMetadata)
self.fa_aot_schedule = (get_flash_attn_version() == 3)
def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
if self.fa_aot_schedule:
return get_scheduler_metadata(
batch_size=num_reqs,
max_seqlen_q=max_query_len,
max_seqlen_k=max_seq_len,
num_heads_q=self.num_heads,
num_heads_kv=1,
headdim=self.mla_dims.qk_rope_head_dim,
cache_seqlens=seqlens,
qkv_dtype=self.kv_cache_spec.dtype,
headdim_v=self.mla_dims.kv_lora_rank,
page_size=self.page_size,
cu_seqlens_q=cu_query_lens,
causal=causal,
)
return None
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor
) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(),
cu_query_lens=query_start_loc_device,
max_query_len=max_query_len,
seqlens=seq_lens_device,
max_seq_len=max_seq_len,
causal=True,
)
return FlashAttnMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
query_start_loc=query_start_loc_device,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
)
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **mla_args)
assert flash_attn_supports_mla(), \
"FlashAttnMLA is not supported on this device"
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features):
raise NotImplementedError(
"FlashAttnMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashAttnMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttnMLA V1 with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashAttnMLAMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
"FP8 FlashAttention MLA not yet supported")
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]
o = flash_attn_varlen_func(
q=q_pe,
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
q_v=q_nope,
max_seqlen_q=attn_metadata.decode.max_query_len,
cu_seqlens_q=attn_metadata.decode.query_start_loc,
max_seqlen_k=attn_metadata.decode.max_seq_len,
seqused_k=attn_metadata.decode.seq_lens,
block_table=attn_metadata.decode.block_table,
softmax_scale=self.scale,
causal=True,
fa_version=3, # only version 3 is supported
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
)
return self._v_up_proj(o)
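
For orientation, the decode call above carves the paged latent cache along its last dimension and hands FA3 an MQA-shaped problem. A small shape sketch with assumed DeepSeek-style dimensions (kv_lora_rank=512, qk_rope_head_dim=64; only the split itself mirrors the code):

import torch

num_blocks, block_size = 8, 64                   # illustrative paging
kv_lora_rank, qk_rope_head_dim = 512, 64         # assumed MLA dims

kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
                                  kv_lora_rank + qk_rope_head_dim)
kv_c_cache = kv_c_and_k_pe_cache[..., :kv_lora_rank]  # latent part, used as V
k_pe_cache = kv_c_and_k_pe_cache[..., kv_lora_rank:]  # rope part, used as K

# unsqueeze(-2) adds the single KV head (MQA layout); q_pe attends k_pe_cache as q,
# while the latent-projected q_nope is supplied via q_v so its contribution to the
# scores is taken against the same latent cache that serves as values.
print(k_pe_cache.unsqueeze(-2).shape)   # torch.Size([8, 64, 1, 64])
print(kv_c_cache.unsqueeze(-2).shape)   # torch.Size([8, 64, 1, 512])

The kernel output therefore stays in the latent space ([num_decode_tokens, num_heads, kv_lora_rank]), and _v_up_proj maps it back to the model's output head dimension.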

View File

@ -85,11 +85,13 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
device=self.device,
dtype=torch.int32)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
seq_lens_device,
self.num_q_heads,
1, # MQA for the decode path
)
@ -123,7 +125,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
seq_lens=seq_lens_device,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
)

View File

@ -104,12 +104,14 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
dtype=torch.int32,
device=device)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
device = self.device
num_reqs = seq_lens.size(0)
num_reqs = seq_lens_device.size(0)
mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
@ -117,7 +119,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = seq_lens_device % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
@ -156,7 +158,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
seq_lens=seq_lens_device,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
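
The last-page-length computation above has the usual modulo edge case; a couple of worked values (page size 64 chosen for illustration):

import torch

page_size = 64
seq_lens = torch.tensor([1, 63, 64, 65, 257])
last = seq_lens % page_size                     # tensor([ 1, 63,  0,  1,  1])
last = torch.where(last == 0, page_size, last)  # tensor([ 1, 63, 64,  1,  1])
# A sequence that exactly fills its final page reports page_size, never 0.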

View File

@ -58,8 +58,9 @@ class ShortConvAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))
has_initial_states = None
if num_prefills > 0:
#[batch,]
@ -78,4 +79,4 @@ class ShortConvAttentionMetadataBuilder(
has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor,
)
return attn_metadata
return attn_metadata

View File

@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, ClassVar, Optional
import torch
@ -197,6 +197,8 @@ class XFormersAttentionMetadata:
class XFormersAttentionMetadataBuilder(
AttentionMetadataBuilder[XFormersAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__(
self,
kv_cache_spec: AttentionSpec,
@ -212,9 +214,10 @@ class XFormersAttentionMetadataBuilder(
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return reorder_batch_to_split_decodes_and_prefills(input_batch,
scheduler_output,
decode_threshold=1)
return reorder_batch_to_split_decodes_and_prefills(
input_batch,
scheduler_output,
decode_threshold=self.reorder_batch_threshold)
def build(
self,
@ -223,8 +226,9 @@ class XFormersAttentionMetadataBuilder(
fast_build: bool = False,
) -> XFormersAttentionMetadata:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))
num_actual_tokens = common_attn_metadata.num_actual_tokens
q_start_loc = common_attn_metadata.query_start_loc