Add docstrings to some modules and classes (#100)
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
"""Custom activation functions."""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -5,6 +6,10 @@ from cacheflow import activation_ops
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
"""An activation function for SwiGLU.
|
||||
|
||||
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
"""Multi-head attention."""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@ -11,6 +12,32 @@ from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
|
||||
|
||||
class GPTCacheFlowAttention(nn.Module):
|
||||
"""GPT-style multi-head attention.
|
||||
|
||||
This class takes flattened 1D query, key, and value tensors as input. The
|
||||
input 1D tensors can be split into three parts: the prompt tokens, the
|
||||
generation tokens, and the paddings.
|
||||
|
||||
|<------------------------------------- num_valid_tokens ------------------------------------->|
|
||||
|<--------------- num_prompt_tokens -------------->|<------- num_generation_tokens (M) ------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|<--generation_0-->|...|<--generation_M-1-->|<--padding-->|
|
||||
|
||||
The prompts might have different lengths, while the generation tokens always
|
||||
have length 1. The paddings are appended to make the input length a multiple
|
||||
of 8, which is desirable for Tensor Cores.
|
||||
|
||||
The class does the following:
|
||||
1. Perform multi_query_kv_attention for the prompts. This operation does
|
||||
not use the KV cache.
|
||||
2. Wait for the cache operations (e.g., swap, copy) to finish. The cache
|
||||
operations are issued by the cache engine before executing the forward
|
||||
pass of the model, and they are executed asynchronously.
|
||||
3. Reshape and store the input key and value tensors in the KV cache.
|
||||
4. Perform single_query_cached_kv_attention for the generation tokens.
|
||||
This operation reads the previous key and value tensors from the KV
|
||||
cache.
|
||||
5. Output a flattened 1D tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, scale: float) -> None:
|
||||
super().__init__()
|
||||
@ -157,7 +184,7 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
|
||||
torch_dtype = torch.get_default_dtype()
|
||||
cache = cache.to(torch_dtype)
|
||||
# Embedding size: [max_position, rotary_dim]
|
||||
self.register_buffer('cos_sin_cache', cache, persistent=False)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
"""Custom normalization layers."""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -5,6 +6,11 @@ from cacheflow import layernorm_ops
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
"""Root mean square normalization.
|
||||
|
||||
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
|
||||
Refer to https://arxiv.org/abs/1910.07467
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
"""A layer that samples the next tokens from the model's outputs."""
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -12,6 +13,19 @@ from cacheflow.sequence import SequenceOutputs
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
"""Samples the next tokens from the model's outputs.
|
||||
|
||||
This layer does the following:
|
||||
1. Discard the hidden states that are not used for sampling (i.e., all
|
||||
tokens except the final one in each prompt).
|
||||
2. Compute the logits for the next tokens.
|
||||
3. Apply presence and frequency penalties.
|
||||
4. Apply temperature scaling.
|
||||
5. Apply top-p and top-k truncation.
|
||||
6. Sample the next tokens.
|
||||
Here, each sequence group within the batch can have different sampling
|
||||
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_size: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
"""Utilities for selecting and loading models."""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
from transformers import PretrainedConfig
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
from cacheflow.model_executor.memory_analyzer import (
|
||||
CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,
|
||||
|
||||
@ -15,7 +15,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""1D GPT-2 model compatible with HuggingFace weights."""
|
||||
"""Inference-only GPT-2 model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -14,7 +14,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""1D GPT-NeoX model compatible with HuggingFace weights."""
|
||||
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -79,6 +83,7 @@ class GPTNeoXAttention(nn.Module):
|
||||
|
||||
|
||||
class GPTNeoXMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTNeoXConfig):
|
||||
super().__init__()
|
||||
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
|
||||
|
||||
@ -19,7 +19,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""1D LLaMA model compatible with HuggingFace weights."""
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -14,7 +14,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""1D OPT model compatible with HuggingFace weights."""
|
||||
"""Inference-only OPT model compatible with HuggingFace weights.
|
||||
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
"""Utils for model executor."""
|
||||
import random
|
||||
from typing import Union
|
||||
|
||||
@ -9,11 +10,11 @@ from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parall
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
'half': torch.half,
|
||||
'float': torch.float,
|
||||
'float16': torch.float16,
|
||||
'float32': torch.float32,
|
||||
'bfloat16': torch.bfloat16,
|
||||
"half": torch.half,
|
||||
"float": torch.float,
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
"""Utilities for downloading and initializing model weights."""
|
||||
import filelock
|
||||
import glob
|
||||
import json
|
||||
@ -106,5 +107,12 @@ def initialize_dummy_weights(
|
||||
low: float = -1e-3,
|
||||
high: float = 1e-3,
|
||||
) -> None:
|
||||
"""Initialize model weights with random values.
|
||||
|
||||
The model weights must be randomly initialized for accurate performance
|
||||
measurements. Additionally, the model weights should not cause NaNs in the
|
||||
forward pass. We empirically found that initializing the weights with
|
||||
values between -1e-3 and 1e-3 works well for most models.
|
||||
"""
|
||||
for param in model.state_dict().values():
|
||||
param.data.uniform_(low, high)
|
||||
|
||||
Reference in New Issue
Block a user