Compare commits

...

5 Commits

Author SHA1 Message Date
f048f16ba7 fix pre-commit
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
2025-10-24 17:45:23 -07:00
180880ddc3 Revert #26709 2025-10-24 17:33:11 -07:00
3e0a770c15 Revert "[Misc] Remove use of CUDA_VISIBLE_DEVICES for device selection (fix DP slow startup time &c) (#26709)"
This reverts commit 237cf6d32a.
2025-10-24 17:32:10 -07:00
83f478bb19 [KVConnector] Migrate the LMCache integration code to be vLLM native (#25542)
Signed-off-by: ApostaC <yihua98@uchicago.edu>
2025-10-25 00:23:53 +00:00
269c4db0a4 [Misc][DP] Guard mxfp4 implementation selection (#27484)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
2025-10-24 23:29:24 +00:00
11 changed files with 1826 additions and 121 deletions

View File

@ -75,9 +75,10 @@ class SampleRequest:
Represents a single inference request for benchmarking.
"""
prompt: str | list[str]
prompt: str | list[str] | None
prompt_len: int
expected_output_len: int
prompt_token_ids: list[int] | None = None
multi_modal_data: MultiModalDataDict | dict | list[dict] | None = None
lora_request: LoRARequest | None = None
request_id: str | None = None
@ -385,7 +386,7 @@ def gen_prompt_decode_to_target_len(
max_retry: int = 10,
add_special_tokens: bool = False,
rng: np.random.Generator | None = None,
) -> tuple[str, list[int]]:
) -> tuple[str, list[int], int]:
"""
Ensure decoded-then-encoded prompt length matches the target token length.
@ -438,9 +439,10 @@ def gen_prompt_decode_to_target_len(
# -----------------------------------------------------------------------------
class RandomDataset(BenchmarkDataset):
class RandomTokenIDDataset(BenchmarkDataset):
"""
Synthetic text-only dataset for serving/throughput benchmarks.
Synthetic token-id-only dataset for serving/throughput benchmarks.
No need to use a tokenizer with this dataset.
Strategy:
- Sample input/output token lengths per request from integer-uniform ranges
@ -448,7 +450,6 @@ class RandomDataset(BenchmarkDataset):
- Prepend a fixed random prefix of length prefix_len.
- Generate the remaining tokens as a reproducible sequence:
(offset + index + arange(input_len)) % vocab_size.
- Decode then re-encode/truncate to ensure prompt token counts match.
- Uses numpy.default_rng seeded with random_seed for reproducible sampling.
"""
@ -467,14 +468,155 @@ class RandomDataset(BenchmarkDataset):
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
no_oversample: bool = False,
prefix_len: int = DEFAULT_PREFIX_LEN,
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
vocab_size: int = 1,
**kwargs,
) -> list[SampleRequest]:
# validate total input tokens (prefix + sampled) is at least 1.
min_sampled_input = math.floor(input_len * (1.0 - float(range_ratio)))
min_total_input = int(prefix_len) + min_sampled_input
if min_total_input < 1:
raise ValueError(
f"--random-input-len is too small: with --random-range-ratio "
f"{range_ratio}, the minimum possible total input tokens (prefix "
f"+ sampled) is {min_total_input}. Increase --random-input-len and/or "
"--random-prefix-len, or decrease --random-range-ratio so that "
"prefix_len + floor(random_input_len * (1 - range_ratio)) >= 1."
)
input_lens, output_lens, offsets = self.get_sampling_params(
num_requests, range_ratio, input_len, output_len, vocab_size
)
# Generate prefix once
prefix_token_ids = self.get_prefix(vocab_size, prefix_len)
requests = []
for i in range(num_requests):
prompt_token_ids, total_input_len = self.generate_token_id_sequence(
prefix_token_ids=prefix_token_ids,
prefix_len=prefix_len,
vocab_size=vocab_size,
input_len=int(input_lens[i]),
offset=int(offsets[i]),
index=i,
)
requests.append(
SampleRequest(
prompt=None,
prompt_token_ids=prompt_token_ids,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
request_id=request_id_prefix + str(i),
)
)
return requests
def get_prefix(self, vocab_size: int, prefix_len: int) -> list[int]:
"""
Get the prefix for the dataset.
"""
return (
self._rng.integers(0, vocab_size, size=prefix_len).tolist()
if prefix_len > 0
else []
)
def get_sampling_params(
self,
num_requests: int,
range_ratio: float,
input_len: int,
output_len: int,
vocab_size: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Get the sampling parameters for the dataset.
"""
# Enforce range_ratio < 1
if not (0.0 <= range_ratio < 1.0):
raise ValueError("range_ratio must be in [0, 1).")
# Bounds use floor for low and ceil for high
input_low = math.floor(input_len * (1 - range_ratio))
input_high = math.ceil(input_len * (1 + range_ratio))
output_low = math.floor(output_len * (1 - range_ratio))
output_high = math.ceil(output_len * (1 + range_ratio))
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
output_low = max(output_low, 1)
output_high = max(output_high, 1)
if input_low > input_high:
raise ValueError(
f"Invalid input sampling interval: low={input_low} > high={input_high}"
)
if output_low > output_high:
raise ValueError(
"Invalid output sampling interval: "
f"low={output_low} > high={output_high}"
)
logger.info(
"Sampling input_len from [%s, %s] and output_len from [%s, %s]",
input_low,
input_high,
output_low,
output_high,
)
input_lens = self._rng.integers(input_low, input_high + 1, size=num_requests)
output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests)
offsets = self._rng.integers(0, vocab_size, size=num_requests)
return input_lens, output_lens, offsets
def generate_token_id_sequence(
self,
*,
prefix_token_ids: list[int],
prefix_len: int,
vocab_size: int,
input_len: int,
offset: int,
index: int,
) -> tuple[list[int], int]:
"""
Returns (token_sequence, total_input_len).
"""
# Build the inner sequence by sampling sequentially from the vocab
inner_seq = ((offset + index + np.arange(input_len)) % vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
total_input_len = prefix_len + int(input_len)
return token_sequence, total_input_len
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
class RandomDataset(RandomTokenIDDataset):
"""
Synthetic text-only dataset for serving/throughput benchmarks.
Additionally to RandomTokenIDDataset, we perform a decode then re-encode/truncate
to ensure prompt token counts match.
"""
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
no_oversample: bool = False,
prefix_len: int = RandomTokenIDDataset.DEFAULT_PREFIX_LEN,
range_ratio: float = RandomTokenIDDataset.DEFAULT_RANGE_RATIO,
input_len: int = RandomTokenIDDataset.DEFAULT_INPUT_LEN,
output_len: int = RandomTokenIDDataset.DEFAULT_OUTPUT_LEN,
batchsize: int = 1,
**kwargs,
) -> list[SampleRequest]:
@ -490,17 +632,17 @@ class RandomDataset(BenchmarkDataset):
"the minimum possible total input tokens (prefix + sampled) is "
f"{min_total_input}. Increase --random-input-len and/or "
"--random-prefix-len, or decrease --random-range-ratio so that "
"prefix_len + floor(max(0, random_input_len - num_special)) "
"* (1 - range_ratio) >= 1."
"prefix_len + floor(max(0, random_input_len - num_special) "
"* (1 - range_ratio)) >= 1."
)
vocab_size = tokenizer.vocab_size
input_lens, output_lens, offsets = self.get_sampling_params(
num_requests, range_ratio, input_len, output_len, tokenizer
num_requests, range_ratio, input_len, output_len, vocab_size
)
# Generate prefix once
prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
vocab_size = tokenizer.vocab_size
prefix_token_ids = self.get_prefix(vocab_size, prefix_len)
requests = []
token_mismatch_total = 0
@ -552,67 +694,6 @@ class RandomDataset(BenchmarkDataset):
return requests
def get_prefix(
self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
) -> list[int]:
"""
Get the prefix for the dataset.
"""
return (
self._rng.integers(0, tokenizer.vocab_size, size=prefix_len).tolist()
if prefix_len > 0
else []
)
def get_sampling_params(
self,
num_requests: int,
range_ratio: float,
input_len: int,
output_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Get the sampling parameters for the dataset.
"""
# Enforce range_ratio < 1
if not (0.0 <= range_ratio < 1.0):
raise ValueError("range_ratio must be in [0, 1).")
num_special_tokens = int(tokenizer.num_special_tokens_to_add())
real_input_len = max(0, int(input_len) - num_special_tokens)
# Bounds use floor for low and ceil for high
input_low = math.floor(real_input_len * (1 - range_ratio))
input_high = math.ceil(real_input_len * (1 + range_ratio))
output_low = math.floor(output_len * (1 - range_ratio))
output_high = math.ceil(output_len * (1 + range_ratio))
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
output_low = max(output_low, 1)
output_high = max(output_high, 1)
if input_low > input_high:
raise ValueError(
f"Invalid input sampling interval: low={input_low} > high={input_high}"
)
if output_low > output_high:
raise ValueError(
"Invalid output sampling interval: "
f"low={output_low} > high={output_high}"
)
logger.info(
"Sampling input_len from [%s, %s] and output_len from [%s, %s]",
input_low,
input_high,
output_low,
output_high,
)
input_lens = self._rng.integers(input_low, input_high + 1, size=num_requests)
output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests)
offsets = self._rng.integers(0, tokenizer.vocab_size, size=num_requests)
return input_lens, output_lens, offsets
def generate_token_sequence(
self,
*,
@ -656,7 +737,7 @@ class RandomDataset(BenchmarkDataset):
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# Random Dataset Implementation (Reranking)
# -----------------------------------------------------------------------------
@ -684,9 +765,10 @@ class RandomDatasetForReranking(RandomDataset):
n_sep_tokens = int(is_reranker)
query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len
vocab_size = tokenizer.vocab_size
query_lens, _, query_offsets = self.get_sampling_params(
1, range_ratio, query_len_param, 0, tokenizer
1, range_ratio, query_len_param, 0, vocab_size
)
query_len = int(query_lens[0])
@ -700,9 +782,8 @@ class RandomDatasetForReranking(RandomDataset):
doc_len_param = input_len - query_len - n_sep_tokens
doc_lens, _, doc_offsets = self.get_sampling_params(
num_requests, range_ratio, doc_len_param, 0, tokenizer
num_requests, range_ratio, doc_len_param, 0, vocab_size
)
vocab_size = tokenizer.vocab_size
query_prompt, query_input_len, token_mismatch_total = (
self.generate_token_sequence(
@ -1054,9 +1135,11 @@ class RandomMultiModalDataset(RandomDataset):
"Video sampling not implemented; set its probability to 0."
)
vocab_size = tokenizer.vocab_size
# Get the sampling parameters for the dataset
input_lens, output_lens, offsets = self.get_sampling_params(
num_requests, range_ratio, input_len, output_len, tokenizer
num_requests, range_ratio, input_len, output_len, vocab_size
)
(
@ -1072,8 +1155,8 @@ class RandomMultiModalDataset(RandomDataset):
)
# Generate prefix once
prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
vocab_size = tokenizer.vocab_size
prefix_token_ids = self.get_prefix(vocab_size, prefix_len)
# Add synthetic multimodal items to each request
mm_requests = []
token_mismatch_total = 0

View File

@ -23,6 +23,7 @@ from vllm.benchmarks.datasets import (
InstructCoderDataset,
PrefixRepetitionRandomDataset,
RandomDataset,
RandomTokenIDDataset,
SampleRequest,
ShareGPTDataset,
SonnetDataset,
@ -340,6 +341,10 @@ def get_requests(args, tokenizer):
"output_len": args.output_len,
}
if args.dataset_name == "random_token_id":
sample_kwargs["range_ratio"] = args.random_range_ratio
sample_kwargs["prefix_len"] = args.prefix_len
dataset_cls = RandomTokenIDDataset
if args.dataset_path is None or args.dataset_name == "random":
sample_kwargs["range_ratio"] = args.random_range_ratio
sample_kwargs["prefix_len"] = args.prefix_len
@ -691,9 +696,14 @@ def main(args: argparse.Namespace):
args.seed = 0
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code
tokenizer = (
AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code
)
if args.skip_tokenizer_init
else None
)
requests = get_requests(args, tokenizer)
is_multi_modal = any(request.multi_modal_data is not None for request in requests)
request_outputs: list[RequestOutput] | None = None

View File

@ -3,7 +3,9 @@
from typing import TYPE_CHECKING, Any
import torch
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
from lmcache.integration.vllm.vllm_v1_adapter import (
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@ -11,6 +13,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
vllm_v1_adapter as _adapter,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
@ -26,7 +31,18 @@ logger = init_logger(__name__)
class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
assert vllm_config.kv_transfer_config is not None
use_native = vllm_config.kv_transfer_config.get_from_extra_config(
"use_native", False
)
if use_native:
logger.info("Initializing native LMCache connector")
cls = _adapter.LMCacheConnectorV1Impl
else:
logger.info("Initializing latest dev LMCache connector")
cls = LMCacheConnectorLatestImpl
self._lmcache_engine = cls(vllm_config, role, self)
# ==============================
# Worker-side methods

View File

@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Standard
import os
import threading
from typing import TYPE_CHECKING, Union
import torch
from lmcache.config import LMCacheEngineConfig as Config
from lmcache.logging import init_logger
from lmcache.v1.config import LMCacheEngineConfig as V1Config
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.multimodal.inputs import PlaceholderRange
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.request import Request
logger = init_logger(__name__)
ENGINE_NAME = "vllm-instance"
# Thread-safe singleton storage
_config_instance: Config | V1Config | None = None
_config_lock = threading.Lock()
def is_false(value: str) -> bool:
"""Check if the given string value is equivalent to 'false'."""
return value.lower() in ("false", "0", "no", "n", "off")
def lmcache_get_or_create_config() -> Config | V1Config:
"""Get the LMCache configuration from the environment variable
`LMCACHE_CONFIG_FILE`. If the environment variable is not set, this
function will return the default configuration.
This function is thread-safe and implements singleton pattern,
ensuring the configuration is loaded only once.
"""
global _config_instance
# Double-checked locking for thread-safe singleton
if _config_instance is None:
with _config_lock:
if _config_instance is None: # Check again within lock
if is_false(os.getenv("LMCACHE_USE_EXPERIMENTAL", "True")):
logger.warning(
"Detected LMCACHE_USE_EXPERIMENTAL is set to False. "
"Using legacy configuration is deprecated and will "
"be remove soon! Please set LMCACHE_USE_EXPERIMENTAL "
"to True."
)
LMCacheEngineConfig = Config # type: ignore[assignment]
else:
LMCacheEngineConfig = V1Config # type: ignore[assignment]
if "LMCACHE_CONFIG_FILE" not in os.environ:
logger.warning(
"No LMCache configuration file is set. Trying to read"
" configurations from the environment variables."
)
logger.warning(
"You can set the configuration file through "
"the environment variable: LMCACHE_CONFIG_FILE"
)
_config_instance = LMCacheEngineConfig.from_env()
else:
config_file = os.environ["LMCACHE_CONFIG_FILE"]
logger.info("Loading LMCache config file %s", config_file)
_config_instance = LMCacheEngineConfig.from_file(config_file)
# Update config from environment variables
_config_instance.update_config_from_env()
return _config_instance
def hex_hash_to_int16(s: str) -> int:
"""
Convert a hex hash string to a 16-bit integer.
"""
return int(s, 16) & 0xFFFF
def apply_mm_hashes_to_token_ids(
token_ids: torch.Tensor,
mm_hashes: list[str],
mm_positions: list["PlaceholderRange"],
) -> torch.Tensor:
"""
Overwrite token_ids in-place for multimodal placeholders using
efficient slice assignments.
"""
n = token_ids.size(0)
for hash_str, placeholder in zip(mm_hashes, mm_positions):
start, length = placeholder.offset, placeholder.length
if start >= n:
continue
end = min(start + length, n)
token_ids[start:end] = hex_hash_to_int16(hash_str)
return token_ids
def mla_enabled(model_config: "ModelConfig") -> bool:
return (
hasattr(model_config, "use_mla")
and isinstance(model_config.use_mla, bool)
and model_config.use_mla
)
def create_lmcache_metadata(
vllm_config=None, model_config=None, parallel_config=None, cache_config=None
):
"""
Create LMCacheEngineMetadata from vLLM configuration.
This function extracts common metadata creation logic that was duplicated
across multiple files.
Args:
vllm_config (VllmConfig): vLLM configuration object containing model,
parallel, and cache configs (alternative to
individual config parameters)
model_config (ModelConfig): Model configuration (alternative to
vllm_config)
parallel_config (ParallelConfig): Parallel configuration (alternative
to vllm_config)
cache_config (CacheConfig): Cache configuration (alternative to
vllm_config)
"""
# Third Party
# First Party
from lmcache.config import LMCacheEngineMetadata
from vllm.utils import get_kv_cache_torch_dtype
config = lmcache_get_or_create_config()
# Support both vllm_config object and individual config parameters
if vllm_config is not None:
model_cfg = vllm_config.model_config
parallel_cfg = vllm_config.parallel_config
cache_cfg = vllm_config.cache_config
else:
if model_config is None or parallel_config is None or cache_config is None:
raise ValueError(
"Either vllm_config must be provided, or all of "
"model_config, parallel_config, and cache_config must be provided."
)
model_cfg = model_config
parallel_cfg = parallel_config
cache_cfg = cache_config
# Get KV cache dtype
kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype)
# Check if MLA is enabled
use_mla = mla_enabled(model_cfg)
# Construct KV shape (for memory pool)
num_layer = model_cfg.get_num_layers(parallel_cfg)
chunk_size = config.chunk_size
num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg)
head_size = model_cfg.get_head_size()
kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)
# Create metadata
metadata = LMCacheEngineMetadata(
model_cfg.model,
parallel_cfg.world_size,
parallel_cfg.rank,
"vllm",
kv_dtype,
kv_shape,
use_mla,
)
return metadata, config
def extract_mm_features(
request: Union["Request", "NewRequestData"], modify: bool = False
) -> tuple[list[str], list["PlaceholderRange"]]:
"""
Normalize multimodal information from a Request into parallel lists.
This helper reads either:
1) `request.mm_features` (objects each exposing `.identifier` and
`.mm_position`), or
2) legacy fields `request.mm_hashes` and `request.mm_positions`.
It returns two equally sized lists: the multimodal hash identifiers and
their corresponding positions. If the request contains no multimodal info,
it returns `([], [])`.
Args:
request (Request): The source object.
modify (bool):
Controls copy semantics for the legacy-path return values.
- If True and legacy fields are used, shallow-copies are returned so
the caller can mutate the lists without affecting `request`.
- If False, the original legacy sequences are returned as-is
(zero-copy); treat them as read-only.
Returns:
tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`).
May be `([], [])` when no multimodal data is present.
"""
if getattr(request, "mm_features", None):
mm_hashes, mm_positions = zip(
*((f.identifier, f.mm_position) for f in request.mm_features)
)
return (list(mm_hashes), list(mm_positions))
elif getattr(request, "mm_hashes", None):
if modify:
return (
request.mm_hashes.copy(), # type: ignore
request.mm_positions.copy(), # type: ignore
)
else:
return (request.mm_hashes, request.mm_positions) # type: ignore
else:
return ([], [])

View File

@ -991,14 +991,11 @@ class NixlConnectorWorker:
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
self.device_id = self.tp_rank
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
base_addr = cache.data_ptr()
if not self.use_host_buffer and current_platform.is_cuda_alike():
self.device_id = cache.device.index
if base_addr in seen_base_addresses:
continue
@ -1026,7 +1023,7 @@ class NixlConnectorWorker:
"All kv cache tensors must have the same size"
)
caches_data.append(
(base_addr, curr_tensor_size_bytes, self.device_id, "")
(base_addr, curr_tensor_size_bytes, self.tp_rank, "")
)
logger.debug(
@ -1073,7 +1070,7 @@ class NixlConnectorWorker:
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.device_id))
blocks_data.append((addr, kv_block_len, self.tp_rank))
if self._use_flashinfer:
# Separate and interleave K/V regions to maintain the same
@ -1084,13 +1081,12 @@ class NixlConnectorWorker:
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
blocks_data.append((v_addr, kv_block_len, self.device_id))
blocks_data.append((v_addr, kv_block_len, self.tp_rank))
logger.debug(
"Created %s blocks for src engine %s and rank %s on device id %s",
"Created %s blocks for src engine %s and rank %s",
len(blocks_data),
self.engine_id,
self.tp_rank,
self.device_id,
)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)

