Merge branch 'main' into wentao-refactor-batch-invariant-fp8-deepgemm

Wentao Ye
2025-10-30 15:33:10 -04:00
committed by GitHub
23 changed files with 237 additions and 45 deletions

View File

@@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio
*Latest News* 🔥
- [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6).
- [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing).
- [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA).
- [2025/08] We hosted [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet). We shared V1 updates, disaggregated serving and MLLM speedups with speakers from Embedded LLM, AMD, WekaIO, and A*STAR. Please find the meetup slides [here](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing).

View File

@@ -361,13 +361,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
&& uv pip install --system dist/*.whl --verbose \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# TODO (huydhn): Remove this once xformers is released for 2.9.0
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
. /etc/environment
export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a'
uv pip install --system --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2"
BASH
# Install FlashInfer pre-compiled kernel cache and binaries
# https://docs.flashinfer.ai/installation.html
RUN --mount=type=cache,target=/root/.cache/uv \

View File

@@ -2,6 +2,7 @@
We host regular meetups in the San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg), October 25th 2025. [[Slides]](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6)
- [vLLM Toronto Meetup](https://luma.com/e80e0ymm), September 25th 2025. [[Slides]](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing)
- [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ), August 30th 2025. [[Slides]](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA)
- [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet), August 27th 2025. [[Slides]](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing)

View File

@@ -9,7 +9,7 @@ torch==2.9.0
torchaudio==2.9.0
# These must be updated alongside torch
torchvision==0.24.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
# xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
# Build from https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
xformers==0.0.33+5d4b92a5.d20251029; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.4.1

View File

@@ -1,10 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm import LLM, AsyncEngineArgs, AsyncLLMEngine, SamplingParams
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.utils.mem_constants import GiB_bytes
@@ -201,3 +203,42 @@ def test_deep_sleep():
# cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text
@create_new_process_for_each_test()
def test_deep_sleep_async():
async def test():
model = "hmellor/tiny-random-LlamaForCausalLM"
free, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free # in case another process is running
engine_args = AsyncEngineArgs(
model=model,
enable_sleep_mode=True,
)
llm = AsyncLLMEngine.from_engine_args(engine_args)
prompt = "How are you?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)
outputs = llm.generate(prompt, sampling_params, request_id="test_request_id1")
async for output in outputs:
pass
# Put the engine to deep sleep
await llm.sleep(level=2)
await llm.wake_up(tags=["weights"])
await llm.collective_rpc("reload_weights")
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
assert used_bytes < 4 * GiB_bytes
# now allocate kv cache and cuda graph memory
await llm.wake_up(tags=["kv_cache"])
outputs2 = llm.generate(prompt, sampling_params, request_id="test_request_id2")
async for output2 in outputs2:
pass
# cmp output
assert output.outputs[0].text == output2.outputs[0].text
asyncio.run(test())

View File

@@ -651,3 +651,79 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
await serving_chat.create_chat_completion(req)
engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
assert engine_prompt.get("cache_salt") == "test_salt"
@pytest.mark.asyncio
async def test_serving_chat_data_parallel_rank_extraction():
"""Test that data_parallel_rank is properly extracted from header and
passed to engine."""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.processor = MagicMock()
mock_engine.io_processor = MagicMock()
# Mock the generate method to return an async generator
async def mock_generate(*args, **kwargs):
# Yield a fake RequestOutput
from vllm.outputs import CompletionOutput, RequestOutput
yield RequestOutput(
request_id="test-request",
prompt="test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[
CompletionOutput(
index=0,
text="test response",
token_ids=[4, 5, 6],
cumulative_logprob=0.0,
logprobs=None,
finish_reason="stop",
stop_reason=None,
)
],
finished=True,
)
mock_engine.generate = AsyncMock(side_effect=mock_generate)
serving_chat = _build_serving_chat(mock_engine)
# Test when data_parallel_rank is present in header
req = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "what is 1+1?"}],
)
# Mock request with X-data-parallel-rank header
mock_raw_request = MagicMock()
mock_raw_request.headers = {"X-data-parallel-rank": "2"}
mock_raw_request.state = MagicMock()
with suppress(Exception):
await serving_chat.create_chat_completion(req, mock_raw_request)
# Verify that data_parallel_rank was passed to engine.generate
assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] == 2
# Test when data_parallel_rank is not present (defaults to None)
req_no_dp = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "what is 2+2?"}],
)
# Mock request with no header
mock_raw_request_no_dp = MagicMock()
mock_raw_request_no_dp.headers = {}
mock_raw_request_no_dp.state = MagicMock()
with suppress(Exception):
await serving_chat.create_chat_completion(req_no_dp, mock_raw_request_no_dp)
# Verify that data_parallel_rank defaults to None
assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None

View File

@@ -363,7 +363,7 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=envs.VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
)

View File

@@ -1008,11 +1008,14 @@ class NixlConnectorWorker:
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
self.device_id = self.tp_rank
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
base_addr = cache.data_ptr()
if not self.use_host_buffer and current_platform.is_cuda_alike():
self.device_id = cache.device.index
if base_addr in seen_base_addresses:
continue
@@ -1040,7 +1043,7 @@ class NixlConnectorWorker:
"All kv cache tensors must have the same size"
)
caches_data.append(
(base_addr, curr_tensor_size_bytes, self.tp_rank, "")
(base_addr, curr_tensor_size_bytes, self.device_id, "")
)
logger.debug(
@@ -1087,7 +1090,7 @@ class NixlConnectorWorker:
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.tp_rank))
blocks_data.append((addr, kv_block_len, self.device_id))
if self._use_flashinfer:
# Separate and interleave K/V regions to maintain the same
@@ -1098,12 +1101,13 @@ class NixlConnectorWorker:
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
blocks_data.append((v_addr, kv_block_len, self.tp_rank))
blocks_data.append((v_addr, kv_block_len, self.device_id))
logger.debug(
"Created %s blocks for src engine %s and rank %s",
"Created %s blocks for src engine %s and rank %s on device id %s",
len(blocks_data),
self.engine_id,
self.tp_rank,
self.device_id,
)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)

View File

@@ -264,6 +264,9 @@ class OpenAIServingChat(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
@@ -331,6 +334,7 @@ class OpenAIServingChat(OpenAIServing):
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
generators.append(generator)

View File

@@ -141,6 +141,9 @@ class OpenAIServingCompletion(OpenAIServing):
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Extract data_parallel_rank from header (router can inject it)
data_parallel_rank = self._get_data_parallel_rank(raw_request)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
@@ -224,6 +227,7 @@ class OpenAIServingCompletion(OpenAIServing):
priority=request.priority,
prompt_text=prompt_text,
tokenization_kwargs=tokenization_kwargs,
data_parallel_rank=data_parallel_rank,
)
generators.append(generator)

View File

@@ -1298,6 +1298,21 @@ class OpenAIServing:
return raw_request.headers.get("X-Request-Id", default)
@staticmethod
def _get_data_parallel_rank(raw_request: Request | None) -> int | None:
"""Pulls the data parallel rank from a header, if provided"""
if raw_request is None:
return None
rank_str = raw_request.headers.get("X-data-parallel-rank")
if rank_str is None:
return None
try:
return int(rank_str)
except ValueError:
return None
@staticmethod
def _get_decoded_token(
logprob: Logprob,

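For context on the new `X-data-parallel-rank` header handled above, here is a minimal client-side sketch. Assumptions: an OpenAI-compatible vLLM server listening on localhost:8000 and a placeholder model name; in a real deployment a router, not the end user, would typically inject this header.

```python
# Hypothetical client call; the server URL and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    # Pin this request to data-parallel rank 2; omit the header and the
    # handler above returns None, leaving rank selection to the server.
    headers={"X-data-parallel-rank": "2"},
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "what is 1+1?"}],
    },
    timeout=60,
)
print(resp.json())
```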
View File

@@ -207,7 +207,6 @@ if TYPE_CHECKING:
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
VLLM_DEEPEP_BUFFER_SIZE_MB: int = 1024
VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE: bool = False
VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK: bool = True
VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL: bool = False
VLLM_DBO_COMM_SMS: int = 20
VLLM_PATTERN_MATCH_DEBUG: str | None = None
@@ -1400,11 +1399,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE": lambda: bool(
int(os.getenv("VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE", "0"))
),
# Allow DeepEP to use nvlink for internode_ll kernel, turn this on for
# better latency on GB200 like system
"VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK": lambda: bool(
int(os.getenv("VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK", "1"))
),
# Allow DeepEP to use MNNVL (multi-node nvlink) for internode_ll kernel,
# turn this on for better latency on GB200-like systems
"VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL": lambda: bool(
@@ -1566,7 +1560,6 @@ def compute_hash() -> str:
"VLLM_NVFP4_GEMM_BACKEND",
"VLLM_USE_FBGEMM",
"VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE",
"VLLM_DEEPEP_LOW_LATENCY_ALLOW_NVLINK",
"VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL",
]
for key in environment_variables_to_hash:

View File

@@ -818,6 +818,9 @@ def get_config_file_name(
E: int, N: int, dtype: str | None, block_shape: list[int] | None = None
) -> str:
device_name = current_platform.get_device_name().replace(" ", "_")
# Set device_name to H200 if a device from the H200 family is detected
if "H200" in device_name:
device_name = "H200"
dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = (
"" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"

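A standalone sketch of the device-name normalization added above, with hard-coded strings standing in for `current_platform.get_device_name()`; the rest of the config-file-name construction is unchanged and not reproduced here.

```python
# Minimal sketch: collapse every H200 variant to a single "H200" key so all
# H200 SKUs share one tuned MoE config file. The example device strings are
# assumptions for illustration.
def normalize_device_name(raw_name: str) -> str:
    device_name = raw_name.replace(" ", "_")
    if "H200" in device_name:
        device_name = "H200"
    return device_name

assert normalize_device_name("NVIDIA H200 NVL") == "H200"
assert normalize_device_name("NVIDIA H100 80GB HBM3") == "NVIDIA_H100_80GB_HBM3"
```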
View File

@@ -357,6 +357,15 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_cpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return self.forward_native(positions, query, key, offsets)
@staticmethod
def get_next_input_positions(
mrope_position_delta: int,

View File

@@ -406,7 +406,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# easily by changing the way we layout chunks in the
# mamba2 kernels.
base_chunk_size = model_config.get_mamba_chunk_size()
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)

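To make the block-size arithmetic above concrete, here is a worked sketch with assumed values; every number below is illustrative rather than taken from a real model config, and `cdiv`/`lcm` behave as ceiling division and least common multiple.

```python
# Worked example (assumed values) of deriving the attention block size
# from the mamba chunk size and the mamba/attention page-size ratio.
from math import lcm


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(-a // b)


mamba_block_size = None            # not set, so fall back to the model default
model_chunk_size = 256             # assumed get_mamba_chunk_size() result
kernel_block_alignment_size = 16   # assumed attention-kernel alignment
mamba_page_size = 400_000          # assumed bytes per mamba state page
attn_page_size_1_token = 400       # assumed attention KV bytes per token

base_chunk_size = mamba_block_size or model_chunk_size                        # 256
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)   # 1000
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)                # 256
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)  # 1024
assert attn_block_size == 1024
```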
View File

@@ -97,7 +97,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
) -> torch.Tensor:
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
inputs_embeds = self.enorm(inputs_embeds)
previous_hidden_states = self.hnorm(previous_hidden_states)

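A minimal sketch of the masking change above: the `torch.where` form zeroes the position-0 rows out of place instead of writing into the input tensor, which presumably plays nicer with shared buffers and graph capture than boolean-index assignment. Shapes are assumptions: `positions` as `[num_tokens]`, `inputs_embeds` as `[num_tokens, hidden_size]`.

```python
import torch

# Toy inputs (assumed shapes): 4 tokens, hidden size 8.
positions = torch.tensor([0, 1, 2, 0])
inputs_embeds = torch.ones(4, 8)

# Rows whose position is 0 become zeros; other rows are kept unchanged,
# and inputs_embeds itself is never mutated.
masked = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)

assert masked[0].abs().sum() == 0 and masked[3].abs().sum() == 0
assert torch.equal(masked[1], inputs_embeds[1])
```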
View File

@@ -11,7 +11,7 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.platforms import current_platform
@@ -30,13 +30,19 @@ def flashinfer_autotune_supported(vllm_config: VllmConfig) -> bool:
Record known issues with vllm + flashinfer autotune here. Return True if
and only if flashinfer autotune will run through without issues.
"""
return not (
vllm_config.parallel_config.data_parallel_size > 1
and (
envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
)
is_tp_or_dp = (vllm_config.parallel_config.data_parallel_size > 1) or (
vllm_config.parallel_config.tensor_parallel_size > 1
)
is_fi_mxfp4_backend = (
envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS
) or (
current_platform.is_cuda() and current_platform.is_device_capability(100)
) # on >=sm100, default mxfp4 backend is flashinfer
is_eager = vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
return not (is_tp_or_dp and is_fi_mxfp4_backend and is_eager)
def kernel_warmup(worker: "Worker"):

View File

@@ -13,7 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
supports_hma,
SupportsHMA,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
@@ -93,7 +93,11 @@ class Scheduler(SchedulerInterface):
)
connector_vllm_config = copy.copy(self.vllm_config)
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
# We're dynamically inserting a kv_cache_config variable into the
# connector_vllm_config. This is distinct from the cache_config
# that is already in there.
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) # type: ignore[attr-defined]
self.connector = KVConnectorFactory.create_connector(
config=connector_vllm_config, role=KVConnectorRole.SCHEDULER
)
@@ -1327,15 +1331,15 @@ class Scheduler(SchedulerInterface):
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
if not supports_hma(self.connector):
if not isinstance(self.connector, SupportsHMA):
# NOTE(Kuntai): We should deprecate this code path once all connectors
# are required to support HMA.
# Hybrid memory allocator should be already turned off for this
# code path, but let's double-check here.
assert len(self.kv_cache_config.kv_cache_groups) == 1
return self.connector.request_finished(request, block_ids[0])
else:
return self.connector.request_finished(request, block_ids)
return self.connector.request_finished_all_groups(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
"""

View File

@@ -134,9 +134,18 @@ class CoreEngineProcManager:
data_parallel = vllm_config.parallel_config.data_parallel_size > 1
try:
for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
# Adjust device control in DP for non-CUDA platforms
# as well as external and ray launchers
# For CUDA platforms, we use torch.cuda.set_device()
with (
set_device_control_env_var(vllm_config, local_dp_rank)
if (data_parallel)
if (
data_parallel
and (
not current_platform.is_cuda_alike()
or vllm_config.parallel_config.use_ray
)
)
else contextlib.nullcontext()
):
proc.start()

View File

@@ -1052,6 +1052,9 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
def record_sleep_state(self, sleep: int = 0, level: int = 0):
if not envs.VLLM_SERVER_DEV_MODE:
return
awake = 1
discard_all = 0
weights_offloaded = 0

View File

@@ -8,7 +8,6 @@ import torch.distributed as dist
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import (
UBatchSlices,
check_ubatch_thresholds,
@@ -20,7 +19,8 @@ logger = init_logger(__name__)
def _get_device_and_group(parallel_config: ParallelConfig):
device = current_platform.device_type
# Use the actual device assigned to the DP group, not just the device type
device = get_dp_group().device
group = get_dp_group().device_group
# Transferring this tensor from GPU to CPU will introduce a GPU sync

View File

@@ -2323,11 +2323,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
else:
sampled_ids = valid_sampled_token_ids[req_idx]
num_sampled_ids: int = len(sampled_ids) if sampled_ids else 0
if cu_num_accepted_tokens is not None:
cu_num_accepted_tokens.append(
cu_num_accepted_tokens[-1] + num_sampled_ids
)
if not sampled_ids:
continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
end_idx = start_idx + len(sampled_ids)
end_idx = start_idx + num_sampled_ids
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
@@ -2343,11 +2351,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
if cu_num_accepted_tokens is not None:
cu_num_accepted_tokens.append(
cu_num_accepted_tokens[-1] + len(sampled_ids)
)
logprobs_lists = (
logprobs_tensors.tolists(cu_num_accepted_tokens)
if not self.use_async_scheduling and logprobs_tensors is not None

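A pure-Python sketch of the bookkeeping change above: the cumulative accepted-token count is now appended for every request before empty ones are skipped, so the list stays aligned with request order. The seed value `[0]` and the sampled-id lists below are assumptions for illustration.

```python
# Assumed sampler output: request 0 accepted 2 tokens, request 1 none,
# request 2 one token.
valid_sampled_token_ids = [[11, 12], [], [21]]

cu_num_accepted_tokens = [0]  # assumed seed value
for sampled_ids in valid_sampled_token_ids:
    num_sampled_ids = len(sampled_ids) if sampled_ids else 0
    # Append before the skip so empty requests still get an entry.
    cu_num_accepted_tokens.append(cu_num_accepted_tokens[-1] + num_sampled_ids)
    if not sampled_ids:
        continue
    # ... per-request state updates with sampled_ids would happen here ...

assert cu_num_accepted_tokens == [0, 2, 2, 3]
```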
View File

@@ -172,6 +172,29 @@ class Worker(WorkerBase):
if self.device_config.device.type == "cuda":
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
if (
self.parallel_config.data_parallel_size > 1
and self.parallel_config.data_parallel_size_local > 0
and self.parallel_config.distributed_executor_backend
not in ["ray", "external_launcher"]
and self.vllm_config.parallel_config.data_parallel_backend != "ray"
):
# Use local DP rank if available, otherwise use global DP rank.
dp_local_rank = self.parallel_config.data_parallel_rank_local
if dp_local_rank is None:
dp_local_rank = self.parallel_config.data_parallel_rank
tp_pp_world_size = (
self.parallel_config.pipeline_parallel_size
* self.parallel_config.tensor_parallel_size
)
# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
self.local_rank += dp_local_rank * tp_pp_world_size
assert self.local_rank < torch.cuda.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device)
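A worked sketch of the local-rank adjustment above under assumed sizes: TP=2, PP=1, and two local DP ranks on one node, so four visible GPUs. The loop variables are illustrative; in the worker, the pre-adjustment `local_rank` is the TP-local rank.

```python
# Assumed topology: 2 local DP ranks, TP=2, PP=1 -> 4 GPUs on this node.
tensor_parallel_size = 2
pipeline_parallel_size = 1
tp_pp_world_size = pipeline_parallel_size * tensor_parallel_size  # 2

for dp_local_rank in (0, 1):
    for tp_local_rank in (0, 1):  # local_rank before the adjustment
        local_rank = tp_local_rank + dp_local_rank * tp_pp_world_size
        print(f"DP {dp_local_rank}, TP {tp_local_rank} -> cuda:{local_rank}")
# DP 0 maps to cuda:0 and cuda:1; DP 1 maps to cuda:2 and cuda:3.
```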