Compare commits

...

8 Commits

Author SHA1 Message Date
40464dbf34 rename
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-29 15:19:17 -07:00
981cc5fdbf Merge branch 'main' into wentao-batch-invariance-dp 2025-10-29 15:18:33 -07:00
b53a65fa46 update using skip if server is not up
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-28 14:58:07 -07:00
4f2a8d9d7f Merge branch 'main' into wentao-batch-invariance-dp
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-28 14:00:34 -07:00
b2f24cd6b7 add todo
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-27 09:52:17 -07:00
bc955355f8 Merge branch 'main' into wentao-batch-invariance-dp 2025-10-27 09:32:27 -07:00
cf31e136ff flashmla + dp support
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-23 08:12:49 -07:00
c2168c50cc batch invariance dp
Signed-off-by: yewentao256 <zhyanwentao@126.com>
2025-10-22 13:40:16 -07:00
8 changed files with 352 additions and 93 deletions

View File

@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests."""
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
yield

View File

@ -6,66 +6,9 @@ import random
import pytest
import torch
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
)
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
"""Automatically enable batch invariant kernel overrides for all tests."""
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
yield
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
# Generate more realistic prompts that will actually produce varied tokens
# Use a mix of common English text patterns
prompt_templates = [
# Question-answer style
"Question: What is the capital of France?\nAnswer: The capital of France is",
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
# Story/narrative style
"Once upon a time in a distant galaxy, there lived",
"The old man walked slowly down the street, remembering",
"In the year 2157, humanity finally discovered",
# Technical/code style
"To implement a binary search tree in Python, first we need to",
"The algorithm works by iterating through the array and",
"Here's how to optimize database queries using indexing:",
# Factual/informative style
"The Renaissance was a period in European history that",
"Climate change is caused by several factors including",
"The human brain contains approximately 86 billion neurons which",
# Conversational style
"I've been thinking about getting a new laptop because",
"Yesterday I went to the store and bought",
"My favorite thing about summer is definitely",
]
# Pick a random template
base_prompt = random.choice(prompt_templates)
if max_words < min_words:
max_words = min_words
target_words = random.randint(min_words, max_words)
if target_words > 50:
# For longer prompts, repeat context
padding_text = (
" This is an interesting topic that deserves more explanation. "
* (target_words // 50)
)
base_prompt = base_prompt + padding_text
return base_prompt
@skip_unsupported
@ -204,22 +147,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
llm_bsN.shutdown()
def _extract_step_logprobs(request_output):
if getattr(request_output, "outputs", None):
inner = request_output.outputs[0]
if hasattr(inner, "logprobs") and inner.logprobs is not None:
t = torch.tensor(
[
inner.logprobs[i][tid].logprob
for i, tid in enumerate(inner.token_ids)
],
dtype=torch.float32,
)
return t, inner.token_ids
return None, None
@skip_unsupported
@pytest.mark.parametrize(
"backend",

View File

@ -0,0 +1,208 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
HTTP-based batch invariance test: send requests to a running
vLLM server and compare BS=1 vs BS=N results (tokens and per-step logprobs).
Environment variables:
- VLLM_API_BASE: base URL like http://127.0.0.1:9256/v1 (default used)
- VLLM_TEST_MODEL: served model name (e.g., Qwen/Qwen3-1.7B)
Note:
- The server must be started beforehand via `vllm serve ...`
- Example usage:
- `export VLLM_ATTENTION_BACKEND="FLASHINFER_MLA"`
- `export VLLM_BATCH_INVARIANT=1`
- `vllm serve deepseek-ai/DeepSeek-R1 -dp 8 --enable-expert-parallel --port 9256`
- `pytest tests/v1/generation/test_online_batch_invariance.py`
"""
import os
import random
import sys
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
import pytest
from utils import _random_prompt, skip_unsupported
def _post_json(
url: str, headers: dict[str, str], payload: dict[str, Any], timeout: float
) -> dict[str, Any]:
import json
data = json.dumps(payload).encode("utf-8")
req = Request(url, data=data, headers=headers, method="POST")
with urlopen(req, timeout=timeout) as resp:
body = resp.read()
return json.loads(body.decode("utf-8"))
def _request_completion(
api_base: str,
model: str,
prompt: Any,
sp: dict[str, Any],
timeout: float = 60.0,
max_retries: int = 3,
retry_backoff: float = 0.5,
) -> dict[str, Any] | None:
url = api_base.rstrip("/") + "/completions"
headers = {"Content-Type": "application/json"}
payload: dict[str, Any] = {"model": model, "prompt": prompt}
payload.update(sp)
for attempt in range(max_retries + 1):
try:
return _post_json(url, headers, payload, timeout)
except HTTPError as e: # type: ignore[reportGeneralTypeIssues]
status = getattr(e, "code", None)
if status in (429, 500, 502, 503, 504) and attempt < max_retries:
import time as _t
_t.sleep(retry_backoff * (2**attempt))
continue
sys.stderr.write(f"HTTPError: {e}\n")
return None
except URLError as e: # type: ignore[reportGeneralTypeIssues]
if attempt < max_retries:
import time as _t
_t.sleep(retry_backoff * (2**attempt))
continue
sys.stderr.write(f"URLError: {e}\n")
return None
except Exception as e: # pragma: no cover
if attempt < max_retries:
import time as _t
_t.sleep(retry_backoff * (2**attempt))
continue
sys.stderr.write(f"Error: {e}\n")
return None
return None
def _extract_tokens_and_logprobs(
choice: dict[str, Any],
) -> tuple[list[Any], list[float] | None]:
tokens: list[Any] = []
token_logprobs: list[float] | None = None
lp = choice.get("logprobs")
if lp and isinstance(lp, dict):
tokens = lp.get("token_ids") or lp.get("tokens") or []
token_logprobs = lp.get("token_logprobs", None)
return tokens, token_logprobs
def _server_is_up(api_base: str, timeout: float = 2.0) -> tuple[bool, str]:
url = api_base.rstrip("/") + "/models"
try:
req = Request(url, method="GET")
with urlopen(req, timeout=timeout) as resp:
_ = resp.read()
return True, "OK"
except URLError as e: # type: ignore[reportGeneralTypeIssues]
return False, f"URLError: {e}"
except Exception as e: # pragma: no cover
return False, f"Error: {e}"
def _compare_bs1_vs_bsn_single_process(
prompts: list[str],
sp_kwargs: dict[str, Any],
api_base: str,
model_name: str,
) -> None:
# BS=1
bs1_tokens_per_prompt: list[list[Any]] = []
bs1_logprobs_per_prompt: list[list[float] | None] = []
for p in prompts:
resp = _request_completion(api_base, model_name, p, sp_kwargs)
if resp is None or not resp.get("choices"):
raise AssertionError("BS=1 empty/failed response")
choice = resp["choices"][0]
toks, lps = _extract_tokens_and_logprobs(choice)
if lps is None:
raise AssertionError(
"logprobs not returned; ensure server supports 'logprobs'"
)
bs1_tokens_per_prompt.append(list(toks))
bs1_logprobs_per_prompt.append(list(lps))
# BS=N
bsN_tokens_per_prompt: list[list[Any]] = [None] * len(prompts) # type: ignore[list-item]
bsN_logprobs_per_prompt: list[list[float] | None] = [None] * len(prompts)
resp = _request_completion(api_base, model_name, prompts, sp_kwargs)
if resp is None or not resp.get("choices"):
raise AssertionError("BS=N empty/failed batched response")
choices = resp.get("choices", [])
if len(choices) != len(prompts):
raise AssertionError(
f"BS=N choices length {len(choices)} != num prompts {len(prompts)}"
)
for idx, choice in enumerate(choices):
toks, lps = _extract_tokens_and_logprobs(choice)
if lps is None:
raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
bsN_tokens_per_prompt[idx] = list(toks)
bsN_logprobs_per_prompt[idx] = list(lps)
# compare
for i, (tokens_bs1, tokens_bsN, logprobs_bs1, logprobs_bsN) in enumerate(
zip(
bs1_tokens_per_prompt,
bsN_tokens_per_prompt,
bs1_logprobs_per_prompt,
bsN_logprobs_per_prompt,
)
):
if tokens_bs1 != tokens_bsN:
raise AssertionError(
f"Prompt {i} (sampling): Different tokens sampled. "
f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
)
if logprobs_bs1 is None or logprobs_bsN is None:
raise AssertionError(f"Prompt {i}: Missing logprobs in one of the runs")
if len(logprobs_bs1) != len(logprobs_bsN):
raise AssertionError(
f"Prompt {i}: Different number of steps: "
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)."
)
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
if a != b:
diff = abs(a - b)
raise AssertionError(
f"Prompt {i} Step {t}: Bitwise mismatch "
f"(abs diff={diff:.6e}). "
f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
)
@skip_unsupported
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN_dp_http():
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
api_base = os.getenv("VLLM_API_BASE", "http://127.0.0.1:9256/v1")
model_name = os.getenv("VLLM_TEST_MODEL", "deepseek-ai/DeepSeek-R1")
up, reason = _server_is_up(api_base)
if not up:
pytest.skip(f"Server not reachable at {api_base}: {reason}")
prompts_all = [_random_prompt(10, 50) for _ in range(32)]
sp_kwargs: dict[str, Any] = {
"temperature": 0.6,
"top_p": 1.0,
"max_tokens": 8,
"seed": 42,
"logprobs": 5,
}
_compare_bs1_vs_bsn_single_process(
prompts=prompts_all,
sp_kwargs=sp_kwargs,
api_base=api_base,
model_name=model_name,
)

