Refactor attention kernels (#53)

Author: Woosuk Kwon
Date: 2023-05-03 13:40:13 -07:00
Committed by: GitHub
Parent: 27f1410d06
Commit: 436e523bf1
14 changed files with 1253 additions and 2569 deletions


@@ -271,78 +271,6 @@ def test_multi_query_kv_attention(
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


def test_multi_query_cached_kv_attention(
    num_queries: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
) -> None:
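    # Sample a random query length per sequence and build cumulative offsets
    # (cu_query_lens), so the queries of all sequences can be packed into a
    # single tensor without padding.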
    query_lens = random.sample(range(1, MAX_SEQ_LEN), num_queries)
    cu_query_lens = [0]
    for query_len in query_lens:
        cu_query_lens.append(cu_query_lens[-1] + query_len)
    num_total_tokens = cu_query_lens[-1]
    qkv = torch.randn(
        num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
    query, _, _ = qkv.unbind(dim=1)
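    # Allocate the paged KV caches: each key block is laid out as
    # (num_heads, head_size // x, block_size, x), where x is the number of
    # elements that fit in 16 bytes for this dtype, and each value block as
    # (num_heads, head_size, block_size).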
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
    key_cache = torch.randn(
        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
    value_block_shape = (num_heads, head_size, block_size)
    value_cache = torch.randn(
        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')

    cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
    context_lens = [
        query_len + random.randint(0, MAX_SEQ_LEN - query_len)
        for query_len in query_lens
    ]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
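    # Build a random block table per sequence: each entry maps a logical KV
    # block of that sequence to a physical block index in the cache.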
    block_tables = []
    for _ in range(num_queries):
        block_table = [
            random.randint(0, num_blocks - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

    scale = float(1.0 / (head_size ** 0.5))
    output = torch.empty(
        num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
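    # The custom op writes its result into the preallocated `output` tensor,
    # one row per packed query token.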
    attention_ops.multi_query_cached_kv_attention(
        cu_query_lens,
        output,
        query,
        key_cache,
        value_cache,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
    )

    ref_output = ref_multi_query_cached_kv_attention(
        cu_query_lens,
        query,
        key_cache,
        value_cache,
        block_tables,
        context_lens,
        dtype,
    )
    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
@torch.inference_mode()
def test_attention(seed: int) -> None:
    # NOTE(woosuk): Even when the seed is fixed, there is a chance that
@@ -364,24 +292,6 @@ def test_attention(seed: int) -> None:
                    dtype=dtype,
                )

    # NOTE(siyuan): Same as above. Re-run the test if it fails. Note that
    # this test is more likely to fail, since the much larger number of
    # tokens in the input increases the variance.
    for dtype in [torch.half, torch.float]:
        for block_size in [8, 16, 32]:
            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                print(f'Testing multi_query_cached_kv_attention with '
                      f'dtype={dtype}, block_size={block_size}, '
                      f'head_size={head_size}')
                test_multi_query_cached_kv_attention(
                    num_queries=11,
                    num_heads=3,
                    head_size=head_size,
                    block_size=block_size,
                    num_blocks=1024,
                    dtype=dtype,
                )

    # NOTE(woosuk): FlashAttention does not support FP32.
    for dtype in [torch.half]:
        # NOTE(woosuk): FlashAttention does not support head_size > 128.