[core] set up data parallel communication (#13591)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -134,7 +134,9 @@ steps:
  - tests/compile/test_basic_correctness
  - examples/offline_inference/rlhf.py
  - examples/offline_inference/rlhf_colocate.py
  - tests/examples/offline_inference/data_parallel.py
  commands:
  - VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py
  - pytest -v -s distributed/test_utils.py
  - pytest -v -s compile/test_basic_correctness.py
  - pytest -v -s distributed/test_pynccl.py
examples/offline_inference/data_parallel.py | 76 (new file)
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
# we need to have a launcher to create multiple data parallel
# ranks. And each rank will create a vLLM instance to process its own prompts.
import os

from vllm import LLM, SamplingParams
from vllm.utils import get_open_port


def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
    os.environ["VLLM_DP_RANK"] = str(dp_rank)
    os.environ["VLLM_DP_SIZE"] = str(dp_size)
    os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
    os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
    # set devices for each dp_rank
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(i) for i in range(dp_rank * GPUs_per_dp_rank, (dp_rank + 1) *
                              GPUs_per_dp_rank))

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # with DP, each rank should process different prompts.
    # usually all the DP ranks process a full dataset,
    # and each rank processes a different part of the dataset.
    prompts_per_rank = len(prompts) // dp_size
    start = dp_rank * prompts_per_rank
    end = start + prompts_per_rank
    prompts = prompts[start:end]
    if len(prompts) == 0:
        # if any rank has no prompts to process,
        # we need to set a placeholder prompt
        prompts = ["Placeholder"]
    print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")

    # Create a sampling params object.
    # since we are doing data parallel, every rank can have different
    # sampling params. here we set different max_tokens for different
    # ranks for demonstration.
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=16 * (dp_rank + 1))

    # Create an LLM.
    llm = LLM(model="facebook/opt-125m",
              tensor_parallel_size=2,
              enforce_eager=True)
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(
            f"DP rank {dp_rank}, Prompt: {prompt!r}, "
            f"Generated text: {generated_text!r}")


if __name__ == "__main__":
    from multiprocessing import Process
    dp_size = 2
    GPUs_per_dp_rank = 2
    dp_master_ip = "127.0.0.1"
    dp_master_port = get_open_port()
    procs = []
    for i in range(dp_size):
        proc = Process(target=main,
                       args=(dp_size, i, dp_master_ip, dp_master_port,
                             GPUs_per_dp_rank))
        proc.start()
        procs.append(proc)
    for proc in procs:
        proc.join()
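As an illustration of the launcher logic above (not part of the committed file): with the default dp_size=2 and GPUs_per_dp_rank=2, each spawned rank ends up with its own GPU slice, prompt slice, and max_tokens budget.

# illustration only: what each dp_rank receives under the defaults above
# dp_rank 0 -> CUDA_VISIBLE_DEVICES="0,1", prompts[0:2], max_tokens=16
# dp_rank 1 -> CUDA_VISIBLE_DEVICES="2,3", prompts[2:4], max_tokens=32
for dp_rank in range(2):
    gpus = ",".join(str(i) for i in range(dp_rank * 2, (dp_rank + 1) * 2))
    print(dp_rank, gpus, 16 * (dp_rank + 1))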

@@ -16,6 +16,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,

import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig

import vllm.envs as envs

@@ -1296,6 +1297,11 @@ class ParallelConfig:

    pipeline_parallel_size: int = 1  # Number of pipeline parallel groups.
    tensor_parallel_size: int = 1  # Number of tensor parallel groups.
    data_parallel_size: int = 1  # Number of data parallel groups.
    data_parallel_rank: int = 0  # Rank of the data parallel group.
    # IP of the data parallel master.
    data_parallel_master_ip: str = "127.0.0.1"
    data_parallel_master_port: int = 29500  # Port of the data parallel master.

    # Maximum number of multiple batches
    # when loading the model sequentially. To avoid RAM OOM when using tensor

@@ -1329,10 +1335,55 @@ class ParallelConfig:
    worker_cls: str = "auto"
    sd_worker_cls: str = "auto"

    # world_size is TPxPP, it affects the number of workers we create.
    world_size: int = field(init=False)
    # world_size_across_dp is TPxPPxDP, it is the size of the world
    # including data parallelism.
    world_size_across_dp: int = field(init=False)

    rank: int = 0

    def get_next_dp_init_port(self) -> int:
        """
        We might need to initialize process groups in multiple
        processes that are related to data parallelism,
        e.g. both in the worker and in the engine, which
        can live in different processes. To avoid port conflicts, we
        increment the port number each time we need to initialize a
        new process group related to data parallelism.
        """
        answer = self.data_parallel_master_port
        self.data_parallel_master_port += 1
        return answer

    def stateless_init_dp_group(self) -> "ProcessGroup":
        from vllm.distributed.utils import (
            stateless_init_torch_distributed_process_group)

        # use gloo since the engine process might not have a cuda device
        dp_group = stateless_init_torch_distributed_process_group(
            self.data_parallel_master_ip,
            self.get_next_dp_init_port(),
            self.data_parallel_rank,
            self.data_parallel_size,
            backend="gloo")

        return dp_group

    @staticmethod
    def has_unfinished_dp(dp_group: "ProcessGroup",
                          has_unfinished: bool) -> bool:
        tensor = torch.tensor([has_unfinished],
                              dtype=torch.int32,
                              device="cpu")
        # dp rank 0: has_unfinished_seqs=True
        # dp rank 1: has_unfinished_seqs=False
        # aggregated: has_unfinished_seqs=True
        # so this is an OR operation, i.e. MAX in integers
        torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
        aggregated_has_unfinished = bool(tensor.item())
        return aggregated_has_unfinished
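has_unfinished_dp encodes the boolean as an int32 and reduces with MAX, which behaves as a logical OR across the DP ranks. A standalone sketch of the same pattern (not part of this commit; the port and two-process layout are made up for illustration):

import torch
import torch.distributed as dist
from torch.multiprocessing import spawn


def worker(rank: int, world_size: int):
    # form a small gloo group, as the engine-side DP group does
    dist.init_process_group("gloo",
                            init_method="tcp://127.0.0.1:29601",
                            rank=rank,
                            world_size=world_size)
    has_unfinished = (rank == 0)  # only rank 0 still has work
    tensor = torch.tensor([has_unfinished], dtype=torch.int32)
    # MAX over {0, 1} acts as OR across ranks
    dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
    print(f"rank {rank}: aggregated has_unfinished = {bool(tensor.item())}")
    dist.destroy_process_group()


if __name__ == "__main__":
    spawn(worker, args=(2, ), nprocs=2)

Both processes print True, so every DP rank keeps stepping as long as any rank has work.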

    def compute_hash(self):
        """
        Provide a hash that uniquely identifies all the configs

@@ -1350,6 +1401,12 @@ class ParallelConfig:
        self.world_size = self.pipeline_parallel_size * \
            self.tensor_parallel_size

        self.data_parallel_size = envs.VLLM_DP_SIZE
        self.data_parallel_rank = envs.VLLM_DP_RANK
        self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
        self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
        self.world_size_across_dp = self.world_size * self.data_parallel_size

        ray_only_devices = ["tpu"]
        from vllm.platforms import current_platform
        if (current_platform.device_type in ray_only_devices

@@ -16,8 +16,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
                 device_group: Optional[ProcessGroup] = None,
                 unique_name: str = ""):
        super().__init__(cpu_group, device, device_group, unique_name)
-        if "pp" in unique_name:
-            # pipeline parallel does not need custom allreduce
+        if "tp" not in unique_name:
+            # only tp uses custom allreduce
            use_custom_allreduce = False
        else:
            from vllm.distributed.parallel_state import (

@@ -87,6 +87,7 @@ class CustomAllreduce:
            return

        rank = dist.get_rank(group=self.group)
        self.rank = rank
        world_size = dist.get_world_size(group=self.group)
        if world_size == 1:
            # No need to initialize custom allreduce for single GPU case.

@@ -201,8 +202,10 @@ class CustomAllreduce:

    @staticmethod
    def free_shared_buffer(pointers: List[int],
-                           group: Optional[ProcessGroup] = None) -> None:
-        rank = dist.get_rank(group=group)
+                           group: Optional[ProcessGroup] = None,
+                           rank: Optional[int] = None) -> None:
+        if rank is None:
+            rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        lib.cudaFree(ctypes.c_void_p(pointers[rank]))

@@ -298,8 +301,8 @@ class CustomAllreduce:
        if not self.disabled and self._ptr:
            ops.dispose(self._ptr)
            self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
-            self.free_shared_buffer(self.buffer_ptrs)
+            self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
+            self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)

    def __del__(self):
        self.close()

@@ -750,6 +750,13 @@ get_tensor_model_parallel_group = get_tp_group

_PP: Optional[GroupCoordinator] = None

_DP: Optional[GroupCoordinator] = None


def get_dp_group() -> GroupCoordinator:
    assert _DP is not None, ("data parallel group is not initialized")
    return _DP


def get_pp_group() -> GroupCoordinator:
    assert _PP is not None, (

@@ -811,6 +818,21 @@ def init_distributed_environment(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None and config.parallel_config.data_parallel_size > 1:
        parallel_config = config.parallel_config
        # adjust to take into account data parallelism
        # offset the rank by the data parallel rank
        rank = parallel_config.data_parallel_rank * world_size + rank
        # adjust the world size to take into account data parallelism
        world_size = parallel_config.world_size_across_dp
        ip = parallel_config.data_parallel_master_ip
        port = parallel_config.get_next_dp_init_port()
        distributed_init_method = f"tcp://{ip}:{port}"  # noqa
        logger.info(
            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
            world_size, rank, distributed_init_method)
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "

@@ -870,20 +892,28 @@ def initialize_model_parallel(
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: DP x PP x TP
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa

    # Build the tensor model-parallel groups.
-    num_tensor_model_parallel_groups: int = (world_size //
-                                             tensor_model_parallel_size)
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        ranks = list(
-            range(i * tensor_model_parallel_size,
-                  (i + 1) * tensor_model_parallel_size))
-        group_ranks.append(ranks)
+    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,

@@ -893,20 +923,33 @@ def initialize_model_parallel(
                                    group_name="tp")

    # Build the pipeline model-parallel groups.
-    num_pipeline_model_parallel_groups: int = (world_size //
-                                               pipeline_model_parallel_size)
    global _PP
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
-    group_ranks = []
-    for i in range(num_pipeline_model_parallel_groups):
-        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
-        group_ranks.append(ranks)
+    group_ranks = all_ranks.transpose(1, 2).reshape(
+        -1, pipeline_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="pp")

    global _DP
    assert _DP is None, ("data parallel group is already initialized")
    group_ranks = all_ranks.transpose(0,
                                      2).reshape(-1,
                                                 data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="dp")

    logger.info(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s", rank, world_size,
        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)


def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """

@@ -1011,6 +1054,11 @@ def destroy_model_parallel():
        _PP.destroy()
    _PP = None

    global _DP
    if _DP:
        _DP.destroy()
    _DP = None


def destroy_distributed_environment():
    global _WORLD

@@ -11,7 +11,11 @@ from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple

import torch
-from torch.distributed import TCPStore
+from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
                                                 is_nccl_available)
from torch.distributed.rendezvous import rendezvous

import vllm.envs as envs
from vllm.logger import init_logger

@@ -227,3 +231,88 @@ class StatelessProcessGroup:
            world_size=world_size,
            store=store,
            data_expiration_seconds=data_expiration_seconds)


def stateless_init_torch_distributed_process_group(
        host: str, port: int, rank: int, world_size: int,
        backend: str) -> ProcessGroup:
    """
    A replacement for `torch.distributed.init_process_group` that does not
    pollute the global state. The created ProcessGroup object can be used for
    some operations such as `allreduce`, because it does not depend on the
    global rank. However, some operations such as `broadcast` cannot be used
    because they depend on the global rank.

    # TODO: ask for help from PyTorch team if we need the `broadcast` operation.

    This function is useful when we are not sure about the total number of
    processes in the process group. For example, we may have processes
    1, 2, ..., 8 that want to communicate, and process 9 might be the same
    process as process 1, or it might be a different process; process 10
    might be the same process as process 5, or it might be a different process.
    In this case, how can we reliably form a communication channel between
    processes 9 and 10, without affecting the communication channel within
    processes 1, 2, ..., 8?

    One possible solution is to figure out if processes 9 and 10 are the same
    as processes 1 and 5 beforehand, and then form a communication channel
    based on that information, adjusting the ranks and world_size etc. However,
    figuring out the information is not always easy, and it will interfere
    with the main communication channel.

    Our solution is to always form a communication channel with processes 1, 2,
    ..., 8, and then use this function to form another communication channel
    with processes 9 and 10. This way, regardless of whether processes 9 and 10
    are the same as processes 1 and 5, the main communication channel is
    always formed with processes 1, 2, ..., 8, and the additional communication
    channel is formed with processes 9 and 10.
    """
    init_method = f"tcp://{host}:{port}"
    backend = Backend(backend)  # it is basically a string
    timeout = _get_default_timeout(backend)

    store, rank, world_size = next(
        rendezvous(init_method, rank, world_size, timeout=timeout))
    store.set_timeout(timeout)

    group_rank = rank
    group_size = world_size

    # Use a PrefixStore to avoid accidental overrides of keys used by
    # different systems (e.g. RPC) in case the store is multi-tenant.
    prefix_store = PrefixStore(init_method, store)

    pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)

    pg: ProcessGroup = ProcessGroup(
        prefix_store,
        group_rank,
        group_size,
        pg_options,
    )

    if backend == "gloo":
        from torch.distributed.distributed_c10d import ProcessGroupGloo
        backend_class = ProcessGroupGloo(prefix_store,
                                         group_rank,
                                         group_size,
                                         timeout=timeout)
        backend_type = ProcessGroup.BackendType.GLOO
        device = torch.device("cpu")
    elif backend == "nccl":
        assert is_nccl_available()
        from torch.distributed.distributed_c10d import ProcessGroupNCCL

        backend_options = ProcessGroupNCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        backend_type = ProcessGroup.BackendType.NCCL
        device = torch.device("cuda")

    backend_class._set_sequence_number_for_group()

    pg._register_backend(device, backend_type, backend_class)

    return pg
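A hedged usage sketch of the helper above (not from the commit; the host, port, and two-process setup are placeholders): each of two already-running processes calls the function with its own rank and then all-reduces over the returned group, which works precisely because allreduce does not rely on global ranks.

import torch
from vllm.distributed.utils import (
    stateless_init_torch_distributed_process_group)


def join_side_channel(rank: int) -> int:
    # rank is 0 in one process and 1 in the other
    pg = stateless_init_torch_distributed_process_group(
        "127.0.0.1", 29700, rank, 2, backend="gloo")
    t = torch.tensor([rank + 1], dtype=torch.int32)
    torch.distributed.all_reduce(t, group=pg)  # SUM by default
    return t.item()  # 3 on both ranks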

vllm/envs.py | 20
@@ -90,6 +90,10 @@ if TYPE_CHECKING:
    VLLM_RAY_BUNDLE_INDICES: str = ""
    VLLM_CUDART_SO_PATH: Optional[str] = None
    VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
    VLLM_DP_RANK: int = 0
    VLLM_DP_SIZE: int = 1
    VLLM_DP_MASTER_IP: str = ""
    VLLM_DP_MASTER_PORT: int = 0


def get_default_cache_root():

@@ -593,6 +597,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH":
    lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
    ("1", "true"),

    # Rank of the process in the data parallel setting
    "VLLM_DP_RANK":
    lambda: int(os.getenv("VLLM_DP_RANK", "0")),

    # World size of the data parallel setting
    "VLLM_DP_SIZE":
    lambda: int(os.getenv("VLLM_DP_SIZE", "1")),

    # IP address of the master node in the data parallel setting
    "VLLM_DP_MASTER_IP":
    lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"),

    # Port of the master node in the data parallel setting
    "VLLM_DP_MASTER_PORT":
    lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),
}

# end-env-vars-definition
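For illustration (not part of the diff): once these variables are exported, the lazy accessors in vllm.envs resolve them to typed values, which is what ParallelConfig reads in its __post_init__ shown earlier.

import os

os.environ["VLLM_DP_RANK"] = "1"
os.environ["VLLM_DP_SIZE"] = "2"

import vllm.envs as envs

assert envs.VLLM_DP_RANK == 1                 # parsed as int
assert envs.VLLM_DP_SIZE == 2
assert envs.VLLM_DP_MASTER_IP == "127.0.0.1"  # default when unset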

@@ -4,9 +4,10 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.config import VllmConfig

@@ -32,6 +33,8 @@ class ForwardContext:
    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    num_tokens_across_dp: Optional[
        List[int]] = None  # set dynamically for each forward pass


_forward_context: Optional[ForwardContext] = None

@@ -48,7 +51,8 @@ def get_forward_context() -> ForwardContext:
@contextmanager
def set_forward_context(attn_metadata: Any,
                        vllm_config: VllmConfig,
-                        virtual_engine: int = 0):
+                        virtual_engine: int = 0,
+                        num_tokens: int = 0):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

@@ -57,12 +61,36 @@ def set_forward_context(attn_metadata: Any,
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()
    num_tokens_across_dp = None
    if vllm_config.parallel_config.data_parallel_size > 1:
        dp_size = vllm_config.parallel_config.data_parallel_size
        dp_rank = vllm_config.parallel_config.data_parallel_rank
        if attn_metadata is not None:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = attn_metadata.num_prefill_tokens + \
                    attn_metadata.num_decode_tokens
            else:
                # for v1 attention backends
                batchsize = attn_metadata.num_input_tokens
        else:
            batchsize = num_tokens
        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = batchsize
        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                         device="cpu",
                                         dtype=torch.int32)
        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
        num_tokens_across_dp = num_tokens_tensor.tolist()

    global _forward_context
    prev_context = _forward_context
    _forward_context = ForwardContext(
        attn_layers=vllm_config.compilation_config.static_forward_context,
        virtual_engine=virtual_engine,
-        attn_metadata=attn_metadata)
+        attn_metadata=attn_metadata,
+        num_tokens_across_dp=num_tokens_across_dp)
    try:
        yield
    finally:
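To make the token-count exchange above concrete (an illustration, not code from the commit): with dp_size=2, a rank that scheduled 5 tokens builds [5, 0] while the other rank builds [0, 3]; after the SUM all-reduce over the DP cpu_group, both ranks hold [5, 3], so every rank knows how many tokens every other DP rank runs in this step.

import torch

# single-process simulation of the exchange (hypothetical token counts)
contributions = {0: 5, 1: 3}  # tokens scheduled by each DP rank
per_rank = [torch.tensor([n if r == rank else 0 for r in range(2)],
                         dtype=torch.int32)
            for rank, n in contributions.items()]
reduced = torch.stack(per_rank).sum(dim=0)
print(reduced.tolist())  # [5, 3] -- what every rank sees after all_reduce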

@@ -501,6 +501,24 @@ def get_open_zmq_ipc_path() -> str:


def get_open_port() -> int:
    """
    Get an open port for the vLLM process to listen on.
    An edge case to handle is when we run data parallel:
    we need to avoid ports that are potentially used by
    the data parallel master process.
    Right now we reserve 10 ports for the data parallel master
    process. Currently it uses 2 ports.
    """
    if "VLLM_DP_MASTER_PORT" in os.environ:
        dp_port = envs.VLLM_DP_MASTER_PORT
        while True:
            port = _get_open_port()
            if port >= dp_port and port < dp_port + 10:
                continue
            return port
    return _get_open_port()


def _get_open_port() -> int:
    port = envs.VLLM_PORT
    if port is not None:
        while True:

@@ -219,6 +219,9 @@ class EngineCore:
    def wake_up(self):
        self.model_executor.wake_up()

    def execute_dummy_batch(self):
        self.model_executor.collective_rpc("execute_dummy_batch")

    def add_lora(self, lora_request: LoRARequest) -> None:
        self.model_executor.add_lora(lora_request)

@@ -87,6 +87,12 @@ class EngineCoreClient(ABC):
    def wake_up(self) -> None:
        raise NotImplementedError

    def execute_dummy_batch(self) -> None:
        raise NotImplementedError

    async def execute_dummy_batch_async(self) -> None:
        raise NotImplementedError

    def abort_requests(self, request_ids: List[str]) -> None:
        raise NotImplementedError

@@ -156,6 +162,9 @@ class InprocClient(EngineCoreClient):
    def wake_up(self) -> None:
        self.engine_core.wake_up()

    def execute_dummy_batch(self) -> None:
        self.engine_core.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> None:
        self.engine_core.add_lora(lora_request)

@@ -331,6 +340,8 @@ class SyncMPClient(MPClient):
    def wake_up(self) -> None:
        self._call_utility("wake_up")

    def execute_dummy_batch(self) -> None:
        self._call_utility("execute_dummy_batch")

class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""

@@ -414,5 +425,8 @@ class AsyncMPClient(MPClient):
    async def wake_up_async(self) -> None:
        await self._call_utility_async("wake_up")

    async def execute_dummy_batch_async(self) -> None:
        await self._call_utility_async("execute_dummy_batch")

    async def add_lora_async(self, lora_request: LoRARequest) -> None:
        await self._call_utility_async("add_lora", lora_request)

@@ -4,7 +4,7 @@ from typing import Dict, List, Mapping, Optional, Type, Union

from typing_extensions import TypeVar

-from vllm.config import VllmConfig
+from vllm.config import ParallelConfig, VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING

@@ -47,6 +47,13 @@ class LLMEngine:
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config

        # important: init dp group before init the engine_core
        self.parallel_config = vllm_config.parallel_config
        self.dp_enabled = self.parallel_config.data_parallel_size > 1  # noqa
        self.should_execute_dummy_batch = False
        if self.dp_enabled:
            self.dp_group = self.parallel_config.stateless_init_dp_group()

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,

@@ -106,7 +113,17 @@ class LLMEngine:
        return self.output_processor.get_num_unfinished_requests()

    def has_unfinished_requests(self) -> bool:
-        return self.output_processor.has_unfinished_requests()
+        has_unfinished = self.output_processor.has_unfinished_requests()
+        if not self.dp_enabled:
+            return has_unfinished
+        return self.has_unfinished_requests_dp(has_unfinished)

    def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
        aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
            self.dp_group, has_unfinished)
        if not has_unfinished and aggregated_has_unfinished:
            self.should_execute_dummy_batch = True
        return aggregated_has_unfinished

    @classmethod
    def validate_outputs(cls, outputs, output_type):

@@ -145,6 +162,11 @@ class LLMEngine:

    def step(self) -> List[RequestOutput]:

        if self.should_execute_dummy_batch:
            self.should_execute_dummy_batch = False
            self.engine_core.execute_dummy_batch()
            return []

        # 1) Get EngineCoreOutput from the EngineCore.
        outputs = self.engine_core.get_output()

@@ -239,7 +239,7 @@ class WorkerProc:
            ready_socket.send_string(WorkerProc.READY_STR)
            ready_socket.send(payload)

-        self.worker.init_device()
+        wrapper.init_device()
        self.worker.load_model()

    @staticmethod

@@ -1167,7 +1167,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                for k, v in self.intermediate_tensors.items()
            })

-        with set_forward_context(None, self.vllm_config):
+        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,

@@ -235,6 +235,9 @@ class Worker(WorkerBase):
        else:
            self.profiler.stop()

    def execute_dummy_batch(self) -> None:
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

@@ -567,6 +567,11 @@ class WorkerWrapperBase:
            self.worker = worker_class(**kwargs)
        assert self.worker is not None

    def init_device(self):
        with set_current_vllm_config(self.vllm_config):
            # To make vLLM config available during device initialization
            self.worker.init_device()  # type: ignore

    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
        try:
            target = self if self.worker is None else self.worker