Compare commits


17 Commits

Author SHA1 Message Date
7557a67655 precommit
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
2025-10-29 20:26:12 +00:00
1af476b0e9 Merge branch 'main' into copilot/disable-batched-triton-kernel 2025-10-29 20:18:03 +00:00
8c3b1c7c62 ditch the unit test honestly
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
2025-10-29 20:17:46 +00:00
d4aa144343 [BugFix] Fix handling of resumed reqs in SharedStorageConnector (#27719)
Signed-off-by: Nick Hill <nhill@redhat.com>
2025-10-29 20:16:52 +00:00
fcb1d570bb [Bug] Fix DeepEP low latency assert self.batched_router_logits.size(-1) == full_router_logits.size(-1) Bug (#27682)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-29 14:50:39 -04:00
accb8fab07 [KVConnector] Add metrics to Prometheus-Grafana dashboard (#26811)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
2025-10-29 18:44:49 +00:00
5b0448104f [Bug] Raise error explicitly if using incompatible backend (#27424)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-29 13:29:20 -04:00
f7a6682872 [CI/Build] Test torchrun with 8 cards (#27548)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
2025-10-29 10:26:06 -07:00
a9fe0793f2 use_aot_compile should respect VLLM_DISABLE_COMPILE_CACHE (#27698)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
2025-10-29 17:08:54 +00:00
7568a282b9 [FIXBUG] Qwen3VL hallucinations without Contiguous on Torch.SDPA (#27744)
Signed-off-by: JartX <sagformas@epdcenter.es>
Co-authored-by: Lukas Geiger <lukas.geiger94@gmail.com>
2025-10-29 16:55:35 +00:00
1da3309ace [Core] Exposing engine sleep & wake_up state as prometheus metrics (#24176)
Signed-off-by: Braulio Dumba <Braulio.Dumba@ibm.com>
2025-10-29 09:32:01 -07:00
5522fb274b [Chore] Optimize P2PNCCLEngine http_address (#27488)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-30 00:05:09 +08:00
0f95a1c3f2 [CI] Fix flaky test_two_responses_with_same_prev_id test (#27745)
Signed-off-by: NickLucche <nlucches@redhat.com>
2025-10-29 15:10:35 +00:00
ded24e3e54 [ROCm][Platform] Add MI308X device id in _ROCM_DEVICE_ID_NAME_MAP (#27623)
Signed-off-by: Xiake Sun <xiake.sun@amd.com>
2025-10-29 14:44:03 +00:00
c72d44ba4a Add test for batched triton fallback behavior
Co-authored-by: tlrmchlsmth <1236979+tlrmchlsmth@users.noreply.github.com>
2025-10-16 03:46:02 +00:00
c292032b44 Add env var to control batched triton kernel fallback
- Add VLLM_ALLOW_BATCHED_TRITON_FALLBACK environment variable
- Modify BatchedTritonOrDeepGemmExperts to crash when deepgemm is unavailable unless debug env is set

Co-authored-by: tlrmchlsmth <1236979+tlrmchlsmth@users.noreply.github.com>
2025-10-16 03:42:58 +00:00
b286fba2bb Initial plan 2025-10-16 03:37:04 +00:00
28 changed files with 779 additions and 434 deletions

View File

@@ -205,6 +205,24 @@ steps:
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd
- label: Distributed Tests (8 GPUs) # 4min
timeout_in_minutes: 10
gpu: h100
num_gpus: 8
working_dir: "/vllm-workspace/tests"
source_file_dependencies:
- examples/offline_inference/torchrun_dp_example.py
- vllm/config/parallel.py
- vllm/distributed/
- vllm/v1/engine/llm_engine.py
- vllm/v1/executor/uniproc_executor.py
- vllm/v1/worker/gpu_worker.py
commands:
# https://github.com/NVIDIA/nccl/issues/1838
- export NCCL_CUMEM_HOST_ENABLE=0
# test with torchrun tp=2 and dp=4 with ep
- torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep
- label: EPLB Algorithm Test # 5min
timeout_in_minutes: 15
working_dir: "/vllm-workspace/tests"
@@ -401,7 +419,7 @@ steps:
--ignore=lora/test_deepseekv2_tp.py \
--ignore=lora/test_gptoss.py \
--ignore=lora/test_qwen3moe_tp.py
parallelism: 4
- label: PyTorch Compilation Unit Tests # 15min
@@ -1126,7 +1144,7 @@ steps:
- tests/weight_loading
commands:
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt
- label: NixlConnector PD accuracy tests (Distributed) # 30min
timeout_in_minutes: 30
working_dir: "/vllm-workspace/tests"

View File

@@ -9,10 +9,76 @@ To run this example:
```bash
$ torchrun --nproc-per-node=2 examples/offline_inference/torchrun_dp_example.py
```
With custom parallelism settings:
```bash
$ torchrun --nproc-per-node=8 examples/offline_inference/torchrun_dp_example.py \
--tp-size=2 --pp-size=1 --dp-size=4 --enable-ep
```
"""
import argparse
from vllm import LLM, SamplingParams
def parse_args():
parser = argparse.ArgumentParser(
description="Data-parallel inference with torchrun"
)
parser.add_argument(
"--tp-size",
type=int,
default=1,
help="Tensor parallel size (default: 1)",
)
parser.add_argument(
"--pp-size",
type=int,
default=1,
help="Pipeline parallel size (default: 1)",
)
parser.add_argument(
"--dp-size",
type=int,
default=2,
help="Data parallel size (default: 2)",
)
parser.add_argument(
"--enable-ep",
action="store_true",
help="Enable expert parallel (default: False)",
)
parser.add_argument(
"--model",
type=str,
default="microsoft/Phi-mini-MoE-instruct",
help="Model name or path (default: microsoft/Phi-mini-MoE-instruct)",
)
parser.add_argument(
"--max-model-len",
type=int,
default=4096,
help="Maximum model length (default: 4096)",
)
parser.add_argument(
"--gpu-memory-utilization",
type=float,
default=0.6,
help="GPU memory utilization (default: 0.6)",
)
parser.add_argument(
"--seed",
type=int,
default=1,
help="Random seed (default: 1)",
)
return parser.parse_args()
args = parse_args()
# Create prompts, the same across all ranks
prompts = [
"Hello, my name is",
@@ -30,15 +96,15 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
model="microsoft/Phi-mini-MoE-instruct",
tensor_parallel_size=1,
data_parallel_size=2,
pipeline_parallel_size=1,
enable_expert_parallel=False,
model=args.model,
tensor_parallel_size=args.tp_size,
data_parallel_size=args.dp_size,
pipeline_parallel_size=args.pp_size,
enable_expert_parallel=args.enable_ep,
distributed_executor_backend="external_launcher",
max_model_len=4096,
gpu_memory_utilization=0.6,
seed=1,
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
seed=args.seed,
)
dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import requests
from prometheus_client.parser import text_string_to_metric_families
from ...utils import RemoteOpenAIServer
@@ -31,12 +32,28 @@ def test_sleep_mode():
assert response.status_code == 200
assert response.json().get("is_sleeping") is True
# check sleep metrics
response = requests.get(remote_server.url_for("metrics"))
assert response.status_code == 200
awake, weights_offloaded, discard_all = _get_sleep_metrics_from_api(response)
assert awake == 0
assert weights_offloaded == 1
assert discard_all == 0
response = requests.post(remote_server.url_for("wake_up"))
assert response.status_code == 200
response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200
assert response.json().get("is_sleeping") is False
# check sleep metrics
response = requests.get(remote_server.url_for("metrics"))
assert response.status_code == 200
awake, weights_offloaded, discard_all = _get_sleep_metrics_from_api(response)
assert awake == 1
assert weights_offloaded == 0
assert discard_all == 0
# test wake up with tags
response = requests.post(remote_server.url_for("sleep"), params={"level": "1"})
assert response.status_code == 200
@@ -59,3 +76,35 @@ def test_sleep_mode():
response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200
assert response.json().get("is_sleeping") is False
# check sleep metrics
response = requests.get(remote_server.url_for("metrics"))
assert response.status_code == 200
awake, weights_offloaded, discard_all = _get_sleep_metrics_from_api(response)
assert awake == 1
assert weights_offloaded == 0
assert discard_all == 0
def _get_sleep_metrics_from_api(response: requests.Response):
"""Return (awake, weights_offloaded, discard_all)"""
awake, weights_offloaded, discard_all = None, None, None
for family in text_string_to_metric_families(response.text):
if family.name == "vllm:engine_sleep_state":
for sample in family.samples:
if sample.name == "vllm:engine_sleep_state":
for label_name, label_value in sample.labels.items():
if label_value == "awake":
awake = sample.value
elif label_value == "weights_offloaded":
weights_offloaded = sample.value
elif label_value == "discard_all":
discard_all = sample.value
assert awake is not None
assert weights_offloaded is not None
assert discard_all is not None
return awake, weights_offloaded, discard_all
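For reference, a minimal sketch of the exposition text this helper parses; the label values below are illustrative, not taken from an actual test run:

```python
from prometheus_client.parser import text_string_to_metric_families

# Hypothetical /metrics excerpt after a level-1 sleep.
sample = """\
# TYPE vllm:engine_sleep_state gauge
vllm:engine_sleep_state{engine="0",model_name="m",sleep_state="awake"} 0.0
vllm:engine_sleep_state{engine="0",model_name="m",sleep_state="weights_offloaded"} 1.0
vllm:engine_sleep_state{engine="0",model_name="m",sleep_state="discard_all"} 0.0
"""

for family in text_string_to_metric_families(sample):
    for s in family.samples:
        print(s.labels["sleep_state"], s.value)
```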

View File

@@ -1,11 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests."""
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
yield

View File

@@ -1,208 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
HTTP-based batch invariance test: send requests to a running
vLLM server and compare BS=1 vs BS=N results (tokens and per-step logprobs).
Environment variables:
- VLLM_API_BASE: base URL like http://127.0.0.1:9256/v1 (default used)
- VLLM_TEST_MODEL: served model name (e.g., Qwen/Qwen3-1.7B)
Note:
- The server must be started beforehand via `vllm serve ...`
- Example usage:
- `export VLLM_ATTENTION_BACKEND="FLASHINFER_MLA"`
- `export VLLM_BATCH_INVARIANT=1`
- `vllm serve deepseek-ai/DeepSeek-R1 -dp 8 --enable-expert-parallel --port 9256`
- `pytest tests/v1/generation/test_online_batch_invariance.py`
"""
import os
import random
import sys
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
import pytest
from utils import _random_prompt, skip_unsupported
def _post_json(
url: str, headers: dict[str, str], payload: dict[str, Any], timeout: float
) -> dict[str, Any]:
import json
data = json.dumps(payload).encode("utf-8")
req = Request(url, data=data, headers=headers, method="POST")
with urlopen(req, timeout=timeout) as resp:
body = resp.read()
return json.loads(body.decode("utf-8"))
def _request_completion(
api_base: str,
model: str,
prompt: Any,
sp: dict[str, Any],
timeout: float = 60.0,
max_retries: int = 3,
retry_backoff: float = 0.5,
) -> dict[str, Any] | None:
url = api_base.rstrip("/") + "/completions"
headers = {"Content-Type": "application/json"}
payload: dict[str, Any] = {"model": model, "prompt": prompt}
payload.update(sp)
for attempt in range(max_retries + 1):
try:
return _post_json(url, headers, payload, timeout)
except HTTPError as e: # type: ignore[reportGeneralTypeIssues]
status = getattr(e, "code", None)
if status in (429, 500, 502, 503, 504) and attempt < max_retries:
import time as _t
_t.sleep(retry_backoff * (2**attempt))
continue
sys.stderr.write(f"HTTPError: {e}\n")
return None
except URLError as e: # type: ignore[reportGeneralTypeIssues]
if attempt < max_retries:
import time as _t
_t.sleep(retry_backoff * (2**attempt))
continue
sys.stderr.write(f"URLError: {e}\n")
return None
except Exception as e: # pragma: no cover
if attempt < max_retries:
import time as _t
_t.sleep(retry_backoff * (2**attempt))
continue
sys.stderr.write(f"Error: {e}\n")
return None
return None
def _extract_tokens_and_logprobs(
choice: dict[str, Any],
) -> tuple[list[Any], list[float] | None]:
tokens: list[Any] = []
token_logprobs: list[float] | None = None
lp = choice.get("logprobs")
if lp and isinstance(lp, dict):
tokens = lp.get("token_ids") or lp.get("tokens") or []
token_logprobs = lp.get("token_logprobs", None)
return tokens, token_logprobs
def _server_is_up(api_base: str, timeout: float = 2.0) -> tuple[bool, str]:
url = api_base.rstrip("/") + "/models"
try:
req = Request(url, method="GET")
with urlopen(req, timeout=timeout) as resp:
_ = resp.read()
return True, "OK"
except URLError as e: # type: ignore[reportGeneralTypeIssues]
return False, f"URLError: {e}"
except Exception as e: # pragma: no cover
return False, f"Error: {e}"
def _compare_bs1_vs_bsn_single_process(
prompts: list[str],
sp_kwargs: dict[str, Any],
api_base: str,
model_name: str,
) -> None:
# BS=1
bs1_tokens_per_prompt: list[list[Any]] = []
bs1_logprobs_per_prompt: list[list[float] | None] = []
for p in prompts:
resp = _request_completion(api_base, model_name, p, sp_kwargs)
if resp is None or not resp.get("choices"):
raise AssertionError("BS=1 empty/failed response")
choice = resp["choices"][0]
toks, lps = _extract_tokens_and_logprobs(choice)
if lps is None:
raise AssertionError(
"logprobs not returned; ensure server supports 'logprobs'"
)
bs1_tokens_per_prompt.append(list(toks))
bs1_logprobs_per_prompt.append(list(lps))
# BS=N
bsN_tokens_per_prompt: list[list[Any]] = [None] * len(prompts) # type: ignore[list-item]
bsN_logprobs_per_prompt: list[list[float] | None] = [None] * len(prompts)
resp = _request_completion(api_base, model_name, prompts, sp_kwargs)
if resp is None or not resp.get("choices"):
raise AssertionError("BS=N empty/failed batched response")
choices = resp.get("choices", [])
if len(choices) != len(prompts):
raise AssertionError(
f"BS=N choices length {len(choices)} != num prompts {len(prompts)}"
)
for idx, choice in enumerate(choices):
toks, lps = _extract_tokens_and_logprobs(choice)
if lps is None:
raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
bsN_tokens_per_prompt[idx] = list(toks)
bsN_logprobs_per_prompt[idx] = list(lps)
# compare
for i, (tokens_bs1, tokens_bsN, logprobs_bs1, logprobs_bsN) in enumerate(
zip(
bs1_tokens_per_prompt,
bsN_tokens_per_prompt,
bs1_logprobs_per_prompt,
bsN_logprobs_per_prompt,
)
):
if tokens_bs1 != tokens_bsN:
raise AssertionError(
f"Prompt {i} (sampling): Different tokens sampled. "
f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
)
if logprobs_bs1 is None or logprobs_bsN is None:
raise AssertionError(f"Prompt {i}: Missing logprobs in one of the runs")
if len(logprobs_bs1) != len(logprobs_bsN):
raise AssertionError(
f"Prompt {i}: Different number of steps: "
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)."
)
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
if a != b:
diff = abs(a - b)
raise AssertionError(
f"Prompt {i} Step {t}: Bitwise mismatch "
f"(abs diff={diff:.6e}). "
f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
)
@skip_unsupported
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN_dp_http():
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
api_base = os.getenv("VLLM_API_BASE", "http://127.0.0.1:9256/v1")
model_name = os.getenv("VLLM_TEST_MODEL", "deepseek-ai/DeepSeek-R1")
up, reason = _server_is_up(api_base)
if not up:
pytest.skip(f"Server not reachable at {api_base}: {reason}")
prompts_all = [_random_prompt(10, 50) for _ in range(32)]
sp_kwargs: dict[str, Any] = {
"temperature": 0.6,
"top_p": 1.0,
"max_tokens": 8,
"seed": 42,
"logprobs": 5,
}
_compare_bs1_vs_bsn_single_process(
prompts=prompts_all,
sp_kwargs=sp_kwargs,
api_base=api_base,
model_name=model_name,
)

View File

@@ -1,74 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import torch
from vllm.platforms import current_platform
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
)
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
# Generate more realistic prompts that will actually produce varied tokens
# Use a mix of common English text patterns
prompt_templates = [
# Question-answer style
"Question: What is the capital of France?\nAnswer: The capital of France is",
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
# Story/narrative style
"Once upon a time in a distant galaxy, there lived",
"The old man walked slowly down the street, remembering",
"In the year 2157, humanity finally discovered",
# Technical/code style
"To implement a binary search tree in Python, first we need to",
"The algorithm works by iterating through the array and",
"Here's how to optimize database queries using indexing:",
# Factual/informative style
"The Renaissance was a period in European history that",
"Climate change is caused by several factors including",
"The human brain contains approximately 86 billion neurons which",
# Conversational style
"I've been thinking about getting a new laptop because",
"Yesterday I went to the store and bought",
"My favorite thing about summer is definitely",
]
# Pick a random template
base_prompt = random.choice(prompt_templates)
if max_words < min_words:
max_words = min_words
target_words = random.randint(min_words, max_words)
if target_words > 50:
# For longer prompts, repeat context
padding_text = (
" This is an interesting topic that deserves more explanation. "
* (target_words // 50)
)
base_prompt = base_prompt + padding_text
return base_prompt
def _extract_step_logprobs(request_output):
if getattr(request_output, "outputs", None):
inner = request_output.outputs[0]
if hasattr(inner, "logprobs") and inner.logprobs is not None:
t = torch.tensor(
[
inner.logprobs[i][tid].logprob
for i, tid in enumerate(inner.token_ids)
],
dtype=torch.float32,
)
return t, inner.token_ids
return None, None

View File

@@ -6,7 +6,7 @@ import pytest_asyncio
from tests.utils import RemoteOpenAIServer
# Use a small reasoning model to test the responses API.
MODEL_NAME = "Qwen/Qwen3-0.6B"
MODEL_NAME = "Qwen/Qwen3-1.7B"
@pytest.fixture(scope="module")

View File

@@ -6,9 +6,66 @@ import random
import pytest
import torch
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
)
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests."""
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
yield
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
# Generate more realistic prompts that will actually produce varied tokens
# Use a mix of common English text patterns
prompt_templates = [
# Question-answer style
"Question: What is the capital of France?\nAnswer: The capital of France is",
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
# Story/narrative style
"Once upon a time in a distant galaxy, there lived",
"The old man walked slowly down the street, remembering",
"In the year 2157, humanity finally discovered",
# Technical/code style
"To implement a binary search tree in Python, first we need to",
"The algorithm works by iterating through the array and",
"Here's how to optimize database queries using indexing:",
# Factual/informative style
"The Renaissance was a period in European history that",
"Climate change is caused by several factors including",
"The human brain contains approximately 86 billion neurons which",
# Conversational style
"I've been thinking about getting a new laptop because",
"Yesterday I went to the store and bought",
"My favorite thing about summer is definitely",
]
# Pick a random template
base_prompt = random.choice(prompt_templates)
if max_words < min_words:
max_words = min_words
target_words = random.randint(min_words, max_words)
if target_words > 50:
# For longer prompts, repeat context
padding_text = (
" This is an interesting topic that deserves more explanation. "
* (target_words // 50)
)
base_prompt = base_prompt + padding_text
return base_prompt
@skip_unsupported
@@ -147,6 +204,22 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
llm_bsN.shutdown()
def _extract_step_logprobs(request_output):
if getattr(request_output, "outputs", None):
inner = request_output.outputs[0]
if hasattr(inner, "logprobs") and inner.logprobs is not None:
t = torch.tensor(
[
inner.logprobs[i][tid].logprob
for i, tid in enumerate(inner.token_ids)
],
dtype=torch.float32,
)
return t, inner.token_ids
return None, None
@skip_unsupported
@pytest.mark.parametrize(
"backend",

View File

@@ -9,10 +9,15 @@ with the standard CUDA-based implementation to ensure numerical accuracy.
import pytest
import torch
from utils import skip_unsupported
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
)
@skip_unsupported

View File

@@ -50,7 +50,12 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@@ -471,3 +476,18 @@ class KVConnectorBase_V1(ABC):
which can implement custom aggregation logic on the data dict.
"""
return None
@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> Optional["KVConnectorPromMetrics"]:
"""
Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to
expose connector transfer stats via Prometheus.
"""
return None

View File

@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any
from typing import Any, TypeAlias, TypeVar
from vllm.config.kv_transfer import KVTransferConfig
from prometheus_client import Counter, Gauge, Histogram
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group
from vllm.logger import init_logger
PromMetric: TypeAlias = Gauge | Counter | Histogram
PromMetricT = TypeVar("PromMetricT", bound=PromMetric)
logger = init_logger(__name__)
@@ -102,3 +107,83 @@ class KVConnectorLogging:
# Reset metrics for next interval
self.reset()
class KVConnectorPromMetrics:
"""
A base class for per-connector Prometheus metric registration
and recording.
"""
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self._kv_transfer_config = vllm_config.kv_transfer_config
self._gauge_cls = metric_types[Gauge]
self._counter_cls = metric_types[Counter]
self._histogram_cls = metric_types[Histogram]
self._labelnames = labelnames
self._per_engine_labelvalues = per_engine_labelvalues
def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]:
"""
Create per-engine children of a prometheus_client metric, keyed by
engine index, with the appropriate labels set. The parent metric
must be created using the labelnames list.
"""
return {
idx: metric.labels(*labelvalues)
for idx, labelvalues in self._per_engine_labelvalues.items()
}
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""
Record the supplied transfer statistics to Prometheus metrics. These
statistics are engine-specific, and should be recorded to a metric
with the appropriate 'engine' label. These metric instances can be
created using the make_per_engine() helper method.
"""
raise NotImplementedError
class KVConnectorPrometheus:
"""
Support for registering per-connector Prometheus metrics, and
recording transfer statistics to those metrics. Uses
KVConnectorBase.build_prom_metrics().
"""
_gauge_cls = Gauge
_counter_cls = Counter
_histogram_cls = Histogram
def __init__(
self,
vllm_config: VllmConfig,
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self.prom_metrics: KVConnectorPromMetrics | None = None
kv_transfer_config = vllm_config.kv_transfer_config
if kv_transfer_config and kv_transfer_config.kv_connector:
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
metric_types = {
Gauge: self._gauge_cls,
Counter: self._counter_cls,
Histogram: self._histogram_cls,
}
self.prom_metrics = connector_cls.build_prom_metrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
if self.prom_metrics is None:
return
self.prom_metrics.observe(transfer_stats_data, engine_idx)
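To make the intended subclassing pattern concrete, a minimal sketch of a connector-specific implementation; the class name, metric name, and stats key are hypothetical (NixlPromMetrics later in this diff is the real in-tree example):

```python
from typing import Any

from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
    KVConnectorPromMetrics,
)


class MyConnectorPromMetrics(KVConnectorPromMetrics):
    def __init__(self, vllm_config, metric_types, labelnames, per_engine_labelvalues):
        super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
        # Create the parent metric with the shared labelnames, then fan it
        # out into one child per engine via make_per_engine().
        counter = self._counter_cls(
            name="vllm:my_connector_num_transfers",  # hypothetical metric name
            documentation="Number of KV transfers (illustrative).",
            labelnames=labelnames,
        )
        self.counter_num_transfers = self.make_per_engine(counter)

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        # "num_transfers" is a hypothetical stats key.
        self.counter_num_transfers[engine_idx].inc(
            transfer_stats_data.get("num_transfers", 0)
        )
```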

View File

@@ -9,13 +9,19 @@ import torch
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
@@ -72,6 +78,27 @@ class MultiKVConnectorStats(KVConnectorStats):
self.data[connector_id] = stats
class MultiKVConnectorPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: "VllmConfig",
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
prom_metrics: dict[str, KVConnectorPromMetrics],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
self._prom_metrics = prom_metrics
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
for connector_id, stats_data in transfer_stats_data.items():
assert connector_id in self._prom_metrics, (
f"{connector_id} is not contained in the list of registered connectors "
f"with Prometheus metrics support: {self._prom_metrics.keys()}"
)
self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx)
class MultiConnector(KVConnectorBase_V1):
"""
A wrapper for using multiple KVConnectors at the same time.
@@ -84,19 +111,13 @@ class MultiConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
assert ktcs is not None
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
self._connectors.append(
KVConnectorFactory.create_connector(temp_config, role)
)
for connector_cls, temp_config in self._get_connector_classes_and_configs(
vllm_config
):
self._connectors.append(connector_cls(temp_config, role))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
# A mapping from request id to the index of the connector chosen to
@@ -109,6 +130,32 @@
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {}
@classmethod
def _get_connector_classes_and_configs(
cls, vllm_config: "VllmConfig"
) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]:
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
assert ktcs is not None
ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = []
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
ret.append(
(
KVConnectorFactory.get_connector_class(
temp_config.kv_transfer_config
),
temp_config,
)
)
return ret
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
c.register_kv_caches(kv_caches)
@@ -295,18 +342,12 @@
None if the connector does not require a specific layout.
"""
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
assert ktcs is not None
layouts: set[str] = set()
temp_vllm_config = copy.copy(vllm_config)
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
temp_vllm_config.kv_transfer_config = kv_transfer_config
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
vllm_config
):
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
temp_vllm_config
temp_config
)
if required_kvcache_layout is not None:
layouts.add(required_kvcache_layout)
@@ -372,3 +413,28 @@
stats_by_connector = MultiKVConnectorStats()
stats_by_connector[c.__class__.__name__] = stats
return stats_by_connector
@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> KVConnectorPromMetrics:
prom_metrics: dict[str, KVConnectorPromMetrics] = {}
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
vllm_config
):
connector_prom = connector_cls.build_prom_metrics(
temp_config, metric_types, labelnames, per_engine_labelvalues
)
if connector_prom is not None:
prom_metrics[connector_cls.__name__] = connector_prom
return MultiKVConnectorPromMetrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
prom_metrics,
)
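For context, a sketch of the kv-transfer config shape that _get_connector_classes_and_configs() reads; the two nested connectors are illustrative choices:

```python
# "connectors" comes from kv_connector_extra_config; each entry is turned
# into its own KVTransferConfig (engine_id defaults to the parent's).
kv_transfer_config = {
    "kv_connector": "MultiConnector",
    "kv_role": "kv_both",
    "kv_connector_extra_config": {
        "connectors": [
            {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
            {"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"},
        ]
    },
}
```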

View File

@@ -30,7 +30,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -254,6 +259,18 @@ class NixlConnector(KVConnectorBase_V1):
else NixlKVConnectorStats()
)
@classmethod
def build_prom_metrics(
cls,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> KVConnectorPromMetrics:
return NixlPromMetrics(
vllm_config, metric_types, labelnames, per_engine_labelvalues
)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
@@ -1960,3 +1977,125 @@ class NixlKVConnectorStats(KVConnectorStats):
@property
def num_successful_transfers(self) -> int:
return len(self.data["transfer_duration"])
class NixlPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
buckets = [
0.001,
0.005,
0.01,
0.025,
0.05,
0.075,
0.1,
0.2,
0.3,
0.5,
0.75,
1.0,
5.0,
]
nixl_histogram_xfer_time = self._histogram_cls(
name="vllm:nixl_xfer_time_seconds",
documentation="Histogram of transfer duration for NIXL KV Cache transfers.",
buckets=buckets[1:],
labelnames=labelnames,
)
self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time)
nixl_histogram_post_time = self._histogram_cls(
name="vllm:nixl_post_time_seconds",
documentation="Histogram of transfer post time for NIXL KV"
" Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time)
# log-spaced buckets from 2 KiB to 8 GiB
buckets = [2 ** (10 + i) for i in range(1, 25, 2)]
nixl_histogram_bytes_transferred = self._histogram_cls(
name="vllm:nixl_bytes_transferred",
documentation="Histogram of bytes transferred per NIXL KV Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_bytes_transferred = self.make_per_engine(
nixl_histogram_bytes_transferred
)
buckets = [
10,
20,
30,
50,
75,
100,
200,
400,
1000,
2000,
4000,
10000,
20000,
50000,
]
nixl_histogram_num_descriptors = self._histogram_cls(
name="vllm:nixl_num_descriptors",
documentation="Histogram of number of descriptors per NIXL"
" KV Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_num_descriptors = self.make_per_engine(
nixl_histogram_num_descriptors
)
counter_nixl_num_failed_transfers = self._counter_cls(
name="vllm:nixl_num_failed_transfers",
documentation="Number of failed NIXL KV Cache transfers.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_transfers = self.make_per_engine(
counter_nixl_num_failed_transfers
)
counter_nixl_num_failed_notifications = self._counter_cls(
name="vllm:nixl_num_failed_notifications",
documentation="Number of failed NIXL KV Cache notifications.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_notifications = self.make_per_engine(
counter_nixl_num_failed_notifications
)
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
for prom_obj, list_item_key in zip(
[
self.nixl_histogram_xfer_time,
self.nixl_histogram_post_time,
self.nixl_histogram_bytes_transferred,
self.nixl_histogram_num_descriptors,
],
[
"transfer_duration",
"post_duration",
"bytes_transferred",
"num_descriptors",
],
):
for list_item in transfer_stats_data[list_item_key]:
prom_obj[engine_idx].observe(list_item)
for counter_obj, counter_item_key in zip(
[
self.counter_nixl_num_failed_transfers,
self.counter_nixl_num_failed_notifications,
],
["num_failed_transfers", "num_failed_notifications"],
):
for list_item in transfer_stats_data[counter_item_key]:
counter_obj[engine_idx].inc(list_item)
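For reference, the shape of the payload observe() consumes, inferred from the key lists above; the values are made up:

```python
# Histogram keys carry one sample per transfer; counter keys carry increments.
transfer_stats_data = {
    "transfer_duration": [0.012, 0.034],       # seconds
    "post_duration": [0.002, 0.003],           # seconds
    "bytes_transferred": [4 << 20, 32 << 20],  # bytes
    "num_descriptors": [128, 512],
    "num_failed_transfers": [0],
    "num_failed_notifications": [1],
}
```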

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import logging
import os
import threading
@@ -96,19 +97,30 @@ class P2pNcclEngine:
# Each card corresponds to a ZMQ address.
self.zmq_address = f"{self._hostname}:{self._port}"
# The `http_port` must be consistent with the port of OpenAI.
self.http_address = (
f"{self._hostname}:{self.config.kv_connector_extra_config['http_port']}"
)
# If `proxy_ip` or `proxy_port` is `""`,
# then the ping thread will not be enabled.
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
proxy_port = self.config.get_from_extra_config("proxy_port", "")
if proxy_ip == "" or proxy_port == "":
self.proxy_address = ""
self.http_address = ""
else:
self.proxy_address = proxy_ip + ":" + proxy_port
# the `http_port` must be consistent with the port of OpenAI.
http_port = self.config.get_from_extra_config("http_port", None)
if http_port is None:
example_cfg = {
"kv_connector": "P2pNcclConnector",
"kv_connector_extra_config": {"http_port": 8000},
}
example = (
f"--port=8000 --kv-transfer-config='{json.dumps(example_cfg)}'"
)
raise ValueError(
"kv_connector_extra_config.http_port is required. "
f"Example: {example}"
)
self.http_address = f"{self._hostname}:{http_port}"
self.context = zmq.Context()
self.router_socket = self.context.socket(zmq.ROUTER)

View File

@@ -336,36 +336,34 @@ class SharedStorageConnector(KVConnectorBase_V1):
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
if not resumed_from_preemption or req_id not in self._requests_need_load:
continue
num_computed_tokens = cached_reqs.num_computed_tokens[i]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cached_reqs.
if not resumed_from_preemption:
break
if req_id in self._requests_need_load:
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[req_id]
total_tokens = num_computed_tokens + num_new_tokens
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[req_id]
total_tokens = num_computed_tokens + num_new_tokens
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
assert new_block_ids is not None
block_ids = new_block_ids[0]
meta.add_request(
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in request.mm_features],
)
total_need_load += 1
meta.add_request(
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in request.mm_features],
)
total_need_load += 1
assert total_need_load == len(self._requests_need_load)
self._requests_need_load.clear()

View File

@@ -206,19 +206,6 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
)
else:
# TODO: optimize the performance here
# in batch invariance mode, force routing to DP rank 0.
data_parallel_rank = None
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
if vllm_is_batch_invariant():
config = self.engine_client.vllm_config
dp_size = config.parallel_config.data_parallel_size
if dp_size and dp_size > 1:
data_parallel_rank = 0
engine_request, tokenization_kwargs = await self._process_inputs(
request_id_item,
engine_prompt,
@@ -226,7 +213,6 @@
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generator = self.engine_client.generate(

View File

@@ -1158,7 +1158,6 @@ class OpenAIServing:
lora_request: LoRARequest | None,
trace_headers: Mapping[str, str] | None,
priority: int,
data_parallel_rank: int | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs: dict[str, Any] = {}
@@ -1174,7 +1173,6 @@
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
return engine_request, tokenization_kwargs

View File

@@ -155,6 +155,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput"
VLLM_ALLOW_BATCHED_TRITON_FALLBACK: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -247,10 +248,19 @@ def maybe_convert_bool(value: str | None) -> bool | None:
return bool(int(value))
def disable_compile_cache() -> bool:
return bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0")))
def use_aot_compile() -> bool:
from vllm.utils.torch_utils import is_torch_equal_or_newer
default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
default_value = (
"1"
if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache()
else "0"
)
return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
@@ -963,9 +973,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LOG_BATCHSIZE_INTERVAL": lambda: float(
os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")
),
"VLLM_DISABLE_COMPILE_CACHE": lambda: bool(
int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))
),
"VLLM_DISABLE_COMPILE_CACHE": disable_compile_cache,
# If set, vllm will run in development mode, which will enable
# some additional endpoints for developing and debugging,
# e.g. `/reset_prefix_cache`
@@ -1138,6 +1146,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))
),
# If set to 1, allow fallback to batched triton kernel when deepgemm
# is unavailable. By default (0), the system will crash if deepgemm
# is expected but not available.
"VLLM_ALLOW_BATCHED_TRITON_FALLBACK": lambda: bool(
int(os.getenv("VLLM_ALLOW_BATCHED_TRITON_FALLBACK", "0"))
),
# Control the cache size used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
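A quick sketch of opting in to the new fallback flag; setting it before vLLM is imported reflects typical usage, and is an assumption rather than a requirement stated in the diff:

```python
import os

# By default (unset or "0"), a missing DeepGemm now raises instead of
# silently falling back to the batched triton kernel.
os.environ["VLLM_ALLOW_BATCHED_TRITON_FALLBACK"] = "1"

import vllm.envs as envs

assert envs.VLLM_ALLOW_BATCHED_TRITON_FALLBACK is True
```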

View File

@@ -3,6 +3,7 @@
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
@@ -22,11 +23,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
):
super().__init__(quant_config)
self.batched_triton_experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
quant_config=self.quant_config,
)
# Store the original request for deep gemm
deep_gemm_requested = allow_deep_gemm
self.allow_deep_gemm = (
allow_deep_gemm
@@ -44,6 +42,31 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
else None
)
# If deep gemm was requested but is not available (either due to
# unsupported configuration or missing dependencies), check if
# we should allow fallback to batched triton kernel
if (
deep_gemm_requested
and self.batched_deep_gemm_experts is None
and not envs.VLLM_ALLOW_BATCHED_TRITON_FALLBACK
):
raise RuntimeError(
"DeepGemm was requested but is not available. "
"The batched triton kernel fallback is disabled by default. "
"Set VLLM_ALLOW_BATCHED_TRITON_FALLBACK=1 to enable the fallback "
"for debugging purposes."
)
self.batched_triton_experts = (
BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
quant_config=self.quant_config,
)
if self.batched_deep_gemm_experts is None
else None
)
assert (
self.batched_deep_gemm_experts is not None
or self.batched_triton_experts is not None
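The gating logic above, reduced to a standalone sketch; the function name and string return values are illustrative:

```python
import os


def select_batched_experts(deep_gemm_requested: bool,
                           deep_gemm_available: bool) -> str:
    """DeepGemm wins when available; otherwise the triton fallback is
    only allowed behind the debug env flag."""
    if deep_gemm_available:
        return "batched_deep_gemm"
    allow = os.getenv("VLLM_ALLOW_BATCHED_TRITON_FALLBACK", "0") == "1"
    if deep_gemm_requested and not allow:
        raise RuntimeError(
            "DeepGemm was requested but is not available. Set "
            "VLLM_ALLOW_BATCHED_TRITON_FALLBACK=1 to enable the fallback."
        )
    return "batched_triton"
```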

View File

@@ -1135,6 +1135,7 @@ class FusedMoE(CustomOp):
)
self.global_num_experts = num_experts + num_redundant_experts
self.logical_num_experts = num_experts
self.zero_expert_num = zero_expert_num
self.zero_expert_type = zero_expert_type
@@ -1998,13 +1999,12 @@
moe = self.moe_config
# Note here we use `num_experts` which is logical expert count
if self.vllm_config.parallel_config.enable_dbo:
states_shape = (2, moe.max_num_tokens, self.hidden_size)
logits_shape = (2, moe.max_num_tokens, moe.num_experts)
logits_shape = (2, moe.max_num_tokens, self.logical_num_experts)
else:
states_shape = (moe.max_num_tokens, self.hidden_size)
logits_shape = (moe.max_num_tokens, moe.num_experts)
logits_shape = (moe.max_num_tokens, self.logical_num_experts)
self.batched_hidden_states = torch.zeros(
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()

View File

@@ -428,6 +428,14 @@ class Qwen2_5_VisionAttention(nn.Module):
)
elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
from vllm.platforms import current_platform
# Never remove the next contiguous logic
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]

View File

@@ -261,6 +261,21 @@ class CudaPlatformBase(Platform):
from vllm.attention.backends.registry import _Backend
if use_mla:
# explicitly reject non-MLA backends when MLA is enabled to avoid
# silently selecting an incompatible backend (e.g., FLASHINFER).
if selected_backend in {
_Backend.FLASHINFER,
_Backend.FLASH_ATTN,
_Backend.TRITON_ATTN,
_Backend.TREE_ATTN,
_Backend.XFORMERS,
}:
raise ValueError(
f"Attention backend {selected_backend} incompatible with MLA. "
"Please use one of the MLA backends: FLASHINFER_MLA, CUTLASS_MLA, "
"FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
"VLLM_MLA_DISABLE=1 to disable MLA for this model."
)
if not use_v1:
raise RuntimeError(
"MLA attention backends require the V1 engine. "

View File

@@ -72,6 +72,7 @@ _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
"0x74a0": "AMD_Instinct_MI300A",
"0x74a1": "AMD_Instinct_MI300X",
"0x74b5": "AMD_Instinct_MI300X", # MI300X VF
"0x74a2": "AMD_Instinct_MI308X",
"0x74a5": "AMD_Instinct_MI325X",
"0x74b9": "AMD_Instinct_MI325X", # MI325X VF
"0x74a9": "AMD_Instinct_MI300X_HF",

View File

@@ -8,9 +8,6 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
@@ -130,44 +127,19 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float
if vllm_is_batch_invariant():
# TODO(wentao): optimize this when it is supported by Flashinfer upstream.
# execute per-request to eliminate batch-shape-dependent kernel paths.
num = q.shape[0]
outs = []
for i in range(num):
qi = q[i : i + 1]
bt_i = attn_metadata.decode.block_table[i : i + 1]
sl_i = attn_metadata.decode.seq_lens[i : i + 1]
oi = trtllm_batch_decode_with_kv_cache_mla(
query=qi,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=bt_i,
seq_lens=sl_i,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
outs.append(oi)
o = torch.cat(outs, dim=0)
else:
o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=attn_metadata.decode.block_table,
seq_lens=attn_metadata.decode.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=attn_metadata.decode.block_table,
seq_lens=attn_metadata.decode.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
# Flatten the output for consistent shape
o = o.view(-1, o.shape[-2], o.shape[-1])

View File

@@ -689,9 +689,15 @@ class AsyncLLM(EngineClient):
await self.reset_prefix_cache()
await self.engine_core.sleep_async(level)
if self.logger_manager is not None:
self.logger_manager.record_sleep_state(1, level)
async def wake_up(self, tags: list[str] | None = None) -> None:
await self.engine_core.wake_up_async(tags)
if self.logger_manager is not None:
self.logger_manager.record_sleep_state(0, 0)
async def is_sleeping(self) -> bool:
return await self.engine_core.is_sleeping_async()

View File

@@ -332,9 +332,15 @@ class LLMEngine:
def sleep(self, level: int = 1):
self.engine_core.sleep(level)
if self.logger_manager is not None:
self.logger_manager.record_sleep_state(1, level)
def wake_up(self, tags: list[str] | None = None):
self.engine_core.wake_up(tags)
if self.logger_manager is not None:
self.logger_manager.record_sleep_state(0, 0)
def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping()

View File

@@ -9,8 +9,12 @@ from typing import TypeAlias
from prometheus_client import Counter, Gauge, Histogram
import vllm.envs as envs
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorLogging,
KVConnectorPrometheus,
)
from vllm.logger import init_logger
from vllm.plugins import load_plugins_by_group
from vllm.v1.engine import FinishReason
@@ -56,6 +60,9 @@ class StatLoggerBase(ABC):
def log(self): # noqa
pass
def record_sleep_state(self, is_awake: int, level: int): # noqa
pass
def load_stat_logger_plugin_factories() -> list[StatLoggerFactory]:
factories: list[StatLoggerFactory] = []
@@ -335,6 +342,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
_counter_cls = Counter
_histogram_cls = Histogram
_spec_decoding_cls = SpecDecodingProm
_kv_connector_cls = KVConnectorPrometheus
def __init__(
self, vllm_config: VllmConfig, engine_indexes: list[int] | None = None
@@ -354,12 +362,15 @@
model_name = vllm_config.model_config.served_model_name
max_model_len = vllm_config.model_config.max_model_len
spec_decode_labelvalues: dict[int, list[str]] = {
per_engine_labelvalues: dict[int, list[str]] = {
idx: [model_name, str(idx)] for idx in engine_indexes
}
self.spec_decoding_prom = self._spec_decoding_cls(
vllm_config.speculative_config, labelnames, spec_decode_labelvalues
vllm_config.speculative_config, labelnames, per_engine_labelvalues
)
self.kv_connector_prom = self._kv_connector_cls(
vllm_config, labelnames, per_engine_labelvalues
)
#
@@ -384,8 +395,33 @@
self.gauge_scheduler_waiting = make_per_engine(
gauge_scheduler_waiting, engine_indexes, model_name
)
if envs.VLLM_SERVER_DEV_MODE:
gauge_engine_sleep_state = self._gauge_cls(
name="vllm:engine_sleep_state",
documentation=(
"Engine sleep state; awake = 0 means engine is sleeping; "
"awake = 1 means engine is awake; "
"weights_offloaded = 1 means sleep level 1; "
"discard_all = 1 means sleep level 2."
),
labelnames=labelnames + ["sleep_state"],
multiprocess_mode="mostrecent",
)
self.gauge_engine_sleep_state = {}
sleep_state = ["awake", "weights_offloaded", "discard_all"]
for s in sleep_state:
self.gauge_engine_sleep_state[s] = {
idx: gauge_engine_sleep_state.labels(
engine=idx, model_name=model_name, sleep_state=s
)
for idx in engine_indexes
}
# Setting default values
self.record_sleep_state()
#
# GPU cache
#
# Deprecated in 0.9.2 - Renamed as vllm:kv_cache_usage_perc
@@ -933,6 +969,11 @@
scheduler_stats.spec_decoding_stats, engine_idx
)
if scheduler_stats.kv_connector_stats is not None:
self.kv_connector_prom.observe(
scheduler_stats.kv_connector_stats, engine_idx
)
if mm_cache_stats is not None:
self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
@@ -1010,6 +1051,25 @@
}
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
def record_sleep_state(self, sleep: int = 0, level: int = 0):
awake = 1
discard_all = 0
weights_offloaded = 0
if sleep == 1:
awake = 0
if level == 1:
weights_offloaded = 1
elif level == 2:
discard_all = 1
for engine_idx in self.engine_indexes:
self.gauge_engine_sleep_state["discard_all"][engine_idx].set(discard_all)
self.gauge_engine_sleep_state["weights_offloaded"][engine_idx].set(
weights_offloaded
)
self.gauge_engine_sleep_state["awake"][engine_idx].set(awake)
def log_engine_initialized(self):
self.log_metrics_info("cache_config", self.vllm_config.cache_config)
@@ -1131,6 +1191,10 @@
engine_idx=engine_idx,
)
def record_sleep_state(self, sleep: int = 0, level: int = 0):
for logger in self.stat_loggers:
logger.record_sleep_state(sleep, level)
def log(self):
for logger in self.stat_loggers:
logger.log()
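To summarize the mapping encoded by record_sleep_state(), the expected gauge values as a small table (derived directly from the code above):

```python
# (sleep, level) -> (awake, weights_offloaded, discard_all)
expected_gauges = {
    (0, 0): (1, 0, 0),  # wake_up(): engine awake
    (1, 1): (0, 1, 0),  # sleep(level=1): weights offloaded
    (1, 2): (0, 0, 1),  # sleep(level=2): discard all
}
```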

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPrometheus
from vllm.v1.metrics.loggers import PrometheusStatLogger
from vllm.v1.spec_decode.metrics import SpecDecodingProm
@@ -141,6 +142,18 @@ class RaySpecDecodingProm(SpecDecodingProm):
_counter_cls = RayCounterWrapper
class RayKVConnectorPrometheus(KVConnectorPrometheus):
"""
RayKVConnectorPrometheus is used by RayMetrics to log Ray
metrics. Provides the same metrics as KV connectors but
uses Ray's util.metrics library.
"""
_gauge_cls = RayGaugeWrapper
_counter_cls = RayCounterWrapper
_histogram_cls = RayHistogramWrapper
class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
@@ -148,6 +161,7 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
_counter_cls = RayCounterWrapper
_histogram_cls = RayHistogramWrapper
_spec_decoding_cls = RaySpecDecodingProm
_kv_connector_cls = RayKVConnectorPrometheus
@staticmethod
def _unregister_vllm_metrics():