Refactor system architecture (#109)

This commit is contained in:
Woosuk Kwon
2023-05-20 13:06:59 -07:00
committed by GitHub
parent 7297fa6f7c
commit c3442c1f6f
24 changed files with 1017 additions and 1034 deletions

View File

@ -0,0 +1,74 @@
import argparse
from typing import Tuple
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from cacheflow.server.llm_server import LLMServer
from cacheflow.server.ray_utils import initialize_cluster
_GiB = 1 << 30
def add_server_arguments(parser: argparse.ArgumentParser):
"""Shared arguments for CacheFlow servers."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--download-dir', type=str, default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument('--use-np-weights', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# TODO(woosuk): Support FP32.
parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
help=('data type for model weights and activations. '
'The "default" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.'))
# Parallel arguments
parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=4, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-seqs', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
return parser
def create_server_configs_from_args(
args: argparse.Namespace,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Post-process the parsed arguments.
args.swap_space = args.swap_space * _GiB
args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
# Initialize the configs.
model_config = ModelConfig(
args.model, args.download_dir, args.use_np_weights,
args.use_dummy_weights, args.dtype, args.seed)
cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
args.swap_space)
parallel_config = ParallelConfig(args.pipeline_parallel_size,
args.tensor_parallel_size, args.use_ray)
scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
args.max_num_seqs)
return model_config, cache_config, parallel_config, scheduler_config
def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
server_configs = create_server_configs_from_args(args)
parallel_config = server_configs[2]
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(parallel_config)
# Create the LLM server.
server = LLMServer(*server_configs, distributed_init_method, devices,
log_stats=not args.disable_log_stats)
return server

View File

@ -0,0 +1,198 @@
import time
from typing import Any, List, Optional
try:
import ray
except ImportError:
ray = None
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from cacheflow.core.scheduler import Scheduler
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter
from cacheflow.worker.worker import Worker
logger = init_logger(__name__)
class LLMServer:
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
distributed_init_method: str,
stage_devices: List[List[Any]],
log_stats: bool = True,
) -> None:
logger.info(
"Initializing an LLM server with config: "
f"model={model_config.model!r}, "
f"dtype={model_config.dtype}, "
f"use_dummy_weights={model_config.use_dummy_weights}, "
f"download_dir={model_config.download_dir!r}, "
f"use_np_weights={model_config.use_np_weights}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"seed={model_config.seed})"
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
self.cache_config = cache_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
self._verify_args()
self.tokenizer = get_tokenizer(model_config.model)
self.seq_counter = Counter()
# Create the parallel GPU workers.
self.workers: List[Worker] = []
assert len(stage_devices) == 1, "Only support one stage for now."
for rank, node_resource, _ in stage_devices[0]:
worker_cls = Worker
if self.parallel_config.use_ray:
worker_cls = ray.remote(
num_cpus=0,
num_gpus=1,
resources={node_resource: 1e-5},
)(worker_cls).remote
worker = worker_cls(
model_config,
parallel_config,
scheduler_config,
rank,
distributed_init_method,
)
self.workers.append(worker)
# Profile the memory usage and initialize the cache.
self._init_cache()
# Create the scheduler.
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
def _init_cache(self) -> None:
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers(
"profile_num_available_blocks",
get_all_outputs=True,
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space,
)
# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
# FIXME(woosuk): Change to debug log.
logger.info(f'# GPU blocks: {num_gpu_blocks}, '
f'# CPU blocks: {num_cpu_blocks}')
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
# Initialize the cache.
self._run_workers("init_cache_engine", cache_config=self.cache_config)
def add_request(
self,
request_id: str,
prompt: str,
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
) -> None:
if arrival_time is None:
arrival_time = time.time()
if prompt_token_ids is None:
prompt_token_ids = self.tokenizer.encode(prompt)
# Create the sequences.
block_size = self.cache_config.block_size
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
seqs.append(seq)
# FIXME(woosuk)
# Add the EOS token to the stop token list.
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
# Create the sequence group.
seq_group = SequenceGroup(request_id, seqs, sampling_params,
arrival_time)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def has_unfinished_requests(self) -> bool:
return self.scheduler.has_unfinished_seqs()
def step(self) -> List[RequestOutput]:
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
# Nothing to do.
return []
# Execute the model.
output = self._run_workers(
"execute_model",
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)
# Update the scheduler.
updated_seq_groups = self.scheduler.update(output)
# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in updated_seq_groups:
# TODO(woosuk): Batch-decode the outputs for speedup.
request_output = RequestOutput.from_seq_group(seq_group,
self.tokenizer)
request_outputs.append(request_output)
return request_outputs
def _run_workers(
self,
method: str,
get_all_outputs: bool = False,
*args,
**kwargs,
) -> Any:
all_outputs = []
for worker in self.workers:
executor = getattr(worker, method)
if self.parallel_config.use_ray:
executor = executor.remote
output = executor(*args, **kwargs)
all_outputs.append(output)
if self.parallel_config.use_ray:
all_outputs = ray.get(all_outputs)
if get_all_outputs:
return all_outputs
# Make sure all workers have the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:
assert output == other_output
return output

View File

@ -0,0 +1,90 @@
import random
from typing import List, Optional, Tuple
try:
import ray
except ImportError:
ray = None
from cacheflow.config import ParallelConfig
DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id
def initialize_cluster(
parallel_config: ParallelConfig,
address: Optional[str] = None,
) -> Tuple[str, List[List[DeviceID]]]:
if not parallel_config.use_ray:
# Initialize cluster locally.
port = random.randint(10000, 20000)
# We need to setup the distributed init method to make sure
# the distributed megatron code (e.g., get world size) works correctly.
distributed_init_method = f"tcp://localhost:{port}"
all_stage_devices = [[(0, None, 0)]]
return distributed_init_method, all_stage_devices
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
ray.init(address=address)
# Assume we have a uniform cluster that each node has the same number of
# GPUs for now.
valid_node_resources = []
num_devices_per_node = None
for node in ray.nodes():
if (not node['Alive']) or node['Resources']['GPU'] <= 0:
continue
if num_devices_per_node is None:
num_devices_per_node = node['Resources']['GPU']
else:
assert num_devices_per_node == node['Resources']['GPU'], (
"The number of GPUs per node is not uniform.")
for key in node['Resources']:
if key.startswith('node:'):
valid_node_resources.append(key)
# Verify the parallel config.
num_nodes = len(valid_node_resources)
if parallel_config.world_size > num_nodes * num_devices_per_node:
raise ValueError(
"The number of required GPUs exceeds the total number of "
"available GPUs.")
if parallel_config.tensor_parallel_size >= num_devices_per_node:
if parallel_config.tensor_parallel_size % num_devices_per_node != 0:
raise ValueError(
"The number of tensor parallelism is not divisible by the "
"number of GPUs per node.")
else:
if num_devices_per_node % parallel_config.tensor_parallel_size != 0:
raise ValueError(
"The number of GPUs per node is not divisible by the number "
"of tensor parallelism.")
# Assign GPUs to pipeline stages.
rank = 0
current_node_id = 0
current_device_id = 0
distributed_init_method = None
all_stage_devices = []
for _ in range(parallel_config.pipeline_parallel_size):
stage_devices = []
for _ in range(parallel_config.tensor_parallel_size):
node_resource = valid_node_resources[current_node_id]
stage_devices.append((rank, node_resource, current_device_id))
if distributed_init_method is None:
ip = node_resource.split("node:")[-1]
port = random.randint(10000, 20000)
distributed_init_method = f"tcp://{ip}:{port}"
rank += 1
current_device_id += 1
if current_device_id >= num_devices_per_node:
current_node_id += 1
current_device_id = 0
all_stage_devices.append(stage_devices)
return distributed_init_method, all_stage_devices

View File

@ -0,0 +1,21 @@
from typing import Union
from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast)
_MODEL_TYPES_WITH_SLOW_TOKENIZER = [
# LLaMA fast tokenizer has a bug related to protobuf.
# See https://github.com/WoosukKwon/cacheflow/issues/80#issue-1698550554
"llama",
]
def get_tokenizer(
model_name: str,
*args,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
config = AutoConfig.from_pretrained(model_name)
if config.model_type in _MODEL_TYPES_WITH_SLOW_TOKENIZER:
kwargs["use_fast"] = False
return AutoTokenizer.from_pretrained(model_name, *args, **kwargs)