Allocate more shared memory to attention kernel (#1154)
@@ -7,8 +7,12 @@ from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from vllm import attention_ops
+from vllm.utils import get_max_shared_memory_bytes
 
-MAX_SEQ_LEN = 8192
+FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+# This will change depending on the compute capability.
+# - 512 as a buffer
+MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 128  # Arbitrary values for testing
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
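The new bound replaces the fixed 8192 with a value derived from the device's shared-memory budget: one float32 value per context token, minus a 512-element safety buffer, which is why the comment notes it varies with compute capability. A minimal worked example of the arithmetic, assuming a hypothetical device where get_max_shared_memory_bytes() reports 96 KB of per-block shared memory:

import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8  # 32 bits // 8 = 4

# Hypothetical return value of vllm.utils.get_max_shared_memory_bytes()
# on a device with 96 KB of opt-in shared memory per block.
max_shared_memory_bytes = 96 * 1024

# One float32 slot per context token, minus a 512-element buffer.
MAX_SEQ_LEN = max_shared_memory_bytes // FLOAT32_BYTES - 512
print(MAX_SEQ_LEN)  # 24064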
@@ -135,6 +139,7 @@ def test_single_query_cached_kv_attention(
                                    device="cuda")
 
     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    context_lens[-1] = MAX_SEQ_LEN
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
 
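Pinning the last entry guarantees that every run of the test includes at least one sequence at the full MAX_SEQ_LEN, so the kernel's shared-memory sizing is always exercised at its limit rather than only when the RNG happens to draw the maximum. A sketch of the pattern, with num_seqs and MAX_SEQ_LEN as placeholder values rather than the test's real parameters:

import random

MAX_SEQ_LEN = 24064  # placeholder; the test derives this from shared memory
num_seqs = 8

context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN  # force the boundary case into every run
assert max(context_lens) == MAX_SEQ_LEN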
@@ -243,6 +248,7 @@ def test_multi_query_kv_attention(
     torch.cuda.manual_seed(seed)
 
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
+    seq_lens[-1] = MAX_SEQ_LEN
     num_tokens = sum(seq_lens)
 
     scale = float(1.0 / (head_size**0.5))
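Here the pin does extra work: random.sample(range(1, MAX_SEQ_LEN), num_seqs) draws from 1 to MAX_SEQ_LEN - 1 inclusive, because range excludes its upper bound, so the maximum length could never be sampled without the explicit assignment. A quick demonstration with a small stand-in bound:

import random

MAX_SEQ_LEN = 16  # small stand-in bound for illustration
num_seqs = 4

seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
assert MAX_SEQ_LEN not in seq_lens  # range(1, 16) tops out at 15

seq_lens[-1] = MAX_SEQ_LEN  # without this, the boundary is never tested
assert max(seq_lens) == MAX_SEQ_LEN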