Compare commits

..

24 Commits

Author SHA1 Message Date
fcec8c8827 add debug cruft
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
2025-06-20 20:37:37 +00:00
850dafea92 update
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
2025-06-20 19:57:07 +00:00
b4f17e12a4 tolerances
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
2025-06-20 19:47:25 +00:00
21ffc7353a fixup
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
2025-06-20 15:56:05 +00:00
39d5d33f8f tweaks
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
2025-06-20 15:36:59 +00:00
7a821f0e7f precommit
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
2025-06-20 14:41:20 +00:00
26fd8ca33c fixes
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
2025-06-20 14:40:21 +00:00
d5f206767c Unit test
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
2025-06-20 14:39:58 +00:00
2b5ad9f233 fixes - use-fp8-dispatch
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
2025-06-18 21:46:22 +00:00
299f829180 DeepGEMM LL optimizations
- Quantized dispatch
- Fused act-and-mul-and-quant in the right layout for DeepGEMM

Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
2025-06-18 20:09:51 +00:00
104a984e6a Merge remote-tracking branch 'nm/varun/deepep-fp8-dispatch' into ll_deepgemm_opt 2025-06-18 19:35:02 +00:00
12575cfa7a [Bugfix] fix RAY_CGRAPH_get_timeout is not set successfully (#19725)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
2025-06-18 10:26:16 -07:00
8b6e1d639c [Hardware][AMD] integrate aiter chunked prefill into vllm (#18596)
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: fsx950223 <fsx950223@outlook.com>
Co-authored-by: charlifu <charlifu@amd.com>
2025-06-18 08:46:51 -07:00
8de2fd39fc deep_ep + use_fp8_dispatch
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
2025-06-18 07:32:15 -07:00
735a9de71f [Qwen] Add tagging rule for Qwen related PRs (#19799)
Signed-off-by: Lu Fang <lufang@fb.com>
2025-06-18 14:26:43 +00:00
257ab95439 [Platform] Allow platform use V1 Engine by default (#19792)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
2025-06-18 13:03:36 +00:00
cca91a7a10 [doc] fix the incorrect label (#19787)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>
2025-06-18 10:30:58 +00:00
f04d604567 [Minor] Zero-initialize attn output buffer (#19784)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-06-18 06:59:27 +00:00
19a53b2783 [V1] Decouple GPU and TPU InputBatch (#19778)
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
2025-06-18 06:38:13 +00:00
eccdc8318c [V1][P/D] An native implementation of xPyD based on P2P NCCL (#18242)
Signed-off-by: Abatom <abzhonghua@gmail.com>
2025-06-18 06:32:36 +00:00
5f52a84685 [V1] Add API docs for EncoderCacheManager (#19294)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-06-18 13:37:01 +08:00
d4629dc43f [Misc] Add __str__ for RequestStatus (#19780)
Signed-off-by: Linkun Chen <github@lkchen.net>
2025-06-18 03:03:01 +00:00
6e9cc73f67 [MISC] correct DeviceConfig device field static type analysis (#19699)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
2025-06-17 17:21:50 -07:00
c53711bd63 [MISC] correct copy_blocks src_to_dists param type (#19696)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
2025-06-17 17:21:06 -07:00
37 changed files with 3469 additions and 597 deletions

15
.github/mergify.yml vendored
View File

@ -65,6 +65,21 @@ pull_request_rules:
add:
- multi-modality
- name: label-qwen
description: Automatically apply qwen label
conditions:
- or:
- files~=^examples/.*qwen.*\.py
- files~=^tests/.*qwen.*\.py
- files~=^vllm/model_executor/models/.*qwen.*\.py
- files~=^vllm/reasoning/.*qwen.*\.py
- title~=(?i)Qwen
- body~=(?i)Qwen
actions:
label:
add:
- qwen
- name: label-rocm
description: Automatically apply rocm label
conditions:

View File

@ -1,387 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import random
import time
from dataclasses import dataclass
from typing import Optional
import aiohttp # Import aiohttp
import numpy as np
from tqdm import tqdm
from backend_request_func import RequestFuncInput, RequestFuncOutput
from benchmark_dataset import RandomDataset, SampleRequest
try:
from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
from backend_request_func import get_tokenizer
logger = logging.getLogger(__name__)
@dataclass
class BenchmarkMetrics:
    """Aggregate statistics for one sequential benchmark run.

    All latency values are in milliseconds. Each ``percentiles_*`` list
    holds ``(percentile, value)`` pairs for the percentiles the caller
    selected.
    """

    completed: int
    total_input: int
    total_output: int
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    percentiles_ttft_ms: list[tuple[float, float]]
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    percentiles_itl_ms: list[tuple[float, float]]
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    percentiles_e2el_ms: list[tuple[float, float]]
async def reset_cache(reset_url: str):
    """Send a POST request to *reset_url* to reset the server's prefix cache.

    All errors are logged but never propagated: a failed cache reset should
    not abort the benchmark run.
    """
    logger.debug("Resetting prefix cache at %s", reset_url)
    try:
        async with (
            aiohttp.ClientSession() as session,
            session.post(reset_url) as response,
        ):
            # Raise for 4xx/5xx so the handlers below can log details.
            response.raise_for_status()
            logger.debug("Prefix cache reset successful: %s", response.status)
    except aiohttp.ClientConnectorError as e:
        # Bug fix: the format string had a stray '}' ("%s}") that leaked
        # into the logged message.
        logger.error("Failed to connect to cache reset endpoint %s: %s", reset_url, e)
    except aiohttp.ClientResponseError as e:
        logger.error(
            "Cache reset request failed with status %s: %s", e.status, e.message
        )
    except Exception as e:
        logger.error("An unexpected error occurred during cache reset: %s", e)
async def sequential_benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer,
    input_requests: list[SampleRequest],
    request_func,
    selected_percentiles: list[float],
    cache_reset_url: Optional[str] = None,
):
    """
    Benchmark that processes requests sequentially, waiting for each to complete
    before starting the next one. Resets prefix cache between requests.

    Args:
        backend: Backend name (kept for interface compatibility; unused here).
        api_url: Full URL of the completions endpoint.
        model_id: Model name sent with every request.
        tokenizer: Tokenizer, forwarded to calculate_metrics() for token counts.
        input_requests: Pre-sampled requests, replayed strictly one at a time.
        request_func: Async callable that performs a single request.
        selected_percentiles: Percentiles to report for each latency metric.
        cache_reset_url: If set, POST here between requests to reset the
            server's prefix cache (skipped after the final request).

    Returns:
        A dict with raw per-request data plus mean/median/std and selected
        percentiles for TTFT, ITL and E2EL.
    """
    outputs = []
    pbar = tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()

    # Process requests sequentially
    for i, request in enumerate(input_requests):
        prompt, prompt_len, output_len = (
            request.prompt,
            request.prompt_len,
            request.expected_output_len,
        )

        logger.info("Sending request with len %s", request.prompt_len)
        logger.debug('Request str: "%s"', request.prompt[:50])
        request_start_time = time.perf_counter()

        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
        )

        output = await request_func(request_func_input=request_func_input)

        request_end_time = time.perf_counter()
        # Add timing information
        if output.success and not hasattr(output, "latency"):
            output.latency = request_end_time - request_start_time
        # Bug fix: failed outputs may not carry a `latency` attribute, which
        # made the unconditional log line raise AttributeError.
        if getattr(output, "latency", None) is not None:
            logger.info("Finished request with latency %.4f s", output.latency)

        outputs.append(output)
        pbar.update(1)

        # Bug fix: this reset was disabled with `and False` and called an
        # undefined dummy request (leftover debug cruft). Reset the prefix
        # cache between requests, except after the very last request.
        if cache_reset_url and i < len(input_requests) - 1:
            await reset_cache(cache_reset_url)

    pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    # Calculate metrics
    metrics = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        selected_percentiles=selected_percentiles,
    )

    print_results(metrics, benchmark_duration)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "input_lens": [request.prompt_len for request in input_requests],
        "output_lens": [
            output.output_tokens if output.success else 0 for output in outputs
        ],
        "ttfts": [output.ttft for output in outputs if output.success],
        "itls": [output.itl for output in outputs if output.success],
        "generated_texts": [
            output.generated_text for output in outputs if output.success
        ],
        "errors": [output.error for output in outputs if not output.success],
    }

    # Add summary statistics
    for stat_name in ["ttft", "itl", "e2el"]:
        for metric_name in ["mean", "median", "std"]:
            result[f"{metric_name}_{stat_name}_ms"] = getattr(
                metrics, f"{metric_name}_{stat_name}_ms"
            )

        for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            result[f"p{p_word}_{stat_name}_ms"] = value

    return result
def calculate_metrics(
    input_requests: list[SampleRequest],
    outputs: list[RequestFuncOutput],
    dur_s: float,
    tokenizer,
    selected_percentiles: list[float],
) -> BenchmarkMetrics:
    """Aggregate per-request outputs into a BenchmarkMetrics summary.

    Token counts come from the backend when available, otherwise from
    re-tokenizing the generated text. `dur_s` is accepted for interface
    compatibility but not used in the aggregation itself.
    """
    completed = 0
    total_input = 0
    total_output = 0
    ttfts: list[float] = []
    itls: list[float] = []
    e2els: list[float] = []

    for request, output in zip(input_requests, outputs):
        if not output.success:
            continue

        num_output_tokens = output.output_tokens
        if not num_output_tokens:
            # Backend did not report a count; re-tokenize the generated text.
            num_output_tokens = len(
                tokenizer(output.generated_text, add_special_tokens=False).input_ids
            )
        total_output += num_output_tokens
        total_input += request.prompt_len

        if getattr(output, "ttft", None) is not None:
            ttfts.append(output.ttft)

        itl = getattr(output, "itl", None)
        if itl:
            # Ensure itl is a list of floats
            if isinstance(itl, list):
                itls.extend(itl)
            else:
                logger.warning(
                    "Expected list for ITL but got %s. Appending as is.",
                    type(itl),
                )
                itls.append(itl)

        if getattr(output, "latency", None) is not None:
            e2els.append(output.latency)

        completed += 1

    def _stats(samples):
        # Empty sample lists fall back to a single zero so numpy never
        # operates on an empty array.
        data = samples or [0]
        return (
            np.mean(data) * 1000,
            np.median(data) * 1000,
            np.std(data) * 1000,
            [(p, np.percentile(data, p) * 1000) for p in selected_percentiles],
        )

    mean_ttft, median_ttft, std_ttft, pct_ttft = _stats(ttfts)
    mean_itl, median_itl, std_itl, pct_itl = _stats(itls)
    mean_e2el, median_e2el, std_e2el, pct_e2el = _stats(e2els)

    return BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        mean_ttft_ms=mean_ttft,
        median_ttft_ms=median_ttft,
        std_ttft_ms=std_ttft,
        percentiles_ttft_ms=pct_ttft,
        mean_itl_ms=mean_itl,
        median_itl_ms=median_itl,
        std_itl_ms=std_itl,
        percentiles_itl_ms=pct_itl,
        mean_e2el_ms=mean_e2el,
        median_e2el_ms=median_e2el,
        std_e2el_ms=std_e2el,
        percentiles_e2el_ms=pct_e2el,
    )
def print_results(metrics: BenchmarkMetrics, benchmark_duration: float):
    """Print benchmark results in a formatted way."""
    print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="="))
    # Summary table: (row template, label, value).
    for template, label, value in (
        ("{:<40} {:<10}", "Successful requests:", metrics.completed),
        ("{:<40} {:<10.2f}", "Benchmark duration (s):", benchmark_duration),
        ("{:<40} {:<10}", "Total input tokens:", metrics.total_input),
        ("{:<40} {:<10}", "Total generated tokens:", metrics.total_output),
    ):
        print(template.format(label, value))

    def print_metric_stats(metric_name, header):
        # One section per latency metric: mean, median, then percentiles.
        print("{s:{c}^{n}}".format(s=header, n=60, c="-"))
        key = metric_name.lower()
        print(
            "{:<40} {:<10.2f}".format(
                f"Mean {metric_name} (ms):",
                getattr(metrics, f"mean_{key}_ms"),
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                f"Median {metric_name} (ms):",
                getattr(metrics, f"median_{key}_ms"),
            )
        )
        for p, value in getattr(metrics, f"percentiles_{key}_ms"):
            # Render 99.0 as "99" but keep fractional percentiles as-is.
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))

    print_metric_stats("TTFT", "Time to First Token")
    print_metric_stats("ITL", "Inter-token Latency")
    print_metric_stats("E2EL", "End-to-end Latency")
    print("=" * 60)
async def main_async(args):
    """Resolve configuration from *args*, build inputs, and run the benchmark.

    Returns:
        The result dict produced by sequential_benchmark().

    Raises:
        ValueError: If args.backend has no registered request function.
    """
    # Import needed functions based on your setup
    from backend_request_func import ASYNC_REQUEST_FUNCS

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    # Set up API URL and cache-reset URL against the *same* server.
    # Bug fix: the reset URL was previously always derived from --host/--port,
    # which targeted the wrong server whenever --base-url was supplied.
    if args.base_url is not None:
        server_base = args.base_url
    else:
        server_base = f"http://{args.host}:{args.port}"
    api_url = f"{server_base}{args.endpoint}"

    # Set up Cache Reset URL
    cache_reset_url = f"{server_base}/reset_prefix_cache"
    logger.info("Prefix cache reset configured at: %s", cache_reset_url)

    # Get tokenizer
    tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code)

    # Get request function
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    input_requests = RandomDataset().sample(
        tokenizer=tokenizer,
        num_requests=args.num_requests,
        prefix_len=0,
        input_len=args.input_len,
        output_len=args.output_len,
        range_ratio=0.0,
    )

    # Run benchmark
    result = await sequential_benchmark(
        backend=backend,
        api_url=api_url,
        model_id=model_id,
        tokenizer=tokenizer,
        input_requests=input_requests,
        request_func=request_func,
        selected_percentiles=[50, 90, 95, 99],
        cache_reset_url=cache_reset_url,
    )

    return result
def main(args):
    """Seed all RNGs for reproducibility, then drive the async benchmark."""
    print(args)
    np.random.seed(args.seed)
    random.seed(args.seed)
    asyncio.run(main_async(args))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving")
    # Declarative table of CLI options: (flags, add_argument kwargs).
    arg_specs = [
        (("--backend",), dict(type=str, default="vllm", help="Backend to use for requests")),
        (
            ("--base-url",),
            dict(type=str, default=None, help="Server base URL (overrides --host and --port)"),
        ),
        (("--host",), dict(type=str, default="127.0.0.1")),
        (("--port",), dict(type=int, default=8000)),
        (("--endpoint",), dict(type=str, default="/v1/completions", help="API endpoint")),
        (("--model",), dict(type=str, required=True, help="Name of the model")),
        (
            ("--tokenizer",),
            dict(type=str, help="Name of the tokenizer (defaults to model name)"),
        ),
        (
            ("--num-requests",),
            dict(type=int, default=100, help="Number of requests to process"),
        ),
        (
            ("--input-len",),
            dict(type=int, default=128, help="Input len for generated prompts"),
        ),
        (
            ("--output-len",),
            dict(type=int, default=None, help="Override output len for requests"),
        ),
        (("--seed",), dict(type=int, default=42)),
        (
            ("--trust-remote-code",),
            dict(action="store_true", help="Trust remote code from HuggingFace"),
        ),
    ]
    for flags, kwargs in arg_specs:
        parser.add_argument(*flags, **kwargs)

    main(parser.parse_args())

View File

@ -0,0 +1,337 @@
An implementation of xPyD with dynamic scaling based on point-to-point communication, partly inspired by Dynamo.
# Detailed Design
## Overall Process
As shown in Figure 1, the overall process of this **PD disaggregation** solution is described through a request flow:
1. The client sends an HTTP request to the Proxy/Router's `/v1/completions` interface.
2. The Proxy/Router selects a **1P1D (1 Prefill instance + 1 Decode instance)** through either round-robin or random selection, generates a `request_id` (rules to be introduced later), modifies the `max_tokens` in the HTTP request message to **1**, and then forwards the request to the **P instance**.
3. Immediately afterward, the Proxy/Router forwards the **original HTTP request** to the **D instance**.
4. The **P instance** performs **Prefill** and then **actively sends the generated KV cache** to the D instance (using **PUT_ASYNC** mode). The D instance's `zmq_addr` can be resolved through the `request_id`.
5. The **D instance** has a **dedicated thread** for receiving the KV cache (to avoid blocking the main process). The received KV cache is saved into the **GPU memory buffer**, the size of which is determined by the vLLM startup parameter `kv_buffer_size`. When the GPU buffer is full, the KV cache is stored in the **local Tensor memory pool**.
6. During the **Decode**, the D instance's main process retrieves the KV cache (transmitted by the P instance) from either the **GPU buffer** or the **memory pool**, thereby **skipping Prefill**.
7. After completing **Decode**, the D instance returns the result to the **Proxy/Router**, which then forwards it to the **client**.
![image1](https://github.com/user-attachments/assets/fb01bde6-755b-49f7-ad45-48a94b1e10a7)
## Proxy/Router (Demo)
A simple HTTP service acts as the entry point for client requests and starts a background thread to listen for P/D instances reporting their HTTP IP and PORT, as well as ZMQ IP and PORT. It maintains a dictionary of `http_addr -> zmq_addr`. The `http_addr` is the IP:PORT for the vLLM instance's request, while the `zmq_addr` is the address for KV cache handshake and metadata reception.
The Proxy/Router is responsible for selecting 1P1D based on the characteristics of the client request, such as the prompt, and generating a corresponding `request_id`, for example:
```
cmpl-___prefill_addr_10.0.1.2:21001___decode_addr_10.0.1.3:22001_93923d63113b4b338973f24d19d4bf11-0
```
Currently, to quickly verify whether xPyD can work, a round-robin selection of 1P1D is used. In the future, it is planned to use a trie combined with the load status of instances to select appropriate P and D.
Each P/D instance periodically sends a heartbeat packet to the Proxy/Router (currently every 3 seconds) to register (i.e., report `http_addr -> zmq_addr`) and keep the connection alive. If an instance crashes and fails to send a ping for a certain period of time, the Proxy/Router will remove the timed-out instance (this feature has not yet been developed).
## KV Cache Transfer Methods
There are three methods for KVcache transfer: PUT, GET, and PUT_ASYNC. These methods can be specified using the `--kv-transfer-config` and `kv_connector_extra_config` parameters, specifically through the `send_type` field. Both PUT and PUT_ASYNC involve the P instance actively sending KVcache to the D instance. The difference is that PUT is a synchronous transfer method that blocks the main process, while PUT_ASYNC is an asynchronous transfer method. PUT_ASYNC uses a dedicated thread for sending KVcache, which means it does not block the main process. In contrast, the GET method involves the P instance saving the KVcache to the memory buffer after computing the prefill. The D instance then actively retrieves the computed KVcache from the P instance once it has allocated space for the KVcache.
Experimental results have shown that the performance of these methods, from highest to lowest, is as follows: PUT_ASYNC → GET → PUT.
## P2P Communication via ZMQ & NCCL
As long as the address of the counterpart is known, point-to-point KV cache transfer (using NCCL) can be performed, without being constrained by rank and world size. This supports dynamic scaling (expansion and contraction) of instances with PD disaggregation, meaning that adding or removing P/D instances does not require a full system restart.
Each P/D instance only needs to create a single `P2pNcclEngine` instance. This instance maintains a ZMQ Server, which runs a dedicated thread to listen on the `zmq_addr` address and receive control flow requests from other instances. These requests include requests to establish an NCCL connection and requests to send KVcache metadata (such as tensor shapes and data types). However, it does not actually transmit the KVcache data itself.
When a P instance and a D instance transmit KVcache for the first time, they need to establish a ZMQ connection and an NCCL group. For subsequent KVcache transmissions, this ZMQ connection and NCCL group are reused. The NCCL group consists of only two ranks, meaning the world size is equal to 2. This design is intended to support dynamic scaling, which means that adding or removing P/D instances does not require a full system restart. As long as the address of the counterpart is known, point-to-point KVcache transmission can be performed, without being restricted by rank or world size.
## NCCL Group Topology
Currently, only symmetric TP (Tensor Parallelism) methods are supported for KVcache transmission. Asymmetric TP and PP (Pipeline Parallelism) methods will be supported in the future. Figure 2 illustrates the 1P2D setup, where each instance has a TP (Tensor Parallelism) degree of 2. There are a total of 7 NCCL groups: three vLLM instances each have one NCCL group with TP=2. Additionally, the 0th GPU card of the P instance establishes an NCCL group with the 0th GPU card of each D instance. Similarly, the 1st GPU card of the P instance establishes an NCCL group with the 1st GPU card of each D instance.
![image2](https://github.com/user-attachments/assets/837e61d6-365e-4cbf-8640-6dd7ab295b36)
Each NCCL group occupies a certain amount of GPU memory buffer for communication, the size of which is primarily influenced by the `NCCL_MAX_NCHANNELS` environment variable. When `NCCL_MAX_NCHANNELS=16`, an NCCL group typically occupies 100MB, while when `NCCL_MAX_NCHANNELS=8`, it usually takes up 52MB. For large-scale xPyD configurations—such as DeepSeek's 96P144D—this implementation is currently not feasible. Moving forward, we are considering using RDMA for point-to-point communication and are also keeping an eye on UCCL.
## GPU Memory Buffer and Tensor Memory Pool
The trade-off in the size of the memory buffer is as follows: For P instances, the memory buffer is not required in PUT and PUT_ASYNC modes, but it is necessary in GET mode. For D instances, a memory buffer is needed in all three modes. The memory buffer for D instances should not be too large. Similarly, for P instances in GET mode, the memory buffer should also not be too large. The memory buffer of D instances is used to temporarily store KVcache sent by P instances. If it is too large, it will reduce the KVcache space available for normal inference by D instances, thereby decreasing the inference batch size and ultimately leading to a reduction in output throughput. The size of the memory buffer is configured by the parameter `kv_buffer_size`, measured in bytes, and is typically set to 5%–10% of the memory size.
If the `--max-num-seqs` parameter for P instances is set to a large value, due to the large batch size, P instances will generate a large amount of KVcache simultaneously. This may exceed the capacity of the memory buffer of D instances, resulting in KVcache loss. Once KVcache is lost, D instances need to recompute Prefill, which is equivalent to performing Prefill twice. Consequently, the time-to-first-token (TTFT) will significantly increase, leading to degraded performance.
To address the above issues, I have designed and developed a local Tensor memory pool for storing KVcache, inspired by the buddy system used in Linux memory modules. Since the memory is sufficiently large, typically in the TB range on servers, there is no need to consider prefix caching or using block-based designs to reuse memory, thereby saving space. When the memory buffer is insufficient, KVcache can be directly stored in the Tensor memory pool, and D instances can subsequently retrieve KVcache from it. The read and write speed is that of PCIe, with PCIe 4.0 having a speed of approximately 21 GB/s, which is usually faster than the Prefill speed. Otherwise, solutions like Mooncake and lmcache would not be necessary. The Tensor memory pool acts as a flood diversion area, typically unused except during sudden traffic surges. In the worst-case scenario, my solution performs no worse than the normal situation with a Cache store.
# Install vLLM
```shell
# Enter the home directory or your working directory.
cd /home
# Download the installation package, and I will update the commit-id in time. You can directly copy the command.
wget https://vllm-wheels.s3.us-west-2.amazonaws.com/9112b443a042d8d815880b8780633882ad32b183/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
# Download the code repository.
git clone -b xpyd-v1 https://github.com/Abatom/vllm.git
cd vllm
# Set the installation package path.
export VLLM_PRECOMPILED_WHEEL_LOCATION=/home/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
# installation
pip install -e . -v
```
# Run xPyD
## Instructions
- The following examples are run on an A800 (80GB) device, using the Meta-Llama-3.1-8B-Instruct model.
- Pay attention to the setting of the `kv_buffer_size` (in bytes). The empirical value is 10% of the GPU memory size. This is related to the kvcache size. If it is too small, the GPU memory buffer for temporarily storing the received kvcache will overflow, causing the kvcache to be stored in the tensor memory pool, which increases latency. If it is too large, the kvcache available for inference will be reduced, leading to a smaller batch size and decreased throughput.
- For Prefill instances, when using non-GET mode, the `kv_buffer_size` can be set to 1, as Prefill currently does not need to receive kvcache. However, when using GET mode, a larger `kv_buffer_size` is required because it needs to store the kvcache sent to the D instance.
- You may need to modify the `kv_buffer_size` and `port` in the following commands (if there is a conflict).
- `PUT_ASYNC` offers the best performance and should be prioritized.
- The `--port` must be consistent with the `http_port` in the `--kv-transfer-config`.
- The `disagg_prefill_proxy_xpyd.py` script will use port 10001 (for receiving client requests) and port 30001 (for receiving service discovery from P and D instances).
- The node running the proxy must have `quart` installed.
- Supports multiple nodes; you just need to modify the `proxy_ip` and `proxy_port` in `--kv-transfer-config`.
- In the following examples, it is assumed that **the proxy's IP is 10.0.1.1**.
## Run 1P3D
### Proxy (e.g. 10.0.1.1)
```shell
cd {your vllm directory}/examples/online_serving/disagg_xpyd/
python3 disagg_prefill_proxy_xpyd.py &
```
### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1)
```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \
--host 0.0.0.0 \
--port 20005 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name base_model \
--dtype float16 \
--max-model-len 10000 \
--max-num-batched-tokens 10000 \
--max-num-seqs 256 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--disable-log-request \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20005","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```
### Decode1 (e.g. 10.0.1.3 or 10.0.1.1)
```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \
--host 0.0.0.0 \
--port 20009 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name base_model \
--dtype float16 \
--max-model-len 10000 \
--max-num-batched-tokens 10000 \
--max-num-seqs 256 \
--trust-remote-code \
--gpu-memory-utilization 0.7 \
--disable-log-request \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20009","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```
### Decode2 (e.g. 10.0.1.4 or 10.0.1.1)
```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \
--host 0.0.0.0 \
--port 20003 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name base_model \
--dtype float16 \
--max-model-len 10000 \
--max-num-batched-tokens 10000 \
--max-num-seqs 256 \
--trust-remote-code \
--gpu-memory-utilization 0.7 \
--disable-log-request \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```
### Decode3 (e.g. 10.0.1.5 or 10.0.1.1)
```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \
--host 0.0.0.0 \
--port 20008 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name base_model \
--dtype float16 \
--max-model-len 10000 \
--max-num-batched-tokens 10000 \
--max-num-seqs 256 \
--trust-remote-code \
--gpu-memory-utilization 0.7 \
--disable-log-request \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20008","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```
## Run 3P1D
### Proxy (e.g. 10.0.1.1)
```shell
cd {your vllm directory}/examples/online_serving/disagg_xpyd/
python3 disagg_prefill_proxy_xpyd.py &
```
### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1)
```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \
--host 0.0.0.0 \
--port 20005 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name base_model \
--dtype float16 \
--max-model-len 10000 \
--max-num-batched-tokens 10000 \
--max-num-seqs 256 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--disable-log-request \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20005","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```
### Prefill2 (e.g. 10.0.1.3 or 10.0.1.1)
```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \
--host 0.0.0.0 \
--port 20009 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name base_model \
--dtype float16 \
--max-model-len 10000 \
--max-num-batched-tokens 10000 \
--max-num-seqs 256 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--disable-log-request \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20009","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```
### Prefill3 (e.g. 10.0.1.4 or 10.0.1.1)
```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \
--host 0.0.0.0 \
--port 20003 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name base_model \
--dtype float16 \
--max-model-len 10000 \
--max-num-batched-tokens 10000 \
--max-num-seqs 256 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--disable-log-request \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```
### Decode1 (e.g. 10.0.1.5 or 10.0.1.1)
```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \
--host 0.0.0.0 \
--port 20008 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name base_model \
--dtype float16 \
--max-model-len 10000 \
--max-num-batched-tokens 10000 \
--max-num-seqs 256 \
--trust-remote-code \
--gpu-memory-utilization 0.7 \
--disable-log-request \
--kv-transfer-config \
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20008","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```
# Single request
```shell
curl -X POST -s http://10.0.1.1:10001/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "base_model",
"prompt": "San Francisco is a",
"max_tokens": 10,
"temperature": 0
}'
```
# Benchmark
```shell
python3 benchmark_serving.py \
--backend vllm \
--model base_model \
--tokenizer meta-llama/Llama-3.1-8B-Instruct \
--dataset-name "random" \
--host 10.0.1.1 \
--port 10001 \
--random-input-len 1024 \
--random-output-len 1024 \
--ignore-eos \
--burstiness 100 \
--percentile-metrics "ttft,tpot,itl,e2el" \
--metric-percentiles "90,95,99" \
--seed $(date +%s) \
--trust-remote-code \
--request-rate 3 \
--num-prompts 1000
```
# Shut down
```shell
pgrep python | xargs kill -9 && pkill -f python
```
# Test data
## **Scenario 1**: 1K input & 1K output tokens, E2E P99 latency ~20s
- **1P5D (6×A800) vs vLLM (1×A800)**:
- Throughput ↑7.2% (1085 → 6979/6)
- ITL (P99) ↓81.3% (120ms → 22.9ms)
- TTFT (P99) ↑26.8% (175ms → 222ms)
- TPOT: No change
- **1P6D (7×A800) vs vLLM (1×A800)**:
- Throughput ↑9.6% (1085 → 8329/7)
- ITL (P99) ↓81.0% (120ms → 22.7ms)
- TTFT (P99) ↑210% (175ms →543ms)
- TPOT: No change
## **Scenario 2**: 1K input & 200 output tokens, E2E P99 latency ~4s
- **1P1D (2×A800) vs vLLM (1×A800)**:
- Throughput ↑37.4% (537 → 1476/2)
- ITL (P99) ↓81.8% (127ms → 23.1ms)
- TTFT (P99) ↑41.8% (160ms → 227ms)
- TPOT: No change
![testdata](https://github.com/user-attachments/assets/f791bfc7-9f3d-4e5c-9171-a42f9f4da627)

View File

@ -42,7 +42,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G
=== "NVIDIA CUDA"
--8<-- "docs/getting_started/installation/gpu/cuda.inc.md:create-a-new-python-environment"
--8<-- "docs/getting_started/installation/gpu/cuda.inc.md:set-up-using-python"
=== "AMD ROCm"

View File

@ -10,8 +10,6 @@ vLLM contains pre-compiled C++ and CUDA (12.8) binaries.
# --8<-- [end:requirements]
# --8<-- [start:set-up-using-python]
### Create a new Python environment
!!! note
PyTorch installed via `conda` will statically link `NCCL` library, which can cause issues when vLLM tries to use `NCCL`. See <gh-issue:8420> for more details.

View File

@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
import os
import socket
import threading
import uuid
import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request
# Round-robin counter used by handle_request to pick a prefill/decode pair.
count = 0
# Registered instances, keyed by HTTP address, valued by ZMQ address.
prefill_instances: dict[str, str] = {}  # http_address: zmq_address
decode_instances: dict[str, str] = {}  # http_address: zmq_address
# Guard the registries against concurrent access from the discovery
# thread and the request handlers.
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
def _listen_for_register(poller, router_socket):
    """Service-discovery loop: record registrations from vLLM instances.

    Each message is a msgpack-encoded dict like
    {"type": "P"|"D", "http_address": "ip:port", "zmq_address": "ip:port"}.
    "P" registers a prefill instance, "D" a decode instance. Runs forever;
    intended to run on a daemon thread (see start_service_discovery).
    """
    while True:
        socks = dict(poller.poll())
        if router_socket in socks:
            remote_address, message = router_socket.recv_multipart()
            data = msgpack.loads(message)
            if data["type"] == "P":
                # No `global` needed: the dict is mutated, never rebound.
                with prefill_cv:
                    prefill_instances[data["http_address"]] = \
                        data["zmq_address"]
            elif data["type"] == "D":
                with decode_cv:
                    decode_instances[data["http_address"]] = \
                        data["zmq_address"]
            else:
                # BUG FIX: print() does not %-interpolate like logging;
                # the original printed the format string and the args as
                # separate values instead of the formatted message.
                print(f"Unexpected, Received message from {remote_address}, "
                      f"data: {data}")
def start_service_discovery(hostname, port):
    """Bind a ZMQ ROUTER socket and start the registration listener.

    Returns the daemon listener thread so the caller may join() it.
    Raises ValueError if port is 0; an empty hostname falls back to the
    local machine's hostname.
    """
    if port == 0:
        raise ValueError("Port cannot be 0")
    if not hostname:
        hostname = socket.gethostname()

    context = zmq.Context()
    router_socket = context.socket(zmq.ROUTER)
    router_socket.bind(f"tcp://{hostname}:{port}")

    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    listener = threading.Thread(target=_listen_for_register,
                                args=[poller, router_socket],
                                daemon=True)
    listener.start()
    return listener
# Generous client timeout (6 hours) so long generations are not cut off.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
app = Quart(__name__)
def random_uuid() -> str:
    """Return a random 32-character lowercase hex string (uuid4, no dashes)."""
    return uuid.uuid4().hex
async def forward_request(url, data, request_id):
    """POST `data` to `url` and stream the response body back in chunks.

    Yields raw response bytes as they arrive. Non-200 responses yield
    nothing (the body is silently dropped, matching prior behavior).
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "X-Request-Id": request_id,
        }
        async with session.post(url=url, json=data,
                                headers=headers) as response:
            if response.status == 200:
                # BUG FIX: removed the `if True:` dead branch and its
                # unreachable `else` (which buffered the whole body);
                # always stream in 1 KiB chunks.
                async for chunk_bytes in response.content.iter_chunked(1024):
                    yield chunk_bytes
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    """Disaggregated-prefill proxy endpoint.

    Runs prefill (max_tokens=1) on one registered prefill instance, then
    streams the decode instance's full response back to the client.
    Instances are selected round-robin via the module-level `count`.
    """
    try:
        original_request_data = await request.get_json()

        prefill_request = original_request_data.copy()
        # Change max_tokens = 1 to let it only do prefill.
        prefill_request["max_tokens"] = 1

        global count
        with prefill_cv:
            prefill_list = list(prefill_instances.items())
        with decode_cv:
            decode_list = list(decode_instances.items())

        # BUG FIX: guard against an empty registry; the original raised
        # ZeroDivisionError on `count % 0` before any instance registered.
        if not prefill_list or not decode_list:
            res = await make_response(
                "No prefill/decode instance registered yet")
            res.status_code = 503
            return res

        prefill_addr, prefill_zmq_addr = \
            prefill_list[count % len(prefill_list)]
        decode_addr, decode_zmq_addr = \
            decode_list[count % len(decode_list)]
        print(f"handle_request count: {count}, [HTTP:{prefill_addr}, "
              f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
              f"ZMQ:{decode_zmq_addr}]")
        count += 1

        # The request id embeds both ZMQ addresses so the P2P NCCL
        # connector can route KV-cache tensors between the two instances.
        request_id = (
            f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
            f"{decode_zmq_addr}_{random_uuid()}")

        # Finish prefill: drain the single-token prefill response.
        async for _ in forward_request(
                f"http://{prefill_addr}/v1/completions", prefill_request,
                request_id):
            continue

        # Return the decode stream to the client.
        generator = forward_request(f"http://{decode_addr}/v1/completions",
                                    original_request_data, request_id)
        response = await make_response(generator)
        response.timeout = None
        return response

    except Exception as e:
        import sys
        import traceback

        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        # BUG FIX: the original fell through and returned None, which the
        # framework turns into an opaque error; return an explicit 500.
        res = await make_response("Internal proxy error")
        res.status_code = 500
        return res
if __name__ == "__main__":
    # Service discovery listens on 30001; the HTTP proxy serves on 10001.
    listener = start_service_discovery("0.0.0.0", 30001)
    app.run(host="0.0.0.0", port=10001)
    listener.join()

View File

@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
silu_mul_fp8_quant_deep_gemm)
from vllm.platforms import current_platform
# Test cases as (E, T, H, group_size, seed):
#   E = number of experts, T = max tokens per expert,
#   H = hidden size (the input carries 2*H per token),
#   group_size = per-group quantization width along H.
CASES = [
    (1, 1, 128, 64, 0),
    (1, 4, 128, 128, 0),
    (2, 4, 256, 128, 0),
    (32, 64, 256, 128, 0),
    (17, 31, 768, 128, 0),
]
@pytest.mark.parametrize("E,T,H,group_size,seed", CASES)
@torch.inference_mode()
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
    """Compare the fused silu-mul + FP8 group-quant kernel against a
    pure-PyTorch reference over the valid (non-padded) tokens."""
    current_platform.seed_everything(seed)

    # Input tensor of shape (E, T, 2*H)
    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
    # Per-expert valid token counts in [0, T); rows >= nt are padding.
    # NOTE(review): randint's high bound is exclusive, so the nt == T
    # (fully occupied expert) case is never exercised — confirm intended.
    tokens_per_expert = torch.randint(
        low=0,
        high=T,
        size=(E, ),
        dtype=torch.int32,
        device="cuda",
    )

    # Run the Triton kernel
    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y,
                                            tokens_per_expert,
                                            group_size=group_size,
                                            eps=1e-10)

    # Reference implementation
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max = fp8_info.max
    fp8_min = fp8_info.min
    eps = 1e-10

    # Compute silu activation and elementwise multiplication:
    # silu(y[..., :H]) * y[..., H:]
    y1 = y[..., :H]
    y2 = y[..., H:]
    silu_x = y1 * torch.sigmoid(y1)
    merged = silu_x * y2

    # Compute reference scales and quantized output, skipping padded tokens
    for e in range(E):
        nt = tokens_per_expert[e].item()
        # Only rows [0, nt) are written/compared; the rest stay garbage.
        ref_s = torch.empty((T, H // group_size),
                            dtype=torch.float32,
                            device="cuda")
        ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")
        for t in range(nt):
            data = merged[e, t]
            data_grp = data.view(H // group_size, group_size)
            # Per-group absolute max, clamped to avoid dividing by zero.
            amax = data_grp.abs().amax(dim=1).clamp(min=eps)
            scale = amax / fp8_max
            scaled = data / scale.repeat_interleave(group_size)
            clamped = scaled.clamp(fp8_min, fp8_max)
            q = clamped.to(torch.float8_e4m3fn)
            ref_s[t] = scale
            ref_q[t] = q

        y_se = y_s[e]
        y_qe = y_q[e]

        # Scales must match exactly; quantized values only within a loose
        # FP8 tolerance (compare in float32 since FP8 lacks comparisons).
        torch.testing.assert_close(y_se[:nt], ref_s[:nt])
        torch.testing.assert_close(
            y_qe[:nt].to(torch.float32),
            ref_q[:nt].to(torch.float32),
            atol=2,
            rtol=2e-1,
        )

15
tests/v1/test_request.py Normal file
View File

@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.request import RequestStatus
def test_request_status_fmt_str():
    """Each RequestStatus member must format as exactly its bare name."""
    expected_names = [
        "WAITING",
        "WAITING_FOR_FSM",
        "WAITING_FOR_REMOTE_KVS",
        "RUNNING",
        "PREEMPTED",
        "FINISHED_STOPPED",
        "FINISHED_LENGTH_CAPPED",
        "FINISHED_ABORTED",
        "FINISHED_IGNORED",
    ]
    for name in expected_names:
        assert f"{getattr(RequestStatus, name)}" == name

View File

@ -1,79 +0,0 @@
# Local disaggregated prefill/decode (P/D) experiments with the NIXL
# KV connector. Needed for the proxy server.
vllm-directory := "/home/rshaw/vllm/"

# MODEL := "Qwen/Qwen3-0.6B"
MODEL := "meta-llama/Llama-3.1-8B-Instruct"

PROXY_PORT := "8192"
PREFILL_PORT := "8100"
DECODE_PORT := "8200"

# Launch the prefill server (TP=2) with the NIXL connector.
prefill:
    VLLM_NIXL_SIDE_CHANNEL_PORT=5557 \
    CUDA_VISIBLE_DEVICES=0,7 \
    vllm serve {{MODEL}} \
    --port {{PREFILL_PORT}} \
    --tensor-parallel-size 2 \
    --enforce-eager \
    --disable-log-requests \
    --block-size 128 \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# Launch the decode server (TP=2) on a different side channel and GPUs.
decode:
    VLLM_NIXL_SIDE_CHANNEL_PORT=5567 \
    CUDA_VISIBLE_DEVICES=4,5 \
    vllm serve {{MODEL}} \
    --port {{DECODE_PORT}} \
    --tensor-parallel-size 2 \
    --enforce-eager \
    --disable-log-requests \
    --block-size 128 \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# Run the toy proxy that fronts the prefill and decode servers.
proxy:
    python "{{vllm-directory}}tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
    --port {{PROXY_PORT}} \
    --prefiller-port {{PREFILL_PORT}} \
    --decoder-port {{DECODE_PORT}}

# Smoke-test a single completion through the proxy.
send_request:
    curl -X POST http://localhost:{{PROXY_PORT}}/v1/completions \
    -H "Content-Type: application/json" \
    -d '{ \
    "model": "{{MODEL}}", \
    "prompt": "Red Hat is the best open source company by far across Linux, K8s, and AI, and vLLM has the greatest community in open source AI software infrastructure. I love vLLM because", \
    "max_tokens": 150, \
    "temperature": 0.7 \
    }'

# Long-input random benchmark against the proxy.
benchmark NUM_PROMPTS:
    python {{vllm-directory}}/benchmarks/benchmark_serving.py \
    --port {{PROXY_PORT}} \
    --model {{MODEL}} \
    --dataset-name random \
    --random-input-len 30000 \
    --random-output-len 10 \
    --num-prompts {{NUM_PROMPTS}} \
    --seed $(date +%s) \

# Single-concurrency latency benchmark through the proxy.
benchmark_one INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
    --port {{PROXY_PORT}} \
    --model {{MODEL}} \
    --input-len {{INPUT_LEN}} \
    --output-len 1 \
    --num-requests 10 \
    --seed $(date +%s)

# Same benchmark, hitting the decode server directly (no P/D split).
benchmark_one_no_pd INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
    --port {{DECODE_PORT}} \
    --model {{MODEL}} \
    --input-len {{INPUT_LEN}} \
    --output-len 1 \
    --num-requests 10 \
    --seed $(date +%s)

# Accuracy check (GSM8K) through the proxy endpoint.
eval:
    lm_eval --model local-completions --tasks gsm8k \
    --model_args model={{MODEL}},base_url=http://127.0.0.1:{{PROXY_PORT}}/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False \
    --limit 1000

View File

@ -209,7 +209,7 @@ class Attention(nn.Module):
if self.use_output:
output_shape = (output_shape
if output_shape is not None else query.shape)
output = torch.empty(output_shape,
output = torch.zeros(output_shape,
dtype=query.dtype,
device=query.device)
hidden_size = output_shape[-1]

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
try:
import intel_extension_for_pytorch.llm.modules as ipex_modules
@ -120,7 +120,7 @@ class _PagedAttention:
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: Dict[int, List[int]],
src_to_dists: torch.Tensor,
*args,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]

View File

@ -2285,7 +2285,7 @@ Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
class DeviceConfig:
"""Configuration for the device to use for vLLM execution."""
device: SkipValidation[Union[Device, torch.device]] = "auto"
device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
"""Device type for vLLM execution.
This parameter is deprecated and will be
removed in a future release.
@ -2327,7 +2327,10 @@ class DeviceConfig:
"to turn on verbose logging to help debug the issue.")
else:
# Device type is assigned explicitly
self.device_type = self.device
if isinstance(self.device, str):
self.device_type = self.device
elif isinstance(self.device, torch.device):
self.device_type = self.device.type
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:

View File

@ -138,9 +138,29 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
super().__init__(cpu_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self.num_sms = 20
# Use all SMs for all2all communication
# This will need to be adjusted for dual-batch overlap
device = self.dp_group.device
props = torch.cuda.get_device_properties(device)
self.num_sms = props.multi_processor_count
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
print(f"Setting num sms to {self.num_sms}")
def get_handle(self, kwargs):
raise NotImplementedError

View File

@ -272,6 +272,14 @@ class NCCLLibrary:
ctypes.byref(unique_id)))
return unique_id
def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
if len(data) != 128:
raise ValueError(
f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes")
unique_id = ncclUniqueId()
ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
return unique_id
def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
rank: int) -> ncclComm_t:
comm = ncclComm_t()

View File

@ -112,6 +112,11 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
"SharedStorageConnector")
KVConnectorFactory.register_connector(
"P2pNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
"P2pNcclConnector")
KVConnectorFactory.register_connector(
"LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",

View File

@ -37,9 +37,6 @@ if TYPE_CHECKING:
Transfer = tuple[int, float] # (xfer_handle, start_time)
GET_META_MSG = b"get_meta_msg"
import os
LOG_XFER_TIME = os.getenv("VLLM_LOG_XFER_TIME", "0") == "1"
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
@ -332,17 +329,7 @@ class NixlConnectorWorker:
self.block_size = vllm_config.cache_config.block_size
# Agent.
import os
num_workers = 32
# setting num_workers on the prefiller causes no notifs to be recved???
# this is a hack to make sure we set num workers on the prefiller to 1.
if os.getenv("VLLM_IS_PREFILL", "0") == "1":
num_workers = None
print(f"NUM_WORKERS: {num_workers=}")
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()),
None,
num_workers=None,
num_shared_workers=num_workers)
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
@ -384,8 +371,8 @@ class NixlConnectorWorker:
self._registered_descs: list[Any] = []
# In progress transfers.
# [req_id -> list[handles], agent_name, notif_id]
self._recving_transfers: dict[str, tuple[list[int], str, str]] = {}
# [req_id -> list[handle]]
self._recving_transfers = defaultdict[str, list[Transfer]](list)
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
@ -767,11 +754,8 @@ class NixlConnectorWorker:
to Rank 0 once their transaction is done + Rank 0 returns
finished sets to Scheduler only once all ranks are done.
"""
start = time.perf_counter()
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
if len(done_sending) > 0 or len(done_recving) > 0:
logger.debug(
"Rank %s, get_finished: %s requests done sending "
@ -812,10 +796,6 @@ class NixlConnectorWorker:
if self._done_sending_count[req_id] == self.world_size:
del self._done_sending_count[req_id]
all_done_sending.add(req_id)
end = time.perf_counter()
if LOG_XFER_TIME:
logger.info("========== .get_finished(): %s ==========", end - start)
return all_done_sending, all_done_recving
@ -825,10 +805,6 @@ class NixlConnectorWorker:
self.tp_group.send_object(finished_req_ids, dst=0)
# Unused as only Rank 0 results are sent to scheduler.
end = time.perf_counter()
if LOG_XFER_TIME:
logger.info("========== .get_finished(): %s ==========", end - start)
return done_sending, done_recving
def _get_new_notifs(self) -> set[str]:
@ -850,8 +826,7 @@ class NixlConnectorWorker:
return notified_req_ids
def _pop_done_transfers(
self, transfers: dict[str, tuple[list[int], str,
str]]) -> set[str]:
self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
"""
Pop completed xfers by checking for DONE state.
Args:
@ -859,30 +834,19 @@ class NixlConnectorWorker:
Returns:
set of req_ids that have all done xfers
"""
done_req_ids: set[str, float] = set()
for req_id, (handles, agent_name, notif_id, start_time) in list(transfers.items()):
new_handles = []
for handle in handles:
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
for handle, xfer_stime in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
self.nixl_wrapper.release_xfer_handle(handle)
done_req_ids.add(req_id)
del transfers[req_id]
elif xfer_state == "PROC":
new_handles.append(handle)
continue
else:
raise RuntimeError("Transfer failed with state %s",
xfer_state)
# Done.
if len(new_handles) == 0:
self.nixl_wrapper.send_notif(agent_name, notif_id)
del transfers[req_id]
done_req_ids.add(req_id)
if LOG_XFER_TIME:
logger.info("========== transmission time: %s ==========", time.perf_counter() - start_time)
else:
transfers[req_id] = (new_handles, agent_name, notif_id, start_time)
return done_req_ids
def start_load_kv(self, metadata: NixlConnectorMetadata):
@ -994,35 +958,22 @@ class NixlConnectorWorker:
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
# Prepare transfer with Nixl.
CHUNK_SIZE = 500
handles = []
# NOTE: this is a hack to make make_prepped_xfer into threads so that
# different workers are allocated for each chuck. Without this change,
# nixl was allocating the same worker (0) for all the chunks and the
# overall launch time was >300 ms.
for i in range(0, len(local_block_descs_ids), CHUNK_SIZE):
handle = self.nixl_wrapper.make_prepped_xfer(
"READ",
local_xfer_side_handle,
local_block_descs_ids[i:i + CHUNK_SIZE],
remote_xfer_side_handle,
remote_block_descs_ids[i:i + CHUNK_SIZE],
skip_desc_merge=True,
)
handles.append(handle)
handle = self.nixl_wrapper.make_prepped_xfer(
"READ",
local_xfer_side_handle,
local_block_descs_ids,
remote_xfer_side_handle,
remote_block_descs_ids,
notif_msg=notif_id,
)
# Begin async xfer.
start = time.perf_counter()
self.nixl_wrapper.transfer_batched(handles)
end = time.perf_counter()
if LOG_XFER_TIME:
logger.info("========== .transfer_batched(): %s ==========", end - start)
self.nixl_wrapper.transfer(handle)
# Keep track of ongoing transfers.
remote_rank = self.tp_rank // tp_ratio
agent_name = self._remote_agents[dst_engine_id][remote_rank]
assert request_id not in self._recving_transfers
self._recving_transfers[request_id] = (handles, agent_name, notif_id, time.perf_counter())
# Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time
self._recving_transfers[request_id].append(
(handle, time.perf_counter()))
def _get_block_descs_ids(self,
engine_id: str,

View File

@ -0,0 +1,481 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine)
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class ReqMeta:
    """Per-request metadata handed to the worker side of the connector."""
    # Request Id
    request_id: str
    # Request tokens
    token_ids: torch.Tensor
    # Slot mappings, should have the same length as token_ids
    slot_mapping: torch.Tensor

    @staticmethod
    def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                  block_size: int) -> "ReqMeta":
        """Build a ReqMeta, expanding block ids to per-token slot indices.

        slot = block_id * block_size + offset, laid out block by block and
        truncated to the number of valid tokens.
        """
        num_valid = len(token_ids)
        blocks = torch.tensor(block_ids)
        offsets = torch.arange(0, block_size).reshape((1, block_size))
        slots = (blocks.reshape((-1, 1)) * block_size + offsets).flatten()
        return ReqMeta(
            request_id=request_id,
            token_ids=torch.tensor(token_ids),
            slot_mapping=slots[:num_valid],
        )
@dataclass
class P2pNcclConnectorMetadata(KVConnectorMetadata):
    """Scheduler-to-worker metadata: the batch of requests to transfer."""
    requests: list[ReqMeta]

    def __init__(self):
        self.requests = []

    def add_request(
        self,
        request_id: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
    ) -> None:
        """Append a ReqMeta built from the given scheduling information."""
        meta = ReqMeta.make_meta(request_id, token_ids, block_ids, block_size)
        self.requests.append(meta)
class P2pNcclConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {}
self.config = vllm_config.kv_transfer_config
self.is_producer = self.config.is_kv_producer
self.chunked_prefill: dict[str, Any] = {}
self._rank = get_world_group().rank \
if role == KVConnectorRole.WORKER else 0
self._local_rank = get_world_group().local_rank \
if role == KVConnectorRole.WORKER else 0
self.p2p_nccl_engine = P2pNcclEngine(
local_rank=self._local_rank,
config=self.config,
hostname="",
port_offset=self._rank,
) if role == KVConnectorRole.WORKER else None
# ==============================
# Worker-side methods
# ==============================
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
# Only consumer/decode loads KV Cache
if self.is_producer:
return
assert self.p2p_nccl_engine is not None
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
request_id: str,
) -> None:
"""Inject the KV cache into the layer.
Args:
dst_kv_cache_layer (torch.Tensor): the destination KV cache
layer. In shape [2, num_pages, page_size, xxx] if not
using MLA, [num_pages, page_size, xxx] otherwise.
src_kv_cache (torch.Tensor): the source KV cache. In shape
[2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
otherwise.
slot_mapping (torch.Tensor): the slot mapping. In shape
[num_tokens].
request_id (str): request id for log
"""
dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
if isinstance(attn_metadata, MLACommonMetadata):
num_pages = dst_kv_cache_layer_shape[0]
page_size = dst_kv_cache_layer_shape[1]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
num_pages * page_size, -1)
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
0)
num_token = src_kv_cache.shape[0]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
else:
dst_kv_cache_layer[slot_mapping[:num_token],
...] = src_kv_cache
logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s", len(slot_mapping),
num_token, request_id)
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
dst_kv_cache_layer = dst_kv_cache_layer.reshape(
2, num_pages * page_size, -1)
self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
1)
num_token = src_kv_cache.shape[1]
if len(slot_mapping) == num_token:
dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
else:
dst_kv_cache_layer[:, slot_mapping[:num_token],
...] = src_kv_cache
logger.warning(
"🚧src_kv_cache does not match, num_slot:%d, "
"num_token:%d, request_id:%s", len(slot_mapping),
num_token, request_id)
dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
# Get the metadata
metadata: KVConnectorMetadata = \
self._get_connector_metadata()
assert isinstance(metadata, P2pNcclConnectorMetadata)
if metadata is None:
return
# Load the KV for each request each layer
for request in metadata.requests:
for layer_name in forward_context.no_compile_layers:
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache_layer = attn_layer.kv_cache[ \
forward_context.virtual_engine]
kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name)
if kv_cache is None:
logger.warning("🚧src_kv_cache is None, %s",
request.request_id)
continue
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""Start saving the KV cache of the layer from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
# Only producer/prefill saves KV Cache
if not self.is_producer:
return
assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
def wait_for_save(self):
if self.is_producer:
assert self.p2p_nccl_engine is not None
self.p2p_nccl_engine.wait_for_sent()
def get_finished(
self, finished_req_ids: set[str],
**kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
Returns:
ids of requests that have finished asynchronous transfer,
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
assert self.p2p_nccl_engine is not None
forward_context: ForwardContext = get_forward_context()
return self.p2p_nccl_engine.get_finished(finished_req_ids,
forward_context)
# ==============================
# Scheduler-side methods
# ==============================
def get_num_new_matched_tokens(
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
if self.is_producer:
return 0, False
num_external_tokens = (len(request.prompt_token_ids) - 1 -
num_computed_tokens)
if num_external_tokens < 0:
num_external_tokens = 0
return num_external_tokens, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
"""
Update KVConnector state after block allocation.
"""
if not self.is_producer and num_external_tokens > 0:
self._requests_need_load[request.request_id] = (
request, blocks.get_block_ids()[0])
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
"""Build the connector metadata for this step.
This function should NOT modify any fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
meta = P2pNcclConnectorMetadata()
for new_req in scheduler_output.scheduled_new_reqs:
if self.is_producer:
num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[new_req.req_id]
num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
# the request's prompt is chunked prefill
if num_tokens < len(new_req.prompt_token_ids):
# 'CachedRequestData' has no attribute 'prompt_token_ids'
self.chunked_prefill[new_req.req_id] = (
new_req.block_ids[0], new_req.prompt_token_ids)
continue
# the request's prompt is not chunked prefill
meta.add_request(request_id=new_req.req_id,
token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size)
continue
if new_req.req_id in self._requests_need_load:
meta.add_request(request_id=new_req.req_id,
token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size)
self._requests_need_load.pop(new_req.req_id)
for cached_req in scheduler_output.scheduled_cached_reqs:
if self.is_producer:
num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[cached_req.req_id]
num_tokens = (num_scheduled_tokens +
cached_req.num_computed_tokens)
assert cached_req.req_id in self.chunked_prefill
block_ids = cached_req.new_block_ids[0]
if not cached_req.resumed_from_preemption:
block_ids = (self.chunked_prefill[cached_req.req_id][0] +
block_ids)
prompt_token_ids = self.chunked_prefill[cached_req.req_id][1]
# the request's prompt is chunked prefill again
if num_tokens < len(prompt_token_ids):
self.chunked_prefill[cached_req.req_id] = (
block_ids, prompt_token_ids)
continue
# the request's prompt is all prefilled finally
meta.add_request(request_id=cached_req.req_id,
token_ids=prompt_token_ids,
block_ids=block_ids,
block_size=self._block_size)
self.chunked_prefill.pop(cached_req.req_id, None)
continue
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if not cached_req.resumed_from_preemption:
break
if cached_req.req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(cached_req.req_id)
total_tokens = cached_req.num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = cached_req.new_block_ids[0]
meta.add_request(request_id=cached_req.req_id,
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size)
# Requests loaded asynchronously are not in the scheduler_output.
# for request_id in self._requests_need_load:
# request, block_ids = self._requests_need_load[request_id]
# meta.add_request(request_id=request.request_id,
# token_ids=request.prompt_token_ids,
# block_ids=block_ids,
# block_size=self._block_size)
self._requests_need_load.clear()
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
self.chunked_prefill.pop(request.request_id, None)
return False, None
# ==============================
# Static methods
# ==============================
@staticmethod
def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
    """Extract the peer address (hostname/IP and port) embedded in a
    request id.

    A prefill instance looks for the decode peer's address and a decode
    instance looks for the prefill peer's address.

    Raises:
        ValueError: if *request_id* carries no recognizable address.
    """
    # Prefillers need the decode-side endpoint; decoders the prefill-side.
    pattern = (r"___decode_addr_(.*):(\d+)"
               if is_prefill else r"___prefill_addr_(.*):(\d+)___")
    matched = re.search(pattern, request_id)
    if matched is None:
        raise ValueError(
            f"Request id {request_id} does not contain hostname and port")
    host, port_str = matched.group(1), matched.group(2)
    return host, int(port_str)
@staticmethod
def check_tensors_except_dim(tensor1, tensor2, dim):
    """Validate that *tensor1* and *tensor2* have identical shapes in every
    dimension except (at most) dimension *dim*.

    Used to enforce symmetric tensor-parallel layouts: only the dimension
    sharded across ranks may differ between the local and remote tensors.

    Raises:
        NotImplementedError: if the tensors have different ranks, or any
            dimension other than *dim* differs.
    """
    shape1 = tensor1.size()
    shape2 = tensor2.size()
    same_rank = len(shape1) == len(shape2)
    if not same_rank or any(
            s1 != s2
            for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
        # Fix: the original adjacent string literals concatenated to
        # "PP,and" (missing space) in the user-facing error message.
        raise NotImplementedError(
            "Currently, only symmetric TP is supported. Asymmetric TP, PP, "
            "and others will be supported in future PRs.")

View File

@ -0,0 +1,531 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import threading
import time
import typing
from collections import deque
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional
import msgpack
import torch
import zmq
from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
TensorMemoryPool)
from vllm.utils import current_stream, get_ip
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
logger = logging.getLogger(__name__)

DEFAULT_MEM_POOL_SIZE_GB = 32


@contextmanager
def set_p2p_nccl_context(num_channels: str):
    """Temporarily pin NCCL environment configuration while the body runs.

    Snapshots the NCCL-related environment variables below, forces the
    channel count to *num_channels* and enables CUMEM, then restores every
    snapshot value (deleting variables that were previously unset) on exit
    — even if the body raises.
    """
    watched = (
        'NCCL_MAX_NCHANNELS',
        'NCCL_MIN_NCHANNELS',
        'NCCL_CUMEM_ENABLE',
        'NCCL_BUFFSIZE',
        'NCCL_PROTO',  # LL,LL128,SIMPLE
        'NCCL_ALGO',  # RING,TREE
    )
    # Snapshot before mutating; None marks "was not set".
    original_values: dict[str, Any] = {
        name: os.environ.get(name)
        for name in watched
    }
    logger.info("set_p2p_nccl_context, original_values: %s", original_values)
    try:
        os.environ['NCCL_MAX_NCHANNELS'] = num_channels
        os.environ['NCCL_MIN_NCHANNELS'] = num_channels
        os.environ['NCCL_CUMEM_ENABLE'] = '1'
        yield
    finally:
        for name, value in original_values.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
class P2pNcclEngine:
    """Point-to-point KV-cache transfer engine over ZMQ + NCCL.

    Each engine instance binds one ZMQ ROUTER socket (one per GPU/card) and
    lazily creates a dedicated 2-rank NCCL communicator per remote peer.
    Tensors are transferred in one of three mutually exclusive modes
    selected by ``send_type`` in the connector extra config:

    - ``PUT``: the sender pushes synchronously via ``_send_sync``.
    - ``PUT_ASYNC``: the sender enqueues and a background thread pushes.
    - ``GET``: the sender stages tensors in ``send_store`` and the receiver
      pulls them on demand.

    A background listener thread services incoming ZMQ requests; an
    optional ping thread advertises this instance to a proxy.
    """

    def __init__(self,
                 local_rank: int,
                 config: KVTransferConfig,
                 hostname: str = "",
                 port_offset: int = 0,
                 library_path: Optional[str] = None) -> None:
        """Set up ZMQ endpoints, memory pool, worker threads and state.

        Args:
            local_rank: CUDA device index this engine binds to.
            config: KV-transfer config (ports, buffer size, extra config).
            hostname: address to bind; autodetected via get_ip() if empty.
            port_offset: added to the configured kv_port; also reported as
                this engine's logical rank in logs.
            library_path: optional explicit path to the NCCL library.
        """
        self.config = config
        self.rank = port_offset
        self.local_rank = local_rank
        self.device = torch.device(f"cuda:{self.local_rank}")
        self.nccl = NCCLLibrary(library_path)

        if not hostname:
            hostname = get_ip()
        port = int(self.config.kv_port) + port_offset
        if port == 0:
            raise ValueError("Port cannot be 0")
        self._hostname = hostname
        self._port = port

        # Each card corresponds to a ZMQ address.
        self.zmq_address = f"{self._hostname}:{self._port}"

        # The `http_port` must be consistent with the port of OpenAI.
        self.http_address = (
            f"{self._hostname}:"
            f"{self.config.kv_connector_extra_config['http_port']}")

        # If `proxy_ip` or `proxy_port` is `""`,
        # then the ping thread will not be enabled.
        proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
        proxy_port = self.config.get_from_extra_config("proxy_port", "")
        if proxy_ip == "" or proxy_port == "":
            self.proxy_address = ""
        else:
            self.proxy_address = proxy_ip + ":" + proxy_port

        self.context = zmq.Context()
        self.router_socket = self.context.socket(zmq.ROUTER)
        self.router_socket.bind(f"tcp://{self.zmq_address}")

        self.poller = zmq.Poller()
        self.poller.register(self.router_socket, zmq.POLLIN)

        # Condition variables guarding send_store / send_queue / recv_store.
        self.send_store_cv = threading.Condition()
        self.send_queue_cv = threading.Condition()
        self.recv_store_cv = threading.Condition()

        # Dedicated CUDA streams so NCCL copies do not block compute.
        self.send_stream = torch.cuda.Stream()
        self.recv_stream = torch.cuda.Stream()

        mem_pool_size_gb = self.config.get_from_extra_config(
            "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
        self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
                                     1024**3)  # GB

        # The sending type includes three mutually exclusive options:
        # PUT, GET, PUT_ASYNC.
        self.send_type = self.config.get_from_extra_config("send_type", "PUT")
        if self.send_type == "GET":
            # tensor_id: torch.Tensor
            self.send_store: dict[str, torch.Tensor] = {}
        else:
            # PUT or PUT_ASYNC
            # tensor_id: torch.Tensor
            self.send_queue: deque[list[Any]] = deque()
            # NOTE(review): send_request_id_to_tensor_ids is only created
            # in PUT/PUT_ASYNC mode, yet _have_sent_tensor_id (reached from
            # the GET branch of _listen_for_requests) and get_finished also
            # touch it — GET mode would hit AttributeError. Confirm.
            self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
            if self.send_type == "PUT_ASYNC":
                self._send_thread = threading.Thread(target=self._send_async,
                                                     daemon=True)
                self._send_thread.start()

        # tensor_id: torch.Tensor/(addr, dtype, shape)
        self.recv_store: dict[str, Any] = {}
        self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
        self.socks: dict[str, Any] = {}  # remote_address: client socket
        self.comms: dict[str, Any] = {}  # remote_address: (ncclComm_t, rank)

        self.buffer_size = 0
        self.buffer_size_threshold = float(self.config.kv_buffer_size)

        self.nccl_num_channels = self.config.get_from_extra_config(
            "nccl_num_channels", "8")

        self._listener_thread = threading.Thread(
            target=self._listen_for_requests, daemon=True)
        self._listener_thread.start()

        self._ping_thread = None
        if port_offset == 0 and self.proxy_address != "":
            self._ping_thread = threading.Thread(target=self._ping,
                                                 daemon=True)
            self._ping_thread.start()

        logger.info(
            "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
            "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
            "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
            self.http_address, self.zmq_address, self.proxy_address,
            self.send_type, self.buffer_size_threshold, self.nccl_num_channels)

    def _create_connect(self, remote_address: typing.Optional[str] = None):
        """Ensure a DEALER socket and a 2-rank NCCL communicator exist for
        *remote_address*, creating both on first use.

        The connecting side takes NCCL rank 0; the listening side (see
        _listen_for_requests "NEW") takes rank 1.
        """
        assert remote_address is not None
        if remote_address not in self.socks:
            sock = self.context.socket(zmq.DEALER)
            # Identify ourselves to the peer's ROUTER socket by our own
            # zmq address.
            sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
            sock.connect(f"tcp://{remote_address}")
            self.socks[remote_address] = sock
            if remote_address in self.comms:
                logger.info("👋comm exists, remote_address:%s, comms:%s",
                            remote_address, self.comms)
                return sock, self.comms[remote_address]

            # Hand the NCCL unique id to the peer so it can join the
            # communicator as rank 1.
            unique_id = self.nccl.ncclGetUniqueId()
            data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
            sock.send(msgpack.dumps(data))

            with torch.cuda.device(self.device):
                rank = 0
                with set_p2p_nccl_context(self.nccl_num_channels):
                    comm: ncclComm_t = self.nccl.ncclCommInitRank(
                        2, unique_id, rank)
                self.comms[remote_address] = (comm, rank)
                logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
                            self.zmq_address, remote_address, rank)

        return self.socks[remote_address], self.comms[remote_address]

    def send_tensor(
        self,
        tensor_id: str,
        tensor: torch.Tensor,
        remote_address: typing.Optional[str] = None,
    ) -> bool:
        """Send (or stage) *tensor* under *tensor_id*.

        With no remote address the tensor is stored locally into
        recv_store (loopback). Otherwise the behavior depends on
        send_type: PUT pushes synchronously, PUT_ASYNC enqueues for the
        background sender, GET stages into send_store (evicting oldest
        entries while over the buffer threshold) until the peer pulls it.
        """
        if remote_address is None:
            with self.recv_store_cv:
                self.recv_store[tensor_id] = tensor
                self.recv_store_cv.notify()
            return True
        else:
            if self.send_type == "PUT":
                return self._send_sync(tensor_id, tensor, remote_address)
            elif self.send_type == "PUT_ASYNC":
                with self.send_queue_cv:
                    self.send_queue.append([tensor_id, remote_address, tensor])
                    self.send_queue_cv.notify()
            else:  # GET
                with self.send_store_cv:
                    tensor_size = tensor.element_size() * tensor.numel()
                    # Evict in insertion order (oldest first) until the
                    # new tensor fits under the threshold.
                    while (self.buffer_size + tensor_size
                           > self.buffer_size_threshold):
                        oldest_tenser_id = next(iter(self.send_store))
                        oldest_tenser = self.send_store.pop(oldest_tenser_id)
                        oldest_tenser_size = oldest_tenser.element_size(
                        ) * oldest_tenser.numel()
                        self.buffer_size -= oldest_tenser_size
                        logger.info(
                            "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                            " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
                            remote_address, tensor_id, tensor_size,
                            self.buffer_size, oldest_tenser_size, self.rank)

                    self.send_store[tensor_id] = tensor
                    self.buffer_size += tensor_size
                    logger.debug(
                        "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
                        "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
                        remote_address, tensor_id, tensor_size, tensor.shape,
                        self.rank, self.buffer_size,
                        self.buffer_size / self.buffer_size_threshold * 100)

        return True

    def recv_tensor(
        self,
        tensor_id: str,
        remote_address: typing.Optional[str] = None,
    ) -> torch.Tensor:
        """Receive the tensor registered under *tensor_id*.

        In PUT/PUT_ASYNC mode this blocks until the listener thread has
        deposited the tensor in recv_store (a tuple entry means it was
        spilled to the pinned-memory pool and is reloaded here). In GET
        mode it actively pulls the tensor from the remote peer.
        """
        if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
            start_time = time.time()
            with self.recv_store_cv:
                while tensor_id not in self.recv_store:
                    self.recv_store_cv.wait()
                tensor = self.recv_store[tensor_id]

            if tensor is not None:
                if isinstance(tensor, tuple):
                    # Spilled to host pool: (addr, dtype, shape) — reload.
                    addr, dtype, shape = tensor
                    tensor = self.pool.load_tensor(addr, dtype, shape,
                                                   self.device)
                else:
                    self.buffer_size -= (tensor.element_size() *
                                         tensor.numel())
            else:
                # None means the peer failed to deliver (e.g. OOM on recv).
                duration = time.time() - start_time
                logger.warning(
                    "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
                    "rank:%d", remote_address, tensor_id, duration * 1000,
                    self.rank)
            return tensor

        # GET
        if remote_address is None:
            return None

        if remote_address not in self.socks:
            self._create_connect(remote_address)

        sock = self.socks[remote_address]
        comm, rank = self.comms[remote_address]

        data = {"cmd": "GET", "tensor_id": tensor_id}
        sock.send(msgpack.dumps(data))

        message = sock.recv()
        data = msgpack.loads(message)
        if data["ret"] != 0:
            logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
                           remote_address, tensor_id, data["ret"])
            return None

        tensor = torch.empty(data["shape"],
                             dtype=getattr(torch, data["dtype"]),
                             device=self.device)

        # Ranks are 0/1 only, so rank ^ 1 is always the peer's rank.
        self._recv(comm, tensor, rank ^ 1, self.recv_stream)
        return tensor

    def _listen_for_requests(self):
        """Background loop servicing peer requests on the ROUTER socket.

        Handles three commands: NEW (join a fresh NCCL communicator as
        rank 1), PUT (receive a pushed tensor, spilling to the host pool
        when over threshold), and GET (serve a staged tensor back).
        Runs forever; started as a daemon thread.
        """
        while True:
            socks = dict(self.poller.poll())
            if self.router_socket in socks:
                remote_address, message = self.router_socket.recv_multipart()
                data = msgpack.loads(message)
                if data["cmd"] == "NEW":
                    unique_id = self.nccl.unique_id_from_bytes(
                        bytes(data["unique_id"]))
                    with torch.cuda.device(self.device):
                        rank = 1
                        with set_p2p_nccl_context(self.nccl_num_channels):
                            comm: ncclComm_t = self.nccl.ncclCommInitRank(
                                2, unique_id, rank)
                        self.comms[remote_address.decode()] = (comm, rank)
                        logger.info(
                            "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
                            self.zmq_address, remote_address.decode(), rank)
                elif data["cmd"] == "PUT":
                    tensor_id = data["tensor_id"]
                    try:
                        tensor = torch.empty(data["shape"],
                                             dtype=getattr(
                                                 torch, data["dtype"]),
                                             device=self.device)
                        # Ack with "0" before the NCCL recv so the sender
                        # can start its ncclSend.
                        self.router_socket.send_multipart(
                            [remote_address, b"0"])
                        comm, rank = self.comms[remote_address.decode()]
                        self._recv(comm, tensor, rank ^ 1, self.recv_stream)
                        tensor_size = tensor.element_size() * tensor.numel()
                        if (self.buffer_size + tensor_size
                                > self.buffer_size_threshold):
                            # Store Tensor in memory pool
                            addr = self.pool.store_tensor(tensor)
                            tensor = (addr, tensor.dtype, tensor.shape)
                            logger.warning(
                                "🔴[PUT]Recv Tensor, Out Of Threshold, "
                                "%s👈%s, data:%s, addr:%d", self.zmq_address,
                                remote_address.decode(), data, addr)
                        else:
                            self.buffer_size += tensor_size
                    except torch.cuda.OutOfMemoryError:
                        # Nack with "1"; record None so a waiting
                        # recv_tensor() call is released with a failure.
                        self.router_socket.send_multipart(
                            [remote_address, b"1"])
                        tensor = None
                        logger.warning(
                            "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
                            "data:%s", self.zmq_address,
                            remote_address.decode(), data)

                    with self.recv_store_cv:
                        self.recv_store[tensor_id] = tensor
                        self._have_received_tensor_id(tensor_id)
                        self.recv_store_cv.notify()

                elif data["cmd"] == "GET":
                    tensor_id = data["tensor_id"]
                    with self.send_store_cv:
                        tensor = self.send_store.pop(tensor_id, None)
                        if tensor is not None:
                            data = {
                                "ret": 0,
                                "shape": tensor.shape,
                                "dtype":
                                str(tensor.dtype).replace("torch.", "")
                            }
                            # LRU
                            self.send_store[tensor_id] = tensor
                            self._have_sent_tensor_id(tensor_id)
                        else:
                            data = {"ret": 1}

                        self.router_socket.send_multipart(
                            [remote_address, msgpack.dumps(data)])

                        if data["ret"] == 0:
                            comm, rank = self.comms[remote_address.decode()]
                            self._send(comm, tensor.to(self.device), rank ^ 1,
                                       self.send_stream)
                else:
                    logger.warning(
                        "🚧Unexpected, Received message from %s, data:%s",
                        remote_address, data)

    def _have_sent_tensor_id(self, tensor_id: str):
        # tensor_id is "<request_id>#<layer_name>"; index by request.
        request_id = tensor_id.split('#')[0]
        if request_id not in self.send_request_id_to_tensor_ids:
            self.send_request_id_to_tensor_ids[request_id] = set()
        self.send_request_id_to_tensor_ids[request_id].add(tensor_id)

    def _have_received_tensor_id(self, tensor_id: str):
        # tensor_id is "<request_id>#<layer_name>"; index by request.
        request_id = tensor_id.split('#')[0]
        if request_id not in self.recv_request_id_to_tensor_ids:
            self.recv_request_id_to_tensor_ids[request_id] = set()
        self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)

    def _send_async(self):
        """PUT_ASYNC worker: drain send_queue forever, pushing each entry
        synchronously; notifies waiters when the queue empties (see
        wait_for_sent)."""
        while True:
            with self.send_queue_cv:
                while not self.send_queue:
                    self.send_queue_cv.wait()
                tensor_id, remote_address, tensor = self.send_queue.popleft()
                if not self.send_queue:
                    self.send_queue_cv.notify()
            self._send_sync(tensor_id, tensor, remote_address)

    def wait_for_sent(self):
        """Block until the PUT_ASYNC send queue has fully drained
        (no-op for PUT/GET modes)."""
        if self.send_type == "PUT_ASYNC":
            start_time = time.time()
            with self.send_queue_cv:
                while self.send_queue:
                    self.send_queue_cv.wait()
                duration = time.time() - start_time
                logger.debug(
                    "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
                    " to be empty, rank:%d", duration * 1000, self.rank)

    def _send_sync(
        self,
        tensor_id: str,
        tensor: torch.Tensor,
        remote_address: typing.Optional[str] = None,
    ) -> bool:
        """Push one tensor to *remote_address* via a PUT handshake:
        announce metadata over ZMQ, wait for the peer's ack ("0"), then
        ncclSend the payload. Returns False when the peer nacks or no
        address is given."""
        if remote_address is None:
            return False
        if remote_address not in self.socks:
            self._create_connect(remote_address)

        sock = self.socks[remote_address]
        comm, rank = self.comms[remote_address]
        data = {
            "cmd": "PUT",
            "tensor_id": tensor_id,
            "shape": tensor.shape,
            "dtype": str(tensor.dtype).replace("torch.", "")
        }
        sock.send(msgpack.dumps(data))

        response = sock.recv()
        if response != b"0":
            logger.error(
                "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
                "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
                self.zmq_address, remote_address, rank, data, tensor.shape,
                tensor.element_size() * tensor.numel() / 1024**3,
                response.decode())
            return False

        self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)

        if self.send_type == "PUT_ASYNC":
            self._have_sent_tensor_id(tensor_id)

        return True

    def get_finished(
        self, finished_req_ids: set[str], forward_context: "ForwardContext"
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens.
        Returns:
            ids of requests that have finished asynchronous transfer,
            tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """
        # Clear the buffer upon request completion.
        for request_id in finished_req_ids:
            for layer_name in forward_context.no_compile_layers:
                tensor_id = request_id + "#" + layer_name
                # NOTE(review): membership is tested before taking
                # recv_store_cv, so an entry added in between is skipped
                # until a later call — confirm this race is acceptable.
                if tensor_id in self.recv_store:
                    with self.recv_store_cv:
                        tensor = self.recv_store.pop(tensor_id, None)
                        self.send_request_id_to_tensor_ids.pop(
                            request_id, None)
                        self.recv_request_id_to_tensor_ids.pop(
                            request_id, None)
                    addr = 0
                    if isinstance(tensor, tuple):
                        # Pooled entry: release its pinned-memory block.
                        addr, _, _ = tensor
                        self.pool.free(addr)

        # TODO:Retrieve requests that have already sent the KV cache.
        finished_sending: set[str] = set()

        # TODO:Retrieve requests that have already received the KV cache.
        finished_recving: set[str] = set()

        return finished_sending or None, finished_recving or None

    def _ping(self):
        """Advertise this engine to the proxy every 3 seconds so the proxy
        can route requests ("P" = producer/prefill, "D" = consumer/decode).
        Runs forever; started as a daemon thread."""
        sock = self.context.socket(zmq.DEALER)
        sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
        logger.debug("ping start, zmq_address:%s", self.zmq_address)
        sock.connect(f"tcp://{self.proxy_address}")
        data = {
            "type": "P" if self.config.is_kv_producer else "D",
            "http_address": self.http_address,
            "zmq_address": self.zmq_address
        }
        while True:
            sock.send(msgpack.dumps(data))
            time.sleep(3)

    def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
        """Blocking ncclSend of *tensor* to peer rank *dst* on *stream*;
        synchronizes the stream before returning."""
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")

        if stream is None:
            stream = current_stream()

        with torch.cuda.stream(stream):
            self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                               ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                               comm, cudaStream_t(stream.cuda_stream))
        stream.synchronize()

    def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
        """Blocking ncclRecv into *tensor* from peer rank *src* on *stream*;
        synchronizes the stream before returning."""
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")

        if stream is None:
            stream = current_stream()

        with torch.cuda.stream(stream):
            self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                               ncclDataTypeEnum.from_torch(tensor.dtype), src,
                               comm, cudaStream_t(stream.cuda_stream))
        stream.synchronize()

    def close(self) -> None:
        # NOTE(review): the listener/sender/ping loops run `while True`
        # with no shutdown signal, so these join() calls can block
        # indefinitely — confirm whether close() is expected to return.
        self._listener_thread.join()
        if self.send_type == "PUT_ASYNC":
            self._send_thread.join()
        if self._ping_thread is not None:
            self._ping_thread.join()

View File

@ -0,0 +1,264 @@
# SPDX-License-Identifier: Apache-2.0
import atexit
import ctypes
import math
from dataclasses import dataclass
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class MemoryBlock:
    # Size of the block in bytes (the pool only ever creates
    # power-of-two sizes).
    size: int
    # Start address of the block within the pinned base allocation.
    addr: int


class TensorMemoryPool:
    """A memory pool for managing pinned host memory allocations for tensors.

    This class implements a buddy allocation system to efficiently manage pinned
    host memory for tensor storage. It supports allocation, deallocation, and
    tensor storage/retrieval operations.

    Key Features:
    - Uses power-of-two block sizes for efficient buddy allocation
    - Supports splitting and merging of memory blocks
    - Provides methods to store CUDA tensors in pinned host memory
    - Allows loading tensors from pinned memory back to device
    - Automatically cleans up memory on destruction

    Attributes:
        max_block_size (int): Maximum block size (rounded to nearest power of two)
        min_block_size (int): Minimum block size (rounded to nearest power of two)
        free_lists (dict): Dictionary of free memory blocks by size
        allocated_blocks (dict): Dictionary of currently allocated blocks
        base_tensor (torch.Tensor): Base pinned memory tensor
        base_address (int): Base memory address of the pinned memory region

    Example:
        >>> pool = TensorMemoryPool(max_block_size=1024*1024)
        >>> tensor = torch.randn(100, device='cuda')
        >>> addr = pool.store_tensor(tensor)
        >>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
        ...                                  tensor.shape, 'cuda')
        >>> pool.free(addr)
    """

    def __init__(self, max_block_size: int, min_block_size: int = 512):
        """Initializes the memory pool with given size constraints.

        Args:
            max_block_size (int): Maximum size of memory blocks to manage
            min_block_size (int, optional): Minimum size of memory blocks
                to manage. Defaults to 512.

        Raises:
            ValueError: If block sizes are invalid or max_block_size is less
                than min_block_size
        """
        if max_block_size <= 0 or min_block_size <= 0:
            raise ValueError("Block sizes must be positive")
        if max_block_size < min_block_size:
            raise ValueError(
                "Max block size must be greater than min block size")
        self.max_block_size = self._round_to_power_of_two(max_block_size)
        self.min_block_size = self._round_to_power_of_two(min_block_size)

        self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
        self.allocated_blocks: dict[int, MemoryBlock] = {}

        self._initialize_free_lists()
        self._allocate_pinned_memory()
        # Release the pinned base tensor even if the owner never calls
        # cleanup() explicitly.
        atexit.register(self.cleanup)

    def _round_to_power_of_two(self, size: int) -> int:
        # Round size up to the next power of two (identity for exact
        # powers of two).
        return 1 << (size - 1).bit_length()

    def _initialize_free_lists(self):
        # One (addr -> block) free list per power-of-two size class.
        size = self.max_block_size
        while size >= self.min_block_size:
            self.free_lists[size] = {}
            size //= 2

    def _allocate_pinned_memory(self):
        # float32 elements are 4 bytes, hence max_block_size // 4 elements.
        self.base_tensor = torch.empty(self.max_block_size // 4,
                                       dtype=torch.float32,
                                       pin_memory=True)
        self.base_address = self.base_tensor.data_ptr()
        initial_block = MemoryBlock(size=self.max_block_size,
                                    addr=self.base_address)
        self.free_lists[self.max_block_size][
            initial_block.addr] = initial_block
        # Fix: the original call passed extra args to a format string with
        # no placeholders, which makes the logging module raise
        # "not all arguments converted" at emit time.
        logger.debug("TensorMemoryPool, base_address:%d, alignment_offset:%d",
                     self.base_address,
                     self.base_address % self.max_block_size)

    def allocate(self, size: int) -> int:
        """Allocates a memory block of at least the requested size.

        Args:
            size (int): Minimum size of memory to allocate

        Returns:
            int: Address of the allocated memory block

        Raises:
            ValueError: If size is invalid or insufficient memory is available
        """
        if size <= 0:
            raise ValueError("Allocation size must be positive")

        required_size = self._round_to_power_of_two(
            max(size, self.min_block_size))
        if required_size > self.max_block_size:
            raise ValueError("Requested size exceeds maximum block size")

        # Take the smallest free block that fits, splitting it down to
        # the required size class (buddy allocation).
        current_size = required_size
        while current_size <= self.max_block_size:
            if self.free_lists[current_size]:
                _, block = self.free_lists[current_size].popitem()
                self._split_block(block, required_size)
                self.allocated_blocks[block.addr] = block
                return block.addr
            current_size *= 2
        raise ValueError("Insufficient memory")

    def _split_block(self, block: MemoryBlock, required_size: int):
        # Repeatedly halve the block, returning the upper half (buddy) to
        # its free list, until it matches the required size class.
        while (block.size > required_size
               and block.size // 2 >= self.min_block_size):
            buddy_size = block.size // 2
            buddy_addr = block.addr + buddy_size
            buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
            block.size = buddy_size
            self.free_lists[buddy_size][buddy.addr] = buddy

    def free(self, addr: int):
        """Frees an allocated memory block.

        Args:
            addr (int): Address of the block to free

        Raises:
            ValueError: If address is invalid or not allocated
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to free")
        block = self.allocated_blocks.pop(addr)
        self._merge_buddies(block)

    def _merge_buddies(self, block: MemoryBlock):
        # Coalesce with the buddy block while it is free, then park the
        # (possibly merged) block on its free list. The depth cap bounds
        # the loop defensively.
        MAX_MERGE_DEPTH = 30
        depth = 0
        while depth < MAX_MERGE_DEPTH:
            # The buddy of a block at offset o with size s lives at o+s if
            # o is even in units of 2s, else at o-s.
            buddy_offset = block.size if (block.addr - self.base_address) % (
                2 * block.size) == 0 else -block.size
            buddy_addr = block.addr + buddy_offset
            buddy = self.free_lists[block.size].get(buddy_addr)
            if buddy:
                del self.free_lists[buddy.size][buddy.addr]
                merged_addr = min(block.addr, buddy.addr)
                merged_size = block.size * 2
                block = MemoryBlock(size=merged_size, addr=merged_addr)
                depth += 1
            else:
                break
        self.free_lists[block.size][block.addr] = block

    def store_tensor(self, tensor: torch.Tensor) -> int:
        """Stores a CUDA tensor in pinned host memory.

        Args:
            tensor (torch.Tensor): CUDA tensor to store

        Returns:
            int: Address where the tensor is stored

        Raises:
            ValueError: If tensor is not on CUDA or allocation fails
        """
        if not tensor.is_cuda:
            raise ValueError("Only CUDA tensors can be stored")

        size = tensor.element_size() * tensor.numel()
        addr = self.allocate(size)
        block = self.allocated_blocks[addr]
        if block.size < size:
            self.free(addr)
            raise ValueError(
                f"Allocated block size {block.size} is smaller than "
                f"required size {size}")
        try:
            # View the pinned block as a tensor of the right dtype/shape
            # and copy device -> host into it (no extra staging buffer).
            buffer = (ctypes.c_byte * block.size).from_address(block.addr)
            cpu_tensor = torch.frombuffer(buffer,
                                          dtype=tensor.dtype,
                                          count=tensor.numel()).reshape(
                                              tensor.shape)
        except ValueError as err:
            self.free(addr)
            raise ValueError(f"Failed to create tensor view: {err}") from err
        cpu_tensor.copy_(tensor)
        return addr

    def load_tensor(self, addr: int, dtype: torch.dtype,
                    shape: tuple[int, ...], device) -> torch.Tensor:
        """Loads a tensor from pinned host memory to the specified device.

        Args:
            addr (int): Address where tensor is stored
            dtype (torch.dtype): Data type of the tensor
            shape (tuple[int, ...]): Shape of the tensor
            device: Target device for the loaded tensor

        Returns:
            torch.Tensor: The loaded tensor on the specified device

        Raises:
            ValueError: If address is invalid or sizes don't match
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to load")
        block = self.allocated_blocks[addr]

        num_elements = math.prod(shape)
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        required_size = num_elements * dtype_size

        if required_size > block.size:
            raise ValueError("Requested tensor size exceeds block size")

        buffer = (ctypes.c_byte * block.size).from_address(block.addr)
        cpu_tensor = torch.frombuffer(buffer, dtype=dtype,
                                      count=num_elements).reshape(shape)

        cuda_tensor = torch.empty(shape, dtype=dtype, device=device)
        cuda_tensor.copy_(cpu_tensor)
        return cuda_tensor

    def cleanup(self):
        """Cleans up all memory resources and resets the pool state."""
        self.free_lists.clear()
        self.allocated_blocks.clear()
        if hasattr(self, 'base_tensor'):
            del self.base_tensor

    def __del__(self):
        self.cleanup()

View File

@ -1018,7 +1018,8 @@ class EngineArgs:
from vllm.platforms import current_platform
current_platform.pre_register_and_update()
device_config = DeviceConfig(device=current_platform.device_type)
device_config = DeviceConfig(
device=cast(Device, current_platform.device_type))
model_config = self.create_model_config()
# * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
@ -1302,7 +1303,7 @@ class EngineArgs:
# Skip this check if we are running on a non-GPU platform,
# or if the device capability is not available
# (e.g. in a Ray actor without GPUs).
from vllm.platforms import CpuArchEnum, current_platform
from vllm.platforms import current_platform
if (current_platform.is_cuda()
and current_platform.get_device_capability()
and current_platform.get_device_capability().major < 8):
@ -1444,14 +1445,10 @@ class EngineArgs:
_raise_or_fallback(feature_name=name, recommend_to_remove=False)
return False
# Non-[CUDA, TPU, x86 CPU] may be supported on V1,
# but off by default for now.
v0_hardware = not any(
(current_platform.is_cuda_alike(), current_platform.is_tpu(),
(current_platform.is_cpu()
and current_platform.get_cpu_architecture() == CpuArchEnum.X86)))
if v0_hardware and _warn_or_fallback( # noqa: SIM103
current_platform.device_name):
# The platform may be supported on V1, but off by default for now.
if not current_platform.default_v1( # noqa: SIM103
model_config=model_config) and _warn_or_fallback(
current_platform.device_name):
return False
#############################################################

View File

@ -87,6 +87,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@ -653,6 +654,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_USE_AITER_MLA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
("true", "1")),
# Whether to use aiter mha ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MHA":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in

View File

@ -557,8 +557,17 @@ class RayDistributedExecutor(DistributedExecutorBase):
def _compiled_ray_dag(self, enable_asyncio: bool):
assert self.parallel_config.use_ray
self._check_ray_cgraph_installation()
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution,
# i.e., the distributed execution that includes model forward runs and
# intermediate tensor communications, in the case of vllm.
# Note: we should set this env var before importing
# ray.dag, otherwise it will not take effect.
os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112
from ray.dag import InputNode, MultiOutputNode
logger.info("RAY_CGRAPH_get_timeout is set to %s",
os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112
logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
@ -570,15 +579,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
"Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution,
# i.e., the distributed execution that includes model forward runs and
# intermediate tensor communications, in the case of vllm.
os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112
logger.info("RAY_CGRAPH_get_timeout is set to %s",
os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112
with InputNode() as input_data:
# Example DAG: PP=2, TP=4
#

View File

@ -6,14 +6,179 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, per_token_group_quant_fp8)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
logger = init_logger(__name__)
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
# Fused SiLU-and-mul + per-group FP8 quantization over a batched-expert
# layout. One program handles one (expert e, column group g) pair and loops
# over that expert's valid tokens; tokens beyond counts_ptr[e] are skipped.
# For each token it computes silu(x) * y2 on GROUP_SIZE elements, derives
# one scale per group from the group's absmax, and stores the quantized
# values plus the scale.
@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
        # Pointers ------------------------------------------------------------
        input_ptr,  # 16-bit activations (E, T, 2*H)
        y_q_ptr,  # fp8 quantized activations (E, T, H)
        y_s_ptr,  # 16-bit scales (E, T, G)
        counts_ptr,  # int32 num tokens per expert (E)

        # Sizes ---------------------------------------------------------------
        H: tl.constexpr,  # hidden dimension (per output)
        GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)

        # Strides for input (elements) ---------------------------------------
        stride_i_e,
        stride_i_t,
        stride_i_h,

        # Strides for y_q (elements) -----------------------------------------
        stride_yq_e,
        stride_yq_t,
        stride_yq_h,

        # Strides for y_s (elements) -----------------------------------------
        stride_ys_e,
        stride_ys_t,
        stride_ys_g,

        # Stride for counts (elements)
        stride_counts_e,

        # Numeric params ------------------------------------------------------
        eps: tl.constexpr,
        fp8_min: tl.constexpr,
        fp8_max: tl.constexpr,

        # Meta ---------------------------------------------------------------
        BLOCK: tl.constexpr,
):
    G = H // GROUP_SIZE

    # map program id -> (e, g)
    pid = tl.program_id(0)
    e = pid // G
    g = pid % G

    # 64-bit indices so offset arithmetic cannot overflow for large tensors.
    e = e.to(tl.int64)
    g = g.to(tl.int64)

    # number of valid tokens for this expert
    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)

    cols = tl.arange(0, BLOCK)
    cols = cols.to(tl.int64)
    # NOTE(review): cols < BLOCK is trivially all-true since cols spans
    # [0, BLOCK); the mask only matters if BLOCK ever exceeds GROUP_SIZE —
    # confirm whether it is intentional future-proofing.
    mask_h = cols < BLOCK

    t = tl.zeros([], tl.int64)
    while t < n_tokens:
        # Element offsets of this token's group in the input, output and
        # scale tensors respectively.
        base_i_offset = (e * stride_i_e + t * stride_i_t +
                         g * GROUP_SIZE * stride_i_h)
        base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
                          g * GROUP_SIZE * stride_yq_h)
        base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g

        mask = mask_h
        # Gate half: y[..., :H]; up half: y[..., H:] (offset by H elements).
        x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
                    mask=mask,
                    other=0.0).to(tl.float32)
        y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
                     cols * stride_i_h,
                     mask=mask,
                     other=0.0).to(tl.float32)

        # SiLU(x) = x * sigmoid(x), then elementwise product with y2.
        x = x * (1.0 / (1.0 + tl.exp(-x)))
        y = x * y2

        # One scale per group; eps floors the absmax to avoid div-by-zero.
        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
        y_s = _absmax / fp8_max
        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
        tl.store(y_s_ptr + base_ys_offset, y_s)
        t += 1
def silu_mul_fp8_quant_deep_gemm(
    y: torch.Tensor,  # (E, T, 2*H) float32
    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
    group_size: int = 128,
    eps: float = 1e-10,
):
    """Fused SiLU-and-mul plus per-group FP8 quantization for DeepGEMM.

    The first half of ``y``'s last dimension is SiLU-activated and
    multiplied elementwise by the second half; the product is quantized to
    FP8 with one scale per ``group_size`` contiguous elements. Only the
    first ``tokens_per_expert[e]`` tokens of each expert are processed.

    Returns:
        ``(y_q, y_s)`` where ``y_q`` is FP8 of shape ``(E, T, H)`` laid out
        like ``y[..., :H]``, and ``y_s`` is float32 of shape
        ``(E, T, H // group_size)`` with strides ``(T*G, 1, T)`` — the
        scale layout the DeepGEMM masked-grouped GEMM expects.
    """
    assert y.ndim == 3, "y must be (E, T, 2*H)"
    E, T, H2 = y.shape
    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
    H = H2 // 2
    G = H // group_size
    assert H % group_size == 0, "H must be divisible by group_size"
    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \
        "tokens_per_expert must be shape (E,)"

    counts = tokens_per_expert.to(device=y.device, dtype=torch.int32)

    # Output buffers: quantized activations plus strided scales.
    fp8_dtype = torch.float8_e4m3fn
    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)

    # Scale layout (elements): strides (T*G, 1, T).
    scale_strides = (T * G, 1, T)
    y_s = torch.empty_strided((E, T, G),
                              scale_strides,
                              dtype=torch.float32,
                              device=y.device)

    finfo = torch.finfo(fp8_dtype)

    # Static launch grid over (expert, H-group) pairs; the kernel itself
    # loops over the token dimension.
    _silu_mul_fp8_quant_deep_gemm[(E * G, )](
        y,
        y_q,
        y_s,
        counts,
        H,
        group_size,
        *y.stride(),
        *y_q.stride(),
        *scale_strides,
        counts.stride()[0],
        eps,
        finfo.min,
        finfo.max,
        BLOCK=group_size,
        num_warps=4,
    )
    return y_q, y_s
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# The Deep Gemm kernels only support block size of 128
@ -96,7 +261,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
hidden_states, w1, w2, topk_ids)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
@ -109,19 +273,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
masked_m=expert_num_tokens,
expected_m=expected_m)
# TODO (varun) [Optimization]: Use a batched version of activation.
# Similarly for the quant below.
self.activation(activation, workspace2, workspace1.view(-1, N))
w2_hidden_size = workspace2.size(-1)
workspace2 = workspace2.view(-1, w2_hidden_size)
a2q_scale: Optional[torch.Tensor] = None
a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
self.block_shape[1],
column_major_scales=False)
a2q = a2q.view(E, max_num_tokens, -1)
a2q_scale = a2q_scale.view(E, max_num_tokens, -1)
assert expert_num_tokens is not None
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
(w2, w2_scale),

View File

@ -45,7 +45,8 @@ if current_platform.is_cuda_alike():
from .pplx_prepare_finalize import PplxPrepareAndFinalize
if has_deepep:
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
DeepEPLLPrepareAndFinalize)
else:
fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
@ -377,6 +378,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
all2all_manager.world_size)
handle = all2all_manager.get_handle(all_to_all_args)
# Note : We may want to use FP8 dispatch even otherwise just to
# reduce datamovement
assert act_quant_block_size is not None
use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
and act_quant_block_size[1]
== DEEPEP_QUANT_BLOCK_SIZE)
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize(
@ -386,7 +394,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
max_tokens_per_rank=moe.max_num_tokens,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
use_fp8_dispatch=False,
use_fp8_dispatch=use_fp8_dispatch,
)
self.topk_indices_dtype = None

View File

@ -75,7 +75,7 @@ enable_hf_transfer()
class DisabledTqdm(tqdm):
    """A tqdm subclass whose progress bar is unconditionally disabled."""

    def __init__(self, *args, **kwargs):
        # The merge left two consecutive super().__init__ calls (the old
        # disable=False line and its replacement); only the disable=True
        # call is kept — initializing tqdm twice is incorrect.
        super().__init__(*args, **kwargs, disable=True)
def get_lock(model_name_or_path: Union[str, Path],

View File

@ -269,3 +269,11 @@ class CpuPlatform(Platform):
model configuration.
"""
return True
@classmethod
def default_v1(cls, model_config) -> bool:
    """Returns whether the current platform can use v1 by default for the
    supplied model configuration.
    """
    # CPU backend defaults to v1 only on x86 (and only when v1 is
    # supported at all for this model).
    return cls.supports_v1(
        model_config) and cls.get_cpu_architecture() == CpuArchEnum.X86

View File

@ -479,6 +479,13 @@ class Platform:
"""
return False
@classmethod
def default_v1(cls, model_config: ModelConfig) -> bool:
    """
    Returns whether the current platform supports v1 by default.

    Base implementation: default to v1 exactly when v1 is supported.
    Platform subclasses may override to add stricter conditions.
    """
    return cls.supports_v1(model_config)
@classmethod
def use_custom_allreduce(cls) -> bool:
"""

View File

@ -215,9 +215,15 @@ class RocmPlatform(Platform):
selected_backend = _Backend.ROCM_FLASH
if envs.VLLM_USE_V1:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
and on_gfx9():
logger.info("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"rocm_aiter_fa.AiterFlashAttentionBackend")
else:
logger.info("Using Triton Attention backend on V1 engine.")
return ("vllm.v1.attention.backends."
"triton_attn.TritonAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.

View File

@ -0,0 +1,585 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if current_platform.is_rocm():
import aiter
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op
@triton.jit
def _vllm_layout_trans_kernel(
    k_buffer_ptr,
    v_buffer_ptr,
    k_values_ptr,
    v_values_ptr,
    b_query_lens_loc,
    b_seq_lens_loc,
    block_table,
    block_table_stride_0,
    E_DIM: tl.constexpr,  # num_kv_heads * head_size (flattened per-token KV)
    BLOCK_SIZE: tl.constexpr,  # tokens per paged-cache block
):
    """Copy paged K/V cache blocks into contiguous per-token K/V tensors.

    Grid: (batch, kv_block). Program (b, i) copies the i-th BLOCK_SIZE-token
    slab of sequence b from the paged cache (indirected through block_table)
    into the packed output at that sequence's token offset.
    """
    batch_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    # Two adjacent cumulative-seq-len entries give this sequence's
    # [start, end) token range in the packed output.
    batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
                                  tl.arange(0, 2))
    batch_token_start, batch_token_end = tl.split(batch_token_indexes)
    seq_len = batch_token_end - batch_token_start

    batch_query_indexes = tl.load(b_query_lens_loc + batch_idx +
                                  tl.arange(0, 2))
    batch_query_start, batch_query_end = tl.split(batch_query_indexes)
    query_len = batch_query_end - batch_query_start
    # Skip decode-only sequences (query_len <= 1); presumably those are
    # handled by the paged-attention path instead — confirm against caller.
    if query_len <= 1:
        return
    if block_idx * BLOCK_SIZE < seq_len:
        # Mask off the tail tokens of the last (partial) block.
        block_mask = (block_idx * BLOCK_SIZE +
                      tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len

        # Physical cache block id for this logical block of the sequence.
        kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
                         block_idx)

        # Source offsets inside the paged cache: (BLOCK_SIZE, E_DIM) tile.
        kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
            0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
        k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
                         mask=block_mask,
                         other=0.0)
        v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
                         mask=block_mask,
                         other=0.0)

        # Destination offsets in the packed (total_tokens, E_DIM) output.
        kv_values_off = batch_token_start * E_DIM + \
            block_idx * BLOCK_SIZE * E_DIM + \
            tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \
            tl.arange(0, E_DIM)[None, :]
        tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
        tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
                      k_buffer, v_buffer, max_seq_len, total_tokens):
    """Flatten paged K/V cache blocks into contiguous tensors.

    Launches _vllm_layout_trans_kernel with one program per
    (sequence, kv-block) pair and returns `(k_values, v_values)`, each of
    shape (total_tokens, H_KV, D) and the cache's dtype.
    """
    H_KV = v_buffer.shape[2]
    D = v_buffer.shape[3]
    BLOCK_SIZE = v_buffer.shape[1]
    dtype = k_buffer.dtype
    # Allocate on the cache's own device instead of the hard-coded default
    # "cuda" device, so multi-GPU workers whose cache lives on a non-default
    # device do not get cross-device kernel arguments.
    k_values = torch.empty((total_tokens, H_KV, D),
                           dtype=dtype,
                           device=k_buffer.device)
    v_values = torch.empty((total_tokens, H_KV, D),
                           dtype=dtype,
                           device=v_buffer.device)
    # One row of programs per sequence, ceil(max_seq_len / BLOCK_SIZE) blocks.
    grid = (block_table.shape[0],
            (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)

    _vllm_layout_trans_kernel[grid](k_buffer,
                                    v_buffer,
                                    k_values,
                                    v_values,
                                    b_query_lens_loc,
                                    b_seq_lens_loc,
                                    block_table,
                                    block_table.stride(0),
                                    E_DIM=H_KV * D,
                                    BLOCK_SIZE=BLOCK_SIZE)
    return k_values, v_values
def flash_attn_varlen_func_impl(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    total_tokens: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    window_size: Optional[list[int]],  # -1 means infinite context window
    alibi_slopes: Optional[list[float]],
    block_table: torch.Tensor,
) -> torch.Tensor:
    """Run aiter's varlen flash attention against the paged KV cache.

    The paged K/V blocks are first copied into contiguous
    (total_tokens, H_KV, D) tensors by vllm_layout_trans, then passed to
    aiter.flash_attn_varlen_func with `out` supplied as the (mutated)
    output buffer.
    """
    k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
                             k_cache, v_cache, max_seqlen_k, total_tokens)
    output = aiter.flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        min_seqlen_q=1,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=True,
        alibi_slopes=alibi_slopes,
        window_size=window_size,
        out=out,
    )
    return output