View File

@ -794,7 +794,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
else:
raise NotImplementedError(
"Incompatible Mxfp4 backend for EP batched experts format"
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
"EP batched experts format"
)
else:
assert self.moe_quant_config is not None
@ -813,8 +814,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
return MarlinExperts(self.moe_quant_config)
else:
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
return OAITritonExperts(self.moe_quant_config)
else:
raise NotImplementedError(
f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
)
def _route_and_experts(
self,

View File

@ -134,12 +134,9 @@ class CoreEngineProcManager:
data_parallel = vllm_config.parallel_config.data_parallel_size > 1
try:
for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
# Adjust device control in DP for non-CUDA platforms
# For CUDA platforms, setting same device id for different DP
# processes affects NCCL init performance.
with (
set_device_control_env_var(vllm_config, local_dp_rank)
if (data_parallel and not current_platform.is_cuda_alike())
if (data_parallel)
else contextlib.nullcontext()
):
proc.start()

View File

@ -8,6 +8,7 @@ import torch.distributed as dist
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import (
UBatchSlices,
check_ubatch_thresholds,
@ -19,8 +20,7 @@ logger = init_logger(__name__)
def _get_device_and_group(parallel_config: ParallelConfig):
# Use the actual device assigned to the DP group, not just the device type
device = get_dp_group().device
device = current_platform.device_type
group = get_dp_group().device_group
# Transfering this tensor from GPU to CPU will introduce a GPU sync

View File

@ -172,27 +172,6 @@ class Worker(WorkerBase):
if self.device_config.device.type == "cuda":
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
if (
self.parallel_config.data_parallel_size > 1
and self.parallel_config.data_parallel_size_local > 0
and self.parallel_config.data_parallel_backend != "ray"
):
# Use local DP rank if available, otherwise use global DP rank.
dp_local_rank = self.parallel_config.data_parallel_rank_local
if dp_local_rank is None:
dp_local_rank = self.parallel_config.data_parallel_rank
tp_pp_world_size = (
self.parallel_config.pipeline_parallel_size
* self.parallel_config.tensor_parallel_size
)
# DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
self.local_rank += dp_local_rank * tp_pp_world_size
assert self.local_rank <= torch.cuda.device_count(), (
f"DP adjusted local rank {self.local_rank} is out of bounds. "
)
self.device = torch.device(f"cuda:{self.local_rank}")
current_platform.set_device(self.device)