Compare commits
9 Commits
copilot/fi ... split_kv_c

| Author | SHA1 | Date |
|---|---|---|
| | 6e1e31a66a | |
| | 50e80db4ef | |
| | d3d6afb355 | |
| | 808fa43d76 | |
| | 4ac510f484 | |
| | 7fb2a5be28 | |
| | 6c036615dc | |
| | 2fc24e94f9 | |
| | 2c3c1bd07a | |
@@ -217,16 +217,14 @@ steps:
num_gpus: 2
source_file_dependencies:
- vllm/
- tests/metrics
- tests/v1/tracing
commands:
- pytest -v -s metrics
- "pip install \
'opentelemetry-sdk>=1.26.0' \
'opentelemetry-api>=1.26.0' \
'opentelemetry-exporter-otlp>=1.26.0' \
'opentelemetry-semantic-conventions-ai>=0.4.1'"
- pytest -v -s tracing
- pytest -v -s v1/tracing

##### fast check tests #####
##### 1 GPU test #####

2
.github/CODEOWNERS
vendored
@@ -37,8 +37,8 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/v1/attention/backends/triton_attn.py @tdoublep
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
/vllm/v1/kv_cache_interface.py @heheda12345
/vllm/v1/worker/kv_cache_initializer_mixin.py @heheda12345
/vllm/v1/offloading @ApostaC

# Test ownership
/.buildkite/lm-eval-harness @mgoin @simon-mo
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo

@@ -26,23 +26,10 @@ logger = init_logger("test_pipeline_parallel")
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
For PP, we fall back to V0 by default. This means
that the TP baseline runs with V1 while the PP engine
runs with V0. This gives divergent results with dummy
weights. Once we enable V1 by default for PP, we can
remove this.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')


class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
eager_mode: bool
chunked_prefill: bool


class PPTestOptions(NamedTuple):
@@ -53,23 +40,10 @@ class PPTestOptions(NamedTuple):
@dataclass
class PPTestSettings:
parallel_setups: list[ParallelSetup]
# NOTE: the length of distributed_backends and
# vllm_major_versions should be the same, and they
# are first zipped together to iterate over all
# test settings.
distributed_backends: list[str]
# vllm major version: "0" for V0, "1" for V1
vllm_major_versions: list[str]
runner: RunnerOption
test_options: PPTestOptions

def __post_init__(self):
if len(self.distributed_backends) != len(self.vllm_major_versions):
raise ValueError(
f"Length mismatch: distributed_backends "
f"({len(self.distributed_backends)}) != "
f"vllm_major_versions ({len(self.vllm_major_versions)})")

@staticmethod
def detailed(
*,
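The NOTE above describes how `distributed_backends` and `vllm_major_versions` are zipped into test cases, and this change removes the explicit version axis. A minimal, standalone sketch of that pairing (illustrative only, not the test module's actual code):

```python
# Before this change: backends were paired with explicit vLLM major versions.
backends_before = ["mp", "mp", "ray", "ray"]
major_versions = ["0", "1", "0", "1"]
cases_before = list(zip(backends_before, major_versions))
# -> [("mp", "0"), ("mp", "1"), ("ray", "0"), ("ray", "1")]

# After this change: the version axis is gone, so only backends are iterated
# (V1 is assumed everywhere).
backends_after = ["mp", "ray"]
cases_after = [(backend,) for backend in backends_after]

print(cases_before)
print(cases_after)
```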
@@ -83,27 +57,21 @@ class PPTestSettings:
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
eager_mode=False,
chunked_prefill=False),
eager_mode=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
eager_mode=False,
chunked_prefill=True),
eager_mode=False),
ParallelSetup(tp_size=tp_base,
pp_size=2 * pp_base,
eager_mode=True,
chunked_prefill=False),
eager_mode=True),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
eager_mode=False,
chunked_prefill=True),
eager_mode=False),
ParallelSetup(tp_size=2 * tp_base,
pp_size=pp_base,
eager_mode=True,
chunked_prefill=False),
eager_mode=True),
],
distributed_backends=["mp", "mp", "ray", "ray"],
vllm_major_versions=["0", "1", "0", "1"],
distributed_backends=["mp", "ray"],
runner=runner,
test_options=PPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
@@ -118,17 +86,14 @@ class PPTestSettings:
multi_node_only: bool = False,
load_format: Optional[str] = None,
):
vllm_major_versions = ["1"] if runner == "pooling" else ["0"]

return PPTestSettings(
parallel_setups=[
ParallelSetup(tp_size=tp_base,
pp_size=pp_base,
eager_mode=True,
chunked_prefill=False),
eager_mode=True),
],
distributed_backends=["mp"],
vllm_major_versions=vllm_major_versions,
runner=runner,
test_options=PPTestOptions(multi_node_only=multi_node_only,
load_format=load_format),
@@ -138,10 +103,8 @@ class PPTestSettings:
opts = self.test_options

for parallel_setup in self.parallel_setups:
for backend, vllm_major_version in zip(self.distributed_backends,
self.vllm_major_versions):
yield (model_id, parallel_setup, backend, vllm_major_version,
self.runner, opts)
for backend in self.distributed_backends:
yield (model_id, parallel_setup, backend, self.runner, opts)


# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
@@ -269,7 +232,6 @@ def _compare_tp(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
runner: RunnerOption,
test_options: PPTestOptions,
num_gpus_available: int,
@@ -281,7 +243,6 @@ def _compare_tp(
tp_size,
pp_size,
eager_mode,
chunked_prefill,
) = parallel_setup

multi_node_only, load_format = test_options
@@ -334,8 +295,6 @@ def _compare_tp(
"--max-num-seqs",
"8",
]
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
if eager_mode:
common_args.append("--enforce-eager")
if runner != "auto":
@@ -353,14 +312,10 @@ def _compare_tp(
if max_num_seqs:
common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])

specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
testing_ray_compiled_graph = False
if distributed_backend == "ray" and (vllm_major_version == "1"
or specific_case):
if distributed_backend == "ray":
# For V1, test Ray Compiled Graph for all the tests
# For V0, test Ray Compiled Graph for a subset of the tests
pp_env = {
"VLLM_USE_V1": vllm_major_version,
"VLLM_USE_V1": "1",
"VLLM_USE_RAY_COMPILED_DAG": "1",
"VLLM_USE_RAY_SPMD_WORKER": "1",
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
@@ -368,17 +323,15 @@
# Temporary. Currently when zeromq + SPMD is used, it does not properly
# terminate because of a Ray Compiled Graph issue.
common_args.append("--disable-frontend-multiprocessing")
testing_ray_compiled_graph = True
elif distributed_backend == "mp":
# Both V0/V1 of multiprocessing executor support PP
pp_env = {
"VLLM_USE_V1": vllm_major_version,
"VLLM_USE_V1": "1",
}
else:
pp_env = None

tp_env = {
"VLLM_USE_V1": vllm_major_version,
"VLLM_USE_V1": "1",
}

pp_args = [
@@ -404,25 +357,17 @@ def _compare_tp(
"mp",
]

try:
compare_two_settings(model_id,
pp_args,
tp_args,
pp_env,
tp_env,
method=method)
except Exception:
if testing_ray_compiled_graph and vllm_major_version == "0":
# Ray Compiled Graph tests are flaky for V0,
# so we don't want to fail the test
logger.exception("Ray Compiled Graph tests failed")
else:
raise
compare_two_settings(model_id,
pp_args,
tp_args,
pp_env,
tp_env,
method=method)


@pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"runner", "test_options"),
("model_id", "parallel_setup", "distributed_backend", "runner",
"test_options"),
[
params for model_id, settings in TEXT_GENERATION_MODELS.items()
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
@@ -433,15 +378,14 @@ def test_tp_language_generation(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
runner: RunnerOption,
test_options: PPTestOptions,
num_gpus_available,
):
pytest.skip("Skipping the test until V1 passes it.")
_compare_tp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
runner,
test_options,
num_gpus_available,
@@ -450,8 +394,8 @@


@pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"runner", "test_options"),
("model_id", "parallel_setup", "distributed_backend", "runner",
"test_options"),
[
params for model_id, settings in EMBEDDING_MODELS.items()
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
@@ -462,15 +406,14 @@ def test_tp_language_embedding(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
runner: RunnerOption,
test_options: PPTestOptions,
num_gpus_available,
):
pytest.skip("Skipping the test until V1 passes it.")
_compare_tp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
runner,
test_options,
num_gpus_available,
@@ -479,8 +422,8 @@


@pytest.mark.parametrize(
("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
"runner", "test_options"),
("model_id", "parallel_setup", "distributed_backend", "runner",
"test_options"),
[
params for model_id, settings in MULTIMODAL_MODELS.items()
for params in settings.iter_params(model_id) if model_id in TEST_MODELS
@@ -491,15 +434,14 @@ def test_tp_multimodal_generation(
model_id: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
vllm_major_version: str,
runner: RunnerOption,
test_options: PPTestOptions,
num_gpus_available,
):
pytest.skip("Skipping the test until V1 passes it.")
_compare_tp(model_id,
parallel_setup,
distributed_backend,
vllm_major_version,
runner,
test_options,
num_gpus_available,

@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
@ -1,37 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
def test_computed_prefix_blocks(model: str, block_size: int):
|
||||
# This test checks if we are able to run the engine to completion
|
||||
# without triggering asserts.
|
||||
# We are in a scenario where all blocks from the second request's prompt
|
||||
# are full and already computed when the second request arrives.
|
||||
prompt = (
|
||||
"You are a helpful assistant. How do I build a car from cardboard and "
|
||||
"paper clips? Is there an easy to follow video tutorial available "
|
||||
"online for free?")
|
||||
prompt2 = (
|
||||
" Please recommend to me some resources where I can learn not only to "
|
||||
"handle technical difficulties of building a car, but also "
|
||||
"decoration.")
|
||||
|
||||
engine_args = EngineArgs(model=model,
|
||||
block_size=block_size,
|
||||
enable_prefix_caching=True)
|
||||
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
engine.add_request("0", prompt + prompt2, sampling_params)
|
||||
engine.step()
|
||||
engine.add_request("1", prompt, sampling_params)
|
||||
engine.step()
|
||||
@ -1,111 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.executor.uniproc_executor import UniProcExecutor
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
class Mock:
|
||||
...
|
||||
|
||||
|
||||
class CustomUniExecutor(UniProcExecutor):
|
||||
|
||||
def collective_rpc(self,
|
||||
method: Union[str, Callable],
|
||||
timeout: Optional[float] = None,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict] = None) -> list[Any]:
|
||||
# Drop marker to show that this was run
|
||||
with open(".marker", "w"):
|
||||
...
|
||||
return super().collective_rpc(method, timeout, args, kwargs)
|
||||
|
||||
|
||||
CustomUniExecutorAsync = CustomUniExecutor
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
|
||||
def test_custom_executor_type_checking(model):
|
||||
with pytest.raises(ValueError):
|
||||
engine_args = EngineArgs(model=model,
|
||||
distributed_executor_backend=Mock)
|
||||
LLMEngine.from_engine_args(engine_args)
|
||||
with pytest.raises(ValueError):
|
||||
engine_args = AsyncEngineArgs(model=model,
|
||||
distributed_executor_backend=Mock)
|
||||
AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
|
||||
def test_custom_executor(model, tmp_path):
|
||||
cwd = os.path.abspath(".")
|
||||
os.chdir(tmp_path)
|
||||
try:
|
||||
assert not os.path.exists(".marker")
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
distributed_executor_backend=CustomUniExecutor,
|
||||
enforce_eager=True, # reduce test time
|
||||
)
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
sampling_params = SamplingParams(max_tokens=1)
|
||||
|
||||
engine.add_request("0", "foo", sampling_params)
|
||||
engine.step()
|
||||
|
||||
assert os.path.exists(".marker")
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
|
||||
def test_custom_executor_async(model, tmp_path):
|
||||
cwd = os.path.abspath(".")
|
||||
os.chdir(tmp_path)
|
||||
try:
|
||||
assert not os.path.exists(".marker")
|
||||
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
distributed_executor_backend=CustomUniExecutorAsync,
|
||||
enforce_eager=True, # reduce test time
|
||||
)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
sampling_params = SamplingParams(max_tokens=1)
|
||||
|
||||
async def t():
|
||||
stream = await engine.add_request("0", "foo", sampling_params)
|
||||
async for x in stream:
|
||||
...
|
||||
|
||||
asyncio.run(t())
|
||||
|
||||
assert os.path.exists(".marker")
|
||||
finally:
|
||||
os.chdir(cwd)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
|
||||
def test_respect_ray(model):
|
||||
# even for TP=1 and PP=1,
|
||||
# if users specify ray, we should use ray.
|
||||
# users might do this if they want to manage the
|
||||
# resources using ray.
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
distributed_executor_backend="ray",
|
||||
enforce_eager=True, # reduce test time
|
||||
)
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
assert engine.model_executor.uses_ray
|
||||
@ -1,179 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
||||
ResultHandler, WorkerMonitor)
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
|
||||
class DummyWorkerWrapper(WorkerWrapperBase):
|
||||
"""Dummy version of vllm.worker.worker.Worker"""
|
||||
|
||||
def worker_method(self, worker_input: Any) -> tuple[int, Any]:
|
||||
sleep(0.05)
|
||||
|
||||
if isinstance(worker_input, Exception):
|
||||
# simulate error case
|
||||
raise worker_input
|
||||
|
||||
return self.rpc_rank, input
|
||||
|
||||
|
||||
def _start_workers() -> tuple[list[ProcessWorkerWrapper], WorkerMonitor]:
|
||||
result_handler = ResultHandler()
|
||||
vllm_config = VllmConfig()
|
||||
workers = [
|
||||
ProcessWorkerWrapper(result_handler, DummyWorkerWrapper, vllm_config,
|
||||
rank) for rank in range(8)
|
||||
]
|
||||
|
||||
worker_monitor = WorkerMonitor(workers, result_handler)
|
||||
assert not worker_monitor.is_alive()
|
||||
|
||||
result_handler.start()
|
||||
worker_monitor.start()
|
||||
assert worker_monitor.is_alive()
|
||||
|
||||
return workers, worker_monitor
|
||||
|
||||
|
||||
def test_local_workers() -> None:
|
||||
"""Test workers with sync task submission"""
|
||||
|
||||
workers, worker_monitor = _start_workers()
|
||||
|
||||
def execute_workers(worker_input: str) -> None:
|
||||
worker_outputs = [
|
||||
worker.execute_method("worker_method", worker_input)
|
||||
for worker in workers
|
||||
]
|
||||
|
||||
for rank, output in enumerate(worker_outputs):
|
||||
assert output.get() == (rank, input)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
# Test concurrent submission from different threads
|
||||
futures = [
|
||||
executor.submit(partial(execute_workers, f"thread {thread_num}"))
|
||||
for thread_num in range(4)
|
||||
]
|
||||
|
||||
for future in futures:
|
||||
future.result()
|
||||
|
||||
# Test error case
|
||||
exception = ValueError("fake error")
|
||||
result = workers[0].execute_method("worker_method", exception)
|
||||
try:
|
||||
result.get()
|
||||
pytest.fail("task should have failed")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ValueError)
|
||||
assert str(e) == "fake error"
|
||||
|
||||
# Test cleanup when a worker fails
|
||||
assert worker_monitor.is_alive()
|
||||
workers[3].process.kill()
|
||||
|
||||
# Other workers should get shut down here
|
||||
worker_monitor.join(20)
|
||||
|
||||
# Ensure everything is stopped
|
||||
assert not worker_monitor.is_alive()
|
||||
assert all(not worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Further attempts to submit tasks should fail
|
||||
try:
|
||||
_result = workers[0].execute_method("worker_method", "test")
|
||||
pytest.fail("task should fail once workers have been shut down")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ChildProcessError)
|
||||
|
||||
|
||||
def test_local_workers_clean_shutdown() -> None:
|
||||
"""Test clean shutdown"""
|
||||
|
||||
workers, worker_monitor = _start_workers()
|
||||
|
||||
assert worker_monitor.is_alive()
|
||||
assert all(worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Clean shutdown
|
||||
worker_monitor.close()
|
||||
|
||||
worker_monitor.join(20)
|
||||
|
||||
# Ensure everything is stopped
|
||||
assert not worker_monitor.is_alive()
|
||||
assert all(not worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Further attempts to submit tasks should fail
|
||||
try:
|
||||
_result = workers[0].execute_method("worker_method", "test")
|
||||
pytest.fail("task should fail once workers have been shut down")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ChildProcessError)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_workers_async() -> None:
|
||||
"""Test local workers with async task submission"""
|
||||
|
||||
workers, worker_monitor = _start_workers()
|
||||
|
||||
async def execute_workers(worker_input: str) -> None:
|
||||
worker_coros = [
|
||||
worker.execute_method_async("worker_method", worker_input)
|
||||
for worker in workers
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*worker_coros)
|
||||
for rank, result in enumerate(results):
|
||||
assert result == (rank, input)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(execute_workers(f"task {task_num}"))
|
||||
for task_num in range(4)
|
||||
]
|
||||
|
||||
for task in tasks:
|
||||
await task
|
||||
|
||||
# Test error case
|
||||
exception = ValueError("fake error")
|
||||
try:
|
||||
_result = await workers[0].execute_method_async(
|
||||
"worker_method", exception)
|
||||
pytest.fail("task should have failed")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ValueError)
|
||||
assert str(e) == "fake error"
|
||||
|
||||
# Test cleanup when a worker fails
|
||||
assert worker_monitor.is_alive()
|
||||
workers[3].process.kill()
|
||||
|
||||
# Other workers should get shut down here
|
||||
worker_monitor.join(20)
|
||||
|
||||
# Ensure everything is stopped
|
||||
assert not worker_monitor.is_alive()
|
||||
assert all(not worker.process.is_alive() for worker in workers)
|
||||
|
||||
# Further attempts to submit tasks should fail
|
||||
try:
|
||||
_result = await workers[0].execute_method_async(
|
||||
"worker_method", "test")
|
||||
pytest.fail("task should fail once workers have been shut down")
|
||||
except Exception as e:
|
||||
assert isinstance(e, ChildProcessError)
|
||||
@ -1,58 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
|
||||
def test_skip_tokenizer_initialization(model: str):
|
||||
# This test checks if the flag skip_tokenizer_init skips the initialization
|
||||
# of tokenizer and detokenizer. The generated output is expected to contain
|
||||
# token ids.
|
||||
llm = LLM(
|
||||
model=model,
|
||||
skip_tokenizer_init=True,
|
||||
enforce_eager=True,
|
||||
)
|
||||
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
|
||||
|
||||
with pytest.raises(ValueError, match="cannot pass text prompts when"):
|
||||
llm.generate("abc", sampling_params)
|
||||
|
||||
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
|
||||
sampling_params=sampling_params)
|
||||
assert len(outputs) > 0
|
||||
completions = outputs[0].outputs
|
||||
assert len(completions) > 0
|
||||
assert completions[0].text == ""
|
||||
assert completions[0].token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
|
||||
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
|
||||
def test_enable_prompt_embeds(hf_runner, model: str,
|
||||
enable_prompt_embeds: bool):
|
||||
prompt = "abc"
|
||||
|
||||
with hf_runner(model) as hf_model:
|
||||
token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids
|
||||
token_ids = token_ids.to(hf_model.model.device)
|
||||
|
||||
embed_layer = hf_model.model.get_input_embeddings()
|
||||
prompt_embeds = embed_layer(token_ids).squeeze(0)
|
||||
|
||||
ctx = (nullcontext() if enable_prompt_embeds else pytest.raises(
|
||||
ValueError, match="set `--enable-prompt-embeds`"))
|
||||
|
||||
llm = LLM(
|
||||
model=model,
|
||||
enable_prompt_embeds=enable_prompt_embeds,
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
with ctx:
|
||||
llm.generate({"prompt_embeds": prompt_embeds})
|
||||
@ -25,6 +25,7 @@ def test_context_length_too_short(vllm_runner, image_assets, model):
|
||||
model,
|
||||
max_model_len=128, # LLaVA has a feature size of 576
|
||||
enforce_eager=True,
|
||||
load_format="dummy",
|
||||
)
|
||||
|
||||
with vllm_model:
|
||||
|
||||
@ -1,225 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Sequence, SequenceStatus
|
||||
|
||||
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
|
||||
|
||||
class MockReasoningParser(ReasoningParser):
|
||||
"""Mock reasoning parser for testing purposes."""
|
||||
|
||||
def __init__(self,
|
||||
tokenizer: AutoTokenizer,
|
||||
reasoning_active: bool = False):
|
||||
super().__init__(tokenizer)
|
||||
self.reasoning_active = reasoning_active
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return not self.reasoning_active
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
return input_ids
|
||||
|
||||
|
||||
class MockSequence(Sequence):
|
||||
"""Mock sequence for testing purposes."""
|
||||
|
||||
def __init__(self, token_ids, output_text="test_output", eos_token_id=0):
|
||||
self.token_ids = token_ids
|
||||
self.output_text = output_text
|
||||
self.eos_token_id = eos_token_id
|
||||
self.status = SequenceStatus.RUNNING
|
||||
self.stop_reason = None
|
||||
|
||||
def get_token_ids(self):
|
||||
return self.token_ids
|
||||
|
||||
def get_last_token_id(self):
|
||||
return self.token_ids[-1] if self.token_ids else None
|
||||
|
||||
def get_len(self):
|
||||
return len(self.token_ids)
|
||||
|
||||
def get_output_len(self):
|
||||
return len(self.token_ids) - 1 # Simulating prompt + outputs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deepseek_r1_qwen_tokenizer():
|
||||
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stop_checker():
|
||||
return StopChecker(max_model_len=10)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stop_checker_with_reasoner():
|
||||
reasoner = MockReasoningParser(deepseek_r1_qwen_tokenizer)
|
||||
return StopChecker(max_model_len=10, reasoner=reasoner)
|
||||
|
||||
|
||||
def test_eos_token_stopping(stop_checker):
|
||||
"""Test sequence stopping when EOS token is encountered."""
|
||||
seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
stop_checker.maybe_stop_sequence(seq,
|
||||
new_char_count=1,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
|
||||
|
||||
def test_ignore_eos(stop_checker):
|
||||
"""Test sequence continuing when EOS token is ignored."""
|
||||
seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
|
||||
sampling_params = SamplingParams(ignore_eos=True)
|
||||
|
||||
stop_checker.maybe_stop_sequence(seq,
|
||||
new_char_count=1,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
assert seq.status == SequenceStatus.RUNNING
|
||||
|
||||
|
||||
def test_min_tokens(stop_checker):
|
||||
"""Test min_tokens prevents early stopping."""
|
||||
seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
|
||||
sampling_params = SamplingParams(min_tokens=3)
|
||||
|
||||
stop_checker.maybe_stop_sequence(seq,
|
||||
new_char_count=1,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
assert seq.status == SequenceStatus.RUNNING
|
||||
|
||||
|
||||
def test_stop_token_ids(stop_checker):
|
||||
"""Test sequence stopping with custom stop token IDs."""
|
||||
seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
|
||||
sampling_params = SamplingParams(stop_token_ids=[3])
|
||||
|
||||
stop_checker.maybe_stop_sequence(seq,
|
||||
new_char_count=1,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
assert seq.stop_reason == 3
|
||||
|
||||
|
||||
def test_stop_strings(stop_checker):
|
||||
"""Test sequence stopping with stop strings."""
|
||||
seq = MockSequence(token_ids=[1, 2, 3],
|
||||
output_text="test output with STOP",
|
||||
eos_token_id=0)
|
||||
sampling_params = SamplingParams(stop=["STOP"])
|
||||
|
||||
stop_checker.maybe_stop_sequence(seq,
|
||||
new_char_count=1,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
assert seq.stop_reason == "STOP"
|
||||
assert "STOP" not in seq.output_text # Default behavior removes stop string
|
||||
|
||||
|
||||
def test_include_stop_str_in_output(stop_checker):
|
||||
"""Test keeping stop strings in output."""
|
||||
seq = MockSequence(token_ids=[1, 2, 3],
|
||||
output_text="test output with STOP",
|
||||
eos_token_id=0)
|
||||
sampling_params = SamplingParams(stop=["STOP"],
|
||||
include_stop_str_in_output=True)
|
||||
|
||||
stop_checker.maybe_stop_sequence(seq,
|
||||
new_char_count=5,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
assert "STOP" in seq.output_text
|
||||
|
||||
|
||||
def test_max_tokens(stop_checker):
|
||||
"""Test sequence stopping at max_tokens."""
|
||||
seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
|
||||
sampling_params = SamplingParams(max_tokens=2)
|
||||
|
||||
stop_checker.maybe_stop_sequence(seq,
|
||||
new_char_count=1,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
assert seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
|
||||
|
||||
def test_max_model_len(stop_checker):
|
||||
"""Test sequence stopping at max_model_len."""
|
||||
seq = MockSequence(token_ids=list(range(11)),
|
||||
eos_token_id=0) # 11 tokens, max is 10
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
stop_checker.maybe_stop_sequence(seq,
|
||||
new_char_count=1,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
assert seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
|
||||
|
||||
def test_reasoning_skip_stops(stop_checker_with_reasoner):
|
||||
"""Test that stop tokens and strings are ignored during reasoning."""
|
||||
# Set reasoning_active to True to simulate being in reasoning mode
|
||||
stop_checker_with_reasoner.reasoner.reasoning_active = True
|
||||
|
||||
# Test with stop token
|
||||
seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
|
||||
sampling_params = SamplingParams(stop_token_ids=[3])
|
||||
|
||||
stop_checker_with_reasoner.maybe_stop_sequence(
|
||||
seq, new_char_count=1, sampling_params=sampling_params)
|
||||
assert seq.status == SequenceStatus.RUNNING
|
||||
|
||||
# Test with stop string
|
||||
seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP")
|
||||
sampling_params = SamplingParams(stop=["STOP"])
|
||||
|
||||
stop_checker_with_reasoner.maybe_stop_sequence(
|
||||
seq, new_char_count=4, sampling_params=sampling_params)
|
||||
assert seq.status == SequenceStatus.RUNNING
|
||||
|
||||
# But EOS token still stops the sequence
|
||||
seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
|
||||
sampling_params = SamplingParams()
|
||||
|
||||
stop_checker_with_reasoner.maybe_stop_sequence(
|
||||
seq, new_char_count=1, sampling_params=sampling_params)
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
|
||||
|
||||
def test_reasoning_end_enables_stops(stop_checker_with_reasoner):
|
||||
"""Test that stop tokens work after reasoning ends."""
|
||||
# Set reasoning_active to False to simulate being out of reasoning mode
|
||||
stop_checker_with_reasoner.reasoner.reasoning_active = False
|
||||
|
||||
# Test with stop token
|
||||
seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
|
||||
sampling_params = SamplingParams(stop_token_ids=[3])
|
||||
|
||||
stop_checker_with_reasoner.maybe_stop_sequence(
|
||||
seq, new_char_count=1, sampling_params=sampling_params)
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
|
||||
# Test with stop string
|
||||
seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP")
|
||||
sampling_params = SamplingParams(stop=["STOP"])
|
||||
|
||||
stop_checker_with_reasoner.maybe_stop_sequence(
|
||||
seq, new_char_count=4, sampling_params=sampling_params)
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
@ -1,268 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
from prometheus_client import REGISTRY
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import EngineArgs, LLMEngine
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.metrics import RayPrometheusStatLogger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
This module tests V0 internals, so set VLLM_USE_V1=0.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
|
||||
MODELS = [
|
||||
"distilbert/distilgpt2",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_metric_counter_prompt_tokens(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
gpu_memory_utilization=0.4) as vllm_model:
|
||||
tokenizer = vllm_model.llm.get_tokenizer()
|
||||
prompt_token_counts = [
|
||||
len(tokenizer.encode(p)) for p in example_prompts
|
||||
]
|
||||
# This test needs at least 2 prompts in a batch of different lengths to
|
||||
# verify their token count is correct despite padding.
|
||||
assert len(example_prompts) > 1, "at least 2 prompts are required"
|
||||
assert prompt_token_counts[0] != prompt_token_counts[1], (
|
||||
"prompts of different lengths are required")
|
||||
vllm_prompt_token_count = sum(prompt_token_counts)
|
||||
|
||||
_ = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus']
|
||||
metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
|
||||
**stat_logger.labels)._value.get()
|
||||
|
||||
assert vllm_prompt_token_count == metric_count, (
|
||||
f"prompt token count: {vllm_prompt_token_count!r}\n"
|
||||
f"metric: {metric_count!r}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_metric_counter_generation_tokens(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
gpu_memory_utilization=0.4) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
tokenizer = vllm_model.llm.get_tokenizer()
|
||||
stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus']
|
||||
metric_count = stat_logger.metrics.counter_generation_tokens.labels(
|
||||
**stat_logger.labels)._value.get()
|
||||
vllm_generation_count = 0
|
||||
for i in range(len(example_prompts)):
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
prompt_ids = tokenizer.encode(example_prompts[i])
|
||||
# vllm_output_ids contains both prompt tokens and generation tokens.
|
||||
# We're interested only in the count of the generation tokens.
|
||||
vllm_generation_count += len(vllm_output_ids) - len(prompt_ids)
|
||||
|
||||
assert vllm_generation_count == metric_count, (
|
||||
f"generation token count: {vllm_generation_count!r}\n"
|
||||
f"metric: {metric_count!r}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize(
|
||||
"served_model_name",
|
||||
[None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]])
|
||||
def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
|
||||
served_model_name: list[str]) -> None:
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
gpu_memory_utilization=0.3,
|
||||
served_model_name=served_model_name) as vllm_model:
|
||||
stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus']
|
||||
metrics_tag_content = stat_logger.labels["model_name"]
|
||||
|
||||
if envs.VLLM_CI_USE_S3:
|
||||
model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}"
|
||||
if served_model_name is None or served_model_name == []:
|
||||
assert metrics_tag_content == model, (
|
||||
f"Metrics tag model_name is wrong! expect: {model!r}\n"
|
||||
f"actual: {metrics_tag_content!r}")
|
||||
else:
|
||||
assert metrics_tag_content == served_model_name[0], (
|
||||
f"Metrics tag model_name is wrong! expect: "
|
||||
f"{served_model_name[0]!r}\n"
|
||||
f"actual: {metrics_tag_content!r}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [4])
|
||||
@pytest.mark.parametrize("disable_log_stats", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_engine_log_metrics_regression(
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
disable_log_stats: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Regression test ensuring async engine generates metrics
|
||||
when disable_log_stats=False
|
||||
(see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678)
|
||||
"""
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=disable_log_stats,
|
||||
)
|
||||
async_engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
for i, prompt in enumerate(example_prompts):
|
||||
results = async_engine.generate(
|
||||
prompt,
|
||||
SamplingParams(max_tokens=max_tokens),
|
||||
f"request-id-{i}",
|
||||
)
|
||||
# Exhaust the async iterator to make the async engine work
|
||||
async for _ in results:
|
||||
pass
|
||||
|
||||
assert_metrics(model, async_engine.engine, disable_log_stats,
|
||||
len(example_prompts))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [4])
|
||||
@pytest.mark.parametrize("disable_log_stats", [True, False])
|
||||
def test_engine_log_metrics_regression(
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
disable_log_stats: bool,
|
||||
) -> None:
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=disable_log_stats,
|
||||
)
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
for i, prompt in enumerate(example_prompts):
|
||||
engine.add_request(
|
||||
f"request-id-{i}",
|
||||
prompt,
|
||||
SamplingParams(max_tokens=max_tokens),
|
||||
)
|
||||
while engine.has_unfinished_requests():
|
||||
engine.step()
|
||||
|
||||
if envs.VLLM_CI_USE_S3:
|
||||
model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}"
|
||||
assert_metrics(model, engine, disable_log_stats, len(example_prompts))
|
||||
|
||||
|
||||
def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool,
|
||||
num_requests: int) -> None:
|
||||
if disable_log_stats:
|
||||
with pytest.raises(AttributeError):
|
||||
_ = engine.stat_loggers
|
||||
else:
|
||||
assert (engine.stat_loggers
|
||||
is not None), "engine.stat_loggers should be set"
|
||||
# Ensure the count bucket of request-level histogram metrics matches
|
||||
# the number of requests as a simple sanity check to ensure metrics are
|
||||
# generated
|
||||
labels = {'model_name': model}
|
||||
request_histogram_metrics = [
|
||||
"vllm:e2e_request_latency_seconds",
|
||||
"vllm:request_prompt_tokens",
|
||||
"vllm:request_generation_tokens",
|
||||
"vllm:request_params_n",
|
||||
"vllm:request_params_max_tokens",
|
||||
]
|
||||
for metric_name in request_histogram_metrics:
|
||||
metric_value = REGISTRY.get_sample_value(f"{metric_name}_count",
|
||||
labels)
|
||||
assert (
|
||||
metric_value == num_requests), "Metrics should be collected"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [16])
|
||||
def test_engine_log_metrics_ray(
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
# This test is quite weak - it only checks that we can use
|
||||
# RayPrometheusStatLogger without exceptions.
|
||||
# Checking whether the metrics are actually emitted is unfortunately
|
||||
# non-trivial.
|
||||
|
||||
# We have to run in a Ray task for Ray metrics to be emitted correctly
|
||||
@ray.remote(num_gpus=1)
|
||||
def _inner():
|
||||
|
||||
class _RayPrometheusStatLogger(RayPrometheusStatLogger):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._i = 0
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
self._i += 1
|
||||
return super().log(*args, **kwargs)
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
)
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
logger = _RayPrometheusStatLogger(
|
||||
local_interval=0.5,
|
||||
labels=dict(model_name=engine.model_config.served_model_name),
|
||||
vllm_config=engine.vllm_config)
|
||||
engine.add_logger("ray", logger)
|
||||
for i, prompt in enumerate(example_prompts):
|
||||
engine.add_request(
|
||||
f"request-id-{i}",
|
||||
prompt,
|
||||
SamplingParams(max_tokens=max_tokens),
|
||||
)
|
||||
while engine.has_unfinished_requests():
|
||||
engine.step()
|
||||
assert logger._i > 0, ".log must be called at least once"
|
||||
|
||||
ray.get(_inner.remote())
|
||||
@ -1,98 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
class MockLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __init__(self, vocab_size: int, scale: float,
|
||||
fake_logits: torch.Tensor):
|
||||
super().__init__(vocab_size=vocab_size, scale=scale)
|
||||
self.fake_logits = fake_logits.clone()
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
with patch(
|
||||
"vllm.model_executor.layers.logits_processor._prune_hidden_states",
|
||||
lambda x, y: x
|
||||
), patch(
|
||||
"vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits",
|
||||
lambda *args, **kwargs: self.fake_logits):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
def _prepare_test(
|
||||
batch_size: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
|
||||
vocab_size = 32000
|
||||
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
|
||||
fake_logits = torch.full((batch_size, vocab_size),
|
||||
1e-2,
|
||||
dtype=input_tensor.dtype)
|
||||
logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
|
||||
return input_tensor, fake_logits, logits_processor
|
||||
|
||||
|
||||
RANDOM_SEEDS = list(range(128))
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_logits_processors(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)
|
||||
|
||||
# This sample logits processor gives infinite score to the i-th token,
|
||||
# where i is the length of the input sequence.
|
||||
# We therefore expect the output token sequence to be [0, 1, 2, ...]
|
||||
def pick_ith(token_ids, logits):
|
||||
logits[len(token_ids)] = float("inf")
|
||||
return logits
|
||||
|
||||
seq_group_metadata_list = []
|
||||
seq_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
|
||||
sampling_params=SamplingParams(temperature=0,
|
||||
logits_processors=[pick_ith]),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=device,
|
||||
pin_memory=is_pin_memory_available())
|
||||
logits_processor_output = logits_processor(
|
||||
lm_head=None,
|
||||
hidden_states=input_tensor,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
assert torch.isinf(logits_processor_output[:, 0]).all()
|
||||
|
||||
fake_logits *= logits_processor.scale
|
||||
torch.testing.assert_close(logits_processor_output[:, 1],
|
||||
fake_logits[:, 1],
|
||||
rtol=1e-4,
|
||||
atol=0.0)
|
||||
@ -1,92 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Test hashing of cache blocks.
|
||||
|
||||
Run `pytest tests/test_cache_block_hashing.py`.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.inputs import token_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import Sequence
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
# Make two prefixes with different first blocks.
|
||||
prefix_start = [("You are an expert"), ("You are a")]
|
||||
prefix_common = (
|
||||
" school principal, skilled in effectively managing "
|
||||
"faculty and staff. Draft 10-15 questions for a potential first grade "
|
||||
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
|
||||
"community, joyful discovery, and life-long learning. The candidate is "
|
||||
"coming in for a first-round panel interview for a 8th grade Math "
|
||||
"teaching role. They have 5 years of previous teaching experience "
|
||||
"as an assistant teacher at a co-ed, public school with experience "
|
||||
"in middle school math teaching. Based on this, fulfill "
|
||||
"the following: ")
|
||||
prefixes = [start + prefix_common for start in prefix_start]
|
||||
|
||||
# Sample prompts.
|
||||
sample_prompts = [
|
||||
"Hello, my name is", "The president of the United States is",
|
||||
"The capital of France is", "The future of AI is"
|
||||
]
|
||||
|
||||
|
||||
# Helper function.
|
||||
def flatten_2d(li):
|
||||
return [lss for ls in li for lss in ls]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.parametrize("max_num_seqs", [256])
|
||||
@pytest.mark.parametrize("concurrent_lora_int_ids",
|
||||
[[None], [1], [None, 1], [None, 1, 2], [1, 2]])
|
||||
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
|
||||
concurrent_lora_int_ids: list[Optional[int]]):
|
||||
|
||||
tokenizer = get_tokenizer("facebook/opt-125m")
|
||||
|
||||
hashes: list[list[list[int]]] = []
|
||||
|
||||
for prefix in prefixes:
|
||||
for lora_int_id in concurrent_lora_int_ids:
|
||||
lora_request = None
|
||||
|
||||
if lora_int_id is not None:
|
||||
lora_request = LoRARequest(
|
||||
f"example_lora_{lora_int_id}",
|
||||
lora_int_id,
|
||||
f"example/path/to/lora_{lora_int_id}",
|
||||
)
|
||||
|
||||
hashes.append([])
|
||||
prompts = [prefix + prompt for prompt in sample_prompts]
|
||||
for seq_id, prompt in enumerate(prompts):
|
||||
hashes[-1].append([])
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
seq = Sequence(seq_id,
|
||||
inputs=token_inputs(prompt_token_ids,
|
||||
prompt=prompt),
|
||||
block_size=block_size,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
lora_request=lora_request)
|
||||
|
||||
num_blocks = len(prompt_token_ids) // block_size
|
||||
for idx in range(num_blocks):
|
||||
hashes[-1][-1].append(seq.hash_of_block(idx))
|
||||
|
||||
# Check that hashes made with two prefixes with different first blocks are
|
||||
# different everywhere.
|
||||
for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):
|
||||
assert (hash0 != hash1)
|
||||
|
||||
# Check that hashes of different prompts made with the same prefix are the
|
||||
# same until the hashes that contain the prompt.
|
||||
for hash_pref in hashes:
|
||||
same_hashes = [tuple(h[:-1]) for h in hash_pref]
|
||||
different_hashes = [h[-1] for h in hash_pref]
|
||||
assert (len(set(same_hashes)) == 1)
|
||||
assert (len(set(different_hashes)) == len(different_hashes))
|
||||
@ -1,237 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
# type: ignore
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections.abc import Iterable
|
||||
from concurrent import futures
|
||||
from typing import Callable, Generator, Literal
|
||||
|
||||
import grpc
|
||||
import pytest
|
||||
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
|
||||
ExportTraceServiceResponse)
|
||||
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
|
||||
TraceServiceServicer, add_TraceServiceServicer_to_server)
|
||||
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
|
||||
from opentelemetry.sdk.environment_variables import (
|
||||
OTEL_EXPORTER_OTLP_TRACES_INSECURE)
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.tracing import SpanAttributes
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
Since this module is V0 only, set VLLM_USE_V1=0 for
|
||||
all tests in the module.
|
||||
"""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv('VLLM_USE_V1', '0')
|
||||
yield
|
||||
|
||||
|
||||
FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"
|
||||
|
||||
FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
|
||||
'array_value']
|
||||
|
||||
|
||||
def decode_value(value: AnyValue):
|
||||
field_decoders: dict[FieldName, Callable] = {
|
||||
"bool_value": (lambda v: v.bool_value),
|
||||
"string_value": (lambda v: v.string_value),
|
||||
"int_value": (lambda v: v.int_value),
|
||||
"double_value": (lambda v: v.double_value),
|
||||
"array_value":
|
||||
(lambda v: [decode_value(item) for item in v.array_value.values]),
|
||||
}
|
||||
for field, decoder in field_decoders.items():
|
||||
if value.HasField(field):
|
||||
return decoder(value)
|
||||
raise ValueError(f"Couldn't decode value: {value}")
|
||||
|
||||
|
||||
def decode_attributes(attributes: Iterable[KeyValue]):
|
||||
return {kv.key: decode_value(kv.value) for kv in attributes}
|
||||
|
||||
|
||||
class FakeTraceService(TraceServiceServicer):
|
||||
|
||||
def __init__(self):
|
||||
self.request = None
|
||||
self.evt = threading.Event()
|
||||
|
||||
def Export(self, request, context):
|
||||
self.request = request
|
||||
self.evt.set()
|
||||
return ExportTraceServiceResponse()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_service() -> Generator[FakeTraceService, None, None]:
|
||||
"""Fixture to set up a fake gRPC trace service"""
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
|
||||
service = FakeTraceService()
|
||||
add_TraceServiceServicer_to_server(service, server)
|
||||
server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
|
||||
server.start()
|
||||
|
||||
yield service
|
||||
|
||||
server.stop(None)
|
||||
|
||||
|
||||
def test_traces(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
trace_service: FakeTraceService,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.01,
|
||||
top_p=0.1,
|
||||
max_tokens=256,
|
||||
)
|
||||
model = "facebook/opt-125m"
|
||||
llm = LLM(
|
||||
model=model,
|
||||
otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
|
||||
)
|
||||
prompts = ["This is a short prompt"]
|
||||
outputs = llm.generate(prompts, sampling_params=sampling_params)
|
||||
|
||||
timeout = 5
|
||||
if not trace_service.evt.wait(timeout):
|
||||
raise TimeoutError(
|
||||
f"The fake trace service didn't receive a trace within "
|
||||
f"the {timeout} seconds timeout")
|
||||
|
||||
request = trace_service.request
|
||||
assert len(request.resource_spans) == 1, (
|
||||
f"Expected 1 resource span, "
|
||||
f"but got {len(request.resource_spans)}")
|
||||
assert len(request.resource_spans[0].scope_spans) == 1, (
|
||||
f"Expected 1 scope span, "
|
||||
f"but got {len(request.resource_spans[0].scope_spans)}")
|
||||
assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
|
||||
f"Expected 1 span, "
|
||||
f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")
|
||||
|
||||
attributes = decode_attributes(
|
||||
request.resource_spans[0].scope_spans[0].spans[0].attributes)
|
||||
assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
|
||||
assert attributes.get(
|
||||
SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
|
||||
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
|
||||
) == sampling_params.temperature
|
||||
assert attributes.get(
|
||||
SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
|
||||
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
|
||||
) == sampling_params.max_tokens
|
||||
assert attributes.get(
|
||||
SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
|
||||
assert attributes.get(
|
||||
SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
|
||||
outputs[0].prompt_token_ids)
|
||||
completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
|
||||
assert attributes.get(
|
||||
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
|
||||
metrics = outputs[0].metrics
|
||||
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE
|
||||
) == metrics.time_in_queue
|
||||
ttft = metrics.first_token_time - metrics.arrival_time
|
||||
assert attributes.get(
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
|
||||
e2e_time = metrics.finished_time - metrics.arrival_time
|
||||
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time
|
||||
assert metrics.scheduler_time > 0
|
||||
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER
|
||||
) == metrics.scheduler_time
|
||||
# Model forward and model execute should be none, since detailed traces is
|
||||
# not enabled.
|
||||
assert metrics.model_forward_time is None
|
||||
assert metrics.model_execute_time is None
|
||||
|
||||
|
||||


def test_traces_with_detailed_steps(
    monkeypatch: pytest.MonkeyPatch,
    trace_service: FakeTraceService,
):
    with monkeypatch.context() as m:
        m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")

        sampling_params = SamplingParams(
            temperature=0.01,
            top_p=0.1,
            max_tokens=256,
        )
        model = "facebook/opt-125m"
        llm = LLM(
            model=model,
            otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
            collect_detailed_traces=["all"],
        )
        prompts = ["This is a short prompt"]
        outputs = llm.generate(prompts, sampling_params=sampling_params)

        timeout = 5
        if not trace_service.evt.wait(timeout):
            raise TimeoutError(
                f"The fake trace service didn't receive a trace within "
                f"the {timeout} seconds timeout")

        request = trace_service.request
        assert len(request.resource_spans) == 1, (
            f"Expected 1 resource span, "
            f"but got {len(request.resource_spans)}")
        assert len(request.resource_spans[0].scope_spans) == 1, (
            f"Expected 1 scope span, "
            f"but got {len(request.resource_spans[0].scope_spans)}")
        assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
            f"Expected 1 span, "
            f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")

        attributes = decode_attributes(
            request.resource_spans[0].scope_spans[0].spans[0].attributes)
        assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
                              ) == sampling_params.temperature
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
                              ) == sampling_params.max_tokens
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
        assert attributes.get(
            SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
                outputs[0].prompt_token_ids)
        completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
        assert attributes.get(
            SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
        metrics = outputs[0].metrics
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE
                              ) == metrics.time_in_queue
        ttft = metrics.first_token_time - metrics.arrival_time
        assert attributes.get(
            SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
        e2e_time = metrics.finished_time - metrics.arrival_time
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time
        assert metrics.scheduler_time > 0
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER
                              ) == metrics.scheduler_time
        assert metrics.model_forward_time > 0
        assert attributes.get(
            SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD
        ) == pytest.approx(metrics.model_forward_time / 1000)
        assert metrics.model_execute_time > 0
        assert attributes.get(
            SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE
        ) == metrics.model_execute_time
        assert metrics.model_forward_time < 1000 * metrics.model_execute_time
@ -135,7 +135,7 @@ if TYPE_CHECKING:
    VLLM_TPU_BUCKET_PADDING_GAP: int = 0
    VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
    VLLM_TPU_USING_PATHWAYS: bool = False
    VLLM_USE_DEEP_GEMM: bool = False
    VLLM_USE_DEEP_GEMM: bool = True
    VLLM_USE_DEEP_GEMM_E8M0: bool = True
    VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False
    VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False
@ -1044,7 +1044,7 @@ environment_variables: dict[str, Callable[[], Any]] = {

    # Allow use of DeepGemm kernels for fused moe ops.
    "VLLM_USE_DEEP_GEMM":
    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),
    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))),

    # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
    "VLLM_USE_DEEP_GEMM_E8M0":
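The hunk above flips the DeepGEMM default from off to on. As a quick sanity check, the snippet below (a standalone sketch, not part of the diff) shows how the new parsing behaves: the flag stays enabled unless it is explicitly set to 0.

import os

def use_deep_gemm() -> bool:
    # Mirrors the lambda in the hunk above: the default is now "1" instead of "0".
    return bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1")))

os.environ.pop("VLLM_USE_DEEP_GEMM", None)
assert use_deep_gemm() is True    # unset -> DeepGEMM enabled by default

os.environ["VLLM_USE_DEEP_GEMM"] = "0"
assert use_deep_gemm() is False   # explicit opt-out still works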
@ -7,7 +7,6 @@ import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, Union, cast

import numpy as np
@ -18,18 +17,14 @@ from tqdm import tqdm
from typing_extensions import TypeAlias

import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.distributed.parallel_state import (
    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
    prepare_communication_buffer_for_model)
@ -37,7 +32,6 @@ from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                  set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
@ -54,7 +48,7 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                        GiB_bytes, LazyLoader, check_use_alibi,
                        is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
@ -70,8 +64,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        CrossAttentionSpec,
                                        EncoderOnlyAttentionSpec,
                                        FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheSpec,
                                        MambaSpec, SlidingWindowSpec)
                                        KVCacheSpec, SlidingWindowSpec)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, LogprobsLists, LogprobsTensors,
@ -88,6 +81,7 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_cache_initializer_mixin import KVCacheInitializerMixin
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@ -95,10 +89,8 @@ from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
from vllm.v1.worker.utils import is_residual_scattered_for_sp

from .utils import (AttentionGroup, MultiModalBudget,
                    add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
                    gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                    scatter_mm_placeholders)
from .utils import (AttentionGroup, MultiModalBudget, gather_mm_placeholders,
                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)

if TYPE_CHECKING:
    import xgrammar as xgr
@ -163,7 +155,8 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
        return output


class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
class GPUModelRunner(KVCacheInitializerMixin, LoRAModelRunnerMixin,
                     KVConnectorModelRunnerMixin):

    def __init__(
        self,
@ -255,7 +248,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.kv_caches: list[torch.Tensor] = []
        # indexes: [kv_cache_group_id][attn_group]
        self.attn_groups: list[list[AttentionGroup]] = []
        # self.kv_cache_config: KVCacheConfig
        # a fake value to satisfy the type checker
        self.kv_cache_config: KVCacheConfig = cast(KVCacheConfig, None)

        # mm_hash -> encoder_output
        self.encoder_cache: dict[str, torch.Tensor] = {}
@ -3529,418 +3523,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
else:
|
||||
self.reorder_batch_threshold = reorder_batch_threshold_i
|
||||
|
||||
def may_reinitialize_input_batch(self,
|
||||
kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Re-initialize the input batch if the block sizes are different from
|
||||
`[self.cache_config.block_size]`. This usually happens when there
|
||||
are multiple KV cache groups.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache configuration.
|
||||
"""
|
||||
block_sizes = [
|
||||
kv_cache_group.kv_cache_spec.block_size
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups
|
||||
]
|
||||
if block_sizes != [self.cache_config.block_size]:
|
||||
assert self.cache_config.cpu_offload_gb == 0, (
|
||||
"Cannot re-initialize the input batch when CPU weight "
|
||||
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
|
||||
"for more details.")
|
||||
self.input_batch = InputBatch(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=max(self.max_model_len, self.max_encoder_len),
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
block_sizes=block_sizes,
|
||||
is_spec_decode=bool(self.vllm_config.speculative_config),
|
||||
logitsprocs=self.input_batch.logitsprocs,
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
num_speculative_tokens=(
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
if self.vllm_config.speculative_config else 0),
|
||||
)
|
||||
|
||||
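The guard in `may_reinitialize_input_batch` only triggers when the per-group block sizes differ from the single configured block size. A toy illustration of that condition (a standalone sketch; the group sizes below are made up):

default_block_size = 16
group_block_sizes = [16, 4096]  # e.g. full attention + mamba (block_size = max_model_len)
needs_reinit = group_block_sizes != [default_block_size]
assert needs_reinit  # multiple KV cache groups -> the InputBatch must be rebuilt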
def _allocate_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initializes the KV cache buffer with the correct size. The buffer needs
|
||||
to be reshaped to the desired shape before being used by the models.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size,
|
||||
dtype=torch.int8,
|
||||
device=self.device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys(
|
||||
)), "Some layers are not correctly initialized"
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
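For reference, a minimal standalone sketch of the allocation loop above (layer names and sizes are hypothetical): each KVCacheTensor becomes one flat int8 buffer, and every layer listed in `shared_by` points at the same storage.

import torch

kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
# Stand-in for kv_cache_config.kv_cache_tensors; names and sizes are invented.
kv_cache_tensors = [
    {"size": 2 * 1024 * 1024, "shared_by": ["layers.0.attn", "layers.1.attn"]},
]
for spec in kv_cache_tensors:
    buf = torch.zeros(spec["size"], dtype=torch.int8)
    for layer_name in spec["shared_by"]:
        kv_cache_raw_tensors[layer_name] = buf

# Both layers share one backing buffer, exactly like `shared_by` above.
assert kv_cache_raw_tensors["layers.0.attn"] is kv_cache_raw_tensors["layers.1.attn"]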
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
|
||||
return itertools.chain.from_iterable(self.attn_groups)
|
||||
|
||||
def _kv_cache_spec_attn_group_iterator(
|
||||
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
|
||||
if not self.kv_cache_config.kv_cache_groups:
|
||||
return
|
||||
for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
|
||||
for attn_group in attn_groups:
|
||||
yield self.kv_cache_config.kv_cache_groups[
|
||||
kv_cache_spec_id].kv_cache_spec, attn_group
|
||||
|
||||
def _reshape_kv_cache_tensors(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Reshape the KV cache tensors to the desired shape and dtype.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
||||
correct size but uninitialized shape.
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
has_attn, has_mamba = False, False
|
||||
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
|
||||
attn_backend = group.backend
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = (raw_tensor.numel() //
|
||||
kv_cache_spec.page_size_bytes)
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
has_attn = True
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
try:
|
||||
kv_cache_stride_order = \
|
||||
attn_backend.get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(
|
||||
kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(
|
||||
range(len(kv_cache_shape)))
|
||||
# The allocation respects the backend-defined stride order
|
||||
# to ensure the semantic remains consistent for each
|
||||
# backend. We first obtain the generic kv cache shape and
|
||||
# then permute it according to the stride order which could
|
||||
# result in a non-contiguous tensor.
|
||||
kv_cache_shape = tuple(kv_cache_shape[i]
|
||||
for i in kv_cache_stride_order)
|
||||
# Maintain original KV shape view.
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
kv_caches[layer_name] = kv_cache_raw_tensors[
|
||||
layer_name].view(dtype).view(kv_cache_shape).permute(
|
||||
*inv_order)
|
||||
elif isinstance(kv_cache_spec, MambaSpec):
|
||||
has_mamba = True
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
state_tensors = []
|
||||
storage_offset_bytes = 0
|
||||
for (shape, dtype) in zip(kv_cache_spec.shapes,
|
||||
kv_cache_spec.dtypes):
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
num_element_per_page = (
|
||||
kv_cache_spec.page_size_bytes // dtype_size)
|
||||
target_shape = (num_blocks, *shape)
|
||||
stride = torch.empty(target_shape).stride()
|
||||
target_stride = (num_element_per_page, *stride[1:])
|
||||
assert storage_offset_bytes % dtype_size == 0
|
||||
tensor = torch.as_strided(
|
||||
raw_tensor.view(dtype),
|
||||
size=target_shape,
|
||||
stride=target_stride,
|
||||
storage_offset=storage_offset_bytes // dtype_size,
|
||||
)
|
||||
state_tensors.append(tensor)
|
||||
storage_offset_bytes += stride[0] * dtype_size
|
||||
|
||||
kv_caches[layer_name] = state_tensors
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if has_attn and has_mamba:
|
||||
self._update_hybrid_attention_mamba_layout(kv_caches)
|
||||
|
||||
return kv_caches
|
||||
|
||||
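The permute/inverse-permute step inside `_reshape_kv_cache_tensors` can be hard to follow. Here is a minimal standalone sketch (the shapes and the stride order are invented) showing that allocating in the backend's physical order and permuting back leaves callers with the generic logical shape, possibly as a non-contiguous view.

import torch

generic_shape = (2, 16, 4, 8, 64)   # e.g. (2, num_blocks, block_size, num_kv_heads, head_size)
stride_order = (1, 0, 2, 3, 4)      # hypothetical backend preference: blocks outermost
physical_shape = tuple(generic_shape[i] for i in stride_order)

buf = torch.zeros(physical_shape)   # allocated in the backend's physical layout
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
kv_cache = buf.permute(*inv_order)  # logical view back in the generic order

assert kv_cache.shape == generic_shape
assert not kv_cache.is_contiguous()  # the permuted view may be non-contiguous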
def _update_hybrid_attention_mamba_layout(
|
||||
self, kv_caches: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Update the layout of attention layers from (2, num_blocks, ...) to
|
||||
(num_blocks, 2, ...).
|
||||
|
||||
Args:
|
||||
kv_caches: The KV cache buffer of each layer.
|
||||
"""
|
||||
|
||||
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
|
||||
for layer_name in group.layer_names:
|
||||
kv_cache = kv_caches[layer_name]
|
||||
if (isinstance(kv_cache_spec, AttentionSpec)
|
||||
and kv_cache.shape[0] == 2):
|
||||
assert kv_cache.shape[1] != 2, \
|
||||
"Fail to determine whether the layout is " \
|
||||
"(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
|
||||
f"a tensor of shape {kv_cache.shape}"
|
||||
hidden_size = kv_cache.shape[2:].numel()
|
||||
kv_cache.as_strided_(size=kv_cache.shape,
|
||||
stride=(hidden_size, 2 * hidden_size,
|
||||
*kv_cache.stride()[2:]))
|
||||
|
||||
def initialize_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initialize the memory buffer for KV cache.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
# Initialize the memory buffer for KV cache
|
||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
||||
# Change the memory buffer to the desired shape
|
||||
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
|
||||
kv_cache_raw_tensors)
|
||||
|
||||
# Set up cross-layer KV cache sharing
|
||||
for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
|
||||
):
|
||||
logger.debug("%s reuses KV cache of %s", layer_name,
|
||||
target_layer_name)
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
|
||||
bind_kv_cache(kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
return kv_caches
|
||||
|
||||
def maybe_add_kv_sharing_layers_to_kv_cache_groups(
|
||||
self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Add layers that re-use KV cache to KV cache group of its target layer.
|
||||
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
|
||||
"""
|
||||
if not self.shared_kv_cache_layers:
|
||||
# No cross-layer KV sharing, return
|
||||
return
|
||||
|
||||
add_kv_sharing_layers_to_kv_cache_groups(
|
||||
self.shared_kv_cache_layers,
|
||||
kv_cache_config.kv_cache_groups,
|
||||
self.runner_only_attn_layers,
|
||||
)
|
||||
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
# In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
|
||||
# similar KV sharing setups, only the layers that generate KV caches
|
||||
# are involved in the prefill phase, enabling prefill to early exit.
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config,
|
||||
Attention)
|
||||
for layer_name in reversed(attn_layers):
|
||||
if layer_name in self.shared_kv_cache_layers:
|
||||
self.kv_sharing_fast_prefill_eligible_layers.add(
|
||||
layer_name)
|
||||
else:
|
||||
break
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
Args:
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
kv_cache_config = deepcopy(kv_cache_config)
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.may_reinitialize_input_batch(kv_cache_config)
|
||||
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
|
||||
self.initialize_attn_backend(kv_cache_config)
|
||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
||||
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
# validate all draft model layers belong to the same kv cache
|
||||
# group
|
||||
self.drafter.validate_same_kv_cache_group(kv_cache_config)
|
||||
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().register_kv_caches(kv_caches)
|
||||
if self.device.type == 'xpu':
|
||||
get_kv_transfer_group().set_host_xfer_buffer_ops(
|
||||
copy_kv_blocks)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
layer_names = self.attn_groups[0][0].layer_names
|
||||
layers = get_layers_from_vllm_config(self.vllm_config,
|
||||
AttentionLayerBase,
|
||||
layer_names)
|
||||
for layer in layers.values():
|
||||
assert layer.impl.need_to_return_lse_for_decode, (
|
||||
"DCP requires attention impls to return"
|
||||
" the softmax lse for decode, but the impl "
|
||||
f"{layer.impl.__class__.__name__} "
|
||||
"does not return the softmax lse for decode.")
|
||||
|
||||
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||
"""
|
||||
Add encoder-only layers to the KV cache config.
|
||||
"""
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
encoder_only_attn_specs: dict[AttentionSpec,
|
||||
list[str]] = defaultdict(list)
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
|
||||
attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
encoder_only_attn_specs[attn_spec].append(layer_name)
|
||||
self.runner_only_attn_layers.add(layer_name)
|
||||
if len(encoder_only_attn_specs) > 0:
|
||||
assert len(
|
||||
encoder_only_attn_specs
|
||||
) == 1, "Only support one encoder-only attention spec now"
|
||||
spec, layer_names = encoder_only_attn_specs.popitem()
|
||||
self.kv_cache_config.kv_cache_groups.append(
|
||||
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
Attention module in the static forward context.
|
||||
Returns:
|
||||
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||
format. Layers that do not need KV cache are not included.
|
||||
"""
|
||||
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if (kv_tgt_layer :=
|
||||
attn_module.kv_sharing_target_layer_name) is not None:
|
||||
# The layer doesn't need its own KV cache and will use that of
|
||||
# the target layer. We skip creating a KVCacheSpec for it, so
|
||||
# that KV cache management logic will act as this layer does
|
||||
# not exist, and doesn't allocate KV cache for the layer. This
|
||||
# enables the memory saving of cross-layer kv sharing, allowing
|
||||
# a given amount of memory to accommodate longer context lengths
|
||||
# or enable more requests to be processed simultaneously.
|
||||
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
||||
continue
|
||||
|
||||
# TODO(lucas): move the attention specs into the model layers like
|
||||
# the attention backends
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=use_mla)
|
||||
elif self.attention_chunk_size is not None \
|
||||
and isinstance(attn_module, ChunkedLocalAttention):
|
||||
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
attention_chunk_size=self.attention_chunk_size,
|
||||
use_mla=use_mla)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
kv_cache_spec[layer_name] = CrossAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
|
||||
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
|
||||
if len(mamba_layers) > 0:
|
||||
if (self.vllm_config.speculative_config is not None
|
||||
and self.vllm_config.model_config.hf_config.model_type
|
||||
not in ["qwen3_next"]):
|
||||
raise NotImplementedError(
|
||||
"Mamba with speculative decoding is not supported yet.")
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
raise NotImplementedError(
|
||||
"Prefix caching is not supported for Mamba yet.")
|
||||
max_model_len = self.vllm_config.model_config.max_model_len
|
||||
|
||||
page_size_padded = (
|
||||
self.vllm_config.cache_config.mamba_page_size_padded)
|
||||
|
||||
# Set block_size to max_model_len, so that mamba model will always
|
||||
# have only one block in the KV cache.
|
||||
for layer_name, mamba_module in mamba_layers.items():
|
||||
kv_cache_spec[layer_name] = MambaSpec(
|
||||
shapes=mamba_module.get_state_shape(),
|
||||
dtypes=mamba_module.get_state_dtype(),
|
||||
block_size=max_model_len,
|
||||
page_size_padded=page_size_padded,
|
||||
mamba_type=mamba_module.mamba_type,
|
||||
num_speculative_blocks=(
|
||||
self.speculative_config.num_speculative_tokens
|
||||
if self.speculative_config else 0),
|
||||
)
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
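To make the return value of `get_kv_cache_spec` concrete, here is a hedged sketch (layer names and sizes are invented; constructor arguments are taken from the calls above) of the mapping produced for a model mixing full and sliding-window attention:

import torch
from vllm.v1.kv_cache_interface import FullAttentionSpec, SlidingWindowSpec

kv_cache_spec = {
    "model.layers.0.self_attn.attn": FullAttentionSpec(
        block_size=16, num_kv_heads=8, head_size=128,
        dtype=torch.bfloat16, use_mla=False),
    "model.layers.1.self_attn.attn": SlidingWindowSpec(
        block_size=16, num_kv_heads=8, head_size=128,
        dtype=torch.bfloat16, sliding_window=4096, use_mla=False),
}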
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
|
||||
# This is a short term mitigation for issue mentioned in
|
||||
# https://github.com/vllm-project/vllm/issues/22754.
|
||||
|
||||
484 vllm/v1/worker/kv_cache_initializer_mixin.py Normal file
@ -0,0 +1,484 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from copy import deepcopy
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
|
||||
from vllm.config import get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.utils import get_dtype_size
|
||||
# yapf: disable
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||
ChunkedLocalAttentionSpec,
|
||||
CrossAttentionSpec,
|
||||
EncoderOnlyAttentionSpec,
|
||||
FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
MambaSpec, SlidingWindowSpec)
|
||||
# yapf: enable
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from .utils import (AttentionGroup, add_kv_sharing_layers_to_kv_cache_groups,
|
||||
bind_kv_cache)
|
||||
|
||||
|
||||
class _KVCacheInitializerSelf(Protocol):
|
||||
cache_config: Any
|
||||
max_num_reqs: int
|
||||
max_model_len: int
|
||||
max_encoder_len: int
|
||||
max_num_tokens: int
|
||||
device: Any
|
||||
pin_memory: bool
|
||||
model_config: Any
|
||||
vllm_config: Any
|
||||
input_batch: InputBatch
|
||||
is_pooling_model: bool
|
||||
shared_kv_cache_layers: dict[str, str]
|
||||
kv_sharing_fast_prefill_eligible_layers: set[str]
|
||||
attention_chunk_size: int
|
||||
runner_only_attn_layers: set[str]
|
||||
kv_cache_dtype: torch.dtype
|
||||
kv_cache_config: KVCacheConfig
|
||||
compilation_config: Any
|
||||
kv_caches: Any
|
||||
speculative_config: Any
|
||||
drafter: Any
|
||||
dcp_world_size: int
|
||||
attn_groups: list[list[AttentionGroup]]
|
||||
|
||||
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
...
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Defined as a mixin for GPUModelRunner
|
||||
class KVCacheInitializerMixin:
|
||||
|
||||
def _runner(self) -> _KVCacheInitializerSelf:
|
||||
return cast(_KVCacheInitializerSelf, self)
|
||||
|
||||
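The `_runner()` helper above is only a typing aid: the mixin never inherits from GPUModelRunner, but casting `self` to the Protocol lets a type checker see the attributes the host class is expected to provide. A tiny self-contained sketch of the same pattern (the class names here are invented):

from typing import Protocol, cast


class _HostLike(Protocol):
    device: str

    def log(self, msg: str) -> None: ...


class DeviceMixin:
    def _host(self) -> _HostLike:
        # Same trick as _runner(): no inheritance, just a typed view of self.
        return cast(_HostLike, self)

    def describe_device(self) -> None:
        host = self._host()
        host.log(f"running on {host.device}")


class Host(DeviceMixin):
    device = "cuda:0"

    def log(self, msg: str) -> None:
        print(msg)


Host().describe_device()  # prints: running on cuda:0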
def may_reinitialize_input_batch(self,
|
||||
kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Re-initialize the input batch if the block sizes are different from
|
||||
`[self.cache_config.block_size]`. This usually happens when there
|
||||
are multiple KV cache groups.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache configuration.
|
||||
"""
|
||||
runner = self._runner()
|
||||
block_sizes = [
|
||||
kv_cache_group.kv_cache_spec.block_size
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups
|
||||
]
|
||||
if block_sizes != [runner.cache_config.block_size]:
|
||||
assert runner.cache_config.cpu_offload_gb == 0, (
|
||||
"Cannot re-initialize the input batch when CPU weight "
|
||||
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
|
||||
"for more details.")
|
||||
runner.input_batch = InputBatch(
|
||||
max_num_reqs=runner.max_num_reqs,
|
||||
max_model_len=max(runner.max_model_len,
|
||||
runner.max_encoder_len),
|
||||
max_num_batched_tokens=runner.max_num_tokens,
|
||||
device=runner.device,
|
||||
pin_memory=runner.pin_memory,
|
||||
vocab_size=runner.model_config.get_vocab_size(),
|
||||
block_sizes=block_sizes,
|
||||
is_spec_decode=bool(runner.vllm_config.speculative_config),
|
||||
logitsprocs=runner.input_batch.logitsprocs,
|
||||
is_pooling_model=runner.is_pooling_model,
|
||||
num_speculative_tokens=(runner.vllm_config.speculative_config.
|
||||
num_speculative_tokens if
|
||||
runner.vllm_config.speculative_config
|
||||
else 0),
|
||||
)
|
||||
|
||||
def _allocate_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initializes the KV cache buffer with the correct size. The buffer needs
|
||||
to be reshaped to the desired shape before being used by the models.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
runner = self._runner()
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size,
|
||||
dtype=torch.int8,
|
||||
device=runner.device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in runner.runner_only_attn_layers:
|
||||
continue
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys(
|
||||
)), "Some layers are not correctly initialized"
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
def _kv_cache_spec_attn_group_iterator(
|
||||
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
|
||||
runner = self._runner()
|
||||
if not runner.kv_cache_config.kv_cache_groups:
|
||||
return
|
||||
for kv_cache_spec_id, attn_groups in enumerate(runner.attn_groups):
|
||||
for attn_group in attn_groups:
|
||||
yield runner.kv_cache_config.kv_cache_groups[
|
||||
kv_cache_spec_id].kv_cache_spec, attn_group
|
||||
|
||||
def _reshape_kv_cache_tensors(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Reshape the KV cache tensors to the desired shape and dtype.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
||||
correct size but uninitialized shape.
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
runner = self._runner()
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
has_attn, has_mamba = False, False
|
||||
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
|
||||
attn_backend = group.backend
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in runner.runner_only_attn_layers:
|
||||
continue
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = (raw_tensor.numel() //
|
||||
kv_cache_spec.page_size_bytes)
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
has_attn = True
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
try:
|
||||
kv_cache_stride_order = \
|
||||
attn_backend.get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(
|
||||
kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(
|
||||
range(len(kv_cache_shape)))
|
||||
kv_cache_shape = tuple(kv_cache_shape[i]
|
||||
for i in kv_cache_stride_order)
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
kv_caches[layer_name] = kv_cache_raw_tensors[
|
||||
layer_name].view(dtype).view(kv_cache_shape).permute(
|
||||
*inv_order)
|
||||
elif isinstance(kv_cache_spec, MambaSpec):
|
||||
has_mamba = True
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
state_tensors = []
|
||||
storage_offset_bytes = 0
|
||||
for (shape, dtype) in zip(kv_cache_spec.shapes,
|
||||
kv_cache_spec.dtypes):
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
num_element_per_page = (
|
||||
kv_cache_spec.page_size_bytes // dtype_size)
|
||||
target_shape = (num_blocks, *shape)
|
||||
stride = torch.empty(target_shape).stride()
|
||||
target_stride = (num_element_per_page, *stride[1:])
|
||||
assert storage_offset_bytes % dtype_size == 0
|
||||
tensor = torch.as_strided(
|
||||
raw_tensor.view(dtype),
|
||||
size=target_shape,
|
||||
stride=target_stride,
|
||||
storage_offset=storage_offset_bytes // dtype_size,
|
||||
)
|
||||
state_tensors.append(tensor)
|
||||
storage_offset_bytes += stride[0] * dtype_size
|
||||
|
||||
kv_caches[layer_name] = state_tensors
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if has_attn and has_mamba:
|
||||
self._update_hybrid_attention_mamba_layout(kv_caches)
|
||||
|
||||
return kv_caches
|
||||
|
||||
def _update_hybrid_attention_mamba_layout(
|
||||
self, kv_caches: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Update the layout of attention layers from (2, num_blocks, ...) to
|
||||
(num_blocks, 2, ...).
|
||||
|
||||
Args:
|
||||
kv_caches: The KV cache buffer of each layer.
|
||||
"""
|
||||
|
||||
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
|
||||
for layer_name in group.layer_names:
|
||||
kv_cache = kv_caches[layer_name]
|
||||
if (isinstance(kv_cache_spec, AttentionSpec)
|
||||
and kv_cache.shape[0] == 2):
|
||||
assert kv_cache.shape[1] != 2, \
|
||||
"Fail to determine whether the layout is " \
|
||||
"(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
|
||||
f"a tensor of shape {kv_cache.shape}"
|
||||
hidden_size = kv_cache.shape[2:].numel()
|
||||
kv_cache.as_strided_(size=kv_cache.shape,
|
||||
stride=(hidden_size, 2 * hidden_size,
|
||||
*kv_cache.stride()[2:]))
|
||||
|
||||
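The in-place `as_strided_` above only rewrites strides: the logical view keeps shape `(2, num_blocks, ...)` while the underlying memory is interpreted block-major, so each block's K and V pages become adjacent. A standalone toy check of the same stride trick (sizes are made up):

import torch

num_blocks, hidden_size = 4, 8
# Raw buffer laid out block-major: [block0_K, block0_V, block1_K, block1_V, ...]
raw = torch.arange(2 * num_blocks * hidden_size, dtype=torch.float32)

# Logical (2, num_blocks, hidden) view of the block-major buffer.
kv = raw.view(num_blocks, 2, hidden_size).permute(1, 0, 2)
# Equivalent to what as_strided_ produces in place above.
kv_strided = raw.as_strided(size=(2, num_blocks, hidden_size),
                            stride=(hidden_size, 2 * hidden_size, 1))

assert torch.equal(kv, kv_strided)
# K and V of the same block now sit next to each other in memory (float32 = 4 bytes).
assert kv_strided[0, 1].data_ptr() + hidden_size * 4 == kv_strided[1, 1].data_ptr()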
def initialize_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initialize the memory buffer for KV cache.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
runner = self._runner()
|
||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
||||
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
|
||||
kv_cache_raw_tensors)
|
||||
|
||||
for layer_name, target_layer_name in (
|
||||
runner.shared_kv_cache_layers.items()):
|
||||
logger.debug("%s reuses KV cache of %s", layer_name,
|
||||
target_layer_name)
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
|
||||
bind_kv_cache(kv_caches,
|
||||
runner.compilation_config.static_forward_context,
|
||||
runner.kv_caches)
|
||||
return kv_caches
|
||||
|
||||
def maybe_add_kv_sharing_layers_to_kv_cache_groups(
|
||||
self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Add layers that re-use KV cache to KV cache group of its target layer.
|
||||
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
|
||||
"""
|
||||
runner = self._runner()
|
||||
if not runner.shared_kv_cache_layers:
|
||||
return
|
||||
|
||||
add_kv_sharing_layers_to_kv_cache_groups(
|
||||
runner.shared_kv_cache_layers,
|
||||
kv_cache_config.kv_cache_groups,
|
||||
runner.runner_only_attn_layers,
|
||||
)
|
||||
|
||||
if runner.cache_config.kv_sharing_fast_prefill:
|
||||
attn_layers = get_layers_from_vllm_config(runner.vllm_config,
|
||||
Attention)
|
||||
for layer_name in reversed(attn_layers):
|
||||
if layer_name in runner.shared_kv_cache_layers:
|
||||
runner.kv_sharing_fast_prefill_eligible_layers.add(
|
||||
layer_name)
|
||||
else:
|
||||
break
|
||||
|
||||
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||
"""
|
||||
Add encoder-only layers to the KV cache config.
|
||||
"""
|
||||
runner = self._runner()
|
||||
block_size = runner.vllm_config.cache_config.block_size
|
||||
use_mla = runner.vllm_config.model_config.use_mla
|
||||
encoder_only_attn_specs: dict[AttentionSpec,
|
||||
list[str]] = defaultdict(list)
|
||||
attn_layers = get_layers_from_vllm_config(runner.vllm_config,
|
||||
Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
|
||||
attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
encoder_only_attn_specs[attn_spec].append(layer_name)
|
||||
runner.runner_only_attn_layers.add(layer_name)
|
||||
if len(encoder_only_attn_specs) > 0:
|
||||
assert len(
|
||||
encoder_only_attn_specs
|
||||
) == 1, "Only support one encoder-only attention spec now"
|
||||
spec, layer_names = encoder_only_attn_specs.popitem()
|
||||
runner.kv_cache_config.kv_cache_groups.append(
|
||||
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
|
||||
Args:
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
runner = self._runner()
|
||||
kv_cache_config = deepcopy(kv_cache_config)
|
||||
runner.kv_cache_config = kv_cache_config
|
||||
self.may_reinitialize_input_batch(kv_cache_config)
|
||||
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
|
||||
runner.initialize_attn_backend(kv_cache_config)
|
||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
||||
|
||||
if runner.speculative_config and runner.speculative_config.use_eagle():
|
||||
assert isinstance(runner.drafter, EagleProposer)
|
||||
runner.drafter.validate_same_kv_cache_group(kv_cache_config)
|
||||
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().register_kv_caches(kv_caches)
|
||||
if runner.device.type == 'xpu':
|
||||
get_kv_transfer_group().set_host_xfer_buffer_ops(
|
||||
copy_kv_blocks)
|
||||
|
||||
if runner.dcp_world_size > 1:
|
||||
layer_names = runner.attn_groups[0][0].layer_names
|
||||
layers = get_layers_from_vllm_config(
|
||||
runner.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
layer_names,
|
||||
)
|
||||
for layer in layers.values():
|
||||
layer_impl = cast(Any, layer).impl
|
||||
assert layer_impl.need_to_return_lse_for_decode, (
|
||||
"DCP requires attention impls to return"
|
||||
" the softmax lse for decode, but the impl "
|
||||
f"{layer_impl.__class__.__name__} "
|
||||
"does not return the softmax lse for decode.")
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
Attention module in the static forward context.
|
||||
Returns:
|
||||
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||
format. Layers that do not need KV cache are not included.
|
||||
"""
|
||||
runner = self._runner()
|
||||
block_size = runner.vllm_config.cache_config.block_size
|
||||
use_mla = runner.vllm_config.model_config.use_mla
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(runner.vllm_config,
|
||||
Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if (kv_tgt_layer :=
|
||||
attn_module.kv_sharing_target_layer_name) is not None:
|
||||
# The layer doesn't need its own KV cache and will use that of
|
||||
# the target layer. We skip creating a KVCacheSpec for it, so
|
||||
# that KV cache management logic will act as this layer does
|
||||
# not exist, and doesn't allocate KV cache for the layer. This
|
||||
# enables the memory saving of cross-layer kv sharing, allowing
|
||||
# a given amount of memory to accommodate longer context lengths
|
||||
# or enable more requests to be processed simultaneously.
|
||||
runner.shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
||||
continue
|
||||
|
||||
# TODO(lucas): move the attention specs into the model layers like
|
||||
# the attention backends
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=runner.kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=use_mla)
|
||||
elif runner.attention_chunk_size is not None \
|
||||
and isinstance(attn_module, ChunkedLocalAttention):
|
||||
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=runner.kv_cache_dtype,
|
||||
attention_chunk_size=runner.attention_chunk_size,
|
||||
use_mla=use_mla)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
kv_cache_spec[layer_name] = CrossAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
|
||||
mamba_layers = get_layers_from_vllm_config(runner.vllm_config,
|
||||
MambaBase)
|
||||
if len(mamba_layers) > 0:
|
||||
if (runner.vllm_config.speculative_config is not None
|
||||
and runner.vllm_config.model_config.hf_config.model_type
|
||||
not in ["qwen3_next"]):
|
||||
raise NotImplementedError(
|
||||
"Mamba with speculative decoding is not supported yet.")
|
||||
if runner.vllm_config.cache_config.enable_prefix_caching:
|
||||
raise NotImplementedError(
|
||||
"Prefix caching is not supported for Mamba yet.")
|
||||
max_model_len = runner.vllm_config.model_config.max_model_len
|
||||
|
||||
page_size_padded = (
|
||||
runner.vllm_config.cache_config.mamba_page_size_padded)
|
||||
|
||||
# Set block_size to max_model_len, so that mamba model will always
|
||||
# have only one block in the KV cache.
|
||||
for layer_name, mamba_module in mamba_layers.items():
|
||||
kv_cache_spec[layer_name] = MambaSpec(
|
||||
shapes=mamba_module.get_state_shape(),
|
||||
dtypes=mamba_module.get_state_dtype(),
|
||||
block_size=max_model_len,
|
||||
page_size_padded=page_size_padded,
|
||||
mamba_type=mamba_module.mamba_type,
|
||||
num_speculative_blocks=(
|
||||
runner.speculative_config.num_speculative_tokens
|
||||
if runner.speculative_config else 0),
|
||||
)
|
||||
|
||||
return kv_cache_spec
|
||||