Use runtime profiling to replace manual memory analyzers (#81)

This commit is contained in:
Zhuohan Li
2023-05-19 11:35:44 -06:00
committed by GitHub
parent 825d8892b5
commit f756799b84
14 changed files with 211 additions and 478 deletions

View File

@ -6,15 +6,14 @@ try:
import ray
except ImportError:
ray = None
import numpy as np
import torch
from cacheflow.core.scheduler import Scheduler
from cacheflow.frontend.simple_frontend import SimpleFrontend
from cacheflow.logger import init_logger
from cacheflow.model_executor import get_memory_analyzer
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroup
from cacheflow.utils import get_gpu_memory, get_cpu_memory
from cacheflow.worker.controller import Controller, DeviceID
logger = init_logger(__name__)
@ -34,14 +33,13 @@ class Server:
dtype: str,
seed: int,
swap_space: int,
gpu_memory_utilization: float,
max_num_batched_tokens: int,
max_num_sequences: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
gpu_memory: int,
cpu_memory: int,
use_ray: bool,
log_stats: bool,
):
@ -63,19 +61,6 @@ class Server:
assert self.world_size == 1, (
"Only support single GPU without Ray.")
self.memory_analyzer = get_memory_analyzer(
model_name=model,
block_size=block_size,
dtype=dtype,
gpu_memory=gpu_memory,
cpu_memory=cpu_memory,
tensor_parallel_size=tensor_parallel_size,
)
self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
max_num_batched_tokens=max_num_batched_tokens)
self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
swap_space_gib=swap_space)
# Create a controller for each pipeline stage.
self.controllers: List[Controller] = []
for i in range(pipeline_parallel_size):
@ -87,19 +72,35 @@ class Server:
tensor_parallel_size=tensor_parallel_size,
distributed_init_method=distributed_init_method,
model_name=model,
block_size=block_size,
num_gpu_blocks=self.num_gpu_blocks,
num_cpu_blocks=self.num_cpu_blocks,
dtype=dtype,
seed=seed,
cache_dir=cache_dir,
use_dummy_weights=use_dummy_weights,
use_np_cache=use_np_cache,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
use_ray=use_ray,
)
self.controllers.append(controller)
# Initialize cache engine.
all_worker_num_available_blocks = []
for controller in self.controllers:
all_worker_num_available_blocks.extend(
controller.get_num_available_blocks(
block_size, swap_space, gpu_memory_utilization)
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
self.num_gpu_blocks = np.min([b[0] for b in all_worker_num_available_blocks])
self.num_cpu_blocks = np.min([b[1] for b in all_worker_num_available_blocks])
logger.info(f'# GPU blocks: {self.num_gpu_blocks}, '
f'# CPU blocks: {self.num_cpu_blocks}')
for controller in self.controllers:
controller.init_cache_engine(block_size, self.num_gpu_blocks,
self.num_cpu_blocks)
# Create a scheduler.
self.scheduler = Scheduler(
controllers=self.controllers,
@ -214,7 +215,11 @@ def initialize_cluster(
all_stage_devices)
_GiB = 1 << 30
def add_server_arguments(parser: argparse.ArgumentParser):
"""Shared arguments for CacheFlow servers."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--cache-dir', type=str, default=None,
@ -238,6 +243,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--log-stats', action='store_true', help='log system statistics')
@ -245,8 +251,11 @@ def add_server_arguments(parser: argparse.ArgumentParser):
def process_server_arguments(args: argparse.Namespace):
"""Post process the parsed arguments."""
if args.pipeline_parallel_size * args.tensor_parallel_size > 1:
args.use_ray = True
args.swap_space = args.swap_space * _GiB
args.max_num_sequences = min(args.max_num_sequences, args.max_num_batched_tokens)
return args
@ -274,14 +283,13 @@ def init_local_server_and_frontend_with_arguments(args: argparse.Namespace):
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
gpu_memory_utilization=args.gpu_memory_utilization,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
use_ray=args.use_ray,
log_stats=args.log_stats,
)