def flash_attn_varlen_func_fake(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    total_tokens: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    window_size: Optional[list[int]],  # -1 means infinite context window
    alibi_slopes: Optional[list[float]],
    block_table: torch.Tensor,
) -> torch.Tensor:
    """Fake (meta) implementation for the registered custom op.

    A fake impl must report the same output metadata as the real op for
    torch.compile tracing to be sound. The real impl returns the provided
    `out` buffer, which shares q's dtype/device — so mirror q here instead
    of hard-coding float8 on the default CUDA device.
    """
    return torch.empty(q.shape[0],
                       q.shape[1],
                       v_cache.shape[-2],
                       dtype=q.dtype,
                       device=q.device)
# Register the aiter varlen attention as a vLLM custom op; "out" is declared
# as a mutated argument and the fake impl supplies tracing metadata.
direct_register_custom_op("flash_attn_varlen_func",
                          flash_attn_varlen_func_impl, ["out"],
                          flash_attn_varlen_func_fake,
                          dispatch_key=current_platform.dispatch_key)
logger = init_logger(__name__)
class AiterFlashAttentionMetadataBuilder:
    """Builds per-step AiterFlashAttentionMetadata from runner state."""

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        model_config = runner.model_config

        self.runner = runner
        self.num_heads_q = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            runner.parallel_config)
        self.headdim = model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # This backend does not reorder requests within the batch.
        return False

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        """Assemble attention metadata for the current scheduler step.

        Copies the slot mapping to the device, builds cumulative sequence
        lengths, and (when chunked "local" attention is enabled) the
        virtual-batch local attention metadata.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
        total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]

        block_table.slot_mapping[:num_actual_tokens].copy_(
            block_table.slot_mapping_cpu[:num_actual_tokens],
            non_blocking=True)
        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
        # mode.
        block_table.slot_mapping[num_actual_tokens:].fill_(-1)

        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

        # Cumulative sequence lengths with a leading 0 (length num_reqs + 1).
        # NOTE(review): allocated on the default "cuda" device rather than
        # self.runner.device — confirm this is safe on multi-GPU workers.
        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                                  dtype=torch.int32,
                                  device="cuda")
        torch.cumsum(seq_lens,
                     dim=0,
                     dtype=cu_seq_lens.dtype,
                     out=cu_seq_lens[1:])

        # Placeholder: no AOT scheduling metadata is produced for this
        # backend yet.
        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            return None

        # for local attention
        local_attn_metadata = None
        if self.runner.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
                    self.runner.attention_chunk_size,
                    self.runner.query_start_loc_np[:num_reqs + 1],
                    self.runner.seq_lens_np[:num_reqs],
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
                self.runner.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                self.runner.device, non_blocking=True)
            local_max_query_len = seqlens_q_local_np.max()
            local_max_seq_len = virt_k_seqlens_np.max()
            local_scheduler_metadata = schedule(
                batch_size=local_query_start_loc.shape[0] - 1,
                cu_query_lens=local_query_start_loc,
                max_query_len=local_max_query_len,
                seqlens=local_seqused_k,
                max_seq_len=local_max_seq_len,
                causal=True)

            local_attn_metadata = \
                AiterFlashAttentionMetadata.LocalAttentionMetadata(
                    local_query_start_loc=local_query_start_loc,
                    local_seqused_k=local_seqused_k,
                    local_block_table=virt_block_table_tensor,
                    local_max_query_len=local_max_query_len,
                    local_max_seq_len=local_max_seq_len,
                    local_scheduler_metadata=local_scheduler_metadata,
                )

        use_cascade = common_prefix_len > 0

        # Cascade attention is unsupported here (see forward()), so the
        # prefix/suffix fields stay unset.
        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None

        attn_metadata = AiterFlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            cu_seq_lens=cu_seq_lens,
            total_tokens=total_tokens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            local_attn_metadata=local_attn_metadata,
        )
        return attn_metadata

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        # Full CUDA Graph always supported (FA2 support checked separately)
        return True

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        # Cascade attention is not implemented for this backend.
        return False
class AiterFlashAttentionBackend(AttentionBackend):
    """V1 attention backend descriptor for the ROCm aiter kernels."""

    # The impl writes attention results directly into a caller-provided
    # output buffer (see AiterFlashAttentionImpl.forward).
    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        """Head sizes the aiter kernels accept (validated in the impl)."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        # Reuses the generic V1 flash-attn backend name.
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
        return AiterFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return AiterFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
        return AiterFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        """Shape of the paged KV cache: K and V stacked along dim 0."""
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)
@dataclass
class AiterFlashAttentionMetadata:
    # Per-step attention metadata consumed by AiterFlashAttentionImpl.
    #
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    # Cumulative seq_lens with a leading 0 (shape num_reqs + 1); built by
    # AiterFlashAttentionMetadataBuilder.build().
    cu_seq_lens: torch.Tensor
    total_tokens: int  # Sum of seq_lens over the batch.
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # for local attention
    @dataclass
    class LocalAttentionMetadata:
        # Virtual-batch metadata produced by
        # make_local_attention_virtual_batches for chunked attention.
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int
        local_scheduler_metadata: Optional[torch.Tensor]

    local_attn_metadata: Optional[LocalAttentionMetadata] = None
