Compare commits

...

3 Commits

f048f16ba7  fix pre-commit
            Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
            2025-10-24 17:45:23 -07:00

180880ddc3  Revert #26709
            2025-10-24 17:33:11 -07:00

3e0a770c15  Revert "[Misc] Remove use of CUDA_VISIBLE_DEVICES for device selection (fix DP slow startup time &c) (#26709)"
            This reverts commit 237cf6d32a.
            2025-10-24 17:32:10 -07:00
6 changed files with 182 additions and 117 deletions

View File

@@ -75,9 +75,10 @@ class SampleRequest:
     Represents a single inference request for benchmarking.
     """

-    prompt: str | list[str]
+    prompt: str | list[str] | None
     prompt_len: int
     expected_output_len: int
+    prompt_token_ids: list[int] | None = None
     multi_modal_data: MultiModalDataDict | dict | list[dict] | None = None
     lora_request: LoRARequest | None = None
     request_id: str | None = None
@@ -385,7 +386,7 @@ def gen_prompt_decode_to_target_len(
     max_retry: int = 10,
     add_special_tokens: bool = False,
     rng: np.random.Generator | None = None,
-) -> tuple[str, list[int]]:
+) -> tuple[str, list[int], int]:
    """
    Ensure decoded-then-encoded prompt length matches the target token length.
@@ -438,9 +439,10 @@ def gen_prompt_decode_to_target_len(
 # -----------------------------------------------------------------------------
-class RandomDataset(BenchmarkDataset):
+class RandomTokenIDDataset(BenchmarkDataset):
     """
-    Synthetic text-only dataset for serving/throughput benchmarks.
+    Synthetic token-id-only dataset for serving/throughput benchmarks.
+    No need to use a tokenizer with this dataset.

     Strategy:
     - Sample input/output token lengths per request from integer-uniform ranges
@@ -448,7 +450,6 @@ class RandomDataset(BenchmarkDataset):
     - Prepend a fixed random prefix of length prefix_len.
     - Generate the remaining tokens as a reproducible sequence:
       (offset + index + arange(input_len)) % vocab_size.
-    - Decode then re-encode/truncate to ensure prompt token counts match.
     - Uses numpy.default_rng seeded with random_seed for reproducible sampling.
     """
@@ -467,14 +468,155 @@ class RandomDataset(BenchmarkDataset):
     def sample(
         self,
-        tokenizer: PreTrainedTokenizerBase,
         num_requests: int,
         request_id_prefix: str = "",
-        no_oversample: bool = False,
         prefix_len: int = DEFAULT_PREFIX_LEN,
         range_ratio: float = DEFAULT_RANGE_RATIO,
         input_len: int = DEFAULT_INPUT_LEN,
         output_len: int = DEFAULT_OUTPUT_LEN,
+        vocab_size: int = 1,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        # validate total input tokens (prefix + sampled) is at least 1.
+        min_sampled_input = math.floor(input_len * (1.0 - float(range_ratio)))
+        min_total_input = int(prefix_len) + min_sampled_input
+        if min_total_input < 1:
+            raise ValueError(
+                f"--random-input-len is too small: with --random-range-ratio "
+                f"{range_ratio}, the minimum possible total input tokens (prefix "
+                f"+ sampled) is {min_total_input}. Increase --random-input-len and/or "
+                "--random-prefix-len, or decrease --random-range-ratio so that "
+                "prefix_len + floor(random_input_len * (1 - range_ratio)) >= 1."
+            )
+
+        input_lens, output_lens, offsets = self.get_sampling_params(
+            num_requests, range_ratio, input_len, output_len, vocab_size
+        )
+
+        # Generate prefix once
+        prefix_token_ids = self.get_prefix(vocab_size, prefix_len)
+
+        requests = []
+        for i in range(num_requests):
+            prompt_token_ids, total_input_len = self.generate_token_id_sequence(
+                prefix_token_ids=prefix_token_ids,
+                prefix_len=prefix_len,
+                vocab_size=vocab_size,
+                input_len=int(input_lens[i]),
+                offset=int(offsets[i]),
+                index=i,
+            )
+            requests.append(
+                SampleRequest(
+                    prompt=None,
+                    prompt_token_ids=prompt_token_ids,
+                    prompt_len=total_input_len,
+                    expected_output_len=int(output_lens[i]),
+                    request_id=request_id_prefix + str(i),
+                )
+            )
+        return requests
+
+    def get_prefix(self, vocab_size: int, prefix_len: int) -> list[int]:
+        """
+        Get the prefix for the dataset.
+        """
+        return (
+            self._rng.integers(0, vocab_size, size=prefix_len).tolist()
+            if prefix_len > 0
+            else []
+        )
+
+    def get_sampling_params(
+        self,
+        num_requests: int,
+        range_ratio: float,
+        input_len: int,
+        output_len: int,
+        vocab_size: int,
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Get the sampling parameters for the dataset.
+        """
+        # Enforce range_ratio < 1
+        if not (0.0 <= range_ratio < 1.0):
+            raise ValueError("range_ratio must be in [0, 1).")
+
+        # Bounds use floor for low and ceil for high
+        input_low = math.floor(input_len * (1 - range_ratio))
+        input_high = math.ceil(input_len * (1 + range_ratio))
+        output_low = math.floor(output_len * (1 - range_ratio))
+        output_high = math.ceil(output_len * (1 + range_ratio))
+        # Ensure the lower bound for output length is at least 1 to
+        # prevent sampling 0 tokens.
+        output_low = max(output_low, 1)
+        output_high = max(output_high, 1)
+
+        if input_low > input_high:
+            raise ValueError(
+                f"Invalid input sampling interval: low={input_low} > high={input_high}"
+            )
+        if output_low > output_high:
+            raise ValueError(
+                "Invalid output sampling interval: "
+                f"low={output_low} > high={output_high}"
+            )
+
+        logger.info(
+            "Sampling input_len from [%s, %s] and output_len from [%s, %s]",
+            input_low,
+            input_high,
+            output_low,
+            output_high,
+        )
+
+        input_lens = self._rng.integers(input_low, input_high + 1, size=num_requests)
+        output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests)
+        offsets = self._rng.integers(0, vocab_size, size=num_requests)
+        return input_lens, output_lens, offsets
+
+    def generate_token_id_sequence(
+        self,
+        *,
+        prefix_token_ids: list[int],
+        prefix_len: int,
+        vocab_size: int,
+        input_len: int,
+        offset: int,
+        index: int,
+    ) -> tuple[list[int], int]:
+        """
+        Returns (token_sequence, total_input_len).
+        """
+        # Build the inner sequence by sampling sequentially from the vocab
+        inner_seq = ((offset + index + np.arange(input_len)) % vocab_size).tolist()
+        token_sequence = prefix_token_ids + inner_seq
+        total_input_len = prefix_len + int(input_len)
+        return token_sequence, total_input_len
+
+
+# -----------------------------------------------------------------------------
+# Random Dataset Implementation (Synthetic Data)
+# -----------------------------------------------------------------------------
+class RandomDataset(RandomTokenIDDataset):
+    """
+    Synthetic text-only dataset for serving/throughput benchmarks.
+    Additionally to RandomTokenIDDataset, we perform a decode then re-encode/truncate
+    to ensure prompt token counts match.
+    """
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        request_id_prefix: str = "",
+        no_oversample: bool = False,
+        prefix_len: int = RandomTokenIDDataset.DEFAULT_PREFIX_LEN,
+        range_ratio: float = RandomTokenIDDataset.DEFAULT_RANGE_RATIO,
+        input_len: int = RandomTokenIDDataset.DEFAULT_INPUT_LEN,
+        output_len: int = RandomTokenIDDataset.DEFAULT_OUTPUT_LEN,
         batchsize: int = 1,
         **kwargs,
     ) -> list[SampleRequest]:
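
Note: to make the new length sampling concrete, here is a small standalone check of the floor/ceil arithmetic that get_sampling_params above performs (illustrative values only, not taken from the benchmark defaults):

    import math

    input_len, output_len, range_ratio = 100, 32, 0.25

    input_low = math.floor(input_len * (1 - range_ratio))            # 75
    input_high = math.ceil(input_len * (1 + range_ratio))            # 125
    output_low = max(math.floor(output_len * (1 - range_ratio)), 1)  # 24
    output_high = max(math.ceil(output_len * (1 + range_ratio)), 1)  # 40

    # Lengths are then drawn uniformly from the inclusive ranges
    # [75, 125] and [24, 40] via rng.integers(low, high + 1).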
@@ -490,17 +632,17 @@ class RandomDataset(BenchmarkDataset):
                 "the minimum possible total input tokens (prefix + sampled) is "
                 f"{min_total_input}. Increase --random-input-len and/or "
                 "--random-prefix-len, or decrease --random-range-ratio so that "
-                "prefix_len + floor(max(0, random_input_len - num_special)) "
-                "* (1 - range_ratio) >= 1."
+                "prefix_len + floor(max(0, random_input_len - num_special) "
+                "* (1 - range_ratio)) >= 1."
             )

+        vocab_size = tokenizer.vocab_size
         input_lens, output_lens, offsets = self.get_sampling_params(
-            num_requests, range_ratio, input_len, output_len, tokenizer
+            num_requests, range_ratio, input_len, output_len, vocab_size
        )

         # Generate prefix once
-        prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
-        vocab_size = tokenizer.vocab_size
+        prefix_token_ids = self.get_prefix(vocab_size, prefix_len)

         requests = []
         token_mismatch_total = 0
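
Note: the parenthesis move in the error-message hint above changes which quantity floor() applies to. A small numeric illustration (values are hypothetical; the surrounding computation of min_total_input is not shown in this hunk):

    import math

    random_input_len, num_special, range_ratio, prefix_len = 10, 2, 0.9, 0

    real_input_len = max(0, random_input_len - num_special)       # 8
    min_sampled = math.floor(real_input_len * (1 - range_ratio))  # floor(0.8) = 0
    min_total_input = prefix_len + min_sampled                    # 0 -> triggers the error

    # The previous wording, floor(max(0, ...)) * (1 - range_ratio), would read as
    # 8 * 0.1 = 0.8, which is not the integer bound the check actually uses.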
@@ -552,67 +694,6 @@ class RandomDataset(BenchmarkDataset):
         return requests

-    def get_prefix(
-        self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
-    ) -> list[int]:
-        """
-        Get the prefix for the dataset.
-        """
-        return (
-            self._rng.integers(0, tokenizer.vocab_size, size=prefix_len).tolist()
-            if prefix_len > 0
-            else []
-        )
-
-    def get_sampling_params(
-        self,
-        num_requests: int,
-        range_ratio: float,
-        input_len: int,
-        output_len: int,
-        tokenizer: PreTrainedTokenizerBase,
-    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-        """
-        Get the sampling parameters for the dataset.
-        """
-        # Enforce range_ratio < 1
-        if not (0.0 <= range_ratio < 1.0):
-            raise ValueError("range_ratio must be in [0, 1).")
-
-        num_special_tokens = int(tokenizer.num_special_tokens_to_add())
-        real_input_len = max(0, int(input_len) - num_special_tokens)
-
-        # Bounds use floor for low and ceil for high
-        input_low = math.floor(real_input_len * (1 - range_ratio))
-        input_high = math.ceil(real_input_len * (1 + range_ratio))
-        output_low = math.floor(output_len * (1 - range_ratio))
-        output_high = math.ceil(output_len * (1 + range_ratio))
-        # Ensure the lower bound for output length is at least 1 to
-        # prevent sampling 0 tokens.
-        output_low = max(output_low, 1)
-        output_high = max(output_high, 1)
-
-        if input_low > input_high:
-            raise ValueError(
-                f"Invalid input sampling interval: low={input_low} > high={input_high}"
-            )
-        if output_low > output_high:
-            raise ValueError(
-                "Invalid output sampling interval: "
-                f"low={output_low} > high={output_high}"
-            )
-
-        logger.info(
-            "Sampling input_len from [%s, %s] and output_len from [%s, %s]",
-            input_low,
-            input_high,
-            output_low,
-            output_high,
-        )
-
-        input_lens = self._rng.integers(input_low, input_high + 1, size=num_requests)
-        output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests)
-        offsets = self._rng.integers(0, tokenizer.vocab_size, size=num_requests)
-        return input_lens, output_lens, offsets
-
     def generate_token_sequence(
         self,
         *,
@@ -656,7 +737,7 @@ class RandomDataset(BenchmarkDataset):
 # -----------------------------------------------------------------------------
-# Random Dataset Implementation (Synthetic Data)
+# Random Dataset Implementation (Reranking)
 # -----------------------------------------------------------------------------
@@ -684,9 +765,10 @@ class RandomDatasetForReranking(RandomDataset):
         n_sep_tokens = int(is_reranker)
         query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len

+        vocab_size = tokenizer.vocab_size
         query_lens, _, query_offsets = self.get_sampling_params(
-            1, range_ratio, query_len_param, 0, tokenizer
+            1, range_ratio, query_len_param, 0, vocab_size
         )

         query_len = int(query_lens[0])
@@ -700,9 +782,8 @@ class RandomDatasetForReranking(RandomDataset):
         doc_len_param = input_len - query_len - n_sep_tokens
         doc_lens, _, doc_offsets = self.get_sampling_params(
-            num_requests, range_ratio, doc_len_param, 0, tokenizer
+            num_requests, range_ratio, doc_len_param, 0, vocab_size
         )
-        vocab_size = tokenizer.vocab_size

         query_prompt, query_input_len, token_mismatch_total = (
             self.generate_token_sequence(
@@ -1054,9 +1135,11 @@ class RandomMultiModalDataset(RandomDataset):
                 "Video sampling not implemented; set its probability to 0."
             )

+        vocab_size = tokenizer.vocab_size
+
         # Get the sampling parameters for the dataset
         input_lens, output_lens, offsets = self.get_sampling_params(
-            num_requests, range_ratio, input_len, output_len, tokenizer
+            num_requests, range_ratio, input_len, output_len, vocab_size
         )

         (
@@ -1072,8 +1155,8 @@ class RandomMultiModalDataset(RandomDataset):
         )

         # Generate prefix once
-        prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
-        vocab_size = tokenizer.vocab_size
+        prefix_token_ids = self.get_prefix(vocab_size, prefix_len)

         # Add synthetic multimodal items to each request
         mm_requests = []
         token_mismatch_total = 0

View File

@@ -23,6 +23,7 @@ from vllm.benchmarks.datasets import (
     InstructCoderDataset,
     PrefixRepetitionRandomDataset,
     RandomDataset,
+    RandomTokenIDDataset,
     SampleRequest,
     ShareGPTDataset,
     SonnetDataset,
@@ -340,6 +341,10 @@ def get_requests(args, tokenizer):
         "output_len": args.output_len,
     }

+    if args.dataset_name == "random_token_id":
+        sample_kwargs["range_ratio"] = args.random_range_ratio
+        sample_kwargs["prefix_len"] = args.prefix_len
+        dataset_cls = RandomTokenIDDataset
     if args.dataset_path is None or args.dataset_name == "random":
         sample_kwargs["range_ratio"] = args.random_range_ratio
         sample_kwargs["prefix_len"] = args.prefix_len
@@ -691,9 +696,14 @@ def main(args: argparse.Namespace):
         args.seed = 0
     random.seed(args.seed)
     # Sample the requests.
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.tokenizer, trust_remote_code=args.trust_remote_code
+    tokenizer = (
+        AutoTokenizer.from_pretrained(
+            args.tokenizer, trust_remote_code=args.trust_remote_code
+        )
+        if args.skip_tokenizer_init
+        else None
     )
     requests = get_requests(args, tokenizer)
     is_multi_modal = any(request.multi_modal_data is not None for request in requests)
     request_outputs: list[RequestOutput] | None = None

View File

@@ -991,14 +991,11 @@ class NixlConnectorWorker:
         # Enable different block lengths for different layers when MLA is used.
         self.block_len_per_layer = list[int]()
         self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
-        self.device_id = self.tp_rank

         for layer_name, cache_or_caches in xfer_buffers.items():
             cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
             for cache in cache_list:
                 base_addr = cache.data_ptr()
-                if not self.use_host_buffer and current_platform.is_cuda_alike():
-                    self.device_id = cache.device.index

                 if base_addr in seen_base_addresses:
                     continue
@@ -1026,7 +1023,7 @@ class NixlConnectorWorker:
                     "All kv cache tensors must have the same size"
                 )
                 caches_data.append(
-                    (base_addr, curr_tensor_size_bytes, self.device_id, "")
+                    (base_addr, curr_tensor_size_bytes, self.tp_rank, "")
                 )

                 logger.debug(
@@ -1073,7 +1070,7 @@ class NixlConnectorWorker:
                 block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, kv_block_len, self.device_id))
+                blocks_data.append((addr, kv_block_len, self.tp_rank))

             if self._use_flashinfer:
                 # Separate and interleave K/V regions to maintain the same
@@ -1084,13 +1081,12 @@ class NixlConnectorWorker:
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
-                    blocks_data.append((v_addr, kv_block_len, self.device_id))
+                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))

         logger.debug(
-            "Created %s blocks for src engine %s and rank %s on device id %s",
+            "Created %s blocks for src engine %s and rank %s",
             len(blocks_data),
             self.engine_id,
             self.tp_rank,
-            self.device_id,
         )

         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)

View File

@@ -134,12 +134,9 @@ class CoreEngineProcManager:
         data_parallel = vllm_config.parallel_config.data_parallel_size > 1
         try:
             for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
-                # Adjust device control in DP for non-CUDA platforms
-                # For CUDA platforms, setting same device id for different DP
-                # processes affects NCCL init performance.
                 with (
                     set_device_control_env_var(vllm_config, local_dp_rank)
-                    if (data_parallel and not current_platform.is_cuda_alike())
+                    if (data_parallel)
                     else contextlib.nullcontext()
                 ):
                     proc.start()
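
Note: the body of set_device_control_env_var is not part of this diff. For context on the pattern the reverted condition guards, here is a minimal, hypothetical stand-in (not vLLM's implementation): a context manager that exports a device-visibility variable so each DP engine process only sees its own accelerator, as the reverted #26709 commit title suggests was previously done via CUDA_VISIBLE_DEVICES.

    import contextlib
    import os

    @contextlib.contextmanager
    def _pin_visible_device(var_name: str, device_index: int):
        """Temporarily restrict device visibility for a child process (illustrative)."""
        old = os.environ.get(var_name)
        os.environ[var_name] = str(device_index)
        try:
            yield
        finally:
            # Restore the parent's previous value after the child is started.
            if old is None:
                os.environ.pop(var_name, None)
            else:
                os.environ[var_name] = old

    # Usage sketch:
    #     with _pin_visible_device("CUDA_VISIBLE_DEVICES", local_dp_rank):
    #         proc.start()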

View File

@@ -8,6 +8,7 @@ import torch.distributed as dist
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.v1.worker.ubatch_utils import (
     UBatchSlices,
     check_ubatch_thresholds,
@@ -19,8 +20,7 @@ logger = init_logger(__name__)

 def _get_device_and_group(parallel_config: ParallelConfig):
-    # Use the actual device assigned to the DP group, not just the device type
-    device = get_dp_group().device
+    device = current_platform.device_type
     group = get_dp_group().device_group

     # Transfering this tensor from GPU to CPU will introduce a GPU sync

View File

@@ -172,27 +172,6 @@ class Worker(WorkerBase):
         if self.device_config.device.type == "cuda":
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            if (
-                self.parallel_config.data_parallel_size > 1
-                and self.parallel_config.data_parallel_size_local > 0
-                and self.parallel_config.data_parallel_backend != "ray"
-            ):
-                # Use local DP rank if available, otherwise use global DP rank.
-                dp_local_rank = self.parallel_config.data_parallel_rank_local
-                if dp_local_rank is None:
-                    dp_local_rank = self.parallel_config.data_parallel_rank
-
-                tp_pp_world_size = (
-                    self.parallel_config.pipeline_parallel_size
-                    * self.parallel_config.tensor_parallel_size
-                )
-
-                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
-                self.local_rank += dp_local_rank * tp_pp_world_size
-                assert self.local_rank <= torch.cuda.device_count(), (
-                    f"DP adjusted local rank {self.local_rank} is out of bounds. "
-                )
             self.device = torch.device(f"cuda:{self.local_rank}")
             current_platform.set_device(self.device)
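
Note: the removed block encoded the mapping DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK. A small numeric check of that arithmetic (illustrative values only, not taken from any real config):

    dp_local_rank = 1
    pipeline_parallel_size = 1
    tensor_parallel_size = 2
    tp_local_rank = 0  # the worker's original local_rank within its TP/PP group

    tp_pp_world_size = pipeline_parallel_size * tensor_parallel_size  # 2
    local_rank = tp_local_rank + dp_local_rank * tp_pp_world_size     # 0 + 1 * 2 = 2

    # With the now-reverted adjustment, DP rank 1 / TP rank 0 would bind cuda:2;
    # after this revert, device selection returns to the pre-#26709 behavior.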