Optimize data movement (#20)
This commit is contained in:
30
tests/kernels/activation.py
Normal file
30
tests/kernels/activation.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from cacheflow import activation_ops
|
||||
|
||||
|
||||
def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Reference implementation of the fused SiLU-and-multiply activation.

    The input is split into two chunks along dimension 1; SiLU is applied
    to the first chunk and the result is multiplied element-wise by the
    second chunk.

    Args:
        x: Input tensor that is split in two along dimension 1.

    Returns:
        ``F.silu(first_chunk) * second_chunk``.
    """
    first_half, second_half = torch.chunk(x, 2, dim=1)
    activated = F.silu(first_half)
    return activated * second_half
|
||||
|
||||
|
||||
@torch.inference_mode()
def test_silu_and_mul(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
) -> None:
    """Compare ``activation_ops.silu_and_mul`` with the PyTorch reference.

    A random input of shape ``(num_tokens, 2 * d)`` and an uninitialized
    output buffer of shape ``(num_tokens, d)`` are created on the CUDA
    device. The op is called with the buffer and the input, and the buffer
    is then compared against ``ref_silu_and_mul`` on the same input.

    Args:
        num_tokens: Number of rows in the input and output tensors.
        d: Number of output columns; the input has ``2 * d`` columns.
        dtype: Element type used for both tensors.

    Raises:
        AssertionError: If the op's output and the reference differ by more
            than ``atol=1e-5`` / ``rtol=1e-5``.
    """
    tensor_kwargs = {'dtype': dtype, 'device': 'cuda'}
    x = torch.randn(num_tokens, 2 * d, **tensor_kwargs)
    out = torch.empty(num_tokens, d, **tensor_kwargs)

    # The op receives the output buffer as its first argument; the buffer
    # is what gets checked below.
    activation_ops.silu_and_mul(out, x)

    ref_out = ref_silu_and_mul(x)
    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Enumerate every (dtype, num_tokens, d) combination up front, in the
    # same order nested loops would visit them (dtype outermost, d innermost).
    configs = [
        (dtype, num_tokens, d)
        for dtype in [torch.half, torch.float]
        for num_tokens in [7, 83, 2048]
        for d in [512, 4096, 13824]
    ]
    for dtype, num_tokens, d in configs:
        print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
        test_silu_and_mul(num_tokens, d, dtype)
|
||||
@ -1,7 +1,7 @@
|
||||
import random
|
||||
from typing import List, Optional
|
||||
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||
import torch
|
||||
|
||||
from cacheflow import attention_ops
|
||||
@ -105,8 +105,9 @@ def test_single_query_cached_kv_attention(
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
query = torch.randn(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
query, _, _ = qkv.unbind(dim=1)
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_block_shape = (num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(
|
||||
@ -115,6 +116,11 @@ def test_single_query_cached_kv_attention(
|
||||
value_cache = torch.randn(
|
||||
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
|
||||
|
||||
# Adjust the range of the values to reduce precision errors.
|
||||
query = query / (head_size ** 0.5)
|
||||
key_cache = key_cache / (head_size ** 0.5)
|
||||
value_cache = value_cache / (head_size ** 0.5)
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
|
||||
@ -130,7 +136,8 @@ def test_single_query_cached_kv_attention(
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
output = torch.empty_like(query)
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
query,
|
||||
@ -175,19 +182,28 @@ def test_multi_query_kv_attention(
|
||||
cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda')
|
||||
|
||||
scale = float(1.0 / (head_size ** 0.5))
|
||||
query = torch.randn(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
key = torch.rand_like(query)
|
||||
value = torch.rand_like(query)
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
# Adjust the range of the values to reduce precision errors.
|
||||
qkv = qkv / (head_size ** 0.5)
|
||||
|
||||
qkv = torch.stack([query, key, value], dim=1)
|
||||
flash_attn = FlashAttention(softmax_scale=scale)
|
||||
output = flash_attn(
|
||||
qkv,
|
||||
cu_seqlens=cu_seq_lens,
|
||||
max_s=max_seq_len,
|
||||
query, key, value = qkv.unbind(dim=1)
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
_flash_attn_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
cu_seq_lens,
|
||||
cu_seq_lens,
|
||||
max_seq_len,
|
||||
max_seq_len,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
)[0]
|
||||
return_softmax=False,
|
||||
)
|
||||
|
||||
cu_seq_lens = cu_seq_lens.cpu().tolist()
|
||||
ref_output = ref_multi_query_kv_attention(
|
||||
|
||||
@ -17,10 +17,10 @@ def test_reshape_and_cache(
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||
|
||||
kv_shape = (num_tokens, num_heads, head_size)
|
||||
key = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
|
||||
value = torch.randn(size=kv_shape, dtype=dtype, device='cuda')
|
||||
|
||||
qkv = torch.randn(
|
||||
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
|
||||
@ -35,7 +35,7 @@ def test_reshape_and_cache(
|
||||
|
||||
for i in range(num_tokens):
|
||||
reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x)
|
||||
block_idx = slot_mapping[i] // block_size
|
||||
block_idx = torch.div(slot_mapping[i], block_size, rounding_mode='floor')
|
||||
block_offset = slot_mapping[i] % block_size
|
||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||
|
||||
@ -85,15 +85,13 @@ def test_rotary_embedding_neox(
|
||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
|
||||
|
||||
# Run the kernel.
|
||||
out_query = torch.empty_like(query)
|
||||
out_key = torch.empty_like(key)
|
||||
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
||||
out_query = query.clone()
|
||||
out_key = key.clone()
|
||||
pos_encoding_ops.rotary_embedding_neox(
|
||||
positions,
|
||||
out_query,
|
||||
out_key,
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
cos_sin_cache,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user