class AiterFlashAttentionImpl(AttentionImpl):
    """Attention implementation backed by ROCm aiter kernels.

    Sequences with query_len > 1 are handled by aiter's varlen flash
    attention (via the registered custom op); aiter.paged_attention_v1 is
    then invoked over the paged cache. Cascade attention is unsupported.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "AiterFlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # Flash-attn window convention: (left, right); -1 means unbounded.
        if sliding_window is None:
            self.sliding_window = [-1, -1]
        else:
            self.sliding_window = [sliding_window - 1, 0]
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0.
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = \
            AiterFlashAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by "
                "AiterFlashAttention. "
                f"Supported head sizes are: {support_head_sizes}. "
                "Set VLLM_USE_V1=0 to use another attention backend.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")
        self.use_irope = use_irope
        # NOTE(review): this rejects quantized kv-cache outright, yet
        # forward() contains an fp8 kv-cache branch — that branch appears
        # unreachable; confirm which is intended.
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "AiterFlashAttention does not support fp8 kv-cache on this "
                "device.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AiterFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with AiterFlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: FP8 quantization, flash-attn expect the size of
              {q,k,v}_descale to be (num_sequences, num_kv_heads). We use
              torch's .expand() to avoid duplicating values
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        # Reshape the input keys and values and store them in the cache.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens] and
        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
        # the slot_mapping's shape to determine the number of actual tokens.
        key_cache, value_cache = kv_cache.unbind(0)
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        if self.kv_cache_dtype.startswith("fp8"):
            # Reinterpret the cache as fp8 and quantize the query to match.
            key_cache = key_cache.view(torch.float8_e4m3fnuz)
            value_cache = value_cache.view(torch.float8_e4m3fnuz)
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # Compute attention and update output up to `num_actual_tokens`.
        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if not attn_metadata.use_cascade or use_local_attn:
            if use_local_attn:
                assert attn_metadata.local_attn_metadata is not None
                local_metadata = attn_metadata.local_attn_metadata
                cu_seqlens_q = local_metadata.local_query_start_loc
                seqused_k = local_metadata.local_seqused_k
                max_seqlen_q = local_metadata.local_max_query_len
                max_seqlen_k = local_metadata.local_max_seq_len
                block_table = local_metadata.local_block_table
            else:
                cu_seqlens_q = attn_metadata.query_start_loc
                seqused_k = attn_metadata.seq_lens
                max_seqlen_q = attn_metadata.max_query_len
                max_seqlen_k = attn_metadata.max_seq_len
                block_table = attn_metadata.block_table

            if max_seqlen_q > 1:
                cu_seq_lens = attn_metadata.cu_seq_lens
                total_tokens = attn_metadata.total_tokens
                torch.ops.vllm.flash_attn_varlen_func(
                    query[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_k=max_seqlen_k,
                    total_tokens=total_tokens,
                    softmax_scale=self.scale,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    cu_seqlens_k=cu_seq_lens,
                )

            # NOTE(review): paged_attention_v1 runs even after the varlen
            # call above — presumably the layout-trans kernel's query_len<=1
            # skip makes varlen handle only prefill rows and this kernel
            # fill the decode rows; confirm against aiter semantics.
            _, num_heads, head_size = query.shape
            _PARTITION_SIZE_ROCM = 256
            num_seqs = seqused_k.shape[0]
            nbyes_per_qo_elem = torch.finfo(output.dtype).bits // 8
            max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
                                  1) // _PARTITION_SIZE_ROCM

            # Scratch space: partitioned Q*O accumulators plus two
            # float32 reduction buffers per (seq, head, partition).
            workspace_buffer = torch.empty(
                (num_seqs * num_heads * max_num_partitions * head_size) *
                nbyes_per_qo_elem + 2 *
                (num_seqs * num_heads * max_num_partitions) * 4,
                dtype=torch.uint8,
                device=output.device,
            )

            aiter.paged_attention_v1(
                output[:num_actual_tokens],
                workspace_buffer,
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                self.scale,
                block_table,
                cu_seqlens_q,
                seqused_k,
                max_seqlen_k,
                self.alibi_slopes,
                self.kv_cache_dtype,
                "NHD",
                self.logits_soft_cap,
                layer._k_scale,
                layer._v_scale,
                None,
                _PARTITION_SIZE_ROCM,
            )
            return output

        else:
            raise NotImplementedError(
                "Cascade attention is not implemented for ROCM AITER")

View File

@ -14,6 +14,39 @@ logger = init_logger(__name__)
class EncoderCacheManager:
"""Manages caching of encoder outputs for multimodal models in vLLM V1.
The EncoderCacheManager handles the lifecycle of multimodal encoder outputs
(such as vision embeddings from images) during request processing. It
provides memory-aware caching to avoid recomputing encoder outputs when the
same multimodal inputs appear in different stages of request processing.
This manager is particularly important for:
- Vision-language models (e.g., LLaVA) where image encoder outputs are
cached
- Any multimodal model where encoder computation is expensive and
cacheable
The cache operates at the granularity of individual multimodal input items
within requests, allowing for fine-grained memory management and enabling
chunked processing of multimodal inputs.
Note that no caching is shared between requests at this time. If the same
input is used across multiple requests, it will be reprocessed for each
request.
Args:
cache_size: Limit the size of the cache, measured by the number of
tokens from the input sequence.
Attributes:
cache_size: Total cache capacity in encoder tokens
num_free_slots: Current available cache capacity in encoder tokens
cached: Mapping from request_id to set of cached input_ids for that
request
freed: List of (request_id, input_id) pairs that were recently freed.
This is cleared after every call to get_freed_ids().
"""
def __init__(self, cache_size: int):
self.cache_size = cache_size
@ -24,14 +57,48 @@ class EncoderCacheManager:
self.freed: list[tuple[str, int]] = []
def has_cache(self, request: Request, input_id: int) -> bool:
    """Check if encoder output for a specific multimodal input is cached.

    Args:
        request: The request containing the multimodal input
        input_id: Index of the multimodal input within the request

    Returns:
        True if the encoder output for this input is already cached
    """
    cached_ids = self.cached.get(request.request_id)
    return cached_ids is not None and input_id in cached_ids
def can_allocate(self, request: Request, input_id: int) -> bool:
    """Check if there's sufficient cache space for a multimodal input.

    Args:
        request: The request containing the multimodal input
        input_id: Index of the multimodal input within the request

    Returns:
        True if there's enough free cache space to store the encoder output
        for this multimodal input
    """
    required_tokens = request.get_num_encoder_tokens(input_id)
    return required_tokens <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
"""Allocate cache space for a multimodal input's encoder output.
This method reserves cache space for storing the encoder output of
the specified multimodal input. The actual encoder output storage
happens in the model runner, but this method ensures the cache
manager tracks the allocation.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input within the request
Note:
This method assumes can_allocate() returned True for the same
request and input_id. It will reduce available cache space.
"""
req_id = request.request_id
if req_id not in self.cached:
self.cached[req_id] = set()
@ -39,10 +106,30 @@ class EncoderCacheManager:
self.num_free_slots -= request.get_num_encoder_tokens(input_id)
def get_cached_input_ids(self, request: Request) -> set[int]:
    """Get all cached multimodal input IDs for a request.

    Args:
        request: The request to query

    Returns:
        Set of input_ids that have cached encoder outputs for this request.
        Returns empty set if no inputs are cached for this request.
    """
    req_id = request.request_id
    if req_id in self.cached:
        return self.cached[req_id]
    return set()
def free_encoder_input(self, request: Request, input_id: int) -> None:
"""Free a single encoder input id for the request."""
"""Free cache space for a single multimodal input's encoder output.
This method is called when:
- The encoder output has been fully consumed by the decoder and is
no longer needed (e.g., in vision-language models after image
tokens are processed)
- A request is being cancelled or aborted
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input to free from cache
"""
req_id = request.request_id
if req_id not in self.cached:
return
@ -54,12 +141,29 @@ class EncoderCacheManager:
self.freed.append((req_id, input_id))
def free(self, request: Request) -> None:
    """Free all cached encoder outputs for a request.

    This method is typically called when a request is finished, cancelled,
    or aborted, and all its encoder outputs should be freed from cache.

    Args:
        request: The request whose encoder outputs should be freed
    """
    # (Removed a stale duplicate one-line docstring left over from a merge;
    # it demoted the real docstring to a no-op string expression.)
    # Copy the set: free_encoder_input() mutates it while we iterate.
    input_ids = self.get_cached_input_ids(request).copy()
    for input_id in input_ids:
        self.free_encoder_input(request, input_id)
def get_freed_ids(self) -> list[tuple[str, int]]:
    """Get and clear the list of recently freed encoder cache entries.

    This method returns all encoder cache entries that were freed since
    the last call to this method. It's used by the scheduler to notify
    workers about which encoder outputs can be removed from their caches.

    Returns:
        List of (request_id, input_id) tuples that were freed since the
        last call. The internal freed list is cleared after this call.
    """
    freed, self.freed = self.freed, []
    return freed

View File

@ -171,6 +171,9 @@ class RequestStatus(enum.IntEnum):
FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
def __str__(self):
    # Render the member name (e.g. "FINISHED_ABORTED") instead of the
    # default IntEnum repr, for readable logs.
    return self.name
@staticmethod
def is_finished(status: "RequestStatus") -> bool:
    # NOTE(review): relies on all FINISHED_* members being declared after
    # PREEMPTED so the ordinal comparison classifies them — confirm
    # against the full enum definition when adding members.
    return status > RequestStatus.PREEMPTED

View File

@ -5,7 +5,7 @@ from typing import Optional
import torch
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch
DEFAULT_SAMPLING_PARAMS = dict(
temperature=-1.0,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining an input batch
# Datastructures defining a GPU input batch
from dataclasses import dataclass
from typing import Optional, cast
@ -453,6 +453,11 @@ class InputBatch:
self.block_table.swap_row(i1, i2)
def condense(self, empty_req_indices: list[int]) -> None:
"""Move non-empty requests down into lower, empty indices.
Args:
empty_req_indices: empty batch indices, sorted descending.
"""
num_reqs = self.num_reqs
if num_reqs == 0:
# The batched states are empty.

View File

@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
"""
from contextlib import contextmanager
from typing import Union
import numpy as np
import torch.nn as nn
@ -15,7 +16,10 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
InputBatch = Union[TPUInputBatch, GPUInputBatch]
logger = init_logger(__name__)

