[Kernel][Hardware][Amd]Custom paged attention kernel for rocm (#8310)
This commit is contained in:
@ -3,8 +3,6 @@ from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
@ -12,6 +10,10 @@ from vllm.utils import get_max_shared_memory_bytes, is_hip
|
||||
|
||||
from .allclose_default import get_default_atol, get_default_rtol
|
||||
|
||||
if not is_hip():
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
# This will change depending on the compute capability.
|
||||
# - 512 as a buffer
|
||||
@ -328,6 +330,165 @@ def ref_multi_query_kv_attention(
|
||||
return torch.cat(ref_outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("version", ["rocm"])
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128
|
||||
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(not is_hip(), reason="only for rocm")
|
||||
def test_paged_attention_rocm(
|
||||
kv_cache_factory,
|
||||
version: str,
|
||||
num_seqs: int,
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
context_lens[-1] = MAX_SEQ_LEN
|
||||
#context_lens = [8192 for _ in range(num_seqs)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int)
|
||||
#print('>>> ctx lens', context_lens)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# TODO(charlifu) enable fp8 kv cache
|
||||
# Using default kv_scale
|
||||
# kv_scale = 1.0
|
||||
|
||||
# Call the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
PARTITION_SIZE_ROCM = 256
|
||||
num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
|
||||
PARTITION_SIZE_ROCM)
|
||||
assert PARTITION_SIZE_ROCM % block_size == 0
|
||||
num_seqs, num_heads, head_size = output.shape
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
if version == "rocm":
|
||||
ops.paged_attention_rocm(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
|
||||
# Run the reference implementation.
|
||||
if kv_cache_dtype == "fp8":
|
||||
# Convert cache data back to dtype.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
|
||||
block_size, x)
|
||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
ops.convert_fp8(key_cache, dequantized_key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
ops.convert_fp8(value_cache, dequantized_value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
ref_output = torch.empty_like(query)
|
||||
ref_single_query_cached_kv_attention(
|
||||
ref_output,
|
||||
query,
|
||||
num_queries_per_kv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
context_lens,
|
||||
scale,
|
||||
alibi_slopes,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Due to the kernel-level differences in the two
|
||||
# implementations, there is a small numerical difference in the two
|
||||
# outputs. Thus, we use a relaxed tolerance for the test.
|
||||
atol = get_default_atol(output) if is_hip() else 1e-3
|
||||
rtol = get_default_rtol(output) if is_hip() else 1e-5
|
||||
|
||||
# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
|
||||
# so we use a relaxed tolerance for the test.
|
||||
atol, rtol = 1e-4, 1e-5
|
||||
if dtype == torch.bfloat16:
|
||||
atol, rtol = 2e-4, 1e-5
|
||||
if use_alibi:
|
||||
if dtype == torch.half:
|
||||
atol, rtol = 5e-4, 1e-5
|
||||
if dtype == torch.bfloat16:
|
||||
atol, rtol = 1e-3, 1e-5
|
||||
if kv_cache_dtype == "fp8":
|
||||
atol, rtol = 1e-2, 1e-5
|
||||
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
# TODO(woosuk): Add tests for USE_ALIBI=True.
|
||||
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@ -335,6 +496,7 @@ def ref_multi_query_kv_attention(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(is_hip(), reason="skip for rocm")
|
||||
@torch.inference_mode()
|
||||
def test_multi_query_kv_attention(
|
||||
num_seqs: int,
|
||||
|
||||
Reference in New Issue
Block a user