Compare commits
8 Commits
main
...
wentao-bat
| Author | SHA1 | Date | |
|---|---|---|---|
| 40464dbf34 | |||
| 981cc5fdbf | |||
| b53a65fa46 | |||
| 4f2a8d9d7f | |||
| b2f24cd6b7 | |||
| bc955355f8 | |||
| cf31e136ff | |||
| c2168c50cc |
11
tests/v1/determinism/conftest.py
Normal file
11
tests/v1/determinism/conftest.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
    """Export ``VLLM_BATCH_INVARIANT=1`` around every test using this conftest.

    The variable is set through ``monkeypatch`` before the test body runs
    (the fixture yields right after setting it).
    """
    env_name, env_value = "VLLM_BATCH_INVARIANT", "1"
    monkeypatch.setenv(env_name, env_value)
    yield
|
||||||
@ -6,66 +6,9 @@ import random
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from utils import _extract_step_logprobs, _random_prompt, skip_unsupported
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
skip_unsupported = pytest.mark.skipif(
|
|
||||||
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
|
|
||||||
reason="Requires CUDA and >= Hopper (SM90)",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
|
|
||||||
"""Automatically enable batch invariant kernel overrides for all tests."""
|
|
||||||
monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
|
||||||
# Generate more realistic prompts that will actually produce varied tokens
|
|
||||||
# Use a mix of common English text patterns
|
|
||||||
|
|
||||||
prompt_templates = [
|
|
||||||
# Question-answer style
|
|
||||||
"Question: What is the capital of France?\nAnswer: The capital of France is",
|
|
||||||
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
|
|
||||||
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
|
|
||||||
# Story/narrative style
|
|
||||||
"Once upon a time in a distant galaxy, there lived",
|
|
||||||
"The old man walked slowly down the street, remembering",
|
|
||||||
"In the year 2157, humanity finally discovered",
|
|
||||||
# Technical/code style
|
|
||||||
"To implement a binary search tree in Python, first we need to",
|
|
||||||
"The algorithm works by iterating through the array and",
|
|
||||||
"Here's how to optimize database queries using indexing:",
|
|
||||||
# Factual/informative style
|
|
||||||
"The Renaissance was a period in European history that",
|
|
||||||
"Climate change is caused by several factors including",
|
|
||||||
"The human brain contains approximately 86 billion neurons which",
|
|
||||||
# Conversational style
|
|
||||||
"I've been thinking about getting a new laptop because",
|
|
||||||
"Yesterday I went to the store and bought",
|
|
||||||
"My favorite thing about summer is definitely",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Pick a random template
|
|
||||||
base_prompt = random.choice(prompt_templates)
|
|
||||||
|
|
||||||
if max_words < min_words:
|
|
||||||
max_words = min_words
|
|
||||||
target_words = random.randint(min_words, max_words)
|
|
||||||
|
|
||||||
if target_words > 50:
|
|
||||||
# For longer prompts, repeat context
|
|
||||||
padding_text = (
|
|
||||||
" This is an interesting topic that deserves more explanation. "
|
|
||||||
* (target_words // 50)
|
|
||||||
)
|
|
||||||
base_prompt = base_prompt + padding_text
|
|
||||||
|
|
||||||
return base_prompt
|
|
||||||
|
|
||||||
|
|
||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
@ -204,22 +147,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
|
|||||||
llm_bsN.shutdown()
|
llm_bsN.shutdown()
|
||||||
|
|
||||||
|
|
||||||
def _extract_step_logprobs(request_output):
|
|
||||||
if getattr(request_output, "outputs", None):
|
|
||||||
inner = request_output.outputs[0]
|
|
||||||
if hasattr(inner, "logprobs") and inner.logprobs is not None:
|
|
||||||
t = torch.tensor(
|
|
||||||
[
|
|
||||||
inner.logprobs[i][tid].logprob
|
|
||||||
for i, tid in enumerate(inner.token_ids)
|
|
||||||
],
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
return t, inner.token_ids
|
|
||||||
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
|
||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"backend",
|
"backend",
|
||||||
208
tests/v1/determinism/test_online_batch_invariance.py
Normal file
208
tests/v1/determinism/test_online_batch_invariance.py
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
HTTP-based batch invariance test: send requests to a running
|
||||||
|
vLLM server and compare BS=1 vs BS=N results (tokens and per-step logprobs).
|
||||||
|
|
||||||
|
Environment variables:
|
||||||
|
- VLLM_API_BASE: base URL of the server (default: http://127.0.0.1:9256/v1)
|
||||||
|
- VLLM_TEST_MODEL: served model name, e.g. Qwen/Qwen3-1.7B (default: deepseek-ai/DeepSeek-R1)
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- The server must be started beforehand via `vllm serve ...`
|
||||||
|
- Example usage:
|
||||||
|
- `export VLLM_ATTENTION_BACKEND="FLASHINFER_MLA"`
|
||||||
|
- `export VLLM_BATCH_INVARIANT=1`
|
||||||
|
- `vllm serve deepseek-ai/DeepSeek-R1 -dp 8 --enable-expert-parallel --port 9256`
|
||||||
|
- `pytest tests/v1/determinism/test_online_batch_invariance.py`
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
from urllib.error import HTTPError, URLError
|
||||||
|
from urllib.request import Request, urlopen
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from utils import _random_prompt, skip_unsupported
|
||||||
|
|
||||||
|
|
||||||
|
def _post_json(
    url: str, headers: dict[str, str], payload: dict[str, Any], timeout: float
) -> dict[str, Any]:
    """POST ``payload`` as UTF-8 encoded JSON to ``url`` and decode the JSON reply.

    Args:
        url: Full endpoint URL.
        headers: HTTP headers sent with the request.
        payload: JSON-serialisable request body.
        timeout: Timeout in seconds passed to ``urlopen``.

    Returns:
        The response body parsed as JSON.

    Errors raised by ``urlopen`` or by JSON (de)serialisation are not caught
    here; callers decide how to handle them.
    """
    import json

    request = Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers=headers,
        method="POST",
    )
    with urlopen(request, timeout=timeout) as response:
        raw_body = response.read()
    return json.loads(raw_body.decode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _request_completion(
    api_base: str,
    model: str,
    prompt: Any,
    sp: dict[str, Any],
    timeout: float = 60.0,
    max_retries: int = 3,
    retry_backoff: float = 0.5,
) -> dict[str, Any] | None:
    """POST a request to ``<api_base>/completions`` with retries.

    Args:
        api_base: Base URL of the server; a trailing ``/`` is stripped.
        model: Value sent as the ``model`` field.
        prompt: Value sent as the ``prompt`` field (a string or a list of
            strings, as used by the callers in this file).
        sp: Extra request fields merged into the payload; keys in ``sp``
            override ``model``/``prompt`` on collision.
        timeout: Per-request timeout in seconds.
        max_retries: Number of retries after the first attempt.
        retry_backoff: Base delay in seconds; attempt ``k`` (0-based) sleeps
            ``retry_backoff * 2**k`` before retrying.

    Returns:
        The decoded JSON response, or ``None`` when the request ultimately
        fails. The final error is written to ``stderr``.
    """
    import time

    url = api_base.rstrip("/") + "/completions"
    headers = {"Content-Type": "application/json"}
    payload: dict[str, Any] = {"model": model, "prompt": prompt, **sp}

    # HTTP statuses treated as transient; any other HTTP error fails at once.
    retryable_statuses = (429, 500, 502, 503, 504)

    for attempt in range(max_retries + 1):
        try:
            return _post_json(url, headers, payload, timeout)
        # HTTPError is a subclass of URLError, so it must be matched first.
        except HTTPError as e:  # type: ignore[reportGeneralTypeIssues]
            label, error = "HTTPError", e
            retryable = getattr(e, "code", None) in retryable_statuses
        except URLError as e:  # type: ignore[reportGeneralTypeIssues]
            label, error, retryable = "URLError", e, True
        except Exception as e:  # pragma: no cover
            label, error, retryable = "Error", e, True

        # Single retry decision shared by all three failure kinds.
        if retryable and attempt < max_retries:
            time.sleep(retry_backoff * (2**attempt))
            continue
        sys.stderr.write(f"{label}: {error}\n")
        return None

    # Only reached when ``max_retries`` is negative and the loop never runs.
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_tokens_and_logprobs(
|
||||||
|
choice: dict[str, Any],
|
||||||
|
) -> tuple[list[Any], list[float] | None]:
|
||||||
|
tokens: list[Any] = []
|
||||||
|
token_logprobs: list[float] | None = None
|
||||||
|
lp = choice.get("logprobs")
|
||||||
|
if lp and isinstance(lp, dict):
|
||||||
|
tokens = lp.get("token_ids") or lp.get("tokens") or []
|
||||||
|
token_logprobs = lp.get("token_logprobs", None)
|
||||||
|
return tokens, token_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
def _server_is_up(api_base: str, timeout: float = 2.0) -> tuple[bool, str]:
    """Probe ``<api_base>/models`` with a GET request.

    Returns:
        ``(True, "OK")`` when the request succeeds and the body can be read,
        otherwise ``(False, <message>)`` where the message is prefixed with
        ``URLError:`` or ``Error:`` depending on the exception caught.
    """
    models_url = api_base.rstrip("/") + "/models"
    try:
        probe = Request(models_url, method="GET")
        with urlopen(probe, timeout=timeout) as response:
            response.read()
    except URLError as exc:  # type: ignore[reportGeneralTypeIssues]
        return False, f"URLError: {exc}"
    except Exception as exc:  # pragma: no cover
        return False, f"Error: {exc}"
    return True, "OK"
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_bs1_vs_bsn_single_process(
    prompts: list[str],
    sp_kwargs: dict[str, Any],
    api_base: str,
    model_name: str,
) -> None:
    """Check that one-by-one and batched completions agree exactly.

    Each prompt is first sent in its own request (BS=1), then all prompts are
    sent together in a single request (BS=N). For every prompt the sampled
    tokens and the per-step logprobs of both runs must be equal.

    Raises:
        AssertionError: on a failed/empty response, missing logprobs, a
            choice-count mismatch, or any token/logprob difference.
    """
    # BS=1: one request per prompt.
    single_tokens: list[list[Any]] = []
    single_logprobs: list[list[float] | None] = []
    for prompt in prompts:
        reply = _request_completion(api_base, model_name, prompt, sp_kwargs)
        if reply is None or not reply.get("choices"):
            raise AssertionError("BS=1 empty/failed response")
        toks, lps = _extract_tokens_and_logprobs(reply["choices"][0])
        if lps is None:
            raise AssertionError(
                "logprobs not returned; ensure server supports 'logprobs'"
            )
        single_tokens.append(list(toks))
        single_logprobs.append(list(lps))

    # BS=N: every prompt in a single request.
    reply = _request_completion(api_base, model_name, prompts, sp_kwargs)
    if reply is None or not reply.get("choices"):
        raise AssertionError("BS=N empty/failed batched response")
    choices = reply.get("choices", [])
    if len(choices) != len(prompts):
        raise AssertionError(
            f"BS=N choices length {len(choices)} != num prompts {len(prompts)}"
        )
    # Choices are taken in response order: position idx is prompt idx.
    batched_tokens: list[list[Any]] = []
    batched_logprobs: list[list[float] | None] = []
    for idx, choice in enumerate(choices):
        toks, lps = _extract_tokens_and_logprobs(choice)
        if lps is None:
            raise AssertionError(f"BS=N missing logprobs for prompt {idx}")
        batched_tokens.append(list(toks))
        batched_logprobs.append(list(lps))

    # Compare prompt by prompt, then step by step.
    for i in range(min(len(single_tokens), len(batched_tokens))):
        tokens_bs1, tokens_bsN = single_tokens[i], batched_tokens[i]
        logprobs_bs1, logprobs_bsN = single_logprobs[i], batched_logprobs[i]

        if tokens_bs1 != tokens_bsN:
            raise AssertionError(
                f"Prompt {i} (sampling): Different tokens sampled. "
                f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
            )
        if logprobs_bs1 is None or logprobs_bsN is None:
            raise AssertionError(f"Prompt {i}: Missing logprobs in one of the runs")
        if len(logprobs_bs1) != len(logprobs_bsN):
            raise AssertionError(
                f"Prompt {i}: Different number of steps: "
                f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bsN)} (BS=N)."
            )
        for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bsN)):
            if a == b:
                continue
            # Exact equality is required; the diff is reported for debugging.
            diff = abs(a - b)
            raise AssertionError(
                f"Prompt {i} Step {t}: Bitwise mismatch "
                f"(abs diff={diff:.6e}). "
                f"BS=1 tokens: {tokens_bs1} BS=N tokens: {tokens_bsN}"
            )
|
||||||
|
|
||||||
|
|
||||||
|
@skip_unsupported
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN_dp_http():
    """Compare BS=1 and BS=N completions from an already running server.

    Configuration is read from ``VLLM_TEST_SEED``, ``VLLM_API_BASE`` and
    ``VLLM_TEST_MODEL``. The test is skipped when the server's ``/models``
    endpoint cannot be reached.
    """
    seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
    random.seed(seed)

    api_base = os.getenv("VLLM_API_BASE", "http://127.0.0.1:9256/v1")
    model_name = os.getenv("VLLM_TEST_MODEL", "deepseek-ai/DeepSeek-R1")

    reachable, detail = _server_is_up(api_base)
    if not reachable:
        pytest.skip(f"Server not reachable at {api_base}: {detail}")

    prompts_all = [_random_prompt(10, 50) for _ in range(32)]

    # Sampling request fields shared by the BS=1 and BS=N runs.
    sp_kwargs: dict[str, Any] = dict(
        temperature=0.6,
        top_p=1.0,
        max_tokens=8,
        seed=42,
        logprobs=5,
    )

    _compare_bs1_vs_bsn_single_process(
        prompts=prompts_all,
        sp_kwargs=sp_kwargs,
        api_base=api_base,
        model_name=model_name,
    )
|
||||||
@ -9,15 +9,10 @@ with the standard CUDA-based implementation to ensure numerical accuracy.
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from utils import skip_unsupported
|
||||||
|
|
||||||
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
|
from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
skip_unsupported = pytest.mark.skipif(
|
|
||||||
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
|
|
||||||
reason="Requires CUDA and >= Hopper (SM90)",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@skip_unsupported
|
@skip_unsupported
|
||||||
74
tests/v1/determinism/utils.py
Normal file
74
tests/v1/determinism/utils.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import random
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
# Shared skip marker: tests run only on CUDA devices with capability >= 90.
skip_unsupported = pytest.mark.skipif(
    (
        not current_platform.is_cuda()
        or not current_platform.has_device_capability(90)
    ),
    reason="Requires CUDA and >= Hopper (SM90)",
)
|
||||||
|
|
||||||
|
|
||||||
|
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
    """Return a randomly chosen template prompt, optionally padded with filler.

    A template is picked with ``random.choice`` and a target size is drawn with
    ``random.randint(min_words, max_words)`` (``max_words`` is raised to
    ``min_words`` when it is smaller). If the target exceeds 50, a fixed filler
    sentence is appended ``target // 50`` times.

    NOTE(review): the result is not ``target`` words long; the target only
    controls how many filler sentences are appended.
    """
    # Realistic openings in several styles so that generations vary.
    templates = (
        # Question-answer style
        "Question: What is the capital of France?\nAnswer: The capital of France is",
        "Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
        "User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
        # Story/narrative style
        "Once upon a time in a distant galaxy, there lived",
        "The old man walked slowly down the street, remembering",
        "In the year 2157, humanity finally discovered",
        # Technical/code style
        "To implement a binary search tree in Python, first we need to",
        "The algorithm works by iterating through the array and",
        "Here's how to optimize database queries using indexing:",
        # Factual/informative style
        "The Renaissance was a period in European history that",
        "Climate change is caused by several factors including",
        "The human brain contains approximately 86 billion neurons which",
        # Conversational style
        "I've been thinking about getting a new laptop because",
        "Yesterday I went to the store and bought",
        "My favorite thing about summer is definitely",
    )
    filler = " This is an interesting topic that deserves more explanation. "

    # Keep the draw order (template first, then size) so seeded runs match.
    prompt = random.choice(templates)
    upper_bound = max(min_words, max_words)
    target_words = random.randint(min_words, upper_bound)

    repeats = target_words // 50 if target_words > 50 else 0
    if repeats:
        prompt = prompt + filler * repeats
    return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_step_logprobs(request_output):
    """Collect the logprob of each sampled token from the first output.

    For step ``i`` with sampled token id ``tid`` the value read is
    ``outputs[0].logprobs[i][tid].logprob``.

    Returns:
        ``(tensor, token_ids)`` where ``tensor`` is a float32 ``torch.Tensor``
        of per-step logprobs, or ``(None, None)`` when ``request_output`` has
        no (non-empty) ``outputs`` or the first output has no logprobs.
    """
    completions = getattr(request_output, "outputs", None)
    if not completions:
        return None, None

    first = completions[0]
    per_step = getattr(first, "logprobs", None)
    if per_step is None:
        return None, None

    sampled_ids = first.token_ids
    values = [per_step[step][tid].logprob for step, tid in enumerate(sampled_ids)]
    return torch.tensor(values, dtype=torch.float32), sampled_ids
|
||||||
@ -206,6 +206,19 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
# TODO: optimize the performance here
|
||||||
|
# in batch invariance mode, force routing to DP rank 0.
|
||||||
|
data_parallel_rank = None
|
||||||
|
from vllm.model_executor.layers.batch_invariant import (
|
||||||
|
vllm_is_batch_invariant,
|
||||||
|
)
|
||||||
|
|
||||||
|
if vllm_is_batch_invariant():
|
||||||
|
config = self.engine_client.vllm_config
|
||||||
|
dp_size = config.parallel_config.data_parallel_size
|
||||||
|
if dp_size and dp_size > 1:
|
||||||
|
data_parallel_rank = 0
|
||||||
|
|
||||||
engine_request, tokenization_kwargs = await self._process_inputs(
|
engine_request, tokenization_kwargs = await self._process_inputs(
|
||||||
request_id_item,
|
request_id_item,
|
||||||
engine_prompt,
|
engine_prompt,
|
||||||
@ -213,6 +226,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
priority=request.priority,
|
priority=request.priority,
|
||||||
|
data_parallel_rank=data_parallel_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
generator = self.engine_client.generate(
|
generator = self.engine_client.generate(
|
||||||
|
|||||||
@ -1158,6 +1158,7 @@ class OpenAIServing:
|
|||||||
lora_request: LoRARequest | None,
|
lora_request: LoRARequest | None,
|
||||||
trace_headers: Mapping[str, str] | None,
|
trace_headers: Mapping[str, str] | None,
|
||||||
priority: int,
|
priority: int,
|
||||||
|
data_parallel_rank: int | None = None,
|
||||||
) -> tuple[EngineCoreRequest, dict[str, Any]]:
|
) -> tuple[EngineCoreRequest, dict[str, Any]]:
|
||||||
"""Use the Processor to process inputs for AsyncLLM."""
|
"""Use the Processor to process inputs for AsyncLLM."""
|
||||||
tokenization_kwargs: dict[str, Any] = {}
|
tokenization_kwargs: dict[str, Any] = {}
|
||||||
@ -1173,6 +1174,7 @@ class OpenAIServing:
|
|||||||
tokenization_kwargs=tokenization_kwargs,
|
tokenization_kwargs=tokenization_kwargs,
|
||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
|
data_parallel_rank=data_parallel_rank,
|
||||||
)
|
)
|
||||||
return engine_request, tokenization_kwargs
|
return engine_request, tokenization_kwargs
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,9 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
|||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
from vllm.attention.backends.abstract import AttentionLayer, AttentionType
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.batch_invariant import (
|
||||||
|
vllm_is_batch_invariant,
|
||||||
|
)
|
||||||
from vllm.v1.attention.backends.mla.common import (
|
from vllm.v1.attention.backends.mla.common import (
|
||||||
MLACommonBackend,
|
MLACommonBackend,
|
||||||
MLACommonImpl,
|
MLACommonImpl,
|
||||||
@ -127,19 +130,44 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
|||||||
if self.bmm2_scale is None:
|
if self.bmm2_scale is None:
|
||||||
self.bmm2_scale = layer._v_scale_float
|
self.bmm2_scale = layer._v_scale_float
|
||||||
|
|
||||||
o = trtllm_batch_decode_with_kv_cache_mla(
|
if vllm_is_batch_invariant():
|
||||||
query=q,
|
# TODO(wentao): optimize this when it is supported by Flashinfer upstream.
|
||||||
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
# execute per-request to eliminate batch-shape-dependent kernel paths.
|
||||||
workspace_buffer=self._workspace_buffer,
|
num = q.shape[0]
|
||||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
outs = []
|
||||||
kv_lora_rank=self.kv_lora_rank,
|
for i in range(num):
|
||||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
qi = q[i : i + 1]
|
||||||
block_tables=attn_metadata.decode.block_table,
|
bt_i = attn_metadata.decode.block_table[i : i + 1]
|
||||||
seq_lens=attn_metadata.decode.seq_lens,
|
sl_i = attn_metadata.decode.seq_lens[i : i + 1]
|
||||||
max_seq_len=attn_metadata.max_seq_len,
|
oi = trtllm_batch_decode_with_kv_cache_mla(
|
||||||
bmm1_scale=self.bmm1_scale,
|
query=qi,
|
||||||
bmm2_scale=self.bmm2_scale,
|
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||||
)
|
workspace_buffer=self._workspace_buffer,
|
||||||
|
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||||
|
kv_lora_rank=self.kv_lora_rank,
|
||||||
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||||
|
block_tables=bt_i,
|
||||||
|
seq_lens=sl_i,
|
||||||
|
max_seq_len=attn_metadata.max_seq_len,
|
||||||
|
bmm1_scale=self.bmm1_scale,
|
||||||
|
bmm2_scale=self.bmm2_scale,
|
||||||
|
)
|
||||||
|
outs.append(oi)
|
||||||
|
o = torch.cat(outs, dim=0)
|
||||||
|
else:
|
||||||
|
o = trtllm_batch_decode_with_kv_cache_mla(
|
||||||
|
query=q,
|
||||||
|
kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
|
||||||
|
workspace_buffer=self._workspace_buffer,
|
||||||
|
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||||
|
kv_lora_rank=self.kv_lora_rank,
|
||||||
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||||
|
block_tables=attn_metadata.decode.block_table,
|
||||||
|
seq_lens=attn_metadata.decode.seq_lens,
|
||||||
|
max_seq_len=attn_metadata.max_seq_len,
|
||||||
|
bmm1_scale=self.bmm1_scale,
|
||||||
|
bmm2_scale=self.bmm2_scale,
|
||||||
|
)
|
||||||
|
|
||||||
# Flatten the output for consistent shape
|
# Flatten the output for consistent shape
|
||||||
o = o.view(-1, o.shape[-2], o.shape[-1])
|
o = o.view(-1, o.shape[-2], o.shape[-1])
|
||||||
|
|||||||
Reference in New Issue
Block a user