Compare commits
5 Commits
revert-267...zhuohan/re
| Author | SHA1 | Date |
|---|---|---|
| | f048f16ba7 | |
| | 180880ddc3 | |
| | 3e0a770c15 | |
| | 83f478bb19 | |
| | 269c4db0a4 | |
@@ -75,9 +75,10 @@ class SampleRequest:
    Represents a single inference request for benchmarking.
    """

    prompt: str | list[str]
    prompt: str | list[str] | None
    prompt_len: int
    expected_output_len: int
    prompt_token_ids: list[int] | None = None
    multi_modal_data: MultiModalDataDict | dict | list[dict] | None = None
    lora_request: LoRARequest | None = None
    request_id: str | None = None
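The hunk above makes `prompt` optional and introduces `prompt_token_ids`, so a benchmark request can now carry raw token IDs with no text prompt at all. A minimal sketch of what such a request looks like; `TokenOnlyRequest` is a hypothetical stand-in that mirrors only the fields shown above (the multimodal and LoRA fields are omitted):

```python
from dataclasses import dataclass


@dataclass
class TokenOnlyRequest:
    # Simplified stand-in for SampleRequest, limited to the fields in the hunk.
    prompt: str | list[str] | None
    prompt_len: int
    expected_output_len: int
    prompt_token_ids: list[int] | None = None
    request_id: str | None = None


# A token-id-only request, as RandomTokenIDDataset.sample() builds them:
req = TokenOnlyRequest(
    prompt=None,
    prompt_token_ids=[17, 42, 42, 7],
    prompt_len=4,
    expected_output_len=16,
    request_id="bench-0",
)
```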
@@ -385,7 +386,7 @@ def gen_prompt_decode_to_target_len(
    max_retry: int = 10,
    add_special_tokens: bool = False,
    rng: np.random.Generator | None = None,
) -> tuple[str, list[int]]:
) -> tuple[str, list[int], int]:
    """
    Ensure decoded-then-encoded prompt length matches the target token length.
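The signature change above adds a third return value to `gen_prompt_decode_to_target_len`. The sketch below is a hedged, simplified re-implementation of the behaviour the docstring describes (decode candidate ids, re-encode, adjust until the target length is hit); the real helper may differ, and `decode_to_target_len` is a hypothetical name. The extra `int` models the leftover token mismatch that the new return type appears to expose:

```python
import numpy as np


def decode_to_target_len(tokenizer, token_ids: list[int], target_len: int,
                         max_retry: int = 10, rng=None) -> tuple[str, list[int], int]:
    """Decode candidate ids, re-encode, and pad/trim until the re-encoded
    length hits target_len (or retries run out).
    Returns (prompt, re-encoded ids, remaining mismatch)."""
    rng = rng or np.random.default_rng()
    ids = list(token_ids)
    for _ in range(max_retry):
        prompt = tokenizer.decode(ids)
        reencoded = tokenizer.encode(prompt, add_special_tokens=False)
        diff = len(reencoded) - target_len
        if diff == 0:
            return prompt, reencoded, 0
        if diff > 0:
            ids = ids[:-diff]  # too long: drop trailing tokens
        else:
            ids += rng.integers(0, tokenizer.vocab_size, size=-diff).tolist()
    prompt = tokenizer.decode(ids)
    reencoded = tokenizer.encode(prompt, add_special_tokens=False)
    return prompt, reencoded[:target_len], len(reencoded) - target_len
```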
@@ -438,9 +439,10 @@ def gen_prompt_decode_to_target_len(
# -----------------------------------------------------------------------------


class RandomDataset(BenchmarkDataset):
class RandomTokenIDDataset(BenchmarkDataset):
    """
    Synthetic text-only dataset for serving/throughput benchmarks.
    Synthetic token-id-only dataset for serving/throughput benchmarks.
    No need to use a tokenizer with this dataset.

    Strategy:
    - Sample input/output token lengths per request from integer-uniform ranges
@@ -448,7 +450,6 @@ class RandomDataset(BenchmarkDataset):
    - Prepend a fixed random prefix of length prefix_len.
    - Generate the remaining tokens as a reproducible sequence:
      (offset + index + arange(input_len)) % vocab_size.
    - Decode then re-encode/truncate to ensure prompt token counts match.
    - Uses numpy.default_rng seeded with random_seed for reproducible sampling.
    """
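The strategy bullets above hinge on one formula for the per-request token ids. A tiny worked example of that reproducible inner sequence (values chosen arbitrarily):

```python
import numpy as np

# (offset + index + arange(input_len)) % vocab_size, as in the docstring above.
vocab_size = 8
offset = 5      # sampled once per request from [0, vocab_size)
index = 2       # request index
input_len = 6
inner_seq = ((offset + index + np.arange(input_len)) % vocab_size).tolist()
print(inner_seq)  # [7, 0, 1, 2, 3, 4] -- same seed and index always give the same ids
```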
@@ -467,14 +468,155 @@ class RandomDataset(BenchmarkDataset):

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        prefix_len: int = DEFAULT_PREFIX_LEN,
        range_ratio: float = DEFAULT_RANGE_RATIO,
        input_len: int = DEFAULT_INPUT_LEN,
        output_len: int = DEFAULT_OUTPUT_LEN,
        vocab_size: int = 1,
        **kwargs,
    ) -> list[SampleRequest]:
        # validate total input tokens (prefix + sampled) is at least 1.
        min_sampled_input = math.floor(input_len * (1.0 - float(range_ratio)))
        min_total_input = int(prefix_len) + min_sampled_input
        if min_total_input < 1:
            raise ValueError(
                f"--random-input-len is too small: with --random-range-ratio "
                f"{range_ratio}, the minimum possible total input tokens (prefix "
                f"+ sampled) is {min_total_input}. Increase --random-input-len and/or "
                "--random-prefix-len, or decrease --random-range-ratio so that "
                "prefix_len + floor(random_input_len * (1 - range_ratio)) >= 1."
            )

        input_lens, output_lens, offsets = self.get_sampling_params(
            num_requests, range_ratio, input_len, output_len, vocab_size
        )

        # Generate prefix once
        prefix_token_ids = self.get_prefix(vocab_size, prefix_len)

        requests = []
        for i in range(num_requests):
            prompt_token_ids, total_input_len = self.generate_token_id_sequence(
                prefix_token_ids=prefix_token_ids,
                prefix_len=prefix_len,
                vocab_size=vocab_size,
                input_len=int(input_lens[i]),
                offset=int(offsets[i]),
                index=i,
            )
            requests.append(
                SampleRequest(
                    prompt=None,
                    prompt_token_ids=prompt_token_ids,
                    prompt_len=total_input_len,
                    expected_output_len=int(output_lens[i]),
                    request_id=request_id_prefix + str(i),
                )
            )

        return requests

    def get_prefix(self, vocab_size: int, prefix_len: int) -> list[int]:
        """
        Get the prefix for the dataset.
        """
        return (
            self._rng.integers(0, vocab_size, size=prefix_len).tolist()
            if prefix_len > 0
            else []
        )

    def get_sampling_params(
        self,
        num_requests: int,
        range_ratio: float,
        input_len: int,
        output_len: int,
        vocab_size: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Get the sampling parameters for the dataset.
        """
        # Enforce range_ratio < 1
        if not (0.0 <= range_ratio < 1.0):
            raise ValueError("range_ratio must be in [0, 1).")
        # Bounds use floor for low and ceil for high
        input_low = math.floor(input_len * (1 - range_ratio))
        input_high = math.ceil(input_len * (1 + range_ratio))
        output_low = math.floor(output_len * (1 - range_ratio))
        output_high = math.ceil(output_len * (1 + range_ratio))
        # Ensure the lower bound for output length is at least 1 to
        # prevent sampling 0 tokens.
        output_low = max(output_low, 1)
        output_high = max(output_high, 1)

        if input_low > input_high:
            raise ValueError(
                f"Invalid input sampling interval: low={input_low} > high={input_high}"
            )
        if output_low > output_high:
            raise ValueError(
                "Invalid output sampling interval: "
                f"low={output_low} > high={output_high}"
            )

        logger.info(
            "Sampling input_len from [%s, %s] and output_len from [%s, %s]",
            input_low,
            input_high,
            output_low,
            output_high,
        )

        input_lens = self._rng.integers(input_low, input_high + 1, size=num_requests)
        output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests)
        offsets = self._rng.integers(0, vocab_size, size=num_requests)
        return input_lens, output_lens, offsets

    def generate_token_id_sequence(
        self,
        *,
        prefix_token_ids: list[int],
        prefix_len: int,
        vocab_size: int,
        input_len: int,
        offset: int,
        index: int,
    ) -> tuple[list[int], int]:
        """
        Returns (token_sequence, total_input_len).
        """
        # Build the inner sequence by sampling sequentially from the vocab
        inner_seq = ((offset + index + np.arange(input_len)) % vocab_size).tolist()
        token_sequence = prefix_token_ids + inner_seq
        total_input_len = prefix_len + int(input_len)
        return token_sequence, total_input_len

# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------


class RandomDataset(RandomTokenIDDataset):
    """
    Synthetic text-only dataset for serving/throughput benchmarks.
    In addition to RandomTokenIDDataset, we perform a decode then re-encode/truncate
    to ensure prompt token counts match.
    """

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        prefix_len: int = RandomTokenIDDataset.DEFAULT_PREFIX_LEN,
        range_ratio: float = RandomTokenIDDataset.DEFAULT_RANGE_RATIO,
        input_len: int = RandomTokenIDDataset.DEFAULT_INPUT_LEN,
        output_len: int = RandomTokenIDDataset.DEFAULT_OUTPUT_LEN,
        batchsize: int = 1,
        **kwargs,
    ) -> list[SampleRequest]:
@@ -490,17 +632,17 @@ class RandomDataset(BenchmarkDataset):
                "the minimum possible total input tokens (prefix + sampled) is "
                f"{min_total_input}. Increase --random-input-len and/or "
                "--random-prefix-len, or decrease --random-range-ratio so that "
                "prefix_len + floor(max(0, random_input_len - num_special)) "
                "* (1 - range_ratio) >= 1."
                "prefix_len + floor(max(0, random_input_len - num_special) "
                "* (1 - range_ratio)) >= 1."
            )

        vocab_size = tokenizer.vocab_size
        input_lens, output_lens, offsets = self.get_sampling_params(
            num_requests, range_ratio, input_len, output_len, tokenizer
            num_requests, range_ratio, input_len, output_len, vocab_size
        )

        # Generate prefix once
        prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
        vocab_size = tokenizer.vocab_size
        prefix_token_ids = self.get_prefix(vocab_size, prefix_len)

        requests = []
        token_mismatch_total = 0
@@ -552,67 +694,6 @@ class RandomDataset(BenchmarkDataset):

        return requests

    def get_prefix(
        self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
    ) -> list[int]:
        """
        Get the prefix for the dataset.
        """
        return (
            self._rng.integers(0, tokenizer.vocab_size, size=prefix_len).tolist()
            if prefix_len > 0
            else []
        )

    def get_sampling_params(
        self,
        num_requests: int,
        range_ratio: float,
        input_len: int,
        output_len: int,
        tokenizer: PreTrainedTokenizerBase,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Get the sampling parameters for the dataset.
        """
        # Enforce range_ratio < 1
        if not (0.0 <= range_ratio < 1.0):
            raise ValueError("range_ratio must be in [0, 1).")
        num_special_tokens = int(tokenizer.num_special_tokens_to_add())
        real_input_len = max(0, int(input_len) - num_special_tokens)
        # Bounds use floor for low and ceil for high
        input_low = math.floor(real_input_len * (1 - range_ratio))
        input_high = math.ceil(real_input_len * (1 + range_ratio))
        output_low = math.floor(output_len * (1 - range_ratio))
        output_high = math.ceil(output_len * (1 + range_ratio))
        # Ensure the lower bound for output length is at least 1 to
        # prevent sampling 0 tokens.
        output_low = max(output_low, 1)
        output_high = max(output_high, 1)

        if input_low > input_high:
            raise ValueError(
                f"Invalid input sampling interval: low={input_low} > high={input_high}"
            )
        if output_low > output_high:
            raise ValueError(
                "Invalid output sampling interval: "
                f"low={output_low} > high={output_high}"
            )

        logger.info(
            "Sampling input_len from [%s, %s] and output_len from [%s, %s]",
            input_low,
            input_high,
            output_low,
            output_high,
        )

        input_lens = self._rng.integers(input_low, input_high + 1, size=num_requests)
        output_lens = self._rng.integers(output_low, output_high + 1, size=num_requests)
        offsets = self._rng.integers(0, tokenizer.vocab_size, size=num_requests)
        return input_lens, output_lens, offsets

    def generate_token_sequence(
        self,
        *,
@@ -656,7 +737,7 @@ class RandomDataset(BenchmarkDataset):


# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# Random Dataset Implementation (Reranking)
# -----------------------------------------------------------------------------


@@ -684,9 +765,10 @@ class RandomDatasetForReranking(RandomDataset):
        n_sep_tokens = int(is_reranker)

        query_len_param = (input_len // 2) - n_sep_tokens if is_reranker else input_len
        vocab_size = tokenizer.vocab_size

        query_lens, _, query_offsets = self.get_sampling_params(
            1, range_ratio, query_len_param, 0, tokenizer
            1, range_ratio, query_len_param, 0, vocab_size
        )

        query_len = int(query_lens[0])
@@ -700,9 +782,8 @@ class RandomDatasetForReranking(RandomDataset):
        doc_len_param = input_len - query_len - n_sep_tokens

        doc_lens, _, doc_offsets = self.get_sampling_params(
            num_requests, range_ratio, doc_len_param, 0, tokenizer
            num_requests, range_ratio, doc_len_param, 0, vocab_size
        )
        vocab_size = tokenizer.vocab_size

        query_prompt, query_input_len, token_mismatch_total = (
            self.generate_token_sequence(
@@ -1054,9 +1135,11 @@ class RandomMultiModalDataset(RandomDataset):
                "Video sampling not implemented; set its probability to 0."
            )

        vocab_size = tokenizer.vocab_size

        # Get the sampling parameters for the dataset
        input_lens, output_lens, offsets = self.get_sampling_params(
            num_requests, range_ratio, input_len, output_len, tokenizer
            num_requests, range_ratio, input_len, output_len, vocab_size
        )

        (
@@ -1072,8 +1155,8 @@ class RandomMultiModalDataset(RandomDataset):
        )

        # Generate prefix once
        prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
        vocab_size = tokenizer.vocab_size
        prefix_token_ids = self.get_prefix(vocab_size, prefix_len)

        # Add synthetic multimodal items to each request
        mm_requests = []
        token_mismatch_total = 0
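As a reference for the range arithmetic used by `get_sampling_params` in the classes above, here is a self-contained sketch of the length sampling: floor/ceil bounds around the base length, a lower clamp of 1 (applied to the output range in the real code), and the `+ 1` because numpy's `integers()` treats `high` as exclusive.

```python
import math

import numpy as np


def sample_lengths(rng: np.random.Generator, n: int, base_len: int, range_ratio: float):
    """Integer-uniform lengths in [floor(base*(1-r)), ceil(base*(1+r))], inclusive."""
    low = math.floor(base_len * (1 - range_ratio))
    high = math.ceil(base_len * (1 + range_ratio))
    low = max(low, 1)                            # never sample 0 tokens
    return rng.integers(low, high + 1, size=n)   # +1: numpy's high is exclusive


rng = np.random.default_rng(0)
print(sample_lengths(rng, 5, base_len=128, range_ratio=0.25))  # values in [96, 160]
```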
@@ -23,6 +23,7 @@ from vllm.benchmarks.datasets import (
    InstructCoderDataset,
    PrefixRepetitionRandomDataset,
    RandomDataset,
    RandomTokenIDDataset,
    SampleRequest,
    ShareGPTDataset,
    SonnetDataset,
@@ -340,6 +341,10 @@ def get_requests(args, tokenizer):
        "output_len": args.output_len,
    }

    if args.dataset_name == "random_token_id":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomTokenIDDataset
    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
@@ -691,9 +696,14 @@ def main(args: argparse.Namespace):
        args.seed = 0
    random.seed(args.seed)
    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code
    tokenizer = (
        AutoTokenizer.from_pretrained(
            args.tokenizer, trust_remote_code=args.trust_remote_code
        )
        if args.skip_tokenizer_init
        else None
    )

    requests = get_requests(args, tokenizer)
    is_multi_modal = any(request.multi_modal_data is not None for request in requests)
    request_outputs: list[RequestOutput] | None = None
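A hedged sketch of the dataset-class dispatch added to `get_requests()` above: the new `random_token_id` path selects a dataset that builds prompts purely from token ids, so no tokenizer is needed to sample it. Class bodies are stubbed so the snippet stands alone, and the real function mutates `dataset_cls` and `sample_kwargs` rather than returning early:

```python
class RandomTokenIDDataset: ...   # stub for the class imported in the hunk above
class RandomDataset: ...          # stub


def pick_dataset_cls(dataset_name: str, dataset_path: str | None) -> type:
    if dataset_name == "random_token_id":
        # Synthetic token ids: no tokenizer required to build prompts.
        return RandomTokenIDDataset
    if dataset_path is None or dataset_name == "random":
        # Text prompts: needs a tokenizer for decode/re-encode.
        return RandomDataset
    raise ValueError(f"unsupported dataset: {dataset_name}")


print(pick_dataset_cls("random_token_id", None).__name__)  # RandomTokenIDDataset
```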
@@ -3,7 +3,9 @@
from typing import TYPE_CHECKING, Any

import torch
from lmcache.integration.vllm.vllm_v1_adapter import LMCacheConnectorV1Impl
from lmcache.integration.vllm.vllm_v1_adapter import (
    LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
)

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@@ -11,6 +13,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata,
    KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
    vllm_v1_adapter as _adapter,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput

@@ -26,7 +31,18 @@ logger = init_logger(__name__)
class LMCacheConnectorV1(KVConnectorBase_V1):
    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._lmcache_engine = LMCacheConnectorV1Impl(vllm_config, role, self)
        assert vllm_config.kv_transfer_config is not None
        use_native = vllm_config.kv_transfer_config.get_from_extra_config(
            "use_native", False
        )
        if use_native:
            logger.info("Initializing native LMCache connector")
            cls = _adapter.LMCacheConnectorV1Impl
        else:
            logger.info("Initializing latest dev LMCache connector")
            cls = LMCacheConnectorLatestImpl

        self._lmcache_engine = cls(vllm_config, role, self)

    # ==============================
    # Worker-side methods
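The constructor above selects the adapter class from the connector's extra config. A sketch of what enabling the native (in-tree) adapter might look like; the surrounding field names (`kv_connector`, `kv_role`, `kv_connector_extra_config`) are assumptions based on vLLM's KV-transfer config, and only the `use_native` key is taken from the hunk:

```python
# Hypothetical KV-transfer configuration enabling the in-tree LMCache adapter.
kv_transfer_config = {
    "kv_connector": "LMCacheConnectorV1",
    "kv_role": "kv_both",
    "kv_connector_extra_config": {"use_native": True},  # read via get_from_extra_config
}
```

In a deployment this would typically be passed to the engine as a JSON string (for example through a `--kv-transfer-config` argument), if that is how the setup wires it up.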
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Standard
import os
import threading
from typing import TYPE_CHECKING, Union

import torch
from lmcache.config import LMCacheEngineConfig as Config
from lmcache.logging import init_logger
from lmcache.v1.config import LMCacheEngineConfig as V1Config

if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.multimodal.inputs import PlaceholderRange
    from vllm.v1.core.sched.output import NewRequestData
    from vllm.v1.request import Request

logger = init_logger(__name__)
ENGINE_NAME = "vllm-instance"

# Thread-safe singleton storage
_config_instance: Config | V1Config | None = None
_config_lock = threading.Lock()


def is_false(value: str) -> bool:
    """Check if the given string value is equivalent to 'false'."""
    return value.lower() in ("false", "0", "no", "n", "off")


def lmcache_get_or_create_config() -> Config | V1Config:
    """Get the LMCache configuration from the environment variable
    `LMCACHE_CONFIG_FILE`. If the environment variable is not set, this
    function will return the default configuration.

    This function is thread-safe and implements the singleton pattern,
    ensuring the configuration is loaded only once.
    """
    global _config_instance

    # Double-checked locking for thread-safe singleton
    if _config_instance is None:
        with _config_lock:
            if _config_instance is None:  # Check again within lock
                if is_false(os.getenv("LMCACHE_USE_EXPERIMENTAL", "True")):
                    logger.warning(
                        "Detected LMCACHE_USE_EXPERIMENTAL is set to False. "
                        "Using legacy configuration is deprecated and will "
                        "be removed soon! Please set LMCACHE_USE_EXPERIMENTAL "
                        "to True."
                    )
                    LMCacheEngineConfig = Config  # type: ignore[assignment]
                else:
                    LMCacheEngineConfig = V1Config  # type: ignore[assignment]

                if "LMCACHE_CONFIG_FILE" not in os.environ:
                    logger.warning(
                        "No LMCache configuration file is set. Trying to read"
                        " configurations from the environment variables."
                    )
                    logger.warning(
                        "You can set the configuration file through "
                        "the environment variable: LMCACHE_CONFIG_FILE"
                    )
                    _config_instance = LMCacheEngineConfig.from_env()
                else:
                    config_file = os.environ["LMCACHE_CONFIG_FILE"]
                    logger.info("Loading LMCache config file %s", config_file)
                    _config_instance = LMCacheEngineConfig.from_file(config_file)
                    # Update config from environment variables
                    _config_instance.update_config_from_env()
    return _config_instance

def hex_hash_to_int16(s: str) -> int:
    """
    Convert a hex hash string to a 16-bit integer.
    """
    return int(s, 16) & 0xFFFF


def apply_mm_hashes_to_token_ids(
    token_ids: torch.Tensor,
    mm_hashes: list[str],
    mm_positions: list["PlaceholderRange"],
) -> torch.Tensor:
    """
    Overwrite token_ids in-place for multimodal placeholders using
    efficient slice assignments.
    """
    n = token_ids.size(0)
    for hash_str, placeholder in zip(mm_hashes, mm_positions):
        start, length = placeholder.offset, placeholder.length
        if start >= n:
            continue
        end = min(start + length, n)
        token_ids[start:end] = hex_hash_to_int16(hash_str)
    return token_ids


def mla_enabled(model_config: "ModelConfig") -> bool:
    return (
        hasattr(model_config, "use_mla")
        and isinstance(model_config.use_mla, bool)
        and model_config.use_mla
    )


def create_lmcache_metadata(
    vllm_config=None, model_config=None, parallel_config=None, cache_config=None
):
    """
    Create LMCacheEngineMetadata from vLLM configuration.

    This function extracts common metadata creation logic that was duplicated
    across multiple files.

    Args:
        vllm_config (VllmConfig): vLLM configuration object containing model,
            parallel, and cache configs (alternative to individual config
            parameters)
        model_config (ModelConfig): Model configuration (alternative to
            vllm_config)
        parallel_config (ParallelConfig): Parallel configuration (alternative
            to vllm_config)
        cache_config (CacheConfig): Cache configuration (alternative to
            vllm_config)
    """
    # Third Party
    # First Party
    from lmcache.config import LMCacheEngineMetadata

    from vllm.utils import get_kv_cache_torch_dtype

    config = lmcache_get_or_create_config()
    # Support both vllm_config object and individual config parameters
    if vllm_config is not None:
        model_cfg = vllm_config.model_config
        parallel_cfg = vllm_config.parallel_config
        cache_cfg = vllm_config.cache_config
    else:
        if model_config is None or parallel_config is None or cache_config is None:
            raise ValueError(
                "Either vllm_config must be provided, or all of "
                "model_config, parallel_config, and cache_config must be provided."
            )
        model_cfg = model_config
        parallel_cfg = parallel_config
        cache_cfg = cache_config

    # Get KV cache dtype
    kv_dtype = get_kv_cache_torch_dtype(cache_cfg.cache_dtype, model_cfg.dtype)

    # Check if MLA is enabled
    use_mla = mla_enabled(model_cfg)

    # Construct KV shape (for memory pool)
    num_layer = model_cfg.get_num_layers(parallel_cfg)
    chunk_size = config.chunk_size
    num_kv_head = model_cfg.get_num_kv_heads(parallel_cfg)
    head_size = model_cfg.get_head_size()
    kv_shape = (num_layer, 1 if use_mla else 2, chunk_size, num_kv_head, head_size)

    # Create metadata
    metadata = LMCacheEngineMetadata(
        model_cfg.model,
        parallel_cfg.world_size,
        parallel_cfg.rank,
        "vllm",
        kv_dtype,
        kv_shape,
        use_mla,
    )

    return metadata, config


def extract_mm_features(
    request: Union["Request", "NewRequestData"], modify: bool = False
) -> tuple[list[str], list["PlaceholderRange"]]:
    """
    Normalize multimodal information from a Request into parallel lists.

    This helper reads either:
    1) `request.mm_features` (objects each exposing `.identifier` and
       `.mm_position`), or
    2) legacy fields `request.mm_hashes` and `request.mm_positions`.

    It returns two equally sized lists: the multimodal hash identifiers and
    their corresponding positions. If the request contains no multimodal info,
    it returns `([], [])`.

    Args:
        request (Request): The source object.
        modify (bool):
            Controls copy semantics for the legacy-path return values.
            - If True and legacy fields are used, shallow copies are returned so
              the caller can mutate the lists without affecting `request`.
            - If False, the original legacy sequences are returned as-is
              (zero-copy); treat them as read-only.

    Returns:
        tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`).
        May be `([], [])` when no multimodal data is present.
    """
    if getattr(request, "mm_features", None):
        mm_hashes, mm_positions = zip(
            *((f.identifier, f.mm_position) for f in request.mm_features)
        )
        return (list(mm_hashes), list(mm_positions))
    elif getattr(request, "mm_hashes", None):
        if modify:
            return (
                request.mm_hashes.copy(),  # type: ignore
                request.mm_positions.copy(),  # type: ignore
            )
        else:
            return (request.mm_hashes, request.mm_positions)  # type: ignore
    else:
        return ([], [])
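To make the placeholder-overwrite behaviour above concrete, here is a standalone usage sketch. `FakePlaceholderRange` is a minimal stand-in for `vllm.multimodal.inputs.PlaceholderRange`, and `hex_hash_to_int16` is copied from the file above:

```python
from dataclasses import dataclass

import torch


@dataclass
class FakePlaceholderRange:
    # Minimal stand-in for vllm.multimodal.inputs.PlaceholderRange.
    offset: int
    length: int


def hex_hash_to_int16(s: str) -> int:
    return int(s, 16) & 0xFFFF


token_ids = torch.arange(10)
mm_hashes = ["deadbeef"]
mm_positions = [FakePlaceholderRange(offset=3, length=4)]

# Same slice-assignment idea as apply_mm_hashes_to_token_ids above.
for h, p in zip(mm_hashes, mm_positions):
    end = min(p.offset + p.length, token_ids.size(0))
    token_ids[p.offset:end] = hex_hash_to_int16(h)

print(token_ids)  # positions 3..6 replaced by 0xBEEF == 48879
```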
File diff suppressed because it is too large.
@@ -991,14 +991,11 @@ class NixlConnectorWorker:
        # Enable different block lengths for different layers when MLA is used.
        self.block_len_per_layer = list[int]()
        self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
        self.device_id = self.tp_rank
        for layer_name, cache_or_caches in xfer_buffers.items():
            cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]

            for cache in cache_list:
                base_addr = cache.data_ptr()
                if not self.use_host_buffer and current_platform.is_cuda_alike():
                    self.device_id = cache.device.index
                if base_addr in seen_base_addresses:
                    continue

@@ -1026,7 +1023,7 @@ class NixlConnectorWorker:
                    "All kv cache tensors must have the same size"
                )
                caches_data.append(
                    (base_addr, curr_tensor_size_bytes, self.device_id, "")
                    (base_addr, curr_tensor_size_bytes, self.tp_rank, "")
                )

        logger.debug(
@@ -1073,7 +1070,7 @@ class NixlConnectorWorker:
                block_offset = block_id * self.block_len_per_layer[i]
                addr = base_addr + block_offset
                # (addr, len, device id)
                blocks_data.append((addr, kv_block_len, self.device_id))
                blocks_data.append((addr, kv_block_len, self.tp_rank))

            if self._use_flashinfer:
                # Separate and interleave K/V regions to maintain the same
@@ -1084,13 +1081,12 @@ class NixlConnectorWorker:
                    addr = base_addr + block_offset
                    # Register addresses for V cache (K registered first).
                    v_addr = addr + kv_block_len
                    blocks_data.append((v_addr, kv_block_len, self.device_id))
                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
        logger.debug(
            "Created %s blocks for src engine %s and rank %s on device id %s",
            "Created %s blocks for src engine %s and rank %s",
            len(blocks_data),
            self.engine_id,
            self.tp_rank,
            self.device_id,
        )

        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
@@ -794,7 +794,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                )
            else:
                raise NotImplementedError(
                    "Incompatible Mxfp4 backend for EP batched experts format"
                    f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for "
                    "EP batched experts format"
                )
        else:
            assert self.moe_quant_config is not None
@@ -813,8 +814,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                return TrtLlmGenExperts(self.moe, self.moe_quant_config, **kwargs)
            elif self.mxfp4_backend == Mxfp4Backend.MARLIN:
                return MarlinExperts(self.moe_quant_config)
            else:
            elif self.mxfp4_backend == Mxfp4Backend.TRITON:
                return OAITritonExperts(self.moe_quant_config)
            else:
                raise NotImplementedError(
                    f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP"
                )

    def _route_and_experts(
        self,
@@ -134,12 +134,9 @@ class CoreEngineProcManager:
        data_parallel = vllm_config.parallel_config.data_parallel_size > 1
        try:
            for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
                # Adjust device control in DP for non-CUDA platforms
                # For CUDA platforms, setting same device id for different DP
                # processes affects NCCL init performance.
                with (
                    set_device_control_env_var(vllm_config, local_dp_rank)
                    if (data_parallel and not current_platform.is_cuda_alike())
                    if (data_parallel)
                    else contextlib.nullcontext()
                ):
                    proc.start()
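The hunk above drops the non-CUDA guard, so the device-control environment variable is now applied for every data-parallel rank. The conditional-context-manager idiom it relies on, in a runnable form with a stubbed context manager (`set_device_control_env_var_stub` is hypothetical):

```python
import contextlib

data_parallel = True  # stands in for data_parallel_size > 1


@contextlib.contextmanager
def set_device_control_env_var_stub():
    # Hypothetical stand-in for set_device_control_env_var(vllm_config, rank).
    print("device-control env var set for this DP rank")
    yield
    print("device-control env var restored")


with (
    set_device_control_env_var_stub()
    if data_parallel
    else contextlib.nullcontext()
):
    print("proc.start() would run here")
```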
@@ -8,6 +8,7 @@ import torch.distributed as dist
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import (
    UBatchSlices,
    check_ubatch_thresholds,
@@ -19,8 +20,7 @@ logger = init_logger(__name__)


def _get_device_and_group(parallel_config: ParallelConfig):
    # Use the actual device assigned to the DP group, not just the device type
    device = get_dp_group().device
    device = current_platform.device_type
    group = get_dp_group().device_group

    # Transferring this tensor from GPU to CPU will introduce a GPU sync
@@ -172,27 +172,6 @@ class Worker(WorkerBase):
        if self.device_config.device.type == "cuda":
            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            if (
                self.parallel_config.data_parallel_size > 1
                and self.parallel_config.data_parallel_size_local > 0
                and self.parallel_config.data_parallel_backend != "ray"
            ):
                # Use local DP rank if available, otherwise use global DP rank.
                dp_local_rank = self.parallel_config.data_parallel_rank_local
                if dp_local_rank is None:
                    dp_local_rank = self.parallel_config.data_parallel_rank

                tp_pp_world_size = (
                    self.parallel_config.pipeline_parallel_size
                    * self.parallel_config.tensor_parallel_size
                )

                # DP_LOCAL_RANK * TP_PP_WORLD_SIZE + TP_LOCAL_RANK
                self.local_rank += dp_local_rank * tp_pp_world_size
                assert self.local_rank <= torch.cuda.device_count(), (
                    f"DP adjusted local rank {self.local_rank} is out of bounds. "
                )

            self.device = torch.device(f"cuda:{self.local_rank}")
            current_platform.set_device(self.device)
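For reference, a worked example of the local-rank arithmetic in the block that this hunk removes (the numbers are arbitrary):

```python
# local_rank += DP_LOCAL_RANK * TP_PP_WORLD_SIZE, where local_rank starts as the
# TP/PP-local rank. With PP=1 and TP=4, the second local DP replica
# (dp_local_rank=1) whose TP-local rank is 2 lands on device cuda:6.
pipeline_parallel_size = 1
tensor_parallel_size = 4
dp_local_rank = 1
local_rank = 2  # TP/PP-local rank before adjustment

tp_pp_world_size = pipeline_parallel_size * tensor_parallel_size
local_rank += dp_local_rank * tp_pp_world_size
print(f"cuda:{local_rank}")  # cuda:6
```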