Use runtime profiling to replace manual memory analyzers (#81)
@@ -11,6 +11,8 @@ from cacheflow import pos_encoding_ops
 from cacheflow.model_executor.input_metadata import InputMetadata
 
+_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
+
 
 class GPTCacheFlowAttention(nn.Module):
     """GPT-style multi-head attention.
 
@@ -39,11 +41,19 @@ class GPTCacheFlowAttention(nn.Module):
     5. Output a flattened 1D tensor.
     """
 
-    def __init__(self, scale: float) -> None:
+    def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
         super().__init__()
+        self.num_heads = num_heads
+        self.head_size = head_size
         self.scale = float(scale)
         self.attn_op = xops.fmha.cutlass.FwOp()
 
+        if self.head_size not in _SUPPORTED_HEAD_SIZES:
+            raise ValueError(f'head_size ({self.head_size}) is not supported by '
+                             'the single_query_cached_kv_attention kernel. '
+                             'Use one of the following head sizes: '
+                             f'{_SUPPORTED_HEAD_SIZES}.')
+
     def multi_query_kv_attention(
         self,
         output: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
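Moving the head-size check from single_query_cached_kv_attention into __init__ turns a failure at the first decoding step into a fail-fast construction error. A minimal sketch of the resulting behavior (the argument values are illustrative, not from this commit):

attn = GPTCacheFlowAttention(num_heads=12, head_size=64, scale=64 ** -0.5)  # ok
attn = GPTCacheFlowAttention(num_heads=12, head_size=48, scale=48 ** -0.5)  # raises
# ValueError: head_size (48) is not supported by the
# single_query_cached_kv_attention kernel. Use one of the following head sizes: ...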
@@ -74,14 +84,6 @@ class GPTCacheFlowAttention(nn.Module):
         value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
     ) -> None:
-        head_size = value_cache.shape[2]
-        supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
-        if head_size not in supported_head_sizes:
-            raise ValueError(f'head_size ({head_size}) is not supported by '
-                             'the single_query_cached_kv_attention kernel. '
-                             'Use one of the following head sizes: '
-                             f'{supported_head_sizes}.')
-
         block_size = value_cache.shape[3]
         attention_ops.single_query_cached_kv_attention(
             output,
@@ -100,8 +102,8 @@ class GPTCacheFlowAttention(nn.Module):
         query: torch.Tensor,  # [num_tokens, num_heads * head_size]
         key: torch.Tensor,  # [num_tokens, num_heads * head_size]
         value: torch.Tensor,  # [num_tokens, num_heads * head_size]
-        key_cache: torch.Tensor,  # [num_blocks, num_heads, head_size/x, block_size, x]
-        value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
+        key_cache: Optional[torch.Tensor],  # [num_blocks, num_heads, head_size/x, block_size, x]
+        value_cache: Optional[torch.Tensor],  # [num_blocks, num_heads, head_size, block_size]
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
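Making key_cache and value_cache Optional appears to be what connects this diff to the commit title: a runtime-profiling forward pass would run before any KV cache blocks are allocated, so it has no cache tensors to pass in. A hedged sketch of that calling pattern (only the parameter names come from the signature above; the surrounding variables are assumptions):

# Hypothetical profiling-style call: no cache blocks exist yet, so both cache
# arguments are None and the new keys/values are simply not cached.
output = attn.forward(
    query, key, value,
    key_cache=None,
    value_cache=None,
    input_metadata=input_metadata,  # assumed to describe prompt tokens only
    cache_event=None,
)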
@@ -109,11 +111,9 @@ class GPTCacheFlowAttention(nn.Module):
         # tensor of shape [num_tokens, 3 * num_heads * head_size].
 
         # Reshape the query, key, and value tensors.
-        num_heads = value_cache.shape[1]
-        head_size = value_cache.shape[2]
-        query = query.view(-1, num_heads, head_size)
-        key = key.view(-1, num_heads, head_size)
-        value = value.view(-1, num_heads, head_size)
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_heads, self.head_size)
+        value = value.view(-1, self.num_heads, self.head_size)
 
         # Pre-allocate the output tensor.
         output = torch.empty_like(query)
@@ -134,8 +134,11 @@ class GPTCacheFlowAttention(nn.Module):
             cache_event.wait()
 
         # Reshape the keys and values and store them in the cache.
+        # When key_cache and value_cache are not provided, the new key
+        # and value vectors will not be cached.
         num_valid_tokens = input_metadata.num_valid_tokens
-        if num_valid_tokens > 0:
+        if (num_valid_tokens > 0 and key_cache is not None
+            and value_cache is not None):
             # The stride is 3 because the key and value are sliced from qkv.
             cache_ops.reshape_and_cache(
                 key[:num_valid_tokens],
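The retained comment about a stride of 3 is worth unpacking: key and value here are views into the fused qkv projection, so stepping one token in either of them steps over a full q/k/v row. A self-contained sketch (shapes are illustrative, not from this commit):

import torch

num_tokens, num_heads, head_size = 10, 12, 64
qkv = torch.empty(num_tokens, 3 * num_heads * head_size)
q, k, v = qkv.view(num_tokens, 3, num_heads, head_size).unbind(dim=1)

# k is a view: its token stride spans the whole qkv row, not just k's part.
assert k.stride(0) == 3 * num_heads * head_size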
@@ -146,6 +149,10 @@ class GPTCacheFlowAttention(nn.Module):
             )
 
         if input_metadata.num_generation_tokens > 0:
+            assert key_cache is not None and value_cache is not None, (
+                "key_cache and value_cache must be provided when "
+                "generating tokens."
+            )
             # Compute the attention op for generation tokens.
             self.single_query_cached_kv_attention(
                 output[num_prompt_tokens:num_valid_tokens],
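Together with the guarded cache write above, this assert pins down the contract for cache-less calls: decoding always reads from the KV cache, so a pass without caches (the profiling case) must contain prompt tokens only. A one-function paraphrase of the combined checks, as a sketch rather than library code:

def check_cache_contract(have_caches: bool, num_generation_tokens: int) -> None:
    # Paraphrase of the checks added in this commit: writing to the cache is
    # silently skipped without caches, but generation strictly requires them.
    if not have_caches and num_generation_tokens > 0:
        raise AssertionError("key_cache and value_cache must be provided "
                             "when generating tokens.")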
@@ -156,7 +163,7 @@ class GPTCacheFlowAttention(nn.Module):
 
         # Reshape the output tensor.
         # NOTE(woosuk): The output tensor may include paddings.
-        return output.view(-1, num_heads * head_size)
+        return output.view(-1, self.num_heads * self.head_size)
 
 
 class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
@@ -164,12 +171,14 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
 
     def __init__(
         self,
+        num_heads: int,
+        head_size: int,
         scale: float,
         rotary_dim: int,
         max_position: int = 8192,
        base: int = 10000,
     ) -> None:
-        super().__init__(scale)
+        super().__init__(num_heads, head_size, scale)
 
         # Create the cos and sin cache.
         inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
@@ -199,12 +208,11 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
     ) -> torch.Tensor:  # [num_tokens, num_heads * head_size]
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        head_size = value_cache.shape[2]
         pos_encoding_ops.rotary_embedding_neox(
             positions,
             query,
             key,
-            head_size,
+            self.head_size,
             self.cos_sin_cache,
         )
         return super().forward(
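The Optional-cache change forces this hunk as well: head_size was previously read off value_cache.shape[2], which no longer works once value_cache can be None, so the module's own head_size becomes the source of truth. A two-line sketch of the failure mode being avoided:

value_cache = None                 # e.g. a profiling pass, before cache allocation
head_size = value_cache.shape[2]   # AttributeError: 'NoneType' object has no attribute 'shape'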
@@ -74,7 +74,7 @@ class Sampler(nn.Module):
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == probs.shape[0]
-        if any(p < 1.0 for p in top_ps) or any(k != -1 for k in top_ks):
+        if any(p < 1.0 for p in top_ps) or any(k != self.vocab_size for k in top_ks):
             probs = _apply_top_p_top_k(probs, top_ps, top_ks)
 
         # Sample the next tokens.
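The sampler change swaps the "no top-k" sentinel from -1 to vocab_size, presumably because keeping the k largest of vocab_size logits with k == vocab_size keeps all of them, making it a natural no-op value. A small sketch of the skip condition, assuming _get_top_p_top_k now returns vocab_size for sequences that do not use top-k:

vocab_size = 50_000
top_ps = [1.0, 0.9]        # second sequence uses nucleus sampling
top_ks = [vocab_size, 40]  # first sequence applies no top-k
needs_truncation = (any(p < 1.0 for p in top_ps)
                    or any(k != vocab_size for k in top_ks))  # True here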