Compare commits: wentao-bat...copilot/di

17 Commits

| SHA1 |
|---|
| 7557a67655 |
| 1af476b0e9 |
| 8c3b1c7c62 |
| d4aa144343 |
| fcb1d570bb |
| accb8fab07 |
| 5b0448104f |
| f7a6682872 |
| a9fe0793f2 |
| 7568a282b9 |
| 1da3309ace |
| 5522fb274b |
| 0f95a1c3f2 |
| ded24e3e54 |
| c72d44ba4a |
| c292032b44 |
| b286fba2bb |
@@ -205,6 +205,24 @@ steps:
  - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
  - popd

- label: Distributed Tests (8 GPUs) # 4min
  timeout_in_minutes: 10
  gpu: h100
  num_gpus: 8
  working_dir: "/vllm-workspace/tests"
  source_file_dependencies:
  - examples/offline_inference/torchrun_dp_example.py
  - vllm/config/parallel.py
  - vllm/distributed/
  - vllm/v1/engine/llm_engine.py
  - vllm/v1/executor/uniproc_executor.py
  - vllm/v1/worker/gpu_worker.py
  commands:
  # https://github.com/NVIDIA/nccl/issues/1838
  - export NCCL_CUMEM_HOST_ENABLE=0
  # test with torchrun tp=2 and dp=4 with ep
  - torchrun --nproc-per-node=8 ../examples/offline_inference/torchrun_dp_example.py --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep

- label: EPLB Algorithm Test # 5min
  timeout_in_minutes: 15
  working_dir: "/vllm-workspace/tests"
@@ -401,7 +419,7 @@ steps:
    --ignore=lora/test_deepseekv2_tp.py \
    --ignore=lora/test_gptoss.py \
    --ignore=lora/test_qwen3moe_tp.py

  parallelism: 4

- label: PyTorch Compilation Unit Tests # 15min
@@ -1126,7 +1144,7 @@ steps:
  - tests/weight_loading
  commands:
  - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt

- label: NixlConnector PD accuracy tests (Distributed) # 30min
  timeout_in_minutes: 30
  working_dir: "/vllm-workspace/tests"
@@ -9,10 +9,76 @@ To run this example:

```bash
$ torchrun --nproc-per-node=2 examples/offline_inference/torchrun_dp_example.py
```

With custom parallelism settings:

```bash
$ torchrun --nproc-per-node=8 examples/offline_inference/torchrun_dp_example.py \
    --tp-size=2 --pp-size=1 --dp-size=4 --enable-ep
```
"""

import argparse

from vllm import LLM, SamplingParams


def parse_args():
    parser = argparse.ArgumentParser(
        description="Data-parallel inference with torchrun"
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size (default: 1)",
    )
    parser.add_argument(
        "--pp-size",
        type=int,
        default=1,
        help="Pipeline parallel size (default: 1)",
    )
    parser.add_argument(
        "--dp-size",
        type=int,
        default=2,
        help="Data parallel size (default: 2)",
    )
    parser.add_argument(
        "--enable-ep",
        action="store_true",
        help="Enable expert parallel (default: False)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="microsoft/Phi-mini-MoE-instruct",
        help="Model name or path (default: microsoft/Phi-mini-MoE-instruct)",
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=4096,
        help="Maximum model length (default: 4096)",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.6,
        help="GPU memory utilization (default: 0.6)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Random seed (default: 1)",
    )
    return parser.parse_args()


args = parse_args()


# Create prompts, the same across all ranks
prompts = [
    "Hello, my name is",
@@ -30,15 +96,15 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
-    model="microsoft/Phi-mini-MoE-instruct",
-    tensor_parallel_size=1,
-    data_parallel_size=2,
-    pipeline_parallel_size=1,
-    enable_expert_parallel=False,
+    model=args.model,
+    tensor_parallel_size=args.tp_size,
+    data_parallel_size=args.dp_size,
+    pipeline_parallel_size=args.pp_size,
+    enable_expert_parallel=args.enable_ep,
    distributed_executor_backend="external_launcher",
-    max_model_len=4096,
-    gpu_memory_utilization=0.6,
-    seed=1,
+    max_model_len=args.max_model_len,
+    gpu_memory_utilization=args.gpu_memory_utilization,
+    seed=args.seed,
)

dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import requests
from prometheus_client.parser import text_string_to_metric_families

from ...utils import RemoteOpenAIServer

@@ -31,12 +32,28 @@ def test_sleep_mode():
    assert response.status_code == 200
    assert response.json().get("is_sleeping") is True

    # check sleep metrics
    response = requests.get(remote_server.url_for("metrics"))
    assert response.status_code == 200
    awake, weights_offloaded, discard_all = _get_sleep_metrics_from_api(response)
    assert awake == 0
    assert weights_offloaded == 1
    assert discard_all == 0

    response = requests.post(remote_server.url_for("wake_up"))
    assert response.status_code == 200
    response = requests.get(remote_server.url_for("is_sleeping"))
    assert response.status_code == 200
    assert response.json().get("is_sleeping") is False

    # check sleep metrics
    response = requests.get(remote_server.url_for("metrics"))
    assert response.status_code == 200
    awake, weights_offloaded, discard_all = _get_sleep_metrics_from_api(response)
    assert awake == 1
    assert weights_offloaded == 0
    assert discard_all == 0

    # test wake up with tags
    response = requests.post(remote_server.url_for("sleep"), params={"level": "1"})
    assert response.status_code == 200

@@ -59,3 +76,35 @@ def test_sleep_mode():
    response = requests.get(remote_server.url_for("is_sleeping"))
    assert response.status_code == 200
    assert response.json().get("is_sleeping") is False

    # check sleep metrics
    response = requests.get(remote_server.url_for("metrics"))
    assert response.status_code == 200
    awake, weights_offloaded, discard_all = _get_sleep_metrics_from_api(response)
    assert awake == 1
    assert weights_offloaded == 0
    assert discard_all == 0


def _get_sleep_metrics_from_api(response: requests.Response):
    """Return (awake, weights_offloaded, discard_all)"""

    awake, weights_offloaded, discard_all = None, None, None

    for family in text_string_to_metric_families(response.text):
        if family.name == "vllm:engine_sleep_state":
            for sample in family.samples:
                if sample.name == "vllm:engine_sleep_state":
                    for label_name, label_value in sample.labels.items():
                        if label_value == "awake":
                            awake = sample.value
                        elif label_value == "weights_offloaded":
                            weights_offloaded = sample.value
                        elif label_value == "discard_all":
                            discard_all = sample.value

    assert awake is not None
    assert weights_offloaded is not None
    assert discard_all is not None

    return awake, weights_offloaded, discard_all
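For reference, a minimal sketch (not part of this diff) of the `vllm:engine_sleep_state` exposition text that `_get_sleep_metrics_from_api` parses. The label names follow the gauge registered later in this change; the sample values are made up and assume a level-1 sleep.

```python
# Illustrative only: parse a hand-written metrics exposition the same way the test does.
from prometheus_client.parser import text_string_to_metric_families

sample = """\
# TYPE vllm:engine_sleep_state gauge
vllm:engine_sleep_state{engine="0",model_name="m",sleep_state="awake"} 0.0
vllm:engine_sleep_state{engine="0",model_name="m",sleep_state="weights_offloaded"} 1.0
vllm:engine_sleep_state{engine="0",model_name="m",sleep_state="discard_all"} 0.0
"""

states = {}
for family in text_string_to_metric_families(sample):
    if family.name == "vllm:engine_sleep_state":
        for s in family.samples:
            states[s.labels["sleep_state"]] = s.value

print(states)  # {'awake': 0.0, 'weights_offloaded': 1.0, 'discard_all': 0.0}
```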
@ -1,11 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Automatically enable batch invariant kernel overrides for all tests."""
|
||||
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||
yield
|
||||
@ -1,208 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
HTTP-based batch invariance test: send requests to a running
|
||||
vLLM server and compare BS=1 vs BS=N results (tokens and per-step logprobs).
|
||||
|
||||
Environment variables:
|
||||
- VLLM_API_BASE: base URL like http://127.0.0.1:9256/v1 (default used)
|
||||
- VLLM_TEST_MODEL: served model name (e.g., Qwen/Qwen3-1.7B)
|
||||
|
||||
Note:
|
||||
- The server must be started beforehand via `vllm serve ...`
|
||||
- Example usage:
|
||||
- `export VLLM_ATTENTION_BACKEND="FLASHINFER_MLA"`
|
||||
- `export VLLM_BATCH_INVARIANT=1`
|
||||
- `vllm serve deepseek-ai/DeepSeek-R1 -dp 8 --enable-expert-parallel --port 9256`
|
||||
- `pytest tests/v1/generation/test_online_batch_invariance.py`
|
||||
"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from typing import Any
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import pytest
|
||||
from utils import _random_prompt, skip_unsupported
|
||||
|
||||
|
||||
def _post_json(
|
||||
url: str, headers: dict[str, str], payload: dict[str, Any], timeout: float
|
||||
) -> dict[str, Any]:
|
||||
import json
|
||||
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
req = Request(url, data=data, headers=headers, method="POST")
|
||||
with urlopen(req, timeout=timeout) as resp:
|
||||
body = resp.read()
|
||||
return json.loads(body.decode("utf-8"))
|
||||
|
||||
|
||||
def _request_completion(
|
||||
api_base: str,
|
||||
model: str,
|
||||
prompt: Any,
|
||||
sp: dict[str, Any],
|
||||
timeout: float = 60.0,
|
||||
max_retries: int = 3,
|
||||
retry_backoff: float = 0.5,
|
||||
) -> dict[str, Any] | None:
|
||||
url = api_base.rstrip("/") + "/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
payload: dict[str, Any] = {"model": model, "prompt": prompt}
|
||||
payload.update(sp)
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return _post_json(url, headers, payload, timeout)
|
||||
except HTTPError as e: # type: ignore[reportGeneralTypeIssues]
|
||||
status = getattr(e, "code", None)
|
||||
if status in (429, 500, 502, 503, 504) and attempt < max_retries:
|
||||
import time as _t
|
||||
|
||||
_t.sleep(retry_backoff * (2**attempt))
|
||||
continue
|
||||
sys.stderr.write(f"HTTPError: {e}\n")
|
||||
return None
|
||||
except URLError as e: # type: ignore[reportGeneralTypeIssues]
|
||||
if attempt < max_retries:
|
||||
import time as _t
|
||||
|
||||
_t.sleep(retry_backoff * (2**attempt))
|
||||
continue
|
||||
sys.stderr.write(f"URLError: {e}\n")
|
||||
return None
|
||||
except Exception as e: # pragma: no cover
|
||||
if attempt < max_retries:
|
||||
import time as _t
|
||||
|
||||
_t.sleep(retry_backoff * (2**attempt))
|
||||
continue
|
||||
sys.stderr.write(f"Error: {e}\n")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _extract_tokens_and_logprobs(
|
||||
choice: dict[str, Any],
|
||||
) -> tuple[list[Any], list[float] | None]:
|
||||
tokens: list[Any] = []
|
||||
token_logprobs: list[float] | None = None
|
||||
lp = choice.get("logprobs")
|
||||
if lp and isinstance(lp, dict):
|
||||
tokens = lp.get("token_ids") or lp.get("tokens") or []
|
||||
token_logprobs = lp.get("token_logprobs", None)
|
||||
return tokens, token_logprobs
|
||||
|
||||
|
||||
def _server_is_up(api_base: str, timeout: float = 2.0) -> tuple[bool, str]:
|
||||
url = api_base.rstrip("/") + "/models"
|
||||
try:
|
||||
req = Request(url, method="GET")
|
||||
with urlopen(req, timeout=timeout) as resp:
|
||||
_ = resp.read()
|
||||
return True, "OK"
|
||||
except URLError as e: # type: ignore[reportGeneralTypeIssues]
|
||||
return False, f"URLError: {e}"
|
||||
except Exception as e: # pragma: no cover
|
||||
return False, f"Error: {e}"
|
||||
|
||||
|
||||
def _compare_bs1_vs_bsn_single_process(
|
||||
prompts: list[str],
|
||||
sp_kwargs: dict[str, Any],
|
||||
api_base: str,
|
||||
model_name: str,
|
||||
) -> None:
|
||||
# BS=1
|
||||
bs1_tokens_per_prompt: list[list[Any]] = []
|
||||
bs1_logprobs_per_prompt: list[list[float] | None] = []
|
||||
for p in prompts:
|
||||
resp = _request_completion(api_base, model_name, p, sp_kwargs)
|
||||
if resp is None or not resp.get("choices"):
|
||||
raise AssertionError("BS=1 empty/failed response")
|
||||
choice = resp["choices"][0]
|
||||
toks, lps = _extract_tokens_and_logprobs(choice)
|
||||
if lps is None:
|
||||
raise AssertionError(
|
||||
"logprobs not returned; ensure server supports 'logprobs'"
|
||||
)
|
||||
bs1_tokens_per_prompt.append(list(toks))
|
||||
bs1_logprobs_per_prompt.append(list(lps))
|
||||
|
||||
# BS=N
|
||||
bsN_tokens_per_prompt: list[list[Any]] = [None] * len(prompts) # type: ignore[list-item]
|
||||
bsN_logprobs_per_prompt: list[list[float] | None] = [None] * len(prompts)
|
||||
resp = _request_completion(api_base, model_name, prompts, sp_kwargs)
|
||||
if resp is None or not resp.get("choices"):
|
||||
raise AssertionError("BS=N empty/failed batched response")
|
||||
choices = resp.get("choices", [])
|
||||
if len(choices) != len(prompts):
|
||||
raise AssertionError(
|
||||
f"BS=N choices length {len(choices)} != num prompts {len(prompts)}"
|
||||
)
|
||||
for idx, choice in enumerate(choices):
|
||||
toks, lps = _extract_tokens_and_logprobs(choice)
|
||||
if lps is None:
|
||||
raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
|
||||
bsN_tokens_per_prompt[idx] = list(toks)
|
||||
bsN_logprobs_per_prompt[idx] = list(lps)
|
||||
|
||||
# compare
|
||||
for i, (tokens_bs1, tokens_bsN, logprobs_bs1, logprobs_bsN) in enumerate(
|
||||
zip(
|
||||
bs1_tokens_per_prompt,
|
||||
bsN_tokens_per_prompt,
|
||||
bs1_logprobs_per_prompt,
|
||||
bsN_logprobs_per_prompt,
|
||||
)
|
||||
):
|
||||
if tokens_bs1 != tokens_bsN:
|
||||
raise AssertionError(
|
||||
f"Prompt {i} (sampling): Different tokens sampled. "
|
||||
f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
|
||||
)
|
||||
if logprobs_bs1 is None or logprobs_bsN is None:
|
||||
raise AssertionError(f"Prompt {i}: Missing logprobs in one of the runs")
|
||||
if len(logprobs_bs1) != len(logprobs_bsN):
|
||||
raise AssertionError(
|
||||
f"Prompt {i}: Different number of steps: "
|
||||
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)."
|
||||
)
|
||||
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
|
||||
if a != b:
|
||||
diff = abs(a - b)
|
||||
raise AssertionError(
|
||||
f"Prompt {i} Step {t}: Bitwise mismatch "
|
||||
f"(abs diff={diff:.6e}). "
|
||||
f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
|
||||
)
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN_dp_http():
|
||||
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
|
||||
api_base = os.getenv("VLLM_API_BASE", "http://127.0.0.1:9256/v1")
|
||||
model_name = os.getenv("VLLM_TEST_MODEL", "deepseek-ai/DeepSeek-R1")
|
||||
up, reason = _server_is_up(api_base)
|
||||
if not up:
|
||||
pytest.skip(f"Server not reachable at {api_base}: {reason}")
|
||||
prompts_all = [_random_prompt(10, 50) for _ in range(32)]
|
||||
|
||||
sp_kwargs: dict[str, Any] = {
|
||||
"temperature": 0.6,
|
||||
"top_p": 1.0,
|
||||
"max_tokens": 8,
|
||||
"seed": 42,
|
||||
"logprobs": 5,
|
||||
}
|
||||
|
||||
_compare_bs1_vs_bsn_single_process(
|
||||
prompts=prompts_all,
|
||||
sp_kwargs=sp_kwargs,
|
||||
api_base=api_base,
|
||||
model_name=model_name,
|
||||
)
|
||||
@ -1,74 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
skip_unsupported = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
|
||||
reason="Requires CUDA and >= Hopper (SM90)",
|
||||
)
|
||||
|
||||
|
||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
# Generate more realistic prompts that will actually produce varied tokens
|
||||
# Use a mix of common English text patterns
|
||||
|
||||
prompt_templates = [
|
||||
# Question-answer style
|
||||
"Question: What is the capital of France?\nAnswer: The capital of France is",
|
||||
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
|
||||
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
|
||||
# Story/narrative style
|
||||
"Once upon a time in a distant galaxy, there lived",
|
||||
"The old man walked slowly down the street, remembering",
|
||||
"In the year 2157, humanity finally discovered",
|
||||
# Technical/code style
|
||||
"To implement a binary search tree in Python, first we need to",
|
||||
"The algorithm works by iterating through the array and",
|
||||
"Here's how to optimize database queries using indexing:",
|
||||
# Factual/informative style
|
||||
"The Renaissance was a period in European history that",
|
||||
"Climate change is caused by several factors including",
|
||||
"The human brain contains approximately 86 billion neurons which",
|
||||
# Conversational style
|
||||
"I've been thinking about getting a new laptop because",
|
||||
"Yesterday I went to the store and bought",
|
||||
"My favorite thing about summer is definitely",
|
||||
]
|
||||
|
||||
# Pick a random template
|
||||
base_prompt = random.choice(prompt_templates)
|
||||
|
||||
if max_words < min_words:
|
||||
max_words = min_words
|
||||
target_words = random.randint(min_words, max_words)
|
||||
|
||||
if target_words > 50:
|
||||
# For longer prompts, repeat context
|
||||
padding_text = (
|
||||
" This is an interesting topic that deserves more explanation. "
|
||||
* (target_words // 50)
|
||||
)
|
||||
base_prompt = base_prompt + padding_text
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
def _extract_step_logprobs(request_output):
|
||||
if getattr(request_output, "outputs", None):
|
||||
inner = request_output.outputs[0]
|
||||
if hasattr(inner, "logprobs") and inner.logprobs is not None:
|
||||
t = torch.tensor(
|
||||
[
|
||||
inner.logprobs[i][tid].logprob
|
||||
for i, tid in enumerate(inner.token_ids)
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
return t, inner.token_ids
|
||||
|
||||
return None, None
|
||||
@ -6,7 +6,7 @@ import pytest_asyncio
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
# Use a small reasoning model to test the responses API.
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
MODEL_NAME = "Qwen/Qwen3-1.7B"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
||||
@ -6,9 +6,66 @@ import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
skip_unsupported = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
|
||||
reason="Requires CUDA and >= Hopper (SM90)",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Automatically enable batch invariant kernel overrides for all tests."""
|
||||
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||
yield
|
||||
|
||||
|
||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
# Generate more realistic prompts that will actually produce varied tokens
|
||||
# Use a mix of common English text patterns
|
||||
|
||||
prompt_templates = [
|
||||
# Question-answer style
|
||||
"Question: What is the capital of France?\nAnswer: The capital of France is",
|
||||
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
|
||||
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
|
||||
# Story/narrative style
|
||||
"Once upon a time in a distant galaxy, there lived",
|
||||
"The old man walked slowly down the street, remembering",
|
||||
"In the year 2157, humanity finally discovered",
|
||||
# Technical/code style
|
||||
"To implement a binary search tree in Python, first we need to",
|
||||
"The algorithm works by iterating through the array and",
|
||||
"Here's how to optimize database queries using indexing:",
|
||||
# Factual/informative style
|
||||
"The Renaissance was a period in European history that",
|
||||
"Climate change is caused by several factors including",
|
||||
"The human brain contains approximately 86 billion neurons which",
|
||||
# Conversational style
|
||||
"I've been thinking about getting a new laptop because",
|
||||
"Yesterday I went to the store and bought",
|
||||
"My favorite thing about summer is definitely",
|
||||
]
|
||||
|
||||
# Pick a random template
|
||||
base_prompt = random.choice(prompt_templates)
|
||||
|
||||
if max_words < min_words:
|
||||
max_words = min_words
|
||||
target_words = random.randint(min_words, max_words)
|
||||
|
||||
if target_words > 50:
|
||||
# For longer prompts, repeat context
|
||||
padding_text = (
|
||||
" This is an interesting topic that deserves more explanation. "
|
||||
* (target_words // 50)
|
||||
)
|
||||
base_prompt = base_prompt + padding_text
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
@ -147,6 +204,22 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
||||
llm_bsN.shutdown()
|
||||
|
||||
|
||||
def _extract_step_logprobs(request_output):
|
||||
if getattr(request_output, "outputs", None):
|
||||
inner = request_output.outputs[0]
|
||||
if hasattr(inner, "logprobs") and inner.logprobs is not None:
|
||||
t = torch.tensor(
|
||||
[
|
||||
inner.logprobs[i][tid].logprob
|
||||
for i, tid in enumerate(inner.token_ids)
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
return t, inner.token_ids
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
@pytest.mark.parametrize(
|
||||
"backend",
|
||||
@ -9,10 +9,15 @@ with the standard CUDA-based implementation to ensure numerical accuracy.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from utils import skip_unsupported
|
||||
|
||||
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
skip_unsupported = pytest.mark.skipif(
|
||||
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
|
||||
reason="Requires CUDA and >= Hopper (SM90)",
|
||||
)
|
||||
|
||||
|
||||
@skip_unsupported
|
||||
@ -50,7 +50,12 @@ if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
@ -471,3 +476,18 @@ class KVConnectorBase_V1(ABC):
|
||||
which can implement custom aggregation logic on the data dict.
|
||||
"""
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def build_prom_metrics(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[str]],
|
||||
) -> Optional["KVConnectorPromMetrics"]:
|
||||
"""
|
||||
Create a KVConnectorPromMetrics subclass which should register
|
||||
per-connector Prometheus metrics and implement observe() to
|
||||
expose connector transfer stats via Prometheus.
|
||||
"""
|
||||
return None
|
||||
|
||||
@ -1,13 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
from typing import Any, TypeAlias, TypeVar
|
||||
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from prometheus_client import Counter, Gauge, Histogram
|
||||
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group
|
||||
from vllm.logger import init_logger
|
||||
|
||||
PromMetric: TypeAlias = Gauge | Counter | Histogram
|
||||
PromMetricT = TypeVar("PromMetricT", bound=PromMetric)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -102,3 +107,83 @@ class KVConnectorLogging:
|
||||
|
||||
# Reset metrics for next interval
|
||||
self.reset()
|
||||
|
||||
|
||||
class KVConnectorPromMetrics:
|
||||
"""
|
||||
A base class for per-connector Prometheus metric registration
|
||||
and recording.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[str]],
|
||||
):
|
||||
self._kv_transfer_config = vllm_config.kv_transfer_config
|
||||
self._gauge_cls = metric_types[Gauge]
|
||||
self._counter_cls = metric_types[Counter]
|
||||
self._histogram_cls = metric_types[Histogram]
|
||||
self._labelnames = labelnames
|
||||
self._per_engine_labelvalues = per_engine_labelvalues
|
||||
|
||||
def make_per_engine(self, metric: PromMetric) -> PromMetric:
|
||||
"""
|
||||
Create a per-engine child of a prometheus_client.Metric with
|
||||
the appropriate labels set. The parent metric must be created
|
||||
using the labelnames list.
|
||||
"""
|
||||
return {
|
||||
idx: metric.labels(*labelvalues)
|
||||
for idx, labelvalues in self._per_engine_labelvalues.items()
|
||||
}
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||
"""
|
||||
Record the supplied transfer statistics to Prometheus metrics. These
|
||||
statistics are engine-specific, and should be recorded to a metric
|
||||
with the appropriate 'engine' label. These metric instances can be
|
||||
created using the make_per_engine() helper method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KVConnectorPrometheus:
|
||||
"""
|
||||
Support for registering per-connector Prometheus metrics, and
|
||||
recording transfer statistics to those metrics. Uses
|
||||
KVConnectorBase.build_prom_metrics().
|
||||
"""
|
||||
|
||||
_gauge_cls = Gauge
|
||||
_counter_cls = Counter
|
||||
_histogram_cls = Histogram
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[str]],
|
||||
):
|
||||
self.prom_metrics: KVConnectorPromMetrics | None = None
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
if kv_transfer_config and kv_transfer_config.kv_connector:
|
||||
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
|
||||
metric_types = {
|
||||
Gauge: self._gauge_cls,
|
||||
Counter: self._counter_cls,
|
||||
Histogram: self._histogram_cls,
|
||||
}
|
||||
self.prom_metrics = connector_cls.build_prom_metrics(
|
||||
vllm_config,
|
||||
metric_types,
|
||||
labelnames,
|
||||
per_engine_labelvalues,
|
||||
)
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||
if self.prom_metrics is None:
|
||||
return
|
||||
self.prom_metrics.observe(transfer_stats_data, engine_idx)
|
||||
|
||||
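A rough sketch (not part of this diff) of how an individual connector could plug into the new hooks. The class name, metric name, and stats key below are invented for illustration; `NixlPromMetrics` further down is the real example in this change.

```python
# Illustrative sketch only: a hypothetical connector-specific metrics class built on
# the KVConnectorPromMetrics base class added above. "vllm:example_kv_saves" and the
# "num_saves" stats key are made up for this example.
from typing import Any

from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPromMetrics


class ExamplePromMetrics(KVConnectorPromMetrics):
    def __init__(self, vllm_config, metric_types, labelnames, per_engine_labelvalues):
        super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
        counter = self._counter_cls(
            name="vllm:example_kv_saves",
            documentation="Number of KV cache saves performed by the example connector.",
            labelnames=labelnames,
        )
        # One child counter per engine index, labelled with that engine's label values.
        self.counter_saves = self.make_per_engine(counter)

    def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
        self.counter_saves[engine_idx].inc(transfer_stats_data["num_saves"])
```

The owning connector would return an instance of such a class from its `build_prom_metrics()` classmethod, mirroring what `NixlConnector.build_prom_metrics()` does later in this diff.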
@ -9,13 +9,19 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
@ -72,6 +78,27 @@ class MultiKVConnectorStats(KVConnectorStats):
|
||||
self.data[connector_id] = stats
|
||||
|
||||
|
||||
class MultiKVConnectorPromMetrics(KVConnectorPromMetrics):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[str]],
|
||||
prom_metrics: dict[str, KVConnectorPromMetrics],
|
||||
):
|
||||
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
||||
self._prom_metrics = prom_metrics
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||
for connector_id, stats_data in transfer_stats_data.items():
|
||||
assert connector_id in self._prom_metrics, (
|
||||
f"{connector_id} is not contained in the list of registered connectors "
|
||||
f"with Prometheus metrics support: {self._prom_metrics.keys()}"
|
||||
)
|
||||
self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx)
|
||||
|
||||
|
||||
class MultiConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
A wrapper for using multiple KVConnectors at the same time.
|
||||
@ -84,19 +111,13 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
|
||||
self._connectors: list[KVConnectorBase_V1] = []
|
||||
self._ktc_kv_transfer_config = []
|
||||
ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
|
||||
assert ktcs is not None
|
||||
for ktc in ktcs:
|
||||
temp_config = copy.copy(vllm_config)
|
||||
engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
|
||||
temp_config.kv_transfer_config = KVTransferConfig(
|
||||
**ktc, engine_id=engine_id
|
||||
)
|
||||
self._connectors.append(
|
||||
KVConnectorFactory.create_connector(temp_config, role)
|
||||
)
|
||||
for connector_cls, temp_config in self._get_connector_classes_and_configs(
|
||||
vllm_config
|
||||
):
|
||||
self._connectors.append(connector_cls(temp_config, role))
|
||||
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
|
||||
|
||||
# A mapping from request id to the index of the connector chosen to
|
||||
@ -109,6 +130,32 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
# Propagated from scheduler to worker side via the connector metadata.
|
||||
self._extra_async_saves: dict[str, int] = {}
|
||||
|
||||
@classmethod
|
||||
def _get_connector_classes_and_configs(
|
||||
cls, vllm_config: "VllmConfig"
|
||||
) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]:
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"connectors"
|
||||
)
|
||||
assert ktcs is not None
|
||||
ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = []
|
||||
for ktc in ktcs:
|
||||
temp_config = copy.copy(vllm_config)
|
||||
engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
|
||||
temp_config.kv_transfer_config = KVTransferConfig(
|
||||
**ktc, engine_id=engine_id
|
||||
)
|
||||
ret.append(
|
||||
(
|
||||
KVConnectorFactory.get_connector_class(
|
||||
temp_config.kv_transfer_config
|
||||
),
|
||||
temp_config,
|
||||
)
|
||||
)
|
||||
return ret
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
for c in self._connectors:
|
||||
c.register_kv_caches(kv_caches)
|
||||
@ -295,18 +342,12 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
None if the connector does not require a specific layout.
|
||||
"""
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"connectors"
|
||||
)
|
||||
assert ktcs is not None
|
||||
layouts: set[str] = set()
|
||||
temp_vllm_config = copy.copy(vllm_config)
|
||||
for ktc in ktcs:
|
||||
kv_transfer_config = KVTransferConfig(**ktc)
|
||||
temp_vllm_config.kv_transfer_config = kv_transfer_config
|
||||
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
|
||||
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
|
||||
vllm_config
|
||||
):
|
||||
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
|
||||
temp_vllm_config
|
||||
temp_config
|
||||
)
|
||||
if required_kvcache_layout is not None:
|
||||
layouts.add(required_kvcache_layout)
|
||||
@ -372,3 +413,28 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
stats_by_connector = MultiKVConnectorStats()
|
||||
stats_by_connector[c.__class__.__name__] = stats
|
||||
return stats_by_connector
|
||||
|
||||
@classmethod
|
||||
def build_prom_metrics(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[str]],
|
||||
) -> KVConnectorPromMetrics:
|
||||
prom_metrics: dict[str, KVConnectorPromMetrics] = {}
|
||||
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
|
||||
vllm_config
|
||||
):
|
||||
connector_prom = connector_cls.build_prom_metrics(
|
||||
temp_config, metric_types, labelnames, per_engine_labelvalues
|
||||
)
|
||||
if connector_prom is not None:
|
||||
prom_metrics[connector_cls.__name__] = connector_prom
|
||||
return MultiKVConnectorPromMetrics(
|
||||
vllm_config,
|
||||
metric_types,
|
||||
labelnames,
|
||||
per_engine_labelvalues,
|
||||
prom_metrics,
|
||||
)
|
||||
|
||||
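As a rough sketch (not part of this diff), this is the kv-transfer config shape that `_get_connector_classes_and_configs()` iterates over; the specific child connectors and roles listed here are only examples.

```python
# Illustrative sketch only: a MultiConnector configuration with two child connectors.
# Each entry under "connectors" becomes its own KVTransferConfig (inheriting the parent
# engine_id unless the entry overrides it), and its class is resolved via the factory.
from vllm.config.kv_transfer import KVTransferConfig

kv_transfer_config = KVTransferConfig(
    kv_connector="MultiConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "connectors": [
            {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
            {"kv_connector": "SharedStorageConnector", "kv_role": "kv_both"},
        ]
    },
)
```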
@ -30,7 +30,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
KVConnectorStats,
|
||||
PromMetric,
|
||||
PromMetricT,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@ -254,6 +259,18 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
else NixlKVConnectorStats()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_prom_metrics(
|
||||
cls,
|
||||
vllm_config: VllmConfig,
|
||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[str]],
|
||||
) -> KVConnectorPromMetrics:
|
||||
return NixlPromMetrics(
|
||||
vllm_config, metric_types, labelnames, per_engine_labelvalues
|
||||
)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
|
||||
@ -1960,3 +1977,125 @@ class NixlKVConnectorStats(KVConnectorStats):
|
||||
@property
|
||||
def num_successful_transfers(self) -> int:
|
||||
return len(self.data["transfer_duration"])
|
||||
|
||||
|
||||
class NixlPromMetrics(KVConnectorPromMetrics):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||
labelnames: list[str],
|
||||
per_engine_labelvalues: dict[int, list[str]],
|
||||
):
|
||||
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
||||
|
||||
buckets = [
|
||||
0.001,
|
||||
0.005,
|
||||
0.01,
|
||||
0.025,
|
||||
0.05,
|
||||
0.075,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.5,
|
||||
0.75,
|
||||
1.0,
|
||||
5.0,
|
||||
]
|
||||
nixl_histogram_xfer_time = self._histogram_cls(
|
||||
name="vllm:nixl_xfer_time_seconds",
|
||||
documentation="Histogram of transfer duration for NIXL KV Cache transfers.",
|
||||
buckets=buckets[1:],
|
||||
labelnames=labelnames,
|
||||
)
|
||||
self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time)
|
||||
nixl_histogram_post_time = self._histogram_cls(
|
||||
name="vllm:nixl_post_time_seconds",
|
||||
documentation="Histogram of transfer post time for NIXL KV"
|
||||
" Cache transfers.",
|
||||
buckets=buckets,
|
||||
labelnames=labelnames,
|
||||
)
|
||||
self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time)
|
||||
# uniform 2kb to 16gb range
|
||||
buckets = [2 ** (10 + i) for i in range(1, 25, 2)]
|
||||
nixl_histogram_bytes_transferred = self._histogram_cls(
|
||||
name="vllm:nixl_bytes_transferred",
|
||||
documentation="Histogram of bytes transferred per NIXL KV Cache transfers.",
|
||||
buckets=buckets,
|
||||
labelnames=labelnames,
|
||||
)
|
||||
self.nixl_histogram_bytes_transferred = self.make_per_engine(
|
||||
nixl_histogram_bytes_transferred
|
||||
)
|
||||
buckets = [
|
||||
10,
|
||||
20,
|
||||
30,
|
||||
50,
|
||||
75,
|
||||
100,
|
||||
200,
|
||||
400,
|
||||
1000,
|
||||
2000,
|
||||
4000,
|
||||
10000,
|
||||
20000,
|
||||
50000,
|
||||
]
|
||||
nixl_histogram_num_descriptors = self._histogram_cls(
|
||||
name="vllm:nixl_num_descriptors",
|
||||
documentation="Histogram of number of descriptors per NIXL"
|
||||
" KV Cache transfers.",
|
||||
buckets=buckets,
|
||||
labelnames=labelnames,
|
||||
)
|
||||
self.nixl_histogram_num_descriptors = self.make_per_engine(
|
||||
nixl_histogram_num_descriptors
|
||||
)
|
||||
counter_nixl_num_failed_transfers = self._counter_cls(
|
||||
name="vllm:nixl_num_failed_transfers",
|
||||
documentation="Number of failed NIXL KV Cache transfers.",
|
||||
labelnames=labelnames,
|
||||
)
|
||||
self.counter_nixl_num_failed_transfers = self.make_per_engine(
|
||||
counter_nixl_num_failed_transfers
|
||||
)
|
||||
counter_nixl_num_failed_notifications = self._counter_cls(
|
||||
name="vllm:nixl_num_failed_notifications",
|
||||
documentation="Number of failed NIXL KV Cache notifications.",
|
||||
labelnames=labelnames,
|
||||
)
|
||||
self.counter_nixl_num_failed_notifications = self.make_per_engine(
|
||||
counter_nixl_num_failed_notifications
|
||||
)
|
||||
|
||||
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
|
||||
for prom_obj, list_item_key in zip(
|
||||
[
|
||||
self.nixl_histogram_xfer_time,
|
||||
self.nixl_histogram_post_time,
|
||||
self.nixl_histogram_bytes_transferred,
|
||||
self.nixl_histogram_num_descriptors,
|
||||
],
|
||||
[
|
||||
"transfer_duration",
|
||||
"post_duration",
|
||||
"bytes_transferred",
|
||||
"num_descriptors",
|
||||
],
|
||||
):
|
||||
for list_item in transfer_stats_data[list_item_key]:
|
||||
prom_obj[engine_idx].observe(list_item)
|
||||
for counter_obj, counter_item_key in zip(
|
||||
[
|
||||
self.counter_nixl_num_failed_transfers,
|
||||
self.counter_nixl_num_failed_notifications,
|
||||
],
|
||||
["num_failed_transfers", "num_failed_notifications"],
|
||||
):
|
||||
for list_item in transfer_stats_data[counter_item_key]:
|
||||
counter_obj[engine_idx].inc(list_item)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
@ -96,19 +97,30 @@ class P2pNcclEngine:
|
||||
# Each card corresponds to a ZMQ address.
|
||||
self.zmq_address = f"{self._hostname}:{self._port}"
|
||||
|
||||
# The `http_port` must be consistent with the port of OpenAI.
|
||||
self.http_address = (
|
||||
f"{self._hostname}:{self.config.kv_connector_extra_config['http_port']}"
|
||||
)
|
||||
|
||||
# If `proxy_ip` or `proxy_port` is `""`,
|
||||
# then the ping thread will not be enabled.
|
||||
proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
|
||||
proxy_port = self.config.get_from_extra_config("proxy_port", "")
|
||||
if proxy_ip == "" or proxy_port == "":
|
||||
self.proxy_address = ""
|
||||
self.http_address = ""
|
||||
else:
|
||||
self.proxy_address = proxy_ip + ":" + proxy_port
|
||||
# the `http_port` must be consistent with the port of OpenAI.
|
||||
http_port = self.config.get_from_extra_config("http_port", None)
|
||||
if http_port is None:
|
||||
example_cfg = {
|
||||
"kv_connector": "P2pNcclConnector",
|
||||
"kv_connector_extra_config": {"http_port": 8000},
|
||||
}
|
||||
example = (
|
||||
f"--port=8000 --kv-transfer-config='{json.dumps(example_cfg)}'"
|
||||
)
|
||||
raise ValueError(
|
||||
"kv_connector_extra_config.http_port is required. "
|
||||
f"Example: {example}"
|
||||
)
|
||||
self.http_address = f"{self._hostname}:{http_port}"
|
||||
|
||||
self.context = zmq.Context()
|
||||
self.router_socket = self.context.socket(zmq.ROUTER)
|
||||
|
||||
@ -336,36 +336,34 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
||||
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
for i, req_id in enumerate(cached_reqs.req_ids):
|
||||
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
|
||||
if not resumed_from_preemption or req_id not in self._requests_need_load:
|
||||
continue
|
||||
|
||||
num_computed_tokens = cached_reqs.num_computed_tokens[i]
|
||||
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
new_block_ids = cached_reqs.new_block_ids[i]
|
||||
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
|
||||
|
||||
# NOTE(rob): here we rely on the resumed requests being
|
||||
# the first N requests in the list scheduled_cache_reqs.
|
||||
if not resumed_from_preemption:
|
||||
break
|
||||
if req_id in self._requests_need_load:
|
||||
# NOTE(rob): cached_req_data does not have the full
|
||||
# list of token ids (only new tokens). So we look it
|
||||
# up in the actual request object.
|
||||
request = self._requests_need_load[req_id]
|
||||
total_tokens = num_computed_tokens + num_new_tokens
|
||||
token_ids = request.all_token_ids[:total_tokens]
|
||||
# NOTE(rob): cached_req_data does not have the full
|
||||
# list of token ids (only new tokens). So we look it
|
||||
# up in the actual request object.
|
||||
request = self._requests_need_load[req_id]
|
||||
total_tokens = num_computed_tokens + num_new_tokens
|
||||
token_ids = request.all_token_ids[:total_tokens]
|
||||
|
||||
# NOTE(rob): For resumed req, new_block_ids is all
|
||||
# of the block_ids for the request.
|
||||
assert new_block_ids is not None
|
||||
block_ids = new_block_ids[0]
|
||||
# NOTE(rob): For resumed req, new_block_ids is all
|
||||
# of the block_ids for the request.
|
||||
assert new_block_ids is not None
|
||||
block_ids = new_block_ids[0]
|
||||
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
block_size=self._block_size,
|
||||
is_store=False,
|
||||
mm_hashes=[f.identifier for f in request.mm_features],
|
||||
)
|
||||
total_need_load += 1
|
||||
meta.add_request(
|
||||
token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
block_size=self._block_size,
|
||||
is_store=False,
|
||||
mm_hashes=[f.identifier for f in request.mm_features],
|
||||
)
|
||||
total_need_load += 1
|
||||
|
||||
assert total_need_load == len(self._requests_need_load)
|
||||
self._requests_need_load.clear()
|
||||
|
||||
@ -206,19 +206,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
# TODO: optimize the performance here
|
||||
# in batch invariance mode, force routing to DP rank 0.
|
||||
data_parallel_rank = None
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
config = self.engine_client.vllm_config
|
||||
dp_size = config.parallel_config.data_parallel_size
|
||||
if dp_size and dp_size > 1:
|
||||
data_parallel_rank = 0
|
||||
|
||||
engine_request, tokenization_kwargs = await self._process_inputs(
|
||||
request_id_item,
|
||||
engine_prompt,
|
||||
@ -226,7 +213,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
|
||||
@ -1158,7 +1158,6 @@ class OpenAIServing:
|
||||
lora_request: LoRARequest | None,
|
||||
trace_headers: Mapping[str, str] | None,
|
||||
priority: int,
|
||||
data_parallel_rank: int | None = None,
|
||||
) -> tuple[EngineCoreRequest, dict[str, Any]]:
|
||||
"""Use the Processor to process inputs for AsyncLLM."""
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
@ -1174,7 +1173,6 @@ class OpenAIServing:
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
return engine_request, tokenization_kwargs
|
||||
|
||||
|
||||
vllm/envs.py
@@ -155,6 +155,7 @@ if TYPE_CHECKING:
    VLLM_USE_FLASHINFER_MOE_FP8: bool = False
    VLLM_USE_FLASHINFER_MOE_FP4: bool = False
    VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput"
    VLLM_ALLOW_BATCHED_TRITON_FALLBACK: bool = False
    VLLM_XGRAMMAR_CACHE_MB: int = 0
    VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
    VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False

@@ -247,10 +248,19 @@ def maybe_convert_bool(value: str | None) -> bool | None:
    return bool(int(value))


def disable_compile_cache() -> bool:
    return bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0")))


def use_aot_compile() -> bool:
    from vllm.utils.torch_utils import is_torch_equal_or_newer

-    default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
+    default_value = (
+        "1"
+        if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache()
+        else "0"
+    )

    return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
@ -963,9 +973,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_LOG_BATCHSIZE_INTERVAL": lambda: float(
|
||||
os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")
|
||||
),
|
||||
"VLLM_DISABLE_COMPILE_CACHE": lambda: bool(
|
||||
int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))
|
||||
),
|
||||
"VLLM_DISABLE_COMPILE_CACHE": disable_compile_cache,
|
||||
# If set, vllm will run in development mode, which will enable
|
||||
# some additional endpoints for developing and debugging,
|
||||
# e.g. `/reset_prefix_cache`
|
||||
@ -1138,6 +1146,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: bool(
|
||||
int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))
|
||||
),
|
||||
# If set to 1, allow fallback to batched triton kernel when deepgemm
|
||||
# is unavailable. By default (0), the system will crash if deepgemm
|
||||
# is expected but not available.
|
||||
"VLLM_ALLOW_BATCHED_TRITON_FALLBACK": lambda: bool(
|
||||
int(os.getenv("VLLM_ALLOW_BATCHED_TRITON_FALLBACK", "0"))
|
||||
),
|
||||
# Control the cache sized used by the xgrammar compiler. The default
|
||||
# of 512 MB should be enough for roughly 1000 JSON schemas.
|
||||
# It can be changed with this variable if needed for some reason.
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
@ -22,11 +23,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
|
||||
self.batched_triton_experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
quant_config=self.quant_config,
|
||||
)
|
||||
# Store the original request for deep gemm
|
||||
deep_gemm_requested = allow_deep_gemm
|
||||
|
||||
self.allow_deep_gemm = (
|
||||
allow_deep_gemm
|
||||
@ -44,6 +42,31 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
else None
|
||||
)
|
||||
|
||||
# If deep gemm was requested but is not available (either due to
|
||||
# unsupported configuration or missing dependencies), check if
|
||||
# we should allow fallback to batched triton kernel
|
||||
if (
|
||||
deep_gemm_requested
|
||||
and self.batched_deep_gemm_experts is None
|
||||
and not envs.VLLM_ALLOW_BATCHED_TRITON_FALLBACK
|
||||
):
|
||||
raise RuntimeError(
|
||||
"DeepGemm was requested but is not available. "
|
||||
"The batched triton kernel fallback is disabled by default. "
|
||||
"Set VLLM_ALLOW_BATCHED_TRITON_FALLBACK=1 to enable the fallback "
|
||||
"for debugging purposes."
|
||||
)
|
||||
|
||||
self.batched_triton_experts = (
|
||||
BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
quant_config=self.quant_config,
|
||||
)
|
||||
if self.batched_deep_gemm_experts is None
|
||||
else None
|
||||
)
|
||||
|
||||
assert (
|
||||
self.batched_deep_gemm_experts is not None
|
||||
or self.batched_triton_experts is not None
|
||||
|
||||
@ -1135,6 +1135,7 @@ class FusedMoE(CustomOp):
|
||||
)
|
||||
|
||||
self.global_num_experts = num_experts + num_redundant_experts
|
||||
self.logical_num_experts = num_experts
|
||||
self.zero_expert_num = zero_expert_num
|
||||
self.zero_expert_type = zero_expert_type
|
||||
|
||||
@ -1998,13 +1999,12 @@ class FusedMoE(CustomOp):
|
||||
|
||||
moe = self.moe_config
|
||||
|
||||
# Note here we use `num_experts` which is logical expert count
|
||||
if self.vllm_config.parallel_config.enable_dbo:
|
||||
states_shape = (2, moe.max_num_tokens, self.hidden_size)
|
||||
logits_shape = (2, moe.max_num_tokens, moe.num_experts)
|
||||
logits_shape = (2, moe.max_num_tokens, self.logical_num_experts)
|
||||
else:
|
||||
states_shape = (moe.max_num_tokens, self.hidden_size)
|
||||
logits_shape = (moe.max_num_tokens, moe.num_experts)
|
||||
logits_shape = (moe.max_num_tokens, self.logical_num_experts)
|
||||
|
||||
self.batched_hidden_states = torch.zeros(
|
||||
states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
|
||||
|
||||
@ -428,6 +428,14 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
)
|
||||
elif self.attn_backend == _Backend.TORCH_SDPA:
|
||||
# Execute attention entry by entry for speed & less VRAM.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Never remove the next contiguous logic
|
||||
# Without it, hallucinations occur with the backend
|
||||
if current_platform.is_rocm():
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
outputs = []
|
||||
for i in range(1, len(cu_seqlens)):
|
||||
start_idx = cu_seqlens[i - 1]
|
||||
|
||||
@ -261,6 +261,21 @@ class CudaPlatformBase(Platform):
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
|
||||
if use_mla:
|
||||
# explicitly reject non-MLA backends when MLA is enabled to avoid
|
||||
# silently selecting an incompatible backend (e.g., FLASHINFER).
|
||||
if selected_backend in {
|
||||
_Backend.FLASHINFER,
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TRITON_ATTN,
|
||||
_Backend.TREE_ATTN,
|
||||
_Backend.XFORMERS,
|
||||
}:
|
||||
raise ValueError(
|
||||
f"Attention backend {selected_backend} incompatible with MLA. "
|
||||
"Please use one of the MLA backends: FLASHINFER_MLA, CUTLASS_MLA, "
|
||||
"FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
|
||||
"VLLM_MLA_DISABLE=1 to disable MLA for this model."
|
||||
)
|
||||
if not use_v1:
|
||||
raise RuntimeError(
|
||||
"MLA attention backends require the V1 engine. "
|
||||
|
||||
@ -72,6 +72,7 @@ _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
|
||||
"0x74a0": "AMD_Instinct_MI300A",
|
||||
"0x74a1": "AMD_Instinct_MI300X",
|
||||
"0x74b5": "AMD_Instinct_MI300X", # MI300X VF
|
||||
"0x74a2": "AMD_Instinct_MI308X",
|
||||
"0x74a5": "AMD_Instinct_MI325X",
|
||||
"0x74b9": "AMD_Instinct_MI325X", # MI325X VF
|
||||
"0x74a9": "AMD_Instinct_MI300X_HF",
|
||||
|
||||
@ -8,9 +8,6 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
@ -130,44 +127,19 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
if self.bmm2_scale is None:
|
||||
self.bmm2_scale = layer._v_scale_float
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
# TODO(wentao): optimize this when it is supported by Flashinfer upstream.
|
||||
# execute per-request to eliminate batch-shape-dependent kernel paths.
|
||||
num = q.shape[0]
|
||||
outs = []
|
||||
for i in range(num):
|
||||
qi = q[i : i + 1]
|
||||
bt_i = attn_metadata.decode.block_table[i : i + 1]
|
||||
sl_i = attn_metadata.decode.seq_lens[i : i + 1]
|
||||
oi = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=qi,
|
||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=bt_i,
|
||||
seq_lens=sl_i,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
)
|
||||
outs.append(oi)
|
||||
o = torch.cat(outs, dim=0)
|
||||
else:
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
seq_lens=attn_metadata.decode.seq_lens,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
)
|
||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||
query=q,
|
||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||
workspace_buffer=self._workspace_buffer,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
block_tables=attn_metadata.decode.block_table,
|
||||
seq_lens=attn_metadata.decode.seq_lens,
|
||||
max_seq_len=attn_metadata.max_seq_len,
|
||||
bmm1_scale=self.bmm1_scale,
|
||||
bmm2_scale=self.bmm2_scale,
|
||||
)
|
||||
|
||||
# Flatten the output for consistent shape
|
||||
o = o.view(-1, o.shape[-2], o.shape[-1])
|
||||
|
||||
@ -689,9 +689,15 @@ class AsyncLLM(EngineClient):
|
||||
await self.reset_prefix_cache()
|
||||
await self.engine_core.sleep_async(level)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
self.logger_manager.record_sleep_state(1, level)
|
||||
|
||||
async def wake_up(self, tags: list[str] | None = None) -> None:
|
||||
await self.engine_core.wake_up_async(tags)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
self.logger_manager.record_sleep_state(0, 0)
|
||||
|
||||
async def is_sleeping(self) -> bool:
|
||||
return await self.engine_core.is_sleeping_async()
|
||||
|
||||
|
||||
@ -332,9 +332,15 @@ class LLMEngine:
|
||||
def sleep(self, level: int = 1):
|
||||
self.engine_core.sleep(level)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
self.logger_manager.record_sleep_state(1, level)
|
||||
|
||||
def wake_up(self, tags: list[str] | None = None):
|
||||
self.engine_core.wake_up(tags)
|
||||
|
||||
if self.logger_manager is not None:
|
||||
self.logger_manager.record_sleep_state(0, 0)
|
||||
|
||||
def is_sleeping(self) -> bool:
|
||||
return self.engine_core.is_sleeping()
|
||||
|
||||
|
||||
@ -9,8 +9,12 @@ from typing import TypeAlias
|
||||
|
||||
from prometheus_client import Counter, Gauge, Histogram
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import SupportsMetricsInfo, VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorLogging,
|
||||
KVConnectorPrometheus,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.plugins import load_plugins_by_group
|
||||
from vllm.v1.engine import FinishReason
|
||||
@ -56,6 +60,9 @@ class StatLoggerBase(ABC):
|
||||
def log(self): # noqa
|
||||
pass
|
||||
|
||||
def record_sleep_state(self, is_awake: int, level: int): # noqa
|
||||
pass
|
||||
|
||||
|
||||
def load_stat_logger_plugin_factories() -> list[StatLoggerFactory]:
|
||||
factories: list[StatLoggerFactory] = []
|
||||
@ -335,6 +342,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
|
||||
_counter_cls = Counter
|
||||
_histogram_cls = Histogram
|
||||
_spec_decoding_cls = SpecDecodingProm
|
||||
_kv_connector_cls = KVConnectorPrometheus
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, engine_indexes: list[int] | None = None
|
||||
@ -354,12 +362,15 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
|
||||
model_name = vllm_config.model_config.served_model_name
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
spec_decode_labelvalues: dict[int, list[str]] = {
|
||||
per_engine_labelvalues: dict[int, list[str]] = {
|
||||
idx: [model_name, str(idx)] for idx in engine_indexes
|
||||
}
|
||||
|
||||
self.spec_decoding_prom = self._spec_decoding_cls(
|
||||
vllm_config.speculative_config, labelnames, spec_decode_labelvalues
|
||||
vllm_config.speculative_config, labelnames, per_engine_labelvalues
|
||||
)
|
||||
self.kv_connector_prom = self._kv_connector_cls(
|
||||
vllm_config, labelnames, per_engine_labelvalues
|
||||
)
|
||||
|
||||
#
|
||||
@ -384,8 +395,33 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
|
||||
self.gauge_scheduler_waiting = make_per_engine(
|
||||
gauge_scheduler_waiting, engine_indexes, model_name
|
||||
)
|
||||
if envs.VLLM_SERVER_DEV_MODE:
|
||||
gauge_engine_sleep_state = self._gauge_cls(
|
||||
name="vllm:engine_sleep_state",
|
||||
documentation=(
|
||||
"Engine sleep state; awake = 0 means engine is sleeping; "
|
||||
"awake = 1 means engine is awake; "
|
||||
"weights_offloaded = 1 means sleep level 1; "
|
||||
"discard_all = 1 means sleep level 2."
|
||||
),
|
||||
labelnames=labelnames + ["sleep_state"],
|
||||
multiprocess_mode="mostrecent",
|
||||
)
|
||||
|
||||
self.gauge_engine_sleep_state = {}
|
||||
sleep_state = ["awake", "weights_offloaded", "discard_all"]
|
||||
|
||||
for s in sleep_state:
|
||||
self.gauge_engine_sleep_state[s] = {
|
||||
idx: gauge_engine_sleep_state.labels(
|
||||
engine=idx, model_name=model_name, sleep_state=s
|
||||
)
|
||||
for idx in engine_indexes
|
||||
}
|
||||
|
||||
# Setting default values
|
||||
self.record_sleep_state()
|
||||
|
||||
#
|
||||
# GPU cache
|
||||
#
|
||||
# Deprecated in 0.9.2 - Renamed as vllm:kv_cache_usage_perc
|
||||
@ -933,6 +969,11 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
|
||||
scheduler_stats.spec_decoding_stats, engine_idx
|
||||
)
|
||||
|
||||
if scheduler_stats.kv_connector_stats is not None:
|
||||
self.kv_connector_prom.observe(
|
||||
scheduler_stats.kv_connector_stats, engine_idx
|
||||
)
|
||||
|
||||
if mm_cache_stats is not None:
|
||||
self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries)
|
||||
self.counter_mm_cache_hits[engine_idx].inc(mm_cache_stats.hits)
|
||||
@ -1010,6 +1051,25 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
|
||||
}
|
||||
self.gauge_lora_info.labels(**lora_info_labels).set_to_current_time()
|
||||
|
||||
def record_sleep_state(self, sleep: int = 0, level: int = 0):
|
||||
awake = 1
|
||||
discard_all = 0
|
||||
weights_offloaded = 0
|
||||
|
||||
if sleep == 1:
|
||||
awake = 0
|
||||
if level == 1:
|
||||
weights_offloaded = 1
|
||||
elif level == 2:
|
||||
discard_all = 1
|
||||
|
||||
for engine_idx in self.engine_indexes:
|
||||
self.gauge_engine_sleep_state["discard_all"][engine_idx].set(discard_all)
|
||||
self.gauge_engine_sleep_state["weights_offloaded"][engine_idx].set(
|
||||
weights_offloaded
|
||||
)
|
||||
self.gauge_engine_sleep_state["awake"][engine_idx].set(awake)
|
||||
|
||||
def log_engine_initialized(self):
|
||||
self.log_metrics_info("cache_config", self.vllm_config.cache_config)
|
||||
|
||||
@ -1131,6 +1191,10 @@ class StatLoggerManager:
|
||||
engine_idx=engine_idx,
|
||||
)
|
||||
|
||||
def record_sleep_state(self, sleep: int = 0, level: int = 0):
|
||||
for logger in self.stat_loggers:
|
||||
logger.record_sleep_state(sleep, level)
|
||||
|
||||
def log(self):
|
||||
for logger in self.stat_loggers:
|
||||
logger.log()
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPrometheus
|
||||
from vllm.v1.metrics.loggers import PrometheusStatLogger
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingProm
|
||||
|
||||
@ -141,6 +142,18 @@ class RaySpecDecodingProm(SpecDecodingProm):
|
||||
_counter_cls = RayCounterWrapper
|
||||
|
||||
|
||||
class RayKVConnectorPrometheus(KVConnectorPrometheus):
|
||||
"""
|
||||
RayKVConnectorPrometheus is used by RayMetrics to log Ray
|
||||
metrics. Provides the same metrics as KV connectors but
|
||||
uses Ray's util.metrics library.
|
||||
"""
|
||||
|
||||
_gauge_cls = RayGaugeWrapper
|
||||
_counter_cls = RayCounterWrapper
|
||||
_histogram_cls = RayHistogramWrapper
|
||||
|
||||
|
||||
class RayPrometheusStatLogger(PrometheusStatLogger):
|
||||
"""RayPrometheusStatLogger uses Ray metrics instead."""
|
||||
|
||||
@ -148,6 +161,7 @@ class RayPrometheusStatLogger(PrometheusStatLogger):
|
||||
_counter_cls = RayCounterWrapper
|
||||
_histogram_cls = RayHistogramWrapper
|
||||
_spec_decoding_cls = RaySpecDecodingProm
|
||||
_kv_connector_cls = RayKVConnectorPrometheus
|
||||
|
||||
@staticmethod
|
||||
def _unregister_vllm_metrics():
|
||||
|
||||