[Attention] FlashAttn MLA (#14258)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -5,11 +5,11 @@ import os
import sys
import zipfile

# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB
# Note that we have 400 MiB quota, please use it wisely.
# See https://github.com/pypi/support/issues/3792 .
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB
# Note that we have 800 MiB quota, please use it wisely.
# See https://github.com/pypi/support/issues/6326 .
# Please also sync the value with the one in Dockerfile.
VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400))
VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450))


def print_top_10_largest_files(zip_file):

@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f
GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

@@ -237,7 +237,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
# Check the size of the wheel if RUN_WHEEL_CHECK is true
COPY .buildkite/check-wheel-size.py check-wheel-size.py
# sync the default value with .buildkite/check-wheel-size.py
ARG VLLM_MAX_SIZE_MB=400
ARG VLLM_MAX_SIZE_MB=450
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
ARG RUN_WHEEL_CHECK=true
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
@@ -22,7 +22,7 @@ def clear_cache():

# Define MLA and non-MLA backends separately
DEVICE_MLA_BACKENDS = {
"cuda": ["TRITON_MLA", "FLASHMLA"],
"cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
"cpu": [],
}

@@ -98,21 +98,14 @@ def test_env(
with patch("vllm.attention.selector.current_platform",
RocmPlatform()):
if use_mla:
# Validate HIP MLA backend-block_size combinations
valid_combination = (
(name == "TRITON_MLA" and block_size != 1)
or (name == "ROCM_AITER_MLA" and block_size == 1))
# ROCm MLA backend logic:
# - TRITON_MLA: supported when block_size != 1
# - ROCM_AITER_MLA: supported when block_size == 1
# If backend is forced but doesn't match block_size,
# should raise ValueError

if valid_combination:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
if name == "TRITON_MLA" and block_size == 1:
# TRITON_MLA doesn't support block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,

@@ -122,6 +115,27 @@ def test_env(
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
elif name == "ROCM_AITER_MLA" and block_size != 1:
# ROCM_AITER_MLA only supports block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
else:
# Valid backend-block_size combination
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,

@@ -136,16 +150,22 @@ def test_env(
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
if use_mla:
if name == "FLASHMLA" and block_size == 64:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
# CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128
#   and Blackwell GPUs (SM 10.0), V1 only
# - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases

# only on cuda platforms with specific capability.
is_supported, _ = is_flashmla_supported()

if not is_supported:
# if platform is not supported then skip this case.
pytest.skip()
if name == "CUTLASS_MLA":
if not use_v1:
# CUTLASS_MLA only supported on V1 engine
pytest.skip(
"CUTLASS_MLA only supported on V1 engine")
elif block_size != 128:
# CUTLASS_MLA only supports block_size == 128
pytest.skip(
"CUTLASS_MLA only supports block_size 128")
else:
backend = get_attn_backend(16,
torch.float16,

@@ -153,9 +173,45 @@ def test_env(
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
expected = "CUTLASS_MLA_VLLM_V1"
assert backend.get_name() == expected
elif name == "FLASHMLA":
if block_size != 64:
# FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
is_supported, _ = is_flashmla_supported()
if not is_supported:
pytest.skip(
"FlashMLA not supported on this platform")
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA":
if not use_v1:
# FlashAttention MLA only supported on V1 engine
pytest.skip(
"FlashAttention MLA only supported on V1 engine"
)
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
else:
# TRITON_MLA or other fallback
backend = get_attn_backend(16,
torch.float16,
torch.float16,
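A rough probe of the same selection logic outside pytest, assuming vllm.attention.selector.get_attn_backend with the positional arguments used in the test above (head size 16, fp16 dtypes, use_mla=True); this is a sketch for a CUDA machine with an MLA-capable GPU, not part of the diff:

import torch
from vllm.attention.selector import get_attn_backend

for block_size in (16, 64, 128):
    # same positional arguments as the test above
    backend = get_attn_backend(16, torch.float16, torch.float16,
                               block_size, False, use_mla=True)
    print(block_size, backend.get_name())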
@@ -70,22 +70,6 @@ BATCH_SPECS = {
}


def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
device: torch.device,
num_blocks: int = 100) -> torch.Tensor:
"""Create a dummy KV cache tensor for testing."""
kv_cache = torch.randn(
2,  # K and V
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
device=device,
)
return kv_cache


def create_and_prepopulate_kv_cache(
k_contexts: list[torch.Tensor],
v_contexts: list[torch.Tensor],
@@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec

BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1,
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA,
_Backend.TRITON_MLA_VLLM_V1
]

@@ -69,20 +69,6 @@ BATCH_SPECS = {
}


def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
device: torch.device,
num_blocks: int = 100) -> torch.Tensor:
"""Create a dummy KV cache tensor for testing."""
kv_cache = torch.randn(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.head_size,  # latent dimension
dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
device=device,
)
return kv_cache


def create_and_prepopulate_kv_cache(
kv_c_contexts: list[torch.Tensor],
k_pe_contexts: list[torch.Tensor],
@@ -315,7 +301,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):

# 2. Generate data and compute SDPA reference output for MLA
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
all_sdpa_outputs = []
all_sdpa_outputs: list[list[torch.Tensor]] = []
kv_c_contexts, k_pe_contexts = [], []

# Create shared MLA weight matrices for consistency across all sequences

@@ -331,6 +317,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device=device)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)

for i, backend in enumerate(BACKENDS_TO_TEST):
all_sdpa_outputs.append([])

for i in range(batch_size):
s_len = seq_lens[i]
q_len = query_lens[i]
@@ -358,85 +347,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
dtype=dtype,
device=device)

# Determine if this is decode (single token)
# or prefill (multiple tokens)
is_decode = q_len == 1
# Determine if this is decode or prefill
is_decode = []
for i, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = get_attention_backend(backend)
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)

# Split q into nope and rope components
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)

if is_decode:
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK)  # [1, num_heads, kv_lora_rank]
#######################################################
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK)  # [1, num_heads, kv_lora_rank]

# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)

# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
# Create custom attention mask for decode path:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their position
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask

sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(
0)  # [1, num_heads, kv_lora_rank]
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)

# Project back to output space: sdpa_out @ W_UV
sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
else:
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full,
kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)
sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
0)  # [1, num_heads, kv_lora_rank]

# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1)  # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
# Project back to output space: sdpa_out @ W_UV
sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode,
W_UV)
sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)

# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len,
s_len,
dtype=torch.bool,
device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
#######################################################
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)

# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1)  # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)

# Single attention call with custom mask
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
attn_mask=attn_mask,
scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask

all_sdpa_outputs.append(sdpa_out_i)
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)

# Single attention call with custom mask
sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)

for i, backend in enumerate(BACKENDS_TO_TEST):
if is_decode[i]:
all_sdpa_outputs[i].append(sdpa_out_i_decode)
else:
all_sdpa_outputs[i].append(sdpa_out_i_prefill)

# Inputs for vLLM MLA backends are just the new tokens
all_q_vllm.append(q_c)
@@ -451,7 +448,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_output = torch.cat(all_sdpa_outputs, dim=0)
sdpa_outputs = []
for i, backend in enumerate(BACKENDS_TO_TEST):
sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))

# Create mock kv_b_proj using the same weights as reference implementation
from vllm.model_executor.layers.linear import ColumnParallelLinear

@@ -486,7 +485,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
randomize_blocks=True)

# 4. Run vLLM backends and compare
for backend_name in BACKENDS_TO_TEST:
for i, backend_name in enumerate(BACKENDS_TO_TEST):
backend_output = run_attention_backend(
backend_name, kv_cache_spec, ["placeholder"], vllm_config, device,
common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache,

@@ -494,12 +493,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
mock_kv_b_proj)

# Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, (
assert backend_output.shape == sdpa_outputs[i].shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_output.shape}")
assert backend_output.dtype == sdpa_output.dtype, (
f"SDPA shape {sdpa_outputs[i].shape}")
assert backend_output.dtype == sdpa_outputs[i].dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_output.dtype}")
f"SDPA dtype {sdpa_outputs[i].dtype}")

assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values")

@@ -508,12 +507,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol = 1e-2
atol = 5e-1

max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
max_diff = torch.max(torch.abs(backend_output -
sdpa_outputs[i])).item()
max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_output) /
torch.abs(sdpa_output)).item()
torch.abs(backend_output - sdpa_outputs[i]) /
torch.abs(sdpa_outputs[i])).item()
all_close = torch.allclose(backend_output,
sdpa_output,
sdpa_outputs[i],
rtol=rtol,
atol=atol)
@@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
_Backend.FLASHMLA_VLLM_V1:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.FLASH_ATTN_MLA:
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
_Backend.TRITON_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
}
@@ -68,5 +68,18 @@ def flash_attn_supports_fp8() -> bool:
current_platform.get_device_capability().major == 9


def flash_attn_supports_mla():
from vllm.platforms import current_platform
if current_platform.is_cuda():
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
is_fa_version_supported)
return is_fa_version_supported(3) \
and current_platform.get_device_capability()[0] == 9
except (ImportError, AssertionError):
pass
return False


def is_flash_attn_varlen_func_available() -> bool:
return current_platform.is_cuda() or current_platform.is_xpu()
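A minimal usage sketch of the helper added above; it returns True only on CUDA SM90 devices where FlashAttention 3 is available, so callers can gate the MLA decode path on it (the import path matches the one used later in this diff):

from vllm.attention.utils.fa_utils import flash_attn_supports_mla

if flash_attn_supports_mla():
    print("FlashAttention MLA decode path is available (FA3 on SM90)")
else:
    print("Fall back to another MLA backend")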
@@ -1488,6 +1488,8 @@ class EngineArgs:
"TRITON_MLA",
"CUTLASS_MLA",
"FLASHMLA",
"FLASHMLA_VLLM_V1",
"FLASH_ATTN_MLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
"ROCM_AITER_MLA",
@@ -463,6 +463,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "ROCM_FLASH": use ROCmFlashAttention
# - "FLASHINFER": use flashinfer
# - "FLASHMLA": use FlashMLA
# - "FLASH_ATTN_MLA": use FlashAttention for MLA
"VLLM_ATTENTION_BACKEND":
lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
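A hedged example of forcing the new backend through this environment variable before engine construction; the variable name and the "FLASH_ATTN_MLA" value come from this diff, while the model id is only an illustrative assumption (any DeepSeek-style MLA checkpoint):

import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN_MLA"

from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")  # assumed MLA model id
print(llm.generate("Hello")[0].outputs[0].text)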
@@ -223,9 +223,30 @@ class CudaPlatformBase(Platform):
if use_mla:
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
if selected_backend == _Backend.CUTLASS_MLA or (
cls.is_device_capability(100) and selected_backend is None
and block_size == 128):

from vllm.attention.ops.flashmla import is_flashmla_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla

use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
selected_backend is None and cls.is_device_capability(100)
and block_size == 128)
use_flashmla = selected_backend in [
_Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
] or (selected_backend is None and is_flashmla_supported()[0])
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
selected_backend is None and flash_attn_supports_mla())
use_triton = selected_backend == _Backend.TRITON_MLA or (
selected_backend is None)

def _get_version(name, import_suffix) -> str:
if use_v1:
logger.info_once(f"Using {name} backend on V1 engine.")
return f"vllm.v1.attention.backends.mla.{import_suffix}"
else:
logger.info_once(f"Using {name} backend.")
return f"vllm.attention.backends.{import_suffix}"

if use_cutlassmla:
if use_v1:
logger.info_once("Using Cutlass MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
@@ -233,36 +254,27 @@ class CudaPlatformBase(Platform):
else:
logger.warning(
"Cutlass MLA backend is only supported on V1 engine")
if selected_backend == _Backend.TRITON_MLA or block_size != 64:
if use_v1:
logger.info_once("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
else:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
else:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
if not is_flashmla_supported()[0]:
logger.warning(
"FlashMLA backend is not supported due to %s",
is_flashmla_supported()[1])
elif block_size != 64:
if use_flashmla:
if block_size != 64:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size)
else:
if use_v1:
logger.info_once(
"Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else:
logger.info("Using FlashMLA backend.")
return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
return _get_version("FlashMLA", "flashmla.FlashMLABackend")
if use_flashattn:
if use_v1:
logger.info_once(
"Using FlashAttention MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashattn_mla.FlashAttnMLABackend")
else:
logger.warning(
"FlashAttention MLA backend is only supported on V1 "
"engine.")
if use_triton:
return _get_version("Triton MLA",
"triton_mla.TritonMLABackend")
if use_v1:
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend"  # noqa: E501
FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"  # noqa: E501
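The booleans above encode a priority order when no backend is forced. A plain-Python distillation (not vLLM code) of that order, under the assumptions visible in this hunk (CUTLASS MLA needs SM100, block size 128 and V1; FlashMLA needs block size 64; FlashAttention MLA needs FA3 on SM90 and V1; Triton MLA is the fallback):

def pick_mla_backend(sm_major: int, block_size: int, has_fa3: bool,
                     has_flashmla: bool, use_v1: bool) -> str:
    # mirrors use_cutlassmla / use_flashmla / use_flashattn / use_triton
    if use_v1 and sm_major == 10 and block_size == 128:
        return "CUTLASS_MLA"
    if has_flashmla and block_size == 64:
        return "FLASHMLA"
    if use_v1 and has_fa3 and sm_major == 9:
        return "FLASH_ATTN_MLA"
    return "TRITON_MLA"

# Hopper (SM90) with FA3 but a 16-token block size lands on FLASH_ATTN_MLA
print(pick_mla_backend(sm_major=9, block_size=16, has_fa3=True,
                       has_flashmla=True, use_v1=True))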
@@ -52,9 +52,10 @@ class _Backend(enum.Enum):
FLASHINFER_VLLM_V1 = enum.auto()
TRITON_MLA = enum.auto()  # Supported by V1
TRITON_MLA_VLLM_V1 = enum.auto()
FLASHMLA_VLLM_V1 = enum.auto()
FLASHMLA = enum.auto()  # Supported by V1
CUTLASS_MLA = enum.auto()
FLASHMLA = enum.auto()  # Supported by V1
FLASHMLA_VLLM_V1 = enum.auto()
FLASH_ATTN_MLA = enum.auto()  # Supported by V1
PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto()
@@ -317,7 +317,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)

page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len
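Several builders in this diff now pass decode_threshold=self.reorder_batch_threshold instead of a hard-coded 1. A conceptual plain-Python sketch (not the vLLM helper) of what that threshold means, assuming the batch has already been reordered so that "decode" requests come first; the FlashAttn MLA builder later in this diff sets the threshold to 512, presumably so multi-token steps can still use the decode path:

def split_decodes_and_prefills_sketch(query_lens, decode_threshold):
    # requests with at most `decode_threshold` new tokens count as decodes
    num_decodes = 0
    for q_len in query_lens:
        if q_len > decode_threshold:
            break
        num_decodes += 1
    num_prefills = len(query_lens) - num_decodes
    return (num_decodes, num_prefills,
            sum(query_lens[:num_decodes]), sum(query_lens[num_decodes:]))

print(split_decodes_and_prefills_sketch([1, 1, 4, 300, 2048], 512))  # (4, 1, 306, 2048)
print(split_decodes_and_prefills_sketch([1, 1, 4, 300, 2048], 1))    # (2, 3, 2, 2352)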
@@ -52,8 +52,9 @@ class LinearAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))

attn_metadata = LinearAttentionMetadata(
num_prefills=num_prefills,

@@ -50,8 +50,9 @@ class Mamba1AttentionMetadataBuilder(
query_start_loc.device)

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))

has_initial_states = None
padded_decodes = num_decodes

@@ -115,8 +115,9 @@ class Mamba2AttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))

# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:
@@ -578,11 +578,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill.prefill_main = self._fi_prefill_main
prefill.prefill_chunks = self._fi_prefill_chunks

def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata:
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
seq_lens=seq_lens_device,
)

def build_for_cudagraph_capture(

@@ -618,6 +620,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
seq_lens_cpu = common_attn_metadata.seq_lens_cpu

query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

@@ -625,7 +628,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_seq_lens_cpu)

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata)
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)

assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens

@@ -725,7 +729,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if num_decodes > 0:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens=seq_lens[:num_decodes],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
seq_lens_device=seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
query_start_loc_device=query_start_loc[:num_decodes + 1],
)

attn_metadata = self.metadata_cls(
vllm/v1/attention/backends/mla/flashattn_mla.py (new file, 189 lines)
@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import ClassVar, Optional

import torch

from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
                                           get_flash_attn_version)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonDecodeMetadata,
                                                   MLACommonImpl,
                                                   MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

logger = init_logger(__name__)


class FlashAttnMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_MLA"

    @staticmethod
    def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
        return FlashAttnMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
        return FlashAttnMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashAttnMLAImpl"]:
        return FlashAttnMLAImpl


@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
    query_start_loc: torch.Tensor
    max_query_len: int
    max_seq_len: int
    scheduler_metadata: Optional[torch.Tensor] = None


@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
    pass


class FlashAttnMLAMetadataBuilder(
        MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
    reorder_batch_threshold: ClassVar[int] = 512

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashAttnMLAMetadata)
        self.fa_aot_schedule = (get_flash_attn_version() == 3)

    def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                         max_seq_len, causal):
        if self.fa_aot_schedule:
            return get_scheduler_metadata(
                batch_size=num_reqs,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                num_heads_q=self.num_heads,
                num_heads_kv=1,
                headdim=self.mla_dims.qk_rope_head_dim,
                cache_seqlens=seqlens,
                qkv_dtype=self.kv_cache_spec.dtype,
                headdim_v=self.mla_dims.kv_lora_rank,
                page_size=self.page_size,
                cu_seqlens_q=cu_query_lens,
                causal=causal,
            )
        return None

    def _build_decode(
            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
            query_start_loc_device: torch.Tensor
    ) -> FlashAttnMLADecodeMetadata:
        query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
        max_query_len = query_lens_cpu.max().item()
        max_seq_len = seq_lens_cpu.max().item()

        scheduler_metadata = self._schedule_decode(
            num_reqs=seq_lens_cpu.numel(),
            cu_query_lens=query_start_loc_device,
            max_query_len=max_query_len,
            seqlens=seq_lens_device,
            max_seq_len=max_seq_len,
            causal=True,
        )

        return FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            query_start_loc=query_start_loc_device,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
        )


class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            logits_soft_cap: Optional[float],
            attn_type: str,
            kv_sharing_target_layer_name: Optional[str],
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         logits_soft_cap, attn_type,
                         kv_sharing_target_layer_name, **mla_args)

        assert flash_attn_supports_mla(), \
            "FlashAttnMLA is not supported on this device"

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashAttnMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttnMLAImpl")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashAttnMLA V1 with FP8 KV cache not yet supported")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashAttnMLAMetadata,
        layer: AttentionLayer,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError(
                "FP8 FlashAttention MLA not yet supported")

        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]

        o = flash_attn_varlen_func(
            q=q_pe,
            k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
            q_v=q_nope,
            max_seqlen_q=attn_metadata.decode.max_query_len,
            cu_seqlens_q=attn_metadata.decode.query_start_loc,
            max_seqlen_k=attn_metadata.decode.max_seq_len,
            seqused_k=attn_metadata.decode.seq_lens,
            block_table=attn_metadata.decode.block_table,
            softmax_scale=self.scale,
            causal=True,
            fa_version=3,  # only version 3 is supported
            scheduler_metadata=attn_metadata.decode.scheduler_metadata,
        )

        return self._v_up_proj(o)
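For orientation, a shape-only sketch of what _forward_decode above hands to flash_attn_varlen_func, using DeepSeek-style MLA sizes as an assumption (kv_lora_rank=512, qk_rope_head_dim=64, 16 query heads, small token/block counts): the RoPE half of the query goes in as q, the latent half as q_v, the cached RoPE keys as a single-head k, the cached latent vectors as a single-head v, and the latent-space output is then mapped back up by _v_up_proj:

import torch

num_tokens, num_heads = 4, 16            # assumed sizes, illustration only
kv_lora_rank, qk_rope_head_dim = 512, 64  # assumed DeepSeek-style dims
num_blocks, block_size = 8, 16

q_pe = torch.randn(num_tokens, num_heads, qk_rope_head_dim)   # -> q
q_nope = torch.randn(num_tokens, num_heads, kv_lora_rank)     # -> q_v
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
                                  kv_lora_rank + qk_rope_head_dim)
kv_c_cache = kv_c_and_k_pe_cache[..., :kv_lora_rank]          # -> v (latent)
k_pe_cache = kv_c_and_k_pe_cache[..., kv_lora_rank:]          # -> k (RoPE)

print(q_pe.shape, q_nope.shape,
      k_pe_cache.unsqueeze(-2).shape, kv_c_cache.unsqueeze(-2).shape)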
@@ -85,11 +85,13 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
device=self.device,
dtype=torch.int32)

def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
seq_lens_device,
self.num_q_heads,
1,  # MQA for the decode path
)

@@ -123,7 +125,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):

return FlashMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
seq_lens=seq_lens_device,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
)
@@ -104,12 +104,14 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
dtype=torch.int32,
device=device)

def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
device = self.device
num_reqs = seq_lens.size(0)
num_reqs = seq_lens_device.size(0)

mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,

@@ -117,7 +119,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]

paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = seq_lens_device % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)

@@ -156,7 +158,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):

attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
seq_lens=seq_lens_device,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
@@ -58,8 +58,9 @@ class ShortConvAttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))
has_initial_states = None
if num_prefills > 0:
#[batch,]

@@ -78,4 +79,4 @@ class ShortConvAttentionMetadataBuilder(
has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor,
)
return attn_metadata
return attn_metadata
@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, ClassVar, Optional

import torch

@@ -197,6 +197,8 @@ class XFormersAttentionMetadata:
class XFormersAttentionMetadataBuilder(
AttentionMetadataBuilder[XFormersAttentionMetadata]):

reorder_batch_threshold: ClassVar[int] = 1

def __init__(
self,
kv_cache_spec: AttentionSpec,

@@ -212,9 +214,10 @@ class XFormersAttentionMetadataBuilder(

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return reorder_batch_to_split_decodes_and_prefills(input_batch,
scheduler_output,
decode_threshold=1)
return reorder_batch_to_split_decodes_and_prefills(
input_batch,
scheduler_output,
decode_threshold=self.reorder_batch_threshold)

def build(
self,

@@ -223,8 +226,9 @@ class XFormersAttentionMetadataBuilder(
fast_build: bool = False,
) -> XFormersAttentionMetadata:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold))

num_actual_tokens = common_attn_metadata.num_actual_tokens
q_start_loc = common_attn_metadata.query_start_loc