flashmla + dp support

Signed-off-by: yewentao256 <zhyanwentao@126.com>
yewentao256
2025-10-23 08:12:49 -07:00
parent c2168c50cc
commit cf31e136ff
5 changed files with 82 additions and 42 deletions

View File

@@ -9,7 +9,7 @@ import pytest
def enable_batch_invariant_mode():
"""Automatically enable batch invariant kernel overrides for all tests."""
old_value = os.environ.get("VLLM_BATCH_INVARIANT")
os.environ["VLLM_BATCH_INVARIANT"] = "0"
os.environ["VLLM_BATCH_INVARIANT"] = "1"
yield
# restore original value after test
if old_value is None:
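The conftest hunk flips the fixture from disabling to enabling `VLLM_BATCH_INVARIANT`, but the restore branch is cut off at the hunk boundary. For orientation, a minimal sketch of the usual save/restore pattern such a fixture follows; the decorator and the restore code are assumptions, not part of this commit:

```python
import os

import pytest


@pytest.fixture(autouse=True)  # autouse assumed from the "for all tests" docstring
def enable_batch_invariant_mode():
    """Automatically enable batch invariant kernel overrides for all tests."""
    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
    os.environ["VLLM_BATCH_INVARIANT"] = "1"
    yield
    # Restore the original value after the test (assumed standard pattern).
    if old_value is None:
        os.environ.pop("VLLM_BATCH_INVARIANT", None)
    else:
        os.environ["VLLM_BATCH_INVARIANT"] = old_value
```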

View File

@@ -9,16 +9,17 @@ Environment variables:
- VLLM_TEST_MODEL: served model name (e.g., Qwen/Qwen3-1.7B)
Note:
- The server must be started beforehand via `vllm serve ...` and should
honor the same sampling params for determinism.
- For BS=N, we dispatch requests concurrently to encourage server-side
dynamic batching.
- The server must be started beforehand via `vllm serve ...`
- Example usage:
- `export VLLM_ATTENTION_BACKEND="FLASHINFER_MLA"`
- `export VLLM_BATCH_INVARIANT=1`
- `vllm serve deepseek-ai/DeepSeek-R1 -dp 8 --enable-expert-parallel --port 9256`
- `pytest tests/v1/generation/test_online_batch_invariance.py`
"""
import os
import random
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
@@ -41,7 +42,7 @@ def _post_json(
def _request_completion(
api_base: str,
model: str,
prompt: str,
prompt: Any,
sp: dict[str, Any],
timeout: float = 60.0,
max_retries: int = 3,
@@ -118,27 +119,25 @@ def _compare_bs1_vs_bsn_single_process(
bs1_tokens_per_prompt.append(list(toks))
bs1_logprobs_per_prompt.append(list(lps))
# BS=N: dispatch concurrently to encourage server batching
# BS=N
bsN_tokens_per_prompt: list[list[Any]] = [None] * len(prompts) # type: ignore[list-item]
bsN_logprobs_per_prompt: list[list[float] | None] = [None] * len(prompts)
with ThreadPoolExecutor(max_workers=min(64, max(1, len(prompts)))) as ex:
futures = {
ex.submit(_request_completion, api_base, model_name, p, sp_kwargs): i
for i, p in enumerate(prompts)
}
for fut in as_completed(futures):
idx = futures[fut]
resp = fut.result()
if resp is None or not resp.get("choices"):
raise AssertionError(f"BS=N empty/failed response for prompt {idx}")
choice = resp["choices"][0]
toks, lps = _extract_tokens_and_logprobs(choice)
if lps is None:
raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
bsN_tokens_per_prompt[idx] = list(toks)
bsN_logprobs_per_prompt[idx] = list(lps)
resp = _request_completion(api_base, model_name, prompts, sp_kwargs)
if resp is None or not resp.get("choices"):
raise AssertionError("BS=N empty/failed batched response")
choices = resp.get("choices", [])
if len(choices) != len(prompts):
raise AssertionError(
f"BS=N choices length {len(choices)} != num prompts {len(prompts)}"
)
for idx, choice in enumerate(choices):
toks, lps = _extract_tokens_and_logprobs(choice)
if lps is None:
raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
bsN_tokens_per_prompt[idx] = list(toks)
bsN_logprobs_per_prompt[idx] = list(lps)
# Compare
# compare
for i, (tokens_bs1, tokens_bsN, logprobs_bs1, logprobs_bsN) in enumerate(
zip(
bs1_tokens_per_prompt,
@@ -172,11 +171,9 @@ def _compare_bs1_vs_bsn_single_process(
@skip_unsupported
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN_dp_http():
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
api_base = os.getenv("VLLM_API_BASE", "http://127.0.0.1:9256/v1")
model_name = os.getenv("VLLM_TEST_MODEL", "deepseek-ai/DeepSeek-V2-lite")
num_prompts = int(os.getenv("VLLM_TEST_NUM_PROMPTS", "32"))
prompts_all = [_random_prompt(10, 50) for _ in range(num_prompts)]
model_name = os.getenv("VLLM_TEST_MODEL", "deepseek-ai/DeepSeek-R1")
prompts_all = [_random_prompt(10, 50) for _ in range(32)]
sp_kwargs: dict[str, Any] = {
"temperature": 0.6,

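The rewritten BS=N path drops the thread-pool fan-out and issues a single `_request_completion(api_base, model_name, prompts, sp_kwargs)` call with the whole prompt list, which is why the `prompt` parameter is widened to `Any` above. A hypothetical sketch of the batched request body this produces against the OpenAI-compatible `/v1/completions` endpoint; field values other than `temperature` are illustrative assumptions:

```python
# Hypothetical BS=N payload: the server forms the batch itself and returns
# one choice per prompt, matched up by index.
payload = {
    "model": "deepseek-ai/DeepSeek-R1",
    "prompt": ["<prompt 0>", "<prompt 1>", "<prompt 2>"],  # a list, not a single string
    "temperature": 0.6,
    "logprobs": 1,      # assumed: per-token logprobs are needed for the comparison
    "max_tokens": 32,   # assumed value
}
# The test asserts len(resp["choices"]) == len(prompts) before comparing BS=1
# and BS=N tokens/logprobs prompt by prompt.
```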
View File

@@ -206,6 +206,19 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
)
else:
# TODO: optimize the performance here
# in batch invariance mode, force routing to DP rank 0.
data_parallel_rank = None
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
if vllm_is_batch_invariant():
config = self.engine_client.vllm_config
dp_size = config.parallel_config.data_parallel_size
if dp_size and dp_size > 1:
data_parallel_rank = 0
engine_request, tokenization_kwargs = await self._process_inputs(
request_id_item,
engine_prompt,
@@ -213,6 +226,7 @@ class OpenAIServingCompletion(OpenAIServing):
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
data_parallel_rank=data_parallel_rank,
)
generator = self.engine_client.generate(
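The handler pins requests to DP rank 0 only when `vllm_is_batch_invariant()` says the mode is active. Judging from the conftest change above, that flag is driven by the `VLLM_BATCH_INVARIANT` environment variable; a plausible sketch of such a check (the real implementation in `vllm.model_executor.layers.batch_invariant` may differ):

```python
import os


def vllm_is_batch_invariant() -> bool:
    # Assumed behaviour: any truthy setting of the env var enables the mode.
    return os.environ.get("VLLM_BATCH_INVARIANT", "0").lower() in ("1", "true", "yes")
```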

View File

@@ -1158,6 +1158,7 @@ class OpenAIServing:
lora_request: LoRARequest | None,
trace_headers: Mapping[str, str] | None,
priority: int,
data_parallel_rank: int | None = None,
) -> tuple[EngineCoreRequest, dict[str, Any]]:
"""Use the Processor to process inputs for AsyncLLM."""
tokenization_kwargs: dict[str, Any] = {}
@@ -1173,6 +1174,7 @@ class OpenAIServing:
tokenization_kwargs=tokenization_kwargs,
trace_headers=trace_headers,
priority=priority,
data_parallel_rank=data_parallel_rank,
)
return engine_request, tokenization_kwargs

View File

@@ -8,6 +8,9 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
@@ -127,19 +130,43 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
if self.bmm2_scale is None:
self.bmm2_scale = layer._v_scale_float
o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=attn_metadata.decode.block_table,
seq_lens=attn_metadata.decode.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
if vllm_is_batch_invariant():
# execute per-request to eliminate batch-shape-dependent kernel paths.
num = q.shape[0]
outs = []
for i in range(num):
qi = q[i : i + 1]
bt_i = attn_metadata.decode.block_table[i : i + 1]
sl_i = attn_metadata.decode.seq_lens[i : i + 1]
oi = trtllm_batch_decode_with_kv_cache_mla(
query=qi,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=bt_i,
seq_lens=sl_i,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
outs.append(oi)
o = torch.cat(outs, dim=0)
else:
o = trtllm_batch_decode_with_kv_cache_mla(
query=q,
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
workspace_buffer=self._workspace_buffer,
qk_nope_head_dim=self.qk_nope_head_dim,
kv_lora_rank=self.kv_lora_rank,
qk_rope_head_dim=self.qk_rope_head_dim,
block_tables=attn_metadata.decode.block_table,
seq_lens=attn_metadata.decode.seq_lens,
max_seq_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale,
)
# Flatten the output for consistent shape
o = o.view(-1, o.shape[-2], o.shape[-1])
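The batch-invariant branch trades throughput for reproducibility: each decode call sees a batch of one, so tiling and reduction order inside the kernel cannot depend on how many requests happened to be batched together. The same pattern in isolation, with a generic stand-in for the FlashInfer call (this helper is illustrative, not part of the commit):

```python
import torch


def per_request_decode(q, block_table, seq_lens, decode_kernel):
    """Run a batched decode kernel one request at a time.

    `decode_kernel` is any callable with the batched signature
    (q, block_table, seq_lens) -> out. Slicing with i:i+1 keeps the leading
    batch dim at 1, so every call sees the same batch shape regardless of
    how many requests arrived together.
    """
    outs = []
    for i in range(q.shape[0]):
        outs.append(decode_kernel(q[i : i + 1], block_table[i : i + 1], seq_lens[i : i + 1]))
    return torch.cat(outs, dim=0)  # same shape as one batched call
```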