Refactor system architecture (#109)
@@ -4,6 +4,7 @@ from typing import Dict, List, Tuple
 import torch
 
 from cacheflow import cache_ops
+from cacheflow.config import CacheConfig, ModelConfig, ParallelConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -18,27 +19,22 @@ class CacheEngine:
 
     def __init__(
         self,
-        worker_id: int,
-        num_layers: int,
-        num_heads: int,
-        head_size: int,
-        block_size: int,
-        num_gpu_blocks: int,
-        num_cpu_blocks: int,
-        dtype: torch.dtype,
+        cache_config: CacheConfig,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
     ) -> None:
-        if head_size % 16 != 0:
-            raise ValueError(
-                f'head_size ({head_size}) must be a multiple of 16.')
+        self.cache_config = cache_config
+        self.model_config = model_config
+        self.parallel_config = parallel_config
 
-        self.worker_id = worker_id
-        self.num_layers = num_layers
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.block_size = block_size
-        self.num_gpu_blocks = num_gpu_blocks
-        self.num_cpu_blocks = num_cpu_blocks
-        self.dtype = dtype
+        self.head_size = model_config.get_head_size()
+        self.num_layers = model_config.get_num_layers(parallel_config)
+        self.num_heads = model_config.get_num_heads(parallel_config)
+        self.dtype = model_config.dtype
+
+        self.block_size = cache_config.block_size
+        self.num_gpu_blocks = cache_config.num_gpu_blocks
+        self.num_cpu_blocks = cache_config.num_cpu_blocks
 
         # Initialize the cache.
         self.gpu_cache = self.allocate_gpu_cache()
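Note: after this hunk, `CacheEngine` no longer takes roughly ten scalar arguments; it derives every shape parameter from the three config objects. A minimal sketch of the new dependency flow, using stand-in dataclasses (the real `CacheConfig`/`ModelConfig`/`ParallelConfig` in `cacheflow.config` have richer APIs than shown here):

```python
# Stand-in configs for illustration only; the real classes expose
# accessors like get_head_size() and get_num_layers(parallel_config).
from dataclasses import dataclass

@dataclass
class StubModelConfig:
    head_size: int = 128
    num_layers: int = 32
    num_heads: int = 32
    dtype: str = "float16"

@dataclass
class StubCacheConfig:
    block_size: int = 16
    num_gpu_blocks: int = 1024
    num_cpu_blocks: int = 256

# CacheEngine(cache_config, model_config, parallel_config) now reads all
# shape parameters from objects like these, so call sites shrink to
# passing three configs instead of threading individual scalars around.
```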
@@ -48,7 +44,7 @@ class CacheEngine:
         self.cache_stream = torch.cuda.Stream()
         assert self.cache_stream != torch.cuda.current_stream()
         # Initialize the events for stream synchronization.
-        self.events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.events = [torch.cuda.Event() for _ in range(self.num_layers)]
 
     def get_key_block_shape(self) -> Tuple[int, int, int, int]:
         element_size = torch.tensor([], dtype=self.dtype).element_size()
@@ -133,3 +129,23 @@ class CacheEngine:
         value_caches = [value_cache for _, value_cache in self.gpu_cache]
         # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
         cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
+
+    @staticmethod
+    def get_cache_block_size(
+        block_size: int,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+    ) -> int:
+        head_size = model_config.get_head_size()
+        num_heads = model_config.get_num_heads(parallel_config)
+        num_layers = model_config.get_num_layers(parallel_config)
+
+        key_cache_block = block_size * num_heads * head_size
+        value_cache_block = key_cache_block
+        total = num_layers * (key_cache_block + value_cache_block)
+        dtype_size = _get_dtype_size(model_config.dtype)
+        return dtype_size * total
+
+
+def _get_dtype_size(dtype: torch.dtype) -> int:
+    return torch.tensor([], dtype=dtype).element_size()
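To make the new `get_cache_block_size` formula concrete, here is a worked example with assumed model dimensions (illustrative numbers, not values from this diff):

```python
# One cache block holds block_size tokens of K and V for every layer.
block_size = 16       # tokens per cache block
num_heads = 32        # attention heads on this worker
head_size = 128       # hidden size per head
num_layers = 32       # transformer layers on this worker
dtype_size = 2        # bytes per element for fp16

key_cache_block = block_size * num_heads * head_size        # 65,536 elements
value_cache_block = key_cache_block                         # V has the same shape
total = num_layers * (key_cache_block + value_cache_block)  # 4,194,304 elements
print(dtype_size * total)  # 8,388,608 bytes = 8 MiB per block
```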
@@ -1,130 +0,0 @@
-from typing import List, Optional, Tuple, Union
-
-try:
-    import ray
-except ImportError:
-    ray = None
-
-from cacheflow.core.scheduler import Scheduler
-from cacheflow.worker.worker import Worker
-
-
-DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id
-
-
-class Controller:
-
-    def __init__(
-        self,
-        stage_id: int,
-        stage_devices: List[DeviceID],
-        world_size: int,
-        tensor_parallel_size: int,
-        pipeline_parallel_size: int,
-        distributed_init_method: str,
-        model_name: str,
-        dtype: str,
-        seed: int,
-        cache_dir: Optional[str],
-        use_dummy_weights: bool,
-        use_np_cache: bool,
-        max_num_batched_tokens: int,
-        max_num_sequences: int,
-        use_ray: bool,
-    ) -> None:
-        self.stage_id = stage_id
-        self.stage_devices = stage_devices
-        self.model_name = model_name
-        self.use_ray = use_ray
-
-        # Which pipeline stage is this node assigned to?
-        self.is_first_stage = stage_id == 0
-        self.is_last_stage = False
-
-        self.workers: List[Worker] = []
-        for rank, node_resource, device_id in stage_devices:
-            if self.use_ray:
-                worker_cls = ray.remote(num_cpus=0,
-                                        num_gpus=1,
-                                        resources={node_resource: 1e-5})(Worker).remote
-            else:
-                worker_cls = Worker
-            worker = worker_cls(
-                model_name=model_name,
-                dtype=dtype,
-                seed=seed,
-                distributed_init_method=distributed_init_method,
-                rank=rank,
-                world_size=world_size,
-                tensor_parallel_size=tensor_parallel_size,
-                pipeline_parallel_size=pipeline_parallel_size,
-                cache_dir=cache_dir,
-                use_dummy_weights=use_dummy_weights,
-                use_np_cache=use_np_cache,
-                max_num_batched_tokens=max_num_batched_tokens,
-                max_num_sequences=max_num_sequences,
-            )
-            self.workers.append(worker)
-
-    def get_num_available_blocks(self, block_size: int, cpu_swap_space: int,
-                                 gpu_memory_utilization: float) -> List[Tuple[int, int]]:
-        all_worker_results = []
-        for worker in self.workers:
-            executor = worker.get_num_available_blocks
-            if self.use_ray:
-                executor = executor.remote
-
-            result = executor(
-                block_size,
-                cpu_swap_space,
-                gpu_memory_utilization,
-            )
-            all_worker_results.append(result)
-        if self.use_ray:
-            all_worker_results = ray.get(all_worker_results)
-        return all_worker_results
-
-    def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
-                          num_cpu_blocks: int):
-        all_worker_futures = []
-        for worker in self.workers:
-            executor = worker.init_cache_engine
-            if self.use_ray:
-                executor = executor.remote
-            future = executor(
-                block_size,
-                num_gpu_blocks,
-                num_cpu_blocks,
-            )
-            all_worker_futures.append(future)
-        if self.use_ray:
-            ray.get(all_worker_futures)
-
-    def set_next(
-        self,
-        next_node: Union['Controller', 'Scheduler'],
-    ) -> None:
-        self.next_node = next_node
-        self.is_last_stage = isinstance(next_node, Scheduler)
-
-    def execute_stage(self, *args, **kwargs) -> None:
-        all_outputs = []
-        for worker in self.workers:
-            executor = (worker.execute_stage.remote
-                        if self.use_ray else worker.execute_stage)
-            output = executor(*args, **kwargs)
-            all_outputs.append(output)
-
-        if self.use_ray:
-            all_outputs = ray.get(all_outputs)
-
-        # Make sure all workers have the same results.
-        output = all_outputs[0]
-        for other_output in all_outputs[1:]:
-            assert output == other_output
-
-        if self.is_last_stage:
-            self.next_node.post_step(output)
-        else:
-            # TODO: Support pipeline parallelism.
-            assert False
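The deleted Controller's core idiom, picking `method.remote` under Ray and the plain bound method otherwise, then resolving the collected futures in one `ray.get`, is worth noting. A minimal standalone sketch of that fan-out pattern, with a toy worker standing in for the real cacheflow `Worker` (names here are illustrative, not from the codebase):

```python
# Minimal sketch of the Ray/local fan-out pattern used by the deleted
# Controller. ToyWorker is a stand-in, not the real cacheflow Worker.
from typing import Any, List

try:
    import ray
except ImportError:
    ray = None


class ToyWorker:
    def ping(self, x: int) -> int:
        return x + 1


def fan_out(workers: List[Any], use_ray: bool, x: int) -> List[int]:
    results = []
    for worker in workers:
        # Under Ray, `.remote(...)` returns a future immediately, so all
        # workers run concurrently; locally the call blocks in turn.
        executor = worker.ping.remote if use_ray else worker.ping
        results.append(executor(x))
    if use_ray:
        results = ray.get(results)  # resolve the futures
    return results


print(fan_out([ToyWorker(), ToyWorker()], use_ray=False, x=41))  # [42, 42]
```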
@@ -1,14 +1,13 @@
 """A GPU worker class."""
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple
 
 import torch
 
-from cacheflow.model_executor import (get_model, get_cache_block_size,
-                                      InputMetadata, set_random_seed)
+from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
+                              SchedulerConfig)
+from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
 from cacheflow.model_executor.parallel_utils.parallel_state import (
-    initialize_model_parallel,
-    initialize_all_reduce_launcher,
-    get_tensor_model_parallel_world_size)
+    initialize_model_parallel, initialize_all_reduce_launcher)
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
                                 SequenceOutputs)
@@ -26,59 +25,46 @@ class Worker:
 
     def __init__(
         self,
-        model_name: str,
-        dtype: str,
-        seed: int,
-        distributed_init_method: str,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
         rank: int,
-        world_size: int,
-        cache_dir: Optional[str],
-        use_dummy_weights: bool,
-        use_np_cache: bool,
-        max_num_batched_tokens: int,
-        max_num_sequences: int,
-        tensor_parallel_size: int = 1,
-        pipeline_parallel_size: int = 1,
+        distributed_init_method: str,
    ) -> None:
-        self.init_distributed_environment(distributed_init_method,
-                                          rank,
-                                          world_size,
-                                          tensor_parallel_size,
-                                          pipeline_parallel_size)
-        self.worker_id = rank
-        self.seed = seed
-        set_random_seed(self.seed)
+        self.model_config = model_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
 
+        # Initialize the distributed environment.
+        _init_distributed_environment(parallel_config, rank,
+                                      distributed_init_method)
+
         # Initialize the model.
-        self.model, self.dtype = get_model(
-            model_name, dtype=dtype, cache_dir=cache_dir,
-            use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
-        self.max_num_batched_tokens = max_num_batched_tokens
+        set_random_seed(self.model_config.seed)
+        self.model = get_model(model_config)
         initialize_all_reduce_launcher(
-            self.max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
-        self.max_num_sequences = max_num_sequences
-        self.num_layers = self.model.config.num_hidden_layers
-        assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
-        self.head_size = self.model.config.hidden_size // (self.num_heads * tensor_model_parallel_world_size)
+            self.scheduler_config.max_num_batched_tokens,
+            self.model_config.get_hidden_size(),
+            self.model_config.dtype,
+        )
 
-        # We reset the seed after initializing the model to ensure that
-        # the random state is not affected by the model initialization.
-        set_random_seed(seed)
-
-        # Uninitialized cache engine. Will be initialized with
+        # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
+        self.cache_config = None
         self.block_size = None
         self.cache_engine = None
         self.cache_events = None
         self.gpu_cache = None
 
     @torch.inference_mode()
-    def get_num_available_blocks(
-        self, block_size: int, cpu_swap_space: int,
-        gpu_memory_utilization: float) -> Tuple[int, int]:
+    def profile_num_available_blocks(
+        self,
+        block_size: int,
+        gpu_memory_utilization: float,
+        cpu_swap_space: int,
+    ) -> Tuple[int, int]:
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
@@ -90,14 +76,15 @@ class Worker:
         # Enable top-k sampling to reflect the accurate memory usage.
         sampling_params = SamplingParams(top_p=0.99,
                                          top_k=self.model.config.vocab_size - 1)
+        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+        max_num_seqs = self.scheduler_config.max_num_seqs
         seqs = []
-        for group_id in range(self.max_num_sequences):
-            seq_len = (self.max_num_batched_tokens // self.max_num_sequences +
-                       (group_id < self.max_num_batched_tokens %
-                        self.max_num_sequences))
+        for group_id in range(max_num_seqs):
+            seq_len = (max_num_batched_tokens // max_num_seqs +
+                       (group_id < max_num_batched_tokens % max_num_seqs))
             seq_data = SequenceData([0] * seq_len)
             seq = SequenceGroupMetadata(
-                group_id=group_id,
+                request_id=str(group_id),
                 is_prompt=True,
                 seq_data={group_id: seq_data},
                 sampling_params=sampling_params,
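The `seq_len` expression above spreads `max_num_batched_tokens` as evenly as possible across `max_num_seqs` dummy sequences: every sequence gets the floor of the quotient, and the first `remainder` sequences get one extra token (a Python `bool` adds as 0 or 1). A quick illustration with toy numbers, not values from the diff:

```python
# 10 tokens over 3 sequences.
max_num_batched_tokens = 10
max_num_seqs = 3
seq_lens = [
    max_num_batched_tokens // max_num_seqs
    + (group_id < max_num_batched_tokens % max_num_seqs)
    for group_id in range(max_num_seqs)
]
print(seq_lens)       # [4, 3, 3] -- first (10 % 3) = 1 sequence gets an extra token
print(sum(seq_lens))  # 10 == max_num_batched_tokens, so profiling hits the full budget
```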
@@ -105,13 +92,14 @@ class Worker:
             )
             seqs.append(seq)
 
-        input_tokens, input_positions, input_metadata = self.prepare_inputs(seqs)
+        input_tokens, input_positions, input_metadata = self._prepare_inputs(seqs)
 
         # Execute the model.
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
         self.model(
             input_ids=input_tokens,
             positions=input_positions,
-            kv_caches=[(None, None)] * self.num_layers,
+            kv_caches=[(None, None)] * num_layers,
             input_metadata=input_metadata,
             cache_events=None,
         )
@@ -121,53 +109,27 @@ class Worker:
         torch.cuda.synchronize()
         peak_memory = torch.cuda.max_memory_allocated()
         total_gpu_memory = get_gpu_memory()
-        cache_block_size = get_cache_block_size(block_size, self.num_heads,
-                                                self.head_size, self.num_layers,
-                                                self.dtype)
+        cache_block_size = CacheEngine.get_cache_block_size(
+            block_size, self.model_config, self.parallel_config)
         num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
                               - peak_memory) // cache_block_size)
         num_cpu_blocks = int(cpu_swap_space // cache_block_size)
         torch.cuda.empty_cache()
-        # Reset the seed to ensure that the model output is not affected by
-        # the profiling.
-        set_random_seed(self.seed)
+
+        # Reset the seed to ensure that the random state is not affected by
+        # the model initialization and profiling.
+        set_random_seed(self.model_config.seed)
         return num_gpu_blocks, num_cpu_blocks
 
-    def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
-                          num_cpu_blocks: int):
-        self.block_size = block_size
+    def init_cache_engine(self, cache_config: CacheConfig) -> None:
+        self.cache_config = cache_config
+        self.block_size = cache_config.block_size
         self.cache_engine = CacheEngine(
-            worker_id=self.worker_id,
-            num_layers=self.num_layers,
-            num_heads=self.num_heads,
-            head_size=self.head_size,
-            block_size=self.block_size,
-            num_gpu_blocks=num_gpu_blocks,
-            num_cpu_blocks=num_cpu_blocks,
-            dtype=self.dtype,
-        )
+            self.cache_config, self.model_config, self.parallel_config)
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
 
-    def init_distributed_environment(self,
-                                     distributed_init_method: str,
-                                     rank: int,
-                                     world_size: int,
-                                     tensor_parallel_size: int = 1,
-                                     pipeline_parallel_size: int = 1) -> None:
-        """Initialize the distributed environment."""
-        torch.distributed.init_process_group(
-            backend='nccl',
-            init_method=distributed_init_method,
-            world_size=world_size,
-            rank=rank,
-        )
-        # A small all_reduce for warmup.
-        torch.distributed.all_reduce(torch.zeros(1).cuda())
-        initialize_model_parallel(tensor_parallel_size,
-                                  pipeline_parallel_size)
-
-    def prepare_inputs(
+    def _prepare_inputs(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
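The block-count arithmetic above is worth working through once: subtract the measured peak activation memory from the usable fraction of GPU memory, then floor-divide by the per-block byte size. A worked example with assumed numbers rather than real profiling output:

```python
# Illustrative values only; real numbers come from torch.cuda profiling.
total_gpu_memory = 80 * 1024**3   # 80 GiB GPU
gpu_memory_utilization = 0.90     # use at most 90% of the GPU
peak_memory = 30 * 1024**3        # peak usage measured during the dummy forward pass
cache_block_size = 8 * 1024**2    # 8 MiB per block (see the worked example above)
cpu_swap_space = 4 * 1024**3      # 4 GiB of CPU swap space

num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
                      - peak_memory) // cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
print(num_gpu_blocks, num_cpu_blocks)  # 5376 512
```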
@@ -284,7 +246,7 @@ class Worker:
         return tokens_tensor, positions_tensor, input_metadata
 
     @torch.inference_mode()
-    def execute_stage(
+    def execute_model(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
@@ -316,7 +278,7 @@ class Worker:
             return {}
 
         # Prepare input tensors.
-        input_tokens, input_positions, input_metadata = self.prepare_inputs(
+        input_tokens, input_positions, input_metadata = self._prepare_inputs(
             seq_group_metadata_list)
 
         # Execute the model.
@@ -330,6 +292,24 @@ class Worker:
         return output
 
 
+def _init_distributed_environment(
+    parallel_config: ParallelConfig,
+    rank: int,
+    distributed_init_method: str,
+) -> None:
+    """Initialize the distributed environment."""
+    torch.distributed.init_process_group(
+        backend="nccl",
+        world_size=parallel_config.world_size,
+        rank=rank,
+        init_method=distributed_init_method,
+    )
+    # A small all_reduce for warmup.
+    torch.distributed.all_reduce(torch.zeros(1).cuda())
+    initialize_model_parallel(parallel_config.tensor_parallel_size,
+                              parallel_config.pipeline_parallel_size)
+
+
 def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
     return x + [0] * ((-len(x)) % multiple_of)
 
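A quick demonstration of the negative-modulo trick in `_pad_to_alignment`: `(-len(x)) % multiple_of` is exactly the number of zeros needed to reach the next multiple, and it is 0 when the list is already aligned, so no branch is needed:

```python
# Standalone copy of the helper, for demonstration.
def _pad_to_alignment(x, multiple_of):
    return x + [0] * ((-len(x)) % multiple_of)

print(_pad_to_alignment([1, 2, 3], 8))  # [1, 2, 3, 0, 0, 0, 0, 0]: (-3) % 8 == 5 zeros
print(_pad_to_alignment([1] * 8, 8))    # unchanged: (-8) % 8 == 0, already aligned
```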