View File

@ -9,15 +9,10 @@ with the standard CUDA-based implementation to ensure numerical accuracy.
import pytest
import torch
from utils import skip_unsupported
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
)
@skip_unsupported

View File

@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import torch
from vllm.platforms import current_platform
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
reason="Requires CUDA and >= Hopper (SM90)",
)
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
# Generate more realistic prompts that will actually produce varied tokens
# Use a mix of common English text patterns
prompt_templates = [
# Question-answer style
"Question: What is the capital of France?\nAnswer: The capital of France is",
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
# Story/narrative style
"Once upon a time in a distant galaxy, there lived",
"The old man walked slowly down the street, remembering",
"In the year 2157, humanity finally discovered",
# Technical/code style
"To implement a binary search tree in Python, first we need to",
"The algorithm works by iterating through the array and",
"Here's how to optimize database queries using indexing:",
# Factual/informative style
"The Renaissance was a period in European history that",
"Climate change is caused by several factors including",
"The human brain contains approximately 86 billion neurons which",
# Conversational style
"I've been thinking about getting a new laptop because",
"Yesterday I went to the store and bought",
"My favorite thing about summer is definitely",
]
# Pick a random template
base_prompt = random.choice(prompt_templates)
if max_words < min_words:
max_words = min_words
target_words = random.randint(min_words, max_words)
if target_words > 50:
# For longer prompts, repeat context
padding_text = (
" This is an interesting topic that deserves more explanation. "
* (target_words // 50)
)
base_prompt = base_prompt + padding_text
return base_prompt
def _extract_step_logprobs(request_output):
if getattr(request_output, "outputs", None):
inner = request_output.outputs[0]
if hasattr(inner, "logprobs") and inner.logprobs is not None:
t = torch.tensor(
[
inner.logprobs[i][tid].logprob
for i, tid in enumerate(inner.token_ids)
],
dtype=torch.float32,
)
return t, inner.token_ids
return None, None

View File

@ -206,6 +206,19 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
)
else:
# TODO: optimize the performance here
# in batch invariance mode, force routing to DP rank 0.
data_parallel_rank = None
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
if vllm_is_batch_invariant():
config = self.engine_client.vllm_config
dp_size = config.parallel_config.data_parallel_size
if dp_size and dp_size > 1:
data_parallel_rank = 0
engine_request, tokenization_kwargs = await self._process_inputs(
request_id_item,
engine_prompt,
@ -213,6 +226,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generator = self.engine_client.generate(

View File

@ -1158,6 +1158,7 @@ class OpenAIServing:
lora_request: LoRARequest | None,
trace_headers: Mapping[str, str] | None,
priority: int,
data_parallel_rank: int | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs: dict[str, Any] = {}
@ -1173,6 +1174,7 @@ class OpenAIServing:
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
return engine_request, tokenization_kwargs

View File

@ -8,6 +8,9 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
@ -127,19 +130,44 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float
o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=attn_metadata.decode.block_table,
seq_lens=attn_metadata.decode.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
if vllm_is_batch_invariant():
# TODO(wentao): optimize this when it is supported by Flashinfer upstream.
# execute per-request to eliminate batch-shape-dependent kernel paths.
num = q.shape[0]
outs = []
for i in range(num):
qi = q[i : i + 1]
bt_i = attn_metadata.decode.block_table[i : i + 1]
sl_i = attn_metadata.decode.seq_lens[i : i + 1]
oi = trtllm_batch_decode_with_kv_cache_mla(
query=qi,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=bt_i,
seq_lens=sl_i,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
outs.append(oi)
o = torch.cat(outs, dim=0)
else:
o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=attn_metadata.decode.block_table,
seq_lens=attn_metadata.decode.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
# Flatten the output for consistent shape
o = o.view(-1, o.shape[-2], o.shape[-1])