View File

@ -0,0 +1,584 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining a TPU input batch
from typing import Optional, cast
import numpy as np
import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState
_SAMPLING_EPS = 1e-5
class InputBatch:
def __init__(
    self,
    max_num_reqs: int,
    max_model_len: int,
    max_num_batched_tokens: int,
    device: torch.device,
    pin_memory: bool,
    vocab_size: int,
    block_sizes: list[int],  # The block_size of each kv cache group
):
    """Preallocate all per-request batch state for up to max_num_reqs
    requests: token-id buffers, block tables, and sampling-parameter
    tensors (each kept as a device tensor plus a pinned-CPU mirror with a
    numpy view for cheap per-request updates).
    """
    self.max_num_reqs = max_num_reqs
    self.max_model_len = max_model_len
    self.max_num_batched_tokens = max_num_batched_tokens
    self.device = device
    self.pin_memory = pin_memory
    self.vocab_size = vocab_size

    # req_index <-> req_id bookkeeping.
    self._req_ids: list[Optional[str]] = []
    self.req_id_to_index: dict[str, int] = {}

    # TODO(woosuk): This buffer could be too large if max_model_len is big.
    # Find a way to reduce the CPU memory usage.
    # This buffer is not directly transferred to the GPU, so it does not
    # need to be pinned.
    self.token_ids_cpu_tensor = torch.zeros(
        (max_num_reqs, max_model_len),
        device="cpu",
        dtype=torch.int32,
        pin_memory=False,
    )
    self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
    self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
    self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
    self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
    self.num_computed_tokens_cpu_tensor = torch.zeros(
        (max_num_reqs, ),
        device="cpu",
        dtype=torch.int32,
        pin_memory=pin_memory,
    )
    self.num_computed_tokens_cpu = \
        self.num_computed_tokens_cpu_tensor.numpy()

    # Block table.
    self.block_table = MultiGroupBlockTable(
        max_num_reqs=max_num_reqs,
        max_model_len=max_model_len,
        max_num_batched_tokens=max_num_batched_tokens,
        pin_memory=pin_memory,
        device=device,
        block_sizes=block_sizes,
    )

    # Sampling-related.
    self.temperature = torch.empty((max_num_reqs, ),
                                   dtype=torch.float32,
                                   device=device)
    self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
                                              dtype=torch.float32,
                                              device="cpu",
                                              pin_memory=pin_memory)
    self.temperature_cpu = self.temperature_cpu_tensor.numpy()
    self.greedy_reqs: set[str] = set()
    self.random_reqs: set[str] = set()

    self.top_p = torch.empty((max_num_reqs, ),
                             dtype=torch.float32,
                             device=device)
    self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                        dtype=torch.float32,
                                        device="cpu",
                                        pin_memory=pin_memory)
    self.top_p_cpu = self.top_p_cpu_tensor.numpy()
    self.top_p_reqs: set[str] = set()

    self.top_k = torch.empty((max_num_reqs, ),
                             dtype=torch.int32,
                             device=device)
    self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
                                        dtype=torch.int32,
                                        device="cpu",
                                        pin_memory=pin_memory)
    self.top_k_cpu = self.top_k_cpu_tensor.numpy()
    self.top_k_reqs: set[str] = set()

    self.min_p = torch.empty((max_num_reqs, ),
                             dtype=torch.float32,
                             device=device)
    self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
                                        dtype=torch.float32,
                                        device="cpu",
                                        pin_memory=pin_memory)
    self.min_p_cpu = self.min_p_cpu_tensor.numpy()
    self.min_p_reqs: set[str] = set()

    # Frequency penalty related data structures
    self.frequency_penalties = torch.empty((max_num_reqs, ),
                                           dtype=torch.float,
                                           device=device)
    self.frequency_penalties_cpu_tensor = torch.empty(
        (max_num_reqs, ),
        dtype=torch.float,
        device="cpu",
        pin_memory=pin_memory)
    self.frequency_penalties_cpu = \
        self.frequency_penalties_cpu_tensor.numpy()
    self.frequency_penalties_reqs: set[str] = set()

    # Presence penalty related data structures
    self.presence_penalties = torch.empty((max_num_reqs, ),
                                          dtype=torch.float,
                                          device=device)
    self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
                                                     dtype=torch.float,
                                                     device="cpu",
                                                     pin_memory=pin_memory)
    self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
    )
    self.presence_penalties_reqs: set[str] = set()

    # Repetition penalty related data structures
    self.repetition_penalties = torch.empty((max_num_reqs, ),
                                            dtype=torch.float,
                                            device=device)
    self.repetition_penalties_cpu_tensor = torch.empty(
        (max_num_reqs, ),
        dtype=torch.float,
        device="cpu",
        pin_memory=pin_memory)
    self.repetition_penalties_cpu = \
        self.repetition_penalties_cpu_tensor.numpy()
    self.repetition_penalties_reqs: set[str] = set()

    # req_index -> (min_tokens, stop_token_ids)
    self.min_tokens: dict[int, tuple[int, set[int]]] = {}

    # lora related
    self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                         dtype=np.int32)
    self.lora_id_to_request_ids: dict[int, set[str]] = {}
    self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

    # req_index -> generator
    # NOTE(woosuk): The indices of the requests that do not have their own
    # generator should not be included in the dictionary.
    self.generators: dict[int, torch.Generator] = {}

    self.num_logprobs: dict[str, int] = {}
    # NOTE(rob): num_prompt_logprobs only includes reqs
    # that are currently in the prefill phase.
    self.num_prompt_logprobs: dict[str, int] = {}

    # To accumulate prompt logprobs tensor chunks across prefill steps.
    self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

    self.logit_bias: list[Optional[dict[int,
                                        float]]] = [None] * max_num_reqs
    self.has_allowed_token_ids: set[str] = set()
    # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
    # the value is False. Since we use masked_fill_ to set -inf.
    self.allowed_token_ids_mask: Optional[torch.Tensor] = None
    self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

    # req_index -> bad_words_token_ids
    self.bad_words_token_ids: dict[int, list[list[int]]] = {}

    self.req_output_token_ids: list[Optional[list[int]]] = []
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
    def add_request(
        self,
        request: "CachedRequestState",
        req_index: Optional[int] = None,
    ) -> None:
        """Populate batch slot ``req_index`` with the state of ``request``.

        Copies the request's token ids, sampling parameters, penalties,
        optional generator, logprob settings, token masks and LoRA mapping
        into the per-slot data structures.

        Args:
            request: Cached request state to insert.
            req_index: Target slot; defaults to the end of the batch.
        """
        if req_index is None:
            # Append at the end of the current batch.
            req_index = self.num_reqs
        assert req_index < self.max_num_reqs
        req_id = request.req_id
        if req_index == len(self._req_ids):
            self._req_ids.append(req_id)
            self.req_output_token_ids.append(request.output_token_ids)
        else:
            # Overwrite a previously vacated slot.
            self._req_ids[req_index] = req_id
            self.req_output_token_ids[req_index] = request.output_token_ids
        self.req_id_to_index[req_id] = req_index
        # Copy the prompt token ids and output token ids.
        num_prompt_tokens = len(request.prompt_token_ids)
        self.num_prompt_tokens[req_index] = num_prompt_tokens
        self.token_ids_cpu[
            req_index, :num_prompt_tokens] = request.prompt_token_ids
        start_idx = num_prompt_tokens
        end_idx = start_idx + len(request.output_token_ids)
        self.token_ids_cpu[req_index,
                           start_idx:end_idx] = request.output_token_ids
        # Number of token ids in token_ids_cpu.
        # NOTE(woosuk): This may include spec decode tokens.
        self.num_tokens[req_index] = request.num_tokens
        # Number of tokens without spec decode tokens.
        self.num_tokens_no_spec[req_index] = request.num_tokens
        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
        self.block_table.add_row(request.block_ids, req_index)
        sampling_params = request.sampling_params
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # Avoid later division by zero.
            self.temperature_cpu[req_index] = -1.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)
        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        top_k = sampling_params.top_k
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            # Out-of-range top_k is treated as disabled (full vocab).
            top_k = self.vocab_size
        self.top_k_cpu[req_index] = top_k
        self.min_p_cpu[req_index] = sampling_params.min_p
        self.frequency_penalties_cpu[
            req_index] = sampling_params.frequency_penalty
        if sampling_params.min_p > _SAMPLING_EPS:
            self.min_p_reqs.add(req_id)
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[
            req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[
            req_index] = sampling_params.repetition_penalty
        # Repetition penalty of 1.0 is a no-op, so only track others.
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)
        if sampling_params.min_tokens:
            self.min_tokens[req_index] = (sampling_params.min_tokens,
                                          sampling_params.all_stop_token_ids)
        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator
        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = sampling_params.logprobs
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
        if sampling_params.logit_bias is not None:
            self.logit_bias[req_index] = sampling_params.logit_bias
        if sampling_params.allowed_token_ids:
            self.has_allowed_token_ids.add(req_id)
            if self.allowed_token_ids_mask_cpu_tensor is None:
                # Lazy allocation for this tensor, which can be large.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
                                                          self.vocab_size,
                                                          dtype=torch.bool,
                                                          device=self.device)
                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device="cpu")
            # Start from "everything masked" (True == fill with -inf),
            # then clear the explicitly allowed ids below.
            self.allowed_token_ids_mask_cpu_tensor[req_index] = True
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids] = False
        if sampling_params.bad_words_token_ids:
            self.bad_words_token_ids[
                req_index] = sampling_params.bad_words_token_ids
        # Add request lora ID
        if request.lora_request:
            lora_id = request.lora_request.lora_int_id
            if lora_id not in self.lora_id_to_request_ids:
                self.lora_id_to_request_ids[lora_id] = set()
            self.request_lora_mapping[req_index] = lora_id
            self.lora_id_to_request_ids[lora_id].add(request.req_id)
            self.lora_id_to_lora_request[lora_id] = request.lora_request
        else:
            # No LoRA
            self.request_lora_mapping[req_index] = 0
