[core] set up data parallel communication (#13591)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2025-02-22 19:28:59 +08:00
Committed by: GitHub
Parent: 7f6bae561c
Commit: 3e472d882a
17 changed files with 416 additions and 28 deletions

View File

@ -134,7 +134,9 @@ steps:
- tests/compile/test_basic_correctness
- examples/offline_inference/rlhf.py
- examples/offline_inference/rlhf_colocate.py
- examples/offline_inference/data_parallel.py
commands:
- VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py

View File

@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
# we need a launcher to create multiple data parallel ranks,
# and each rank will create a vLLM instance to process its own prompts.
import os
from vllm import LLM, SamplingParams
from vllm.utils import get_open_port
def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
os.environ["VLLM_DP_RANK"] = str(dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = dp_master_ip
os.environ["VLLM_DP_MASTER_PORT"] = str(dp_master_port)
# set devices for each dp_rank
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
str(i) for i in range(dp_rank * GPUs_per_dp_rank, (dp_rank + 1) *
GPUs_per_dp_rank))
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# with DP, each rank should process different prompts.
# usually, all the DP ranks together process a full dataset,
# with each rank handling a different shard of it.
prompts_per_rank = len(prompts) // dp_size
start = dp_rank * prompts_per_rank
end = start + prompts_per_rank
prompts = prompts[start:end]
if len(prompts) == 0:
# if any rank has no prompts to process,
# we need to set a placeholder prompt
prompts = ["Placeholder"]
print(f"DP rank {dp_rank} needs to process {len(prompts)} prompts")
# Create a sampling params object.
# since we are doing data parallel, every rank can have different
# sampling params. here we set different max_tokens for different
# ranks for demonstration.
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=16 * (dp_rank + 1))
# Create an LLM.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2, enforce_eager=True)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"DP rank {dp_rank}, Prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
if __name__ == "__main__":
from multiprocessing import Process
dp_size = 2
GPUs_per_dp_rank = 2
dp_master_ip = "127.0.0.1"
dp_master_port = get_open_port()
procs = []
for i in range(dp_size):
proc = Process(target=main,
args=(dp_size, i, dp_master_ip, dp_master_port,
GPUs_per_dp_rank))
proc.start()
procs.append(proc)
for proc in procs:
proc.join()
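
For reference, here is a small standalone sketch of the sharding that main() performs above, worked out for the default dp_size=2 and GPUs_per_dp_rank=2 (the values and prompts are taken from the example itself):

# Worked example of the GPU / prompt sharding done in main() above,
# for dp_size=2 and GPUs_per_dp_rank=2 (a sketch for illustration only).
dp_size, gpus_per_dp_rank = 2, 2
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
for dp_rank in range(dp_size):
    devices = ",".join(
        str(i) for i in range(dp_rank * gpus_per_dp_rank,
                              (dp_rank + 1) * gpus_per_dp_rank))
    prompts_per_rank = len(prompts) // dp_size
    shard = prompts[dp_rank * prompts_per_rank:(dp_rank + 1) * prompts_per_rank]
    print(f"DP rank {dp_rank}: CUDA_VISIBLE_DEVICES={devices}, prompts={shard}")
# DP rank 0: CUDA_VISIBLE_DEVICES=0,1 and the first two prompts
# DP rank 1: CUDA_VISIBLE_DEVICES=2,3 and the last two prompts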

View File

@ -16,6 +16,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
import vllm.envs as envs
@ -1296,6 +1297,11 @@ class ParallelConfig:
pipeline_parallel_size: int = 1 # Number of pipeline parallel groups.
tensor_parallel_size: int = 1 # Number of tensor parallel groups.
data_parallel_size: int = 1 # Number of data parallel groups.
data_parallel_rank: int = 0 # Rank of the data parallel group.
# IP of the data parallel master.
data_parallel_master_ip: str = "127.0.0.1"
data_parallel_master_port: int = 29500 # Port of the data parallel master.
# Maximum number of multiple batches
# when loading the model sequentially. To avoid RAM OOM when using tensor
@ -1329,10 +1335,55 @@ class ParallelConfig:
worker_cls: str = "auto"
sd_worker_cls: str = "auto"
# world_size is TPxPP, it affects the number of workers we create.
world_size: int = field(init=False)
# world_size_across_dp is TPxPPxDP, it is the size of the world
# including data parallelism.
world_size_across_dp: int = field(init=False)
rank: int = 0
def get_next_dp_init_port(self) -> int:
"""
We might need to initialize process groups related to data
parallelism in multiple processes, e.g. both in the worker
and in the engine, which can live in different processes.
To avoid port conflicts, we increment the port number each
time we need to initialize a new process group for data
parallelism.
"""
answer = self.data_parallel_master_port
self.data_parallel_master_port += 1
return answer
def stateless_init_dp_group(self) -> "ProcessGroup":
from vllm.distributed.utils import (
stateless_init_torch_distributed_process_group)
# use gloo since the engine process might not have cuda device
dp_group = stateless_init_torch_distributed_process_group(
self.data_parallel_master_ip,
self.get_next_dp_init_port(),
self.data_parallel_rank,
self.data_parallel_size,
backend="gloo")
return dp_group
@staticmethod
def has_unfinished_dp(dp_group: "ProcessGroup",
has_unfinished: bool) -> bool:
tensor = torch.tensor([has_unfinished],
dtype=torch.int32,
device="cpu")
# dp rank 0: has_unfinished_seqs=True
# dp rank 1: has_unfinished_seqs=False
# aggregated: has_unfinished_seqs=True
# so this is an OR operation, i.e. MAX in integers
torch.distributed.all_reduce(tensor, op=ReduceOp.MAX, group=dp_group)
aggregated_has_unfinished = bool(tensor.item())
return aggregated_has_unfinished
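
The two helpers above (stateless_init_dp_group and has_unfinished_dp) are consumed by the v1 LLMEngine later in this commit. A condensed, illustrative sketch of that usage pattern follows; the `engine` object and its attributes are stand-ins, not the actual engine code:

# Illustrative driver loop built on the ParallelConfig DP helpers above
# (condensed from the v1 LLMEngine changes later in this commit; `engine`
# is a stand-in object, so this is a sketch rather than the real code).
from vllm.config import ParallelConfig

def dp_aware_run(engine):
    # one gloo group per engine process; the engine may run on a host
    # without a CUDA device, hence the gloo backend
    dp_group = engine.vllm_config.parallel_config.stateless_init_dp_group()
    while True:
        local_unfinished = engine.has_unfinished_requests()
        # OR across DP ranks (implemented as an integer MAX all-reduce):
        # keep stepping while *any* rank still has work.
        if not ParallelConfig.has_unfinished_dp(dp_group, local_unfinished):
            break
        if local_unfinished:
            engine.step()
        else:
            # this rank is idle but another rank is not: run a dummy batch
            # so that collective ops in the forward pass stay in lockstep.
            engine.execute_dummy_batch()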
def compute_hash(self):
"""
Provide a hash that uniquely identifies all the configs
@ -1350,6 +1401,12 @@ class ParallelConfig:
self.world_size = self.pipeline_parallel_size * \
self.tensor_parallel_size
self.data_parallel_size = envs.VLLM_DP_SIZE
self.data_parallel_rank = envs.VLLM_DP_RANK
self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP
self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT
self.world_size_across_dp = self.world_size * self.data_parallel_size
ray_only_devices = ["tpu"]
from vllm.platforms import current_platform
if (current_platform.device_type in ray_only_devices

View File

@ -16,8 +16,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
device_group: Optional[ProcessGroup] = None,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if "pp" in unique_name:
# pipeline parallel does not need custom allreduce
if "tp" not in unique_name:
# only tp uses custom allreduce
use_custom_allreduce = False
else:
from vllm.distributed.parallel_state import (

View File

@ -87,6 +87,7 @@ class CustomAllreduce:
return
rank = dist.get_rank(group=self.group)
self.rank = rank
world_size = dist.get_world_size(group=self.group)
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
@ -201,8 +202,10 @@ class CustomAllreduce:
@staticmethod
def free_shared_buffer(pointers: List[int],
group: Optional[ProcessGroup] = None) -> None:
rank = dist.get_rank(group=group)
group: Optional[ProcessGroup] = None,
rank: Optional[int] = None) -> None:
if rank is None:
rank = dist.get_rank(group=group)
lib = CudaRTLibrary()
lib.cudaFree(ctypes.c_void_p(pointers[rank]))
@ -298,8 +301,8 @@ class CustomAllreduce:
if not self.disabled and self._ptr:
ops.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs)
self.free_shared_buffer(self.buffer_ptrs)
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
def __del__(self):
self.close()

View File

@ -750,6 +750,13 @@ get_tensor_model_parallel_group = get_tp_group
_PP: Optional[GroupCoordinator] = None
_DP: Optional[GroupCoordinator] = None
def get_dp_group() -> GroupCoordinator:
assert _DP is not None, ("data parallel group is not initialized")
return _DP
def get_pp_group() -> GroupCoordinator:
assert _PP is not None, (
@ -811,6 +818,21 @@ def init_distributed_environment(
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank,
distributed_init_method, backend)
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None and config.parallel_config.data_parallel_size > 1:
parallel_config = config.parallel_config
# adjust to take into account data parallelism
# offset the rank by the data parallel rank
rank = parallel_config.data_parallel_rank * world_size + rank
# adjust the world size to take into account data parallelism
world_size = parallel_config.world_size_across_dp
ip = parallel_config.data_parallel_master_ip
port = parallel_config.get_next_dp_init_port()
distributed_init_method = f"tcp://{ip}:{port}" # noqa
logger.info(
"Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
world_size, rank, distributed_init_method)
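
As a quick sanity check of the offset above: with data_parallel_size=2 and a per-replica world_size of 2 (TP=2, PP=1), DP rank 1's local ranks 0 and 1 become global ranks 2 and 3, and world_size_across_dp is 4. A one-line sketch of the arithmetic:

# Sketch of the DP rank offsetting performed above (dp_rank=1, TPxPP=2).
dp_rank, world_size = 1, 2
print([dp_rank * world_size + local_rank for local_rank in range(world_size)])
# [2, 3] -> the global ranks owned by DP rank 1; world_size_across_dp == 4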
if not torch.distributed.is_initialized():
assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing "
@ -870,20 +892,28 @@ def initialize_model_parallel(
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
data_parallel_size = 1
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None:
data_parallel_size = config.parallel_config.data_parallel_size
# the layout order is: DP x PP x TP
# to get group_ranks for each dimension, transpose that dimension to the
# last dimension, then reshape to 2D, then unbind the last dimension
all_ranks = torch.arange(world_size).reshape(
data_parallel_size, pipeline_model_parallel_size,
tensor_model_parallel_size) # noqa
# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size)
global _TP
assert _TP is None, ("tensor model parallel group is already initialized")
group_ranks = []
for i in range(num_tensor_model_parallel_groups):
ranks = list(
range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size))
group_ranks.append(ranks)
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(group_ranks,
@ -893,20 +923,33 @@ def initialize_model_parallel(
group_name="tp")
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size //
pipeline_model_parallel_size)
global _PP
assert _PP is None, (
"pipeline model parallel group is already initialized")
group_ranks = []
for i in range(num_pipeline_model_parallel_groups):
ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
group_ranks.append(ranks)
group_ranks = all_ranks.transpose(1, 2).reshape(
-1, pipeline_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="pp")
global _DP
assert _DP is None, ("data parallel group is already initialized")
group_ranks = all_ranks.transpose(0,
2).reshape(-1,
data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="dp")
logger.info(
"rank %s in world size %s is assigned as "
"DP rank %s, PP rank %s, TP rank %s", rank, world_size,
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
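
To make the DP x PP x TP layout concrete, the following standalone sketch reproduces the reshape/transpose/unbind logic above for DP=2, PP=1, TP=2 (four ranks in total) and prints the resulting groups:

# Worked example of the DP x PP x TP rank layout used above
# (dp=2, pp=1, tp=2, i.e. world_size_across_dp=4); a sketch for illustration.
import torch

dp, pp, tp = 2, 1, 2
all_ranks = torch.arange(dp * pp * tp).reshape(dp, pp, tp)

tp_groups = [x.tolist() for x in all_ranks.view(-1, tp).unbind(0)]
pp_groups = [x.tolist() for x in all_ranks.transpose(1, 2).reshape(-1, pp).unbind(0)]
dp_groups = [x.tolist() for x in all_ranks.transpose(0, 2).reshape(-1, dp).unbind(0)]

print(tp_groups)  # [[0, 1], [2, 3]] -> ranks that share a TP group
print(pp_groups)  # [[0], [1], [2], [3]] -> trivial groups since pp=1
print(dp_groups)  # [[0, 2], [1, 3]] -> the same TP rank across the two DP replicas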
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
@ -1011,6 +1054,11 @@ def destroy_model_parallel():
_PP.destroy()
_PP = None
global _DP
if _DP:
_DP.destroy()
_DP = None
def destroy_distributed_environment():
global _WORLD

View File

@ -11,7 +11,11 @@ from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
import torch
from torch.distributed import TCPStore
from torch.distributed import ProcessGroup, TCPStore
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
_get_default_timeout,
is_nccl_available)
from torch.distributed.rendezvous import rendezvous
import vllm.envs as envs
from vllm.logger import init_logger
@ -227,3 +231,88 @@ class StatelessProcessGroup:
world_size=world_size,
store=store,
data_expiration_seconds=data_expiration_seconds)
def stateless_init_torch_distributed_process_group(
host: str, port: int, rank: int, world_size: int,
backend: str) -> ProcessGroup:
"""
A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for
some operations such as `allreduce`, because they do not depend on the
global rank. However, operations such as `broadcast` cannot be used,
because they depend on the global rank.
# TODO: ask the PyTorch team for help if we need the `broadcast` operation.
This function is useful when we are not sure about the total number of
processes in the process group. For example, we may have processes
1, 2, ..., 8 that want to communicate, and process 9 might be the same
process as process 1, or it might be a different process; process 10
might be the same process as process 5, or it might be a different process.
In this case, how can we reliably form a communication channel between
processes 9 and 10, without affecting the communication channel among
processes 1, 2, ..., 8?
One possible solution is to figure out beforehand whether processes 9 and 10
are the same as processes 1 and 5, and then form a communication channel
based on that information, adjusting the ranks, world_size, etc. However,
figuring out that information is not always easy, and it would interfere
with the main communication channel.
Our solution is to always form a communication channel among processes 1, 2,
..., 8, and then use this function to form another communication channel
between processes 9 and 10. This way, regardless of whether processes 9 and 10
are the same as processes 1 and 5, the main communication channel is
always formed among processes 1, 2, ..., 8, and the additional communication
channel is formed between processes 9 and 10.
"""
init_method = f"tcp://{host}:{port}"
backend = Backend(backend)  # it is basically a string
timeout = _get_default_timeout(backend)
store, rank, world_size = next(
rendezvous(init_method, rank, world_size, timeout=timeout))
store.set_timeout(timeout)
group_rank = rank
group_size = world_size
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)
pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
pg: ProcessGroup = ProcessGroup(
prefix_store,
group_rank,
group_size,
pg_options,
)
if backend == "gloo":
from torch.distributed.distributed_c10d import ProcessGroupGloo
backend_class = ProcessGroupGloo(prefix_store,
group_rank,
group_size,
timeout=timeout)
backend_type = ProcessGroup.BackendType.GLOO
device = torch.device("cpu")
elif backend == "nccl":
assert is_nccl_available()
from torch.distributed.distributed_c10d import ProcessGroupNCCL
backend_options = ProcessGroupNCCL.Options()
backend_options._timeout = timeout
backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
backend_options)
backend_type = ProcessGroup.BackendType.NCCL
device = torch.device("cuda")
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
return pg
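
Below is a minimal, self-contained sketch of how this helper can be exercised on its own: two local processes form a gloo group on an arbitrarily chosen port (an assumption for illustration) and run an all-reduce, which is exactly the operation ParallelConfig.has_unfinished_dp relies on. In this commit the function is actually reached via ParallelConfig.stateless_init_dp_group.

# Standalone exercise of stateless_init_torch_distributed_process_group
# (a sketch; the host/port values here are placeholders for illustration).
from multiprocessing import Process

import torch
import torch.distributed as dist

from vllm.distributed.utils import stateless_init_torch_distributed_process_group


def _worker(rank: int, world_size: int, port: int):
    pg = stateless_init_torch_distributed_process_group(
        "127.0.0.1", port, rank, world_size, backend="gloo")
    t = torch.tensor([rank + 1], dtype=torch.int32)
    # allreduce is supported on this stateless group (broadcast is not).
    dist.all_reduce(t, op=dist.ReduceOp.MAX, group=pg)
    print(f"rank {rank}: max = {t.item()}")  # both ranks print 2


if __name__ == "__main__":
    port = 29600  # assumed to be free; pick any unused port
    procs = [Process(target=_worker, args=(r, 2, port)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()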

View File

@ -90,6 +90,10 @@ if TYPE_CHECKING:
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True
VLLM_DP_RANK: int = 0
VLLM_DP_SIZE: int = 1
VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0
def get_default_cache_root():
@ -593,6 +597,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH":
lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in
("1", "true"),
# Rank of the process in the data parallel setting
"VLLM_DP_RANK":
lambda: int(os.getenv("VLLM_DP_RANK", "0")),
# World size of the data parallel setting
"VLLM_DP_SIZE":
lambda: int(os.getenv("VLLM_DP_SIZE", "1")),
# IP address of the master node in the data parallel setting
"VLLM_DP_MASTER_IP":
lambda: os.getenv("VLLM_DP_MASTER_IP", "127.0.0.1"),
# Port of the master node in the data parallel setting
"VLLM_DP_MASTER_PORT":
lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),
}
# end-env-vars-definition

View File

@ -4,9 +4,10 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.config import VllmConfig
@ -32,6 +33,8 @@ class ForwardContext:
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
num_tokens_across_dp: Optional[
List[int]] = None # set dynamically for each forward pass
_forward_context: Optional[ForwardContext] = None
@ -48,7 +51,8 @@ def get_forward_context() -> ForwardContext:
@contextmanager
def set_forward_context(attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0):
virtual_engine: int = 0,
num_tokens: int = 0):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
@ -57,12 +61,36 @@ def set_forward_context(attn_metadata: Any,
need_to_track_batchsize = track_batchsize and attn_metadata is not None
if need_to_track_batchsize:
forward_start_time = time.perf_counter()
num_tokens_across_dp = None
if vllm_config.parallel_config.data_parallel_size > 1:
dp_size = vllm_config.parallel_config.data_parallel_size
dp_rank = vllm_config.parallel_config.data_parallel_rank
if attn_metadata is not None:
if hasattr(attn_metadata, "num_prefill_tokens"):
# for v0 attention backends
batchsize = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
# for v1 attention backends
batchsize = attn_metadata.num_input_tokens
else:
batchsize = num_tokens
num_tokens_across_dp = [0] * dp_size
num_tokens_across_dp[dp_rank] = batchsize
num_tokens_tensor = torch.tensor(num_tokens_across_dp,
device="cpu",
dtype=torch.int32)
from vllm.distributed.parallel_state import get_dp_group
dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group)
num_tokens_across_dp = num_tokens_tensor.tolist()
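
The all-reduce above doubles as an all-gather: each rank writes its own token count into its slot of a zero vector, and the default SUM reduction fills in every other rank's slot. A tiny sketch of the arithmetic, simulated without a process group:

# Sketch of the gather-via-sum pattern used above: every DP rank contributes
# a vector that is zero except at its own index, and the elementwise sum
# reconstructs the per-rank token counts on all ranks.
contributions = [[7, 0],   # dp rank 0 ran a batch of 7 tokens
                 [0, 3]]   # dp rank 1 ran a batch of 3 tokens
num_tokens_across_dp = [sum(col) for col in zip(*contributions)]
print(num_tokens_across_dp)  # [7, 3] on every rank after the all_reduce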
global _forward_context
prev_context = _forward_context
_forward_context = ForwardContext(
attn_layers=vllm_config.compilation_config.static_forward_context,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata)
attn_metadata=attn_metadata,
num_tokens_across_dp=num_tokens_across_dp)
try:
yield
finally:

View File

@ -501,6 +501,24 @@ def get_open_zmq_ipc_path() -> str:
def get_open_port() -> int:
"""
Get an open port for the vLLM process to listen on.
An edge case to handle: when we run data parallel,
we need to avoid ports that are potentially used by
the data parallel master process.
Right now we reserve 10 ports for the data parallel master
process, although it currently only uses 2.
"""
if "VLLM_DP_MASTER_PORT" in os.environ:
dp_port = envs.VLLM_DP_MASTER_PORT
while True:
port = _get_open_port()
if port >= dp_port and port < dp_port + 10:
continue
return port
return _get_open_port()
def _get_open_port() -> int:
port = envs.VLLM_PORT
if port is not None:
while True:

View File

@ -219,6 +219,9 @@ class EngineCore:
def wake_up(self):
self.model_executor.wake_up()
def execute_dummy_batch(self):
self.model_executor.collective_rpc("execute_dummy_batch")
def add_lora(self, lora_request: LoRARequest) -> None:
self.model_executor.add_lora(lora_request)

View File

@ -87,6 +87,12 @@ class EngineCoreClient(ABC):
def wake_up(self) -> None:
raise NotImplementedError
def execute_dummy_batch(self) -> None:
raise NotImplementedError
async def execute_dummy_batch_async(self) -> None:
raise NotImplementedError
def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError
@ -156,6 +162,9 @@ class InprocClient(EngineCoreClient):
def wake_up(self) -> None:
self.engine_core.wake_up()
def execute_dummy_batch(self) -> None:
self.engine_core.execute_dummy_batch()
def add_lora(self, lora_request: LoRARequest) -> None:
self.engine_core.add_lora(lora_request)
@ -331,6 +340,8 @@ class SyncMPClient(MPClient):
def wake_up(self) -> None:
self._call_utility("wake_up")
def execute_dummy_batch(self) -> None:
self._call_utility("execute_dummy_batch")
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
@ -414,5 +425,8 @@ class AsyncMPClient(MPClient):
async def wake_up_async(self) -> None:
await self._call_utility_async("wake_up")
async def execute_dummy_batch_async(self) -> None:
await self._call_utility_async("execute_dummy_batch")
async def add_lora_async(self, lora_request: LoRARequest) -> None:
await self._call_utility_async("add_lora", lora_request)

View File

@ -4,7 +4,7 @@ from typing import Dict, List, Mapping, Optional, Type, Union
from typing_extensions import TypeVar
from vllm.config import VllmConfig
from vllm.config import ParallelConfig, VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase
from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
@ -47,6 +47,13 @@ class LLMEngine:
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
# important: init the DP group before initializing the engine_core
self.parallel_config = vllm_config.parallel_config
self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa
self.should_execute_dummy_batch = False
if self.dp_enabled:
self.dp_group = self.parallel_config.stateless_init_dp_group()
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
@ -106,7 +113,17 @@ class LLMEngine:
return self.output_processor.get_num_unfinished_requests()
def has_unfinished_requests(self) -> bool:
return self.output_processor.has_unfinished_requests()
has_unfinished = self.output_processor.has_unfinished_requests()
if not self.dp_enabled:
return has_unfinished
return self.has_unfinished_requests_dp(has_unfinished)
def has_unfinished_requests_dp(self, has_unfinished: bool) -> bool:
aggregated_has_unfinished = ParallelConfig.has_unfinished_dp(
self.dp_group, has_unfinished)
if not has_unfinished and aggregated_has_unfinished:
self.should_execute_dummy_batch = True
return aggregated_has_unfinished
@classmethod
def validate_outputs(cls, outputs, output_type):
@ -145,6 +162,11 @@ class LLMEngine:
def step(self) -> List[RequestOutput]:
if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False
self.engine_core.execute_dummy_batch()
return []
# 1) Get EngineCoreOutput from the EngineCore.
outputs = self.engine_core.get_output()

View File

@ -239,7 +239,7 @@ class WorkerProc:
ready_socket.send_string(WorkerProc.READY_STR)
ready_socket.send(payload)
self.worker.init_device()
wrapper.init_device()
self.worker.load_model()
@staticmethod

View File

@ -1167,7 +1167,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items()
})
with set_forward_context(None, self.vllm_config):
with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
hidden_states = model(
input_ids=input_ids,
positions=positions,

View File

@ -235,6 +235,9 @@ class Worker(WorkerBase):
else:
self.profiler.stop()
def execute_dummy_batch(self) -> None:
self.model_runner._dummy_run(1)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)

View File

@ -567,6 +567,11 @@ class WorkerWrapperBase:
self.worker = worker_class(**kwargs)
assert self.worker is not None
def init_device(self):
with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during device initialization
self.worker.init_device() # type: ignore
def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try:
target = self if self.worker is None else self.worker