Use FlashAttention for multi_query_kv_attention (#4)
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from flash_attn.flash_attention import FlashAttention
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -14,20 +15,7 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
|
||||
def _masked_attention(
|
||||
self,
|
||||
query: torch.Tensor, # [num_queries, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_keys, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_keys, num_heads, head_size]
|
||||
attn_mask: Optional[torch.Tensor] = None, # [num_queries, num_keys]
|
||||
) -> torch.Tensor: # [num_queries, num_heads, head_size]
|
||||
query = query * self.scale
|
||||
attn = torch.einsum('qhd,khd->hqk', query, key)
|
||||
if attn_mask is not None:
|
||||
attn = attn + attn_mask
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
out = torch.einsum('hqk,khd->qhd', attn, value)
|
||||
return out
|
||||
self.flash_attn = FlashAttention(softmax_scale=self.scale)
|
||||
|
||||
def multi_query_kv_attention(
|
||||
self,
|
||||
@ -37,21 +25,31 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
prompt_lens: List[int],
|
||||
) -> None:
|
||||
# FIXME(woosuk): Replace the following with a custom op.
|
||||
start_idx = 0
|
||||
if query.dtype == torch.float:
|
||||
raise ValueError('The float data type is not supported by '
|
||||
'FlashAttention. Use the half data type instead.')
|
||||
head_size = query.shape[2]
|
||||
if head_size > 128:
|
||||
raise ValueError('FlashAttention does not support head_size > 128.')
|
||||
|
||||
device = query.device
|
||||
prefix_sum = [0]
|
||||
for prompt_len in prompt_lens:
|
||||
out = output[start_idx:start_idx + prompt_len]
|
||||
q = query[start_idx:start_idx + prompt_len]
|
||||
k = key[start_idx:start_idx + prompt_len]
|
||||
v = value[start_idx:start_idx + prompt_len]
|
||||
prefix_sum.append(prefix_sum[-1] + prompt_len)
|
||||
prefix_sum = torch.tensor(prefix_sum, dtype=torch.int, device=device)
|
||||
max_prompt_len = max(prompt_lens)
|
||||
|
||||
attention_mask = torch.triu(
|
||||
torch.ones(q.shape[0], k.shape[0]), diagonal=1) * -1e5
|
||||
attention_mask = attention_mask.to(dtype=q.dtype, device=q.device)
|
||||
attention_out = self._masked_attention(q, k, v, attention_mask)
|
||||
out.copy_(attention_out, non_blocking=True)
|
||||
|
||||
start_idx += prompt_len
|
||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
||||
qkv = torch.stack([query, key, value], dim=1)
|
||||
out = self.flash_attn(
|
||||
qkv,
|
||||
cu_seqlens=prefix_sum,
|
||||
max_s=max_prompt_len,
|
||||
causal=True,
|
||||
)[0]
|
||||
num_tokens = prefix_sum[-1]
|
||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
||||
output[:num_tokens].copy_(out, non_blocking=True)
|
||||
|
||||
def single_query_cached_kv_attention(
|
||||
self,
|
||||
@ -61,6 +59,14 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
|
||||
input_metadata: InputMetadata,
|
||||
) -> None:
|
||||
head_size = value_cache.shape[2]
|
||||
supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
|
||||
if head_size not in supported_head_sizes:
|
||||
raise ValueError(f'head_size ({head_size}) is not supported by '
|
||||
'the single_query_cached_kv_attention kernel. '
|
||||
'Use one of the following head sizes: '
|
||||
f'{supported_head_sizes}.')
|
||||
|
||||
block_size = value_cache.shape[3]
|
||||
attention_ops.single_query_cached_kv_attention(
|
||||
output,
|
||||
@ -101,8 +107,9 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
output = output.view(-1, num_heads, head_size)
|
||||
|
||||
# Compute the attention op for prompts.
|
||||
self.multi_query_kv_attention(
|
||||
output, query, key, value, input_metadata.prompt_lens)
|
||||
if input_metadata.num_prompts > 0:
|
||||
self.multi_query_kv_attention(
|
||||
output, query, key, value, input_metadata.prompt_lens)
|
||||
|
||||
# Wait until the cache op is done.
|
||||
if cache_event is not None:
|
||||
|
||||
Reference in New Issue
Block a user