@@ -9,7 +9,7 @@ import pytest
 def enable_batch_invariant_mode():
     """Automatically enable batch invariant kernel overrides for all tests."""
     old_value = os.environ.get("VLLM_BATCH_INVARIANT")
-    os.environ["VLLM_BATCH_INVARIANT"] = "0"
+    os.environ["VLLM_BATCH_INVARIANT"] = "1"
     yield
     # restore original value after test
     if old_value is None:

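The hunk cuts off inside the restore branch. A minimal sketch of how such a fixture typically completes, assuming it is an autouse pytest fixture; the decorator and the two restore branches below are assumptions, only the lines shown in the hunk come from the PR:

```python
import os

import pytest


@pytest.fixture(autouse=True)
def enable_batch_invariant_mode():
    """Automatically enable batch invariant kernel overrides for all tests."""
    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
    os.environ["VLLM_BATCH_INVARIANT"] = "1"
    yield
    # restore original value after test
    if old_value is None:
        os.environ.pop("VLLM_BATCH_INVARIANT", None)
    else:
        os.environ["VLLM_BATCH_INVARIANT"] = old_value
```
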
@@ -9,16 +9,17 @@ Environment variables:
 - VLLM_TEST_MODEL: served model name (e.g., Qwen/Qwen3-1.7B)
 
 Note:
-- The server must be started beforehand via `vllm serve ...` and should
-  honor the same sampling params for determinism.
-- For BS=N, we dispatch requests concurrently to encourage server-side
-  dynamic batching.
+- The server must be started beforehand via `vllm serve ...`
+- Example usage:
+  - `export VLLM_ATTENTION_BACKEND="FLASHINFER_MLA"`
+  - `export VLLM_BATCH_INVARIANT=1`
+  - `vllm serve deepseek-ai/DeepSeek-R1 -dp 8 --enable-expert-parallel --port 9256`
+  - `pytest tests/v1/generation/test_online_batch_invariance.py`
 """
 
 import os
 import random
 import sys
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Any
 from urllib.error import HTTPError, URLError
 from urllib.request import Request, urlopen

@@ -41,7 +42,7 @@ def _post_json(
 def _request_completion(
     api_base: str,
     model: str,
-    prompt: str,
+    prompt: Any,
     sp: dict[str, Any],
     timeout: float = 60.0,
     max_retries: int = 3,

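Widening `prompt` from `str` to `Any` lets the caller pass the whole list of prompts as one batched request. A self-contained sketch of what such a helper could look like; this is not the PR's `_request_completion`, it only assumes the standard OpenAI-compatible /v1/completions endpoint and omits the retry logic:

```python
import json
from typing import Any
from urllib.request import Request, urlopen


def request_completion_sketch(
    api_base: str, model: str, prompt: Any, sp: dict[str, Any], timeout: float = 60.0
) -> dict[str, Any] | None:
    """Hypothetical stand-in for _request_completion.

    `prompt` may be a single string (BS=1) or a list of strings; the
    /v1/completions endpoint returns one choice per prompt in the list.
    """
    payload = {"model": model, "prompt": prompt, **sp}
    req = Request(
        f"{api_base}/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    try:
        with urlopen(req, timeout=timeout) as resp:
            return json.loads(resp.read().decode("utf-8"))
    except Exception:
        return None
```
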
@@ -118,27 +119,25 @@ def _compare_bs1_vs_bsn_single_process(
         bs1_tokens_per_prompt.append(list(toks))
         bs1_logprobs_per_prompt.append(list(lps))
 
-    # BS=N: dispatch concurrently to encourage server batching
+    # BS=N
     bsN_tokens_per_prompt: list[list[Any]] = [None] * len(prompts)  # type: ignore[list-item]
     bsN_logprobs_per_prompt: list[list[float] | None] = [None] * len(prompts)
-    with ThreadPoolExecutor(max_workers=min(64, max(1, len(prompts)))) as ex:
-        futures = {
-            ex.submit(_request_completion, api_base, model_name, p, sp_kwargs): i
-            for i, p in enumerate(prompts)
-        }
-        for fut in as_completed(futures):
-            idx = futures[fut]
-            resp = fut.result()
-            if resp is None or not resp.get("choices"):
-                raise AssertionError(f"BS=N empty/failed response for prompt {idx}")
-            choice = resp["choices"][0]
-            toks, lps = _extract_tokens_and_logprobs(choice)
-            if lps is None:
-                raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
-            bsN_tokens_per_prompt[idx] = list(toks)
-            bsN_logprobs_per_prompt[idx] = list(lps)
+    resp = _request_completion(api_base, model_name, prompts, sp_kwargs)
+    if resp is None or not resp.get("choices"):
+        raise AssertionError("BS=N empty/failed batched response")
+    choices = resp.get("choices", [])
+    if len(choices) != len(prompts):
+        raise AssertionError(
+            f"BS=N choices length {len(choices)} != num prompts {len(prompts)}"
+        )
+    for idx, choice in enumerate(choices):
+        toks, lps = _extract_tokens_and_logprobs(choice)
+        if lps is None:
+            raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
+        bsN_tokens_per_prompt[idx] = list(toks)
+        bsN_logprobs_per_prompt[idx] = list(lps)
 
-    # Compare
+    # compare
     for i, (tokens_bs1, tokens_bsN, logprobs_bs1, logprobs_bsN) in enumerate(
         zip(
             bs1_tokens_per_prompt,

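The comparison loop is cut off by the hunk boundary; the test is built around bitwise equality of tokens and logprobs, not approximate closeness. A minimal sketch of that per-prompt check (the function name and assertion messages are illustrative, not from the PR):

```python
def compare_prompt_outputs(
    idx: int,
    tokens_bs1: list[int],
    tokens_bsN: list[int],
    logprobs_bs1: list[float],
    logprobs_bsN: list[float],
) -> None:
    # token IDs must match position by position
    assert tokens_bs1 == tokens_bsN, f"prompt {idx}: sampled tokens differ"
    # logprobs must be bitwise identical, not merely close
    for pos, (lp1, lpN) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
        assert lp1 == lpN, f"prompt {idx}, step {pos}: {lp1!r} != {lpN!r}"
```
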
@@ -172,11 +171,9 @@ def _compare_bs1_vs_bsn_single_process(
 @skip_unsupported
 def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN_dp_http():
     random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
 
     api_base = os.getenv("VLLM_API_BASE", "http://127.0.0.1:9256/v1")
-    model_name = os.getenv("VLLM_TEST_MODEL", "deepseek-ai/DeepSeek-V2-lite")
-    num_prompts = int(os.getenv("VLLM_TEST_NUM_PROMPTS", "32"))
-    prompts_all = [_random_prompt(10, 50) for _ in range(num_prompts)]
+    model_name = os.getenv("VLLM_TEST_MODEL", "deepseek-ai/DeepSeek-R1")
+    prompts_all = [_random_prompt(10, 50) for _ in range(32)]
 
     sp_kwargs: dict[str, Any] = {
         "temperature": 0.6,

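The sampling-parameter dict is also truncated by the hunk; for a logprob comparison at temperature 0.6 the request needs a fixed seed and per-token logprobs. The field names below follow the OpenAI completions API, and every value except `temperature` is an assumption:

```python
from typing import Any

sp_kwargs: dict[str, Any] = {
    "temperature": 0.6,
    "max_tokens": 64,  # assumed value
    "seed": 12345,     # assumed: fixed seed so BS=1 and BS=N sample identically
    "logprobs": 1,     # assumed: ask the server to return per-token logprobs
}
```
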
@@ -206,6 +206,19 @@ class OpenAIServingCompletion(OpenAIServing):
                         lora_request=lora_request,
                     )
                 else:
+                    # TODO: optimize the performance here
+                    # in batch invariance mode, force routing to DP rank 0.
+                    data_parallel_rank = None
+                    from vllm.model_executor.layers.batch_invariant import (
+                        vllm_is_batch_invariant,
+                    )
+
+                    if vllm_is_batch_invariant():
+                        config = self.engine_client.vllm_config
+                        dp_size = config.parallel_config.data_parallel_size
+                        if dp_size and dp_size > 1:
+                            data_parallel_rank = 0
+
                     engine_request, tokenization_kwargs = await self._process_inputs(
                         request_id_item,
                         engine_prompt,

@@ -213,6 +226,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         lora_request=lora_request,
                         trace_headers=trace_headers,
                         priority=request.priority,
+                        data_parallel_rank=data_parallel_rank,
                     )
 
                     generator = self.engine_client.generate(

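Together, these two hunks pin every completion to data-parallel rank 0 whenever batch-invariant mode is enabled and more than one DP rank exists, so all requests share one engine's batching behavior (the TODO notes this costs DP throughput). A standalone sketch of that routing decision, with hypothetical argument names:

```python
def pick_data_parallel_rank(batch_invariant: bool, data_parallel_size: int) -> int | None:
    """Return the DP rank to pin a request to, or None to let the router decide."""
    if batch_invariant and data_parallel_size > 1:
        # batch invariant mode: route everything to rank 0
        return 0
    return None


# usage sketch
assert pick_data_parallel_rank(batch_invariant=True, data_parallel_size=8) == 0
assert pick_data_parallel_rank(batch_invariant=False, data_parallel_size=8) is None
```
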
@@ -1158,6 +1158,7 @@ class OpenAIServing:
         lora_request: LoRARequest | None,
         trace_headers: Mapping[str, str] | None,
         priority: int,
+        data_parallel_rank: int | None = None,
     ) -> tuple[EngineCoreRequest, dict[str, Any]]:
         """Use the Processor to process inputs for AsyncLLM."""
         tokenization_kwargs: dict[str, Any] = {}

@@ -1173,6 +1174,7 @@ class OpenAIServing:
             tokenization_kwargs=tokenization_kwargs,
             trace_headers=trace_headers,
             priority=priority,
+            data_parallel_rank=data_parallel_rank,
         )
         return engine_request, tokenization_kwargs
 

@@ -8,6 +8,9 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 
 from vllm.attention.backends.abstract import AttentionLayer, AttentionType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_is_batch_invariant,
+)
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
     MLACommonImpl,

@@ -127,19 +130,43 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         if self.bmm2_scale is None:
             self.bmm2_scale = layer._v_scale_float
 
-        o = trtllm_batch_decode_with_kv_cache_mla(
-            query=q,
-            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
-            workspace_buffer=self._workspace_buffer,
-            qk_nope_head_dim=self.qk_nope_head_dim,
-            kv_lora_rank=self.kv_lora_rank,
-            qk_rope_head_dim=self.qk_rope_head_dim,
-            block_tables=attn_metadata.decode.block_table,
-            seq_lens=attn_metadata.decode.seq_lens,
-            max_seq_len=attn_metadata.max_seq_len,
-            bmm1_scale=self.bmm1_scale,
-            bmm2_scale=self.bmm2_scale,
-        )
+        if vllm_is_batch_invariant():
+            # execute per-request to eliminate batch-shape-dependent kernel paths.
+            num = q.shape[0]
+            outs = []
+            for i in range(num):
+                qi = q[i : i + 1]
+                bt_i = attn_metadata.decode.block_table[i : i + 1]
+                sl_i = attn_metadata.decode.seq_lens[i : i + 1]
+                oi = trtllm_batch_decode_with_kv_cache_mla(
+                    query=qi,
+                    kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
+                    workspace_buffer=self._workspace_buffer,
+                    qk_nope_head_dim=self.qk_nope_head_dim,
+                    kv_lora_rank=self.kv_lora_rank,
+                    qk_rope_head_dim=self.qk_rope_head_dim,
+                    block_tables=bt_i,
+                    seq_lens=sl_i,
+                    max_seq_len=attn_metadata.max_seq_len,
+                    bmm1_scale=self.bmm1_scale,
+                    bmm2_scale=self.bmm2_scale,
+                )
+                outs.append(oi)
+            o = torch.cat(outs, dim=0)
+        else:
+            o = trtllm_batch_decode_with_kv_cache_mla(
+                query=q,
+                kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
+                workspace_buffer=self._workspace_buffer,
+                qk_nope_head_dim=self.qk_nope_head_dim,
+                kv_lora_rank=self.kv_lora_rank,
+                qk_rope_head_dim=self.qk_rope_head_dim,
+                block_tables=attn_metadata.decode.block_table,
+                seq_lens=attn_metadata.decode.seq_lens,
+                max_seq_len=attn_metadata.max_seq_len,
+                bmm1_scale=self.bmm1_scale,
+                bmm2_scale=self.bmm2_scale,
+            )
 
         # Flatten the output for consistent shape
         o = o.view(-1, o.shape[-2], o.shape[-1])

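The per-request loop trades decode throughput for batch invariance: each sequence runs through the FlashInfer MLA kernel with batch size 1, so the kernel path no longer depends on how many requests happen to be batched together. The same decomposition pattern in miniature, with a deterministic stand-in op instead of `trtllm_batch_decode_with_kv_cache_mla` (shapes and names are illustrative assumptions):

```python
import torch


def per_request_apply(q: torch.Tensor, op) -> torch.Tensor:
    """Apply `op` to one request at a time and concatenate the results.

    `q` is assumed to be [batch, heads, dim] and `op` maps a [1, heads, dim]
    slice to a [1, heads, dim] output, mirroring the batch-size-1 kernel calls
    in the hunk above.
    """
    outs = [op(q[i : i + 1]) for i in range(q.shape[0])]
    return torch.cat(outs, dim=0)


# usage sketch with a trivial stand-in op
q = torch.randn(4, 8, 16)
out = per_request_apply(q, lambda x: x * 2.0)
assert out.shape == q.shape
```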