Use runtime profiling to replace manual memory analyzers (#81)

This commit is contained in:
Zhuohan Li
2023-05-19 11:35:44 -06:00
committed by GitHub
parent 825d8892b5
commit f756799b84
14 changed files with 211 additions and 478 deletions

View File

@ -5,9 +5,6 @@ import torch
import torch.nn as nn
from transformers import AutoConfig, PretrainedConfig
from cacheflow.model_executor.memory_analyzer import (
CacheFlowMemoryAnalyzer, GPT2MemoryAnalyzer, GPTNeoXMemoryAnalyzer,
LlamaMemoryAnalyzer, OPTMemoryAnalyzer)
from cacheflow.model_executor.models import (
GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
from cacheflow.model_executor.utils import get_torch_dtype
@ -22,14 +19,6 @@ _MODEL_REGISTRY = {
"OPTForCausalLM": OPTForCausalLM,
}
# Registry of memory analyzer classes, keyed by architecture name.
# The keys are matched against the entries of `config.architectures`
# when selecting the analyzer for a model.
_MEMORY_ANALYZERS = {
"GPT2LMHeadModel": GPT2MemoryAnalyzer,
"GPTNeoXForCausalLM": GPTNeoXMemoryAnalyzer,
"LlamaForCausalLM": LlamaMemoryAnalyzer,
"OPTForCausalLM": OPTMemoryAnalyzer,
}
def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
architectures = getattr(config, "architectures", [])
for arch in architectures:
@ -41,17 +30,6 @@ def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
)
def _get_memory_analyzer(config: PretrainedConfig) -> CacheFlowMemoryAnalyzer:
    """Return the memory analyzer class registered for the model's architecture.

    The first entry of ``config.architectures`` that has an entry in
    ``_MEMORY_ANALYZERS`` wins.

    Args:
        config: Model configuration whose ``architectures`` attribute lists
            candidate architecture names.

    Returns:
        The registered analyzer class itself (not an instance); the caller
        is expected to instantiate it.

    Raises:
        ValueError: If none of the listed architectures is registered,
            including when the config lists no architectures at all.
    """
    # NOTE: getattr(config, "architectures", []) is not sufficient because
    # config.architectures can exist and be None; normalize it to an empty
    # list so the ValueError below is raised instead of a TypeError.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _MEMORY_ANALYZERS:
            return _MEMORY_ANALYZERS[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MEMORY_ANALYZERS.keys())}"
    )
def _get_dtype(config: PretrainedConfig, dtype: str) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
@ -100,18 +78,3 @@ def get_model(
model = model.cuda()
return model.eval(), torch_dtype
def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: str,
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
    """Build the memory analyzer that matches the given model.

    Loads the model configuration with ``AutoConfig.from_pretrained``,
    resolves the dtype through ``_get_dtype``, selects the analyzer class
    through ``_get_memory_analyzer``, and instantiates that class.

    Args:
        model_name: Name or path handed to ``AutoConfig.from_pretrained``
            and forwarded to the analyzer.
        block_size: Forwarded unchanged to the analyzer.
        dtype: Requested dtype string, resolved against the model config.
        gpu_memory: Forwarded unchanged to the analyzer.
        cpu_memory: Forwarded unchanged to the analyzer.
        tensor_parallel_size: Forwarded unchanged to the analyzer.

    Returns:
        An instance of the analyzer class selected for the model.
    """
    hf_config = AutoConfig.from_pretrained(model_name)
    # Resolve the dtype before selecting the analyzer so that failures
    # surface in the same order as the lookups are written.
    resolved_dtype = _get_dtype(hf_config, dtype)
    analyzer_cls = _get_memory_analyzer(hf_config)
    # The analyzer constructor is called positionally, in this exact order.
    analyzer_args = (
        model_name,
        block_size,
        resolved_dtype,
        gpu_memory,
        cpu_memory,
        tensor_parallel_size,
    )
    return analyzer_cls(*analyzer_args)