def remove_request(self, req_id: str) -> Optional[int]:
"""This method must always be followed by a call to condense()."""
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.min_p_reqs.discard(req_id)
self.min_tokens.pop(req_index, None)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
self.lora_id_to_request_ids[lora_id].discard(req_id)
if len(self.lora_id_to_request_ids[lora_id]) == 0:
self.lora_id_to_request_ids.pop(lora_id)
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0
self.logit_bias[req_index] = None
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] =\
self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
self.num_tokens[i1], self.num_tokens[i2] =\
self.num_tokens[i2], self.num_tokens[i1]
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
self.min_p_cpu[i2], self.min_p_cpu[i1]
# NOTE: the following is unsafe
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
# instead, we need to temporiarily copy the data for one of the indices
# TODO(lucas): optimize this by only copying valid indices
tmp = self.token_ids_cpu[i1, ...].copy()
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
self.token_ids_cpu[i2, ...] = tmp
swap_dict_values(self.generators, i1, i2)
swap_dict_values(self.min_tokens, i1, i2)
swap_dict_values(self.bad_words_token_ids, i1, i2)
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
self.logit_bias[i1], self.logit_bias[i2] =\
self.logit_bias[i2], self.logit_bias[i1]
if self.allowed_token_ids_mask_cpu_tensor is not None:
self.allowed_token_ids_mask_cpu_tensor[i1], \
self.allowed_token_ids_mask_cpu_tensor[i2] =\
self.allowed_token_ids_mask_cpu_tensor[i2], \
self.allowed_token_ids_mask_cpu_tensor[i1]
self.block_table.swap_row(i1, i2)
    def condense(self, empty_req_indices: list[int]) -> None:
        """Move non-empty requests down into lower, empty indices.

        Compacts the batch in place so that all active requests occupy the
        lowest slots, then trims the per-slot lists to the batch size.

        Args:
          empty_req_indices: empty batch indices, sorted descending.
        """
        num_reqs = self.num_reqs
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return
        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1
            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                # All remaining holes are at or above the last occupied
                # slot, so the batch is already compact.
                break
            # Swap the states.
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index
            num_tokens = self.num_tokens[last_req_index]
            # Only the valid prefix of the row needs to move.
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
            self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
            # Sparse (dict-keyed) state is only moved when present.
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator
            min_token = self.min_tokens.pop(last_req_index, None)
            if min_token is not None:
                self.min_tokens[empty_index] = min_token
            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]
            self.logit_bias[empty_index] = self.logit_bias[last_req_index]
            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]
            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1
        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[:self.
num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values())
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
    @property
    def num_reqs(self) -> int:
        # Each active request has exactly one entry in req_id_to_index.
        return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_min_p(self) -> bool:
return len(self.min_p_reqs) == 0
@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0)
@property
def max_num_logprobs(self) -> Optional[int]:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0

View File

@ -42,8 +42,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
from .utils import (initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs)