Compare commits
24 Commits
nixl-debug...deepep_twe

| SHA1 |
|---|
| fcec8c8827 |
| 850dafea92 |
| b4f17e12a4 |
| 21ffc7353a |
| 39d5d33f8f |
| 7a821f0e7f |
| 26fd8ca33c |
| d5f206767c |
| 2b5ad9f233 |
| 299f829180 |
| 104a984e6a |
| 12575cfa7a |
| 8b6e1d639c |
| 8de2fd39fc |
| 735a9de71f |
| 257ab95439 |
| cca91a7a10 |
| f04d604567 |
| 19a53b2783 |
| eccdc8318c |
| 5f52a84685 |
| d4629dc43f |
| 6e9cc73f67 |
| c53711bd63 |
.github/mergify.yml (vendored): 15 lines changed
@@ -65,6 +65,21 @@ pull_request_rules:
       add:
         - multi-modality

- name: label-qwen
  description: Automatically apply qwen label
  conditions:
    - or:
      - files~=^examples/.*qwen.*\.py
      - files~=^tests/.*qwen.*\.py
      - files~=^vllm/model_executor/models/.*qwen.*\.py
      - files~=^vllm/reasoning/.*qwen.*\.py
      - title~=(?i)Qwen
      - body~=(?i)Qwen
  actions:
    label:
      add:
        - qwen

- name: label-rocm
  description: Automatically apply rocm label
  conditions:
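(For example, a PR that touches a hypothetical file such as `vllm/model_executor/models/qwen2.py`, or whose title or body mentions "Qwen", would match one of the conditions above and automatically receive the `qwen` label.)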
@@ -1,387 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import random
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp  # Import aiohttp
import numpy as np
from tqdm import tqdm

from backend_request_func import RequestFuncInput, RequestFuncOutput
from benchmark_dataset import RandomDataset, SampleRequest

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer

logger = logging.getLogger(__name__)


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    percentiles_ttft_ms: list[tuple[float, float]]
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    percentiles_itl_ms: list[tuple[float, float]]
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    percentiles_e2el_ms: list[tuple[float, float]]


async def reset_cache(reset_url: str):
    """Sends a POST request to reset the prefix cache."""
    logger.debug("Resetting prefix cache at %s", reset_url)
    try:
        async with (
            aiohttp.ClientSession() as session,
            session.post(reset_url) as response,
        ):
            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
            logger.debug("Prefix cache reset successful: %s", response.status)
    except aiohttp.ClientConnectorError as e:
        logger.error("Failed to connect to cache reset endpoint %s: %s", reset_url, e)
    except aiohttp.ClientResponseError as e:
        logger.error(
            "Cache reset request failed with status %s: %s", e.status, e.message
        )
    except Exception as e:
        logger.error("An unexpected error occurred during cache reset: %s", e)


async def sequential_benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer,
    input_requests: list[SampleRequest],
    request_func,
    selected_percentiles: list[float],
    cache_reset_url: Optional[str] = None,
):
    """
    Benchmark that processes requests sequentially, waiting for each to complete
    before starting the next one. Resets prefix cache between requests.
    """
    outputs = []

    pbar = tqdm(total=len(input_requests))

    # Small request to force a forward pass.
    # Used for resetting the prefix cache.
    # dummy_req_input = RequestFuncInput(
    #     model=model_id,
    #     prompt="0",
    #     api_url=api_url,
    #     prompt_len=1,
    #     output_len=2,
    # )

    # print("Starting initial single prompt test run...")
    # test_output = await request_func(request_func_input=dummy_req_input)
    # if not test_output.success:
    #     raise ValueError(
    #         "Initial test run failed - Please check your configuration. "
    #         "Error: %s", test_output.error)
    # else:
    #     print("Initial test run completed. Starting sequential benchmark...")

    benchmark_start_time = time.perf_counter()

    # Process requests sequentially
    for request in input_requests:
        prompt, prompt_len, output_len = (
            request.prompt,
            request.prompt_len,
            request.expected_output_len,
        )

        logger.info("Sending request with len %s", request.prompt_len)
        logger.debug('Request str: "%s"', request.prompt[:50])
        request_start_time = time.perf_counter()

        # print(f"{prompt=}")
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
        )

        output = await request_func(request_func_input=request_func_input)

        request_end_time = time.perf_counter()
        # Add timing information
        if output.success and not hasattr(output, "latency"):
            output.latency = request_end_time - request_start_time
        logger.info("Finished request with latency %.4f s", output.latency)

        outputs.append(output)
        pbar.update(1)

        # Reset prefix cache between requests if configured.
        if cache_reset_url:
            await reset_cache(cache_reset_url)

    pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    # Calculate metrics
    metrics = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        selected_percentiles=selected_percentiles,
    )

    print_results(metrics, benchmark_duration)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "input_lens": [request.prompt_len for request in input_requests],
        "output_lens": [
            output.output_tokens if output.success else 0 for output in outputs
        ],
        "ttfts": [output.ttft for output in outputs if output.success],
        "itls": [output.itl for output in outputs if output.success],
        "generated_texts": [
            output.generated_text for output in outputs if output.success
        ],
        "errors": [output.error for output in outputs if not output.success],
    }

    # Add summary statistics
    for stat_name in ["ttft", "itl", "e2el"]:
        for metric_name in ["mean", "median", "std"]:
            result[f"{metric_name}_{stat_name}_ms"] = getattr(
                metrics, f"{metric_name}_{stat_name}_ms"
            )

        for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            result[f"p{p_word}_{stat_name}_ms"] = value

    return result


def calculate_metrics(
    input_requests: list[SampleRequest],
    outputs: list[RequestFuncOutput],
    dur_s: float,
    tokenizer,
    selected_percentiles: list[float],
) -> BenchmarkMetrics:
    """Calculate benchmark metrics from results."""
    total_input = 0
    completed = 0
    total_output = 0
    ttfts = []
    itls = []
    e2els = []

    for i, output in enumerate(outputs):
        if output.success:
            output_len = output.output_tokens

            if not output_len:
                # Use tokenizer to count output tokens if not provided
                output_len = len(
                    tokenizer(output.generated_text, add_special_tokens=False).input_ids
                )

            total_output += output_len
            total_input += input_requests[i].prompt_len

            if hasattr(output, "ttft") and output.ttft is not None:
                ttfts.append(output.ttft)

            if hasattr(output, "itl") and output.itl:
                # Ensure itl is a list of floats
                if isinstance(output.itl, list):
                    itls.extend(output.itl)
                else:
                    logger.warning(
                        "Expected list for ITL but got %s. Appending as is.",
                        type(output.itl),
                    )
                    itls.append(output.itl)

            if hasattr(output, "latency") and output.latency is not None:
                e2els.append(output.latency)

            completed += 1

    return BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        mean_ttft_ms=np.mean(ttfts or [0]) * 1000,
        median_ttft_ms=np.median(ttfts or [0]) * 1000,
        std_ttft_ms=np.std(ttfts or [0]) * 1000,
        percentiles_ttft_ms=[
            (p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles
        ],
        mean_itl_ms=np.mean(itls or [0]) * 1000,
        median_itl_ms=np.median(itls or [0]) * 1000,
        std_itl_ms=np.std(itls or [0]) * 1000,
        percentiles_itl_ms=[
            (p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles
        ],
        mean_e2el_ms=np.mean(e2els or [0]) * 1000,
        median_e2el_ms=np.median(e2els or [0]) * 1000,
        std_e2el_ms=np.std(e2els or [0]) * 1000,
        percentiles_e2el_ms=[
            (p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles
        ],
    )


def print_results(metrics: BenchmarkMetrics, benchmark_duration: float):
    """Print benchmark results in a formatted way."""
    print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="="))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))

    def print_metric_stats(metric_name, header):
        print("{s:{c}^{n}}".format(s=header, n=60, c="-"))
        print(
            "{:<40} {:<10.2f}".format(
                f"Mean {metric_name} (ms):",
                getattr(metrics, f"mean_{metric_name.lower()}_ms"),
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                f"Median {metric_name} (ms):",
                getattr(metrics, f"median_{metric_name.lower()}_ms"),
            )
        )

        for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))

    print_metric_stats("TTFT", "Time to First Token")
    print_metric_stats("ITL", "Inter-token Latency")
    print_metric_stats("E2EL", "End-to-end Latency")
    print("=" * 60)


async def main_async(args):
    # Import needed functions based on your setup
    from backend_request_func import ASYNC_REQUEST_FUNCS

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    # Set up API URL
    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"

    # Set up Cache Reset URL
    cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache"
    logger.info("Prefix cache reset configured at: %s", cache_reset_url)

    # Get tokenizer
    tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code)

    # Get request function
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    input_requests = RandomDataset().sample(
        tokenizer=tokenizer,
        num_requests=args.num_requests,
        prefix_len=0,
        input_len=args.input_len,
        output_len=args.output_len,
        range_ratio=0.0,
    )

    # Run benchmark
    result = await sequential_benchmark(
        backend=backend,
        api_url=api_url,
        model_id=model_id,
        tokenizer=tokenizer,
        input_requests=input_requests,
        request_func=request_func,
        selected_percentiles=[50, 90, 95, 99],
        cache_reset_url=cache_reset_url,
    )

    return result


def main(args):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    asyncio.run(main_async(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving")
    parser.add_argument(
        "--backend", type=str, default="vllm", help="Backend to use for requests"
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server base URL (overrides --host and --port)",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint", type=str, default="/v1/completions", help="API endpoint"
    )
    parser.add_argument("--model", type=str, required=True, help="Name of the model")
    parser.add_argument(
        "--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)"
    )
    parser.add_argument(
        "--num-requests", type=int, default=100, help="Number of requests to process"
    )
    parser.add_argument(
        "--input-len", type=int, default=128, help="Input len for generated prompts"
    )
    parser.add_argument(
        "--output-len", type=int, default=None, help="Override output len for requests"
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from HuggingFace",
    )

    args = parser.parse_args()
    main(args)
docs/design/v1/p2p_nccl_connector.md (new file): 337 lines
@@ -0,0 +1,337 @@
An implementation of xPyD with dynamic scaling based on point-to-point communication, partly inspired by Dynamo.

# Detailed Design

## Overall Process
As shown in Figure 1, the overall process of this **PD disaggregation** solution is described through a request flow:

1. The client sends an HTTP request to the Proxy/Router's `/v1/completions` interface.
2. The Proxy/Router selects a **1P1D (1 Prefill instance + 1 Decode instance)** through either round-robin or random selection, generates a `request_id` (rules to be introduced later), modifies the `max_tokens` in the HTTP request message to **1**, and then forwards the request to the **P instance**.
3. Immediately afterward, the Proxy/Router forwards the **original HTTP request** to the **D instance**.
4. The **P instance** performs **Prefill** and then **actively sends the generated KV cache** to the D instance (using **PUT_ASYNC** mode). The D instance's `zmq_addr` can be resolved from the `request_id`.
5. The **D instance** has a **dedicated thread** for receiving the KV cache (to avoid blocking the main process). The received KV cache is saved into the **GPU memory buffer**, the size of which is determined by the vLLM startup parameter `kv_buffer_size`. When the GPU buffer is full, the KV cache is stored in the **local Tensor memory pool**.
6. During **Decode**, the D instance's main process retrieves the KV cache (transmitted by the P instance) from either the **GPU buffer** or the **memory pool**, thereby **skipping Prefill**.
7. After completing **Decode**, the D instance returns the result to the **Proxy/Router**, which then forwards it to the **client**.

![image1](https://github.com/user-attachments/assets/fb01bde6-755b-49f7-a3e2-89e2d9b177e1)

## Proxy/Router (Demo)

A simple HTTP service acts as the entry point for client requests and starts a background thread to listen for P/D instances reporting their HTTP IP and PORT, as well as ZMQ IP and PORT. It maintains a dictionary of `http_addr -> zmq_addr`. The `http_addr` is the IP:PORT of the vLLM instance's request endpoint, while the `zmq_addr` is the address for KV cache handshake and metadata reception.

The Proxy/Router is responsible for selecting 1P1D based on the characteristics of the client request, such as the prompt, and generating a corresponding `request_id`, for example:

```
cmpl-___prefill_addr_10.0.1.2:21001___decode_addr_10.0.1.3:22001_93923d63113b4b338973f24d19d4bf11-0
```
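
The P and D instances can recover each other's ZMQ addresses directly from this identifier. Below is a minimal sketch of that parsing, assuming only the separator conventions shown above; the function name is illustrative, not the connector's actual API.

```python
import re

def parse_request_id(request_id: str) -> tuple[str, str]:
    """Extract the prefill and decode ZMQ addresses embedded in a request_id.

    Assumes the format shown above:
    ...___prefill_addr_<ip:port>___decode_addr_<ip:port>_<uuid>...
    """
    m = re.search(
        r"___prefill_addr_(?P<prefill>[\d.]+:\d+)"
        r"___decode_addr_(?P<decode>[\d.]+:\d+)_",
        request_id,
    )
    if m is None:
        raise ValueError(f"Not a PD-disaggregated request id: {request_id!r}")
    return m.group("prefill"), m.group("decode")

# With the example id above this returns:
# ("10.0.1.2:21001", "10.0.1.3:22001")
```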

Currently, to quickly verify whether xPyD works, a round-robin selection of 1P1D is used. In the future, the plan is to use a trie combined with the load status of instances to select an appropriate P and D.

Each P/D instance periodically sends a heartbeat packet to the Proxy/Router (currently every 3 seconds) to register (i.e., report `http_addr -> zmq_addr`) and keep the connection alive. If an instance crashes and fails to send a ping for a certain period of time, the Proxy/Router will remove the timed-out instance (this feature has not yet been implemented).
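
The registration payload itself is small; the demo proxy later in this page decodes it with msgpack. Below is a sketch of what one heartbeat from a P instance could look like. The field names are taken from the demo proxy's code; the socket setup and function name are illustrative assumptions.

```python
import msgpack
import zmq

def register_with_proxy(proxy_addr: str, http_addr: str, zmq_addr: str) -> None:
    """Send one heartbeat/registration message to the proxy's ZMQ ROUTER socket."""
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.DEALER)  # DEALER pairs with the proxy's ROUTER
    sock.connect(f"tcp://{proxy_addr}")
    # "P" marks a prefill instance, "D" a decode instance (see the proxy demo below).
    payload = {"type": "P", "http_address": http_addr, "zmq_address": zmq_addr}
    sock.send(msgpack.dumps(payload))

# e.g. register_with_proxy("10.0.1.1:30001", "10.0.1.2:20005", "10.0.1.2:21001")
```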

## KV Cache Transfer Methods

There are three methods for KV cache transfer: PUT, GET, and PUT_ASYNC, selected via the `send_type` field inside the `kv_connector_extra_config` of `--kv-transfer-config`. Both PUT and PUT_ASYNC involve the P instance actively sending the KV cache to the D instance. The difference is that PUT is a synchronous transfer method that blocks the main process, while PUT_ASYNC is asynchronous: it uses a dedicated thread for sending the KV cache and therefore does not block the main process. In contrast, with the GET method the P instance saves the KV cache to a memory buffer after computing the prefill, and the D instance actively retrieves it from the P instance once it has allocated space for the KV cache.

Experimental results show that the performance of these methods, from highest to lowest, is: PUT_ASYNC → GET → PUT.
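
Concretely, the mode lives inside `kv_connector_extra_config`. The sketch below shows the producer-side configuration as a Python dict; the values mirror the JSON string passed to `--kv-transfer-config` in the launch commands later in this document.

```python
# Producer (P instance) configuration selecting the asynchronous transfer mode.
kv_transfer_config = {
    "kv_connector": "P2pNcclConnector",
    "kv_role": "kv_producer",
    "kv_buffer_size": "1e1",       # P needs no real buffer in PUT/PUT_ASYNC mode
    "kv_port": "21001",
    "kv_connector_extra_config": {
        "proxy_ip": "10.0.1.1",
        "proxy_port": "30001",
        "http_port": "20005",
        "send_type": "PUT_ASYNC",  # one of: PUT, PUT_ASYNC, GET
        "nccl_num_channels": "16",
    },
}
```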

## P2P Communication via ZMQ & NCCL

As long as the address of the counterpart is known, point-to-point KV cache transfer (using NCCL) can be performed, without being constrained by rank or world size. This supports dynamic scaling (expansion and contraction) of instances with PD disaggregation: adding or removing P/D instances does not require a full system restart.

Each P/D instance only needs to create a single `P2pNcclEngine` instance. This instance maintains a ZMQ server, which runs a dedicated thread to listen on the `zmq_addr` address and receive control-flow requests from other instances. These include requests to establish an NCCL connection and requests to send KV cache metadata (such as tensor shapes and data types); the ZMQ channel does not transmit the KV cache data itself.

When a P instance and a D instance transmit KV cache for the first time, they establish a ZMQ connection and an NCCL group. For subsequent transmissions, this ZMQ connection and NCCL group are reused. Each NCCL group consists of only two ranks, i.e., the world size is 2.

## NCCL Group Topology

Currently, only symmetric TP (Tensor Parallelism) is supported for KV cache transmission; asymmetric TP and PP (Pipeline Parallelism) will be supported in the future. Figure 2 illustrates a 1P2D setup where each instance has a TP degree of 2. There are 7 NCCL groups in total: the three vLLM instances each have one internal NCCL group with TP=2; additionally, the 0th GPU of the P instance establishes an NCCL group with the 0th GPU of each D instance, and likewise the 1st GPU of the P instance with the 1st GPU of each D instance.

![image2](https://github.com/user-attachments/assets/837e61d6-365e-4cbf-8640-6dd7ab295b36)

Each NCCL group occupies a certain amount of GPU memory buffer for communication, the size of which is primarily influenced by the `NCCL_MAX_NCHANNELS` environment variable. When `NCCL_MAX_NCHANNELS=16`, an NCCL group typically occupies 100 MB; when `NCCL_MAX_NCHANNELS=8`, it usually takes up 52 MB. For large-scale xPyD configurations, such as DeepSeek's 96P144D, this implementation is currently not feasible. Moving forward, we are considering using RDMA for point-to-point communication and are also keeping an eye on UCCL.
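
As a rough worked example based on the figures above: with `NCCL_MAX_NCHANNELS=16`, each of the 7 NCCL groups in the 1P2D setup occupies on the order of 100 MB, i.e. roughly 700 MB of communication buffers across the deployment; dropping to `NCCL_MAX_NCHANNELS=8` (about 52 MB per group) roughly halves that. This per-group cost is what makes a 96P144D-scale mesh impractical with the current design.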

## GPU Memory Buffer and Tensor Memory Pool

The memory buffer size involves a trade-off. For P instances, no buffer is required in PUT and PUT_ASYNC modes, but one is necessary in GET mode. For D instances, a buffer is needed in all three modes, and it should not be too large (the same holds for P instances in GET mode). The D instance's buffer is used to temporarily store KV cache sent by P instances; if it is too large, it reduces the KV cache space available for the D instance's normal inference, shrinking the inference batch size and ultimately lowering output throughput. The buffer size is configured by the `kv_buffer_size` parameter, measured in bytes, and is typically set to 5% to 10% of the GPU memory size.

If the `--max-num-seqs` parameter for P instances is set to a large value, the large batch size will cause P instances to generate a large amount of KV cache simultaneously. This may exceed the capacity of the D instance's memory buffer, resulting in KV cache loss. Once KV cache is lost, the D instance must recompute the Prefill, which is equivalent to performing Prefill twice; the time-to-first-token (TTFT) consequently increases significantly, degrading performance.

To address this, I designed and developed a local Tensor memory pool for storing KV cache, inspired by the buddy system used in Linux memory management. Since server memory is sufficiently large, typically in the TB range, there is no need for prefix caching or block-based designs to reuse memory and thereby save space. When the memory buffer is insufficient, KV cache can be stored directly in the Tensor memory pool, and D instances can subsequently retrieve it from there. Reads and writes run at PCIe speed (approximately 21 GB/s for PCIe 4.0), which is usually faster than the Prefill speed; otherwise, solutions like Mooncake and LMCache would not be necessary. The Tensor memory pool acts as a flood diversion area, typically unused except during sudden traffic surges. In the worst case, this solution performs no worse than a normal setup with a cache store.
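
For illustration only, here is a heavily simplified sketch of a buddy-style pool over one large pinned CPU tensor. All names are hypothetical and the real connector's memory pool is more involved; this only shows the core power-of-two split/merge idea.

```python
import torch

class TensorMemoryPool:
    """Toy buddy-style allocator over a single large pinned CPU tensor.

    Block sizes are powers of two; freeing a block merges it with its buddy
    whenever the buddy is also free. A sketch of the idea, not the
    connector's actual implementation.
    """

    def __init__(self, total_bytes: int, min_block: int = 1 << 20):
        assert total_bytes & (total_bytes - 1) == 0, "must be a power of two"
        self.backing = torch.empty(total_bytes, dtype=torch.uint8,
                                   pin_memory=True)
        self.min_block = min_block
        # free_lists[k] holds offsets of free blocks of size 2**k
        self.free_lists = {total_bytes.bit_length() - 1: [0]}

    def alloc(self, nbytes: int) -> tuple[int, int]:
        k = max(nbytes, self.min_block).bit_length() - 1
        if (1 << k) < nbytes:
            k += 1                           # round up to a power of two
        j = k
        while not self.free_lists.get(j):    # find a large-enough free block
            j += 1
            if j > 60:
                raise MemoryError("pool exhausted")
        off = self.free_lists[j].pop()
        while j > k:                         # split down, freeing the buddies
            j -= 1
            self.free_lists.setdefault(j, []).append(off + (1 << j))
        return off, 1 << k                   # offset into self.backing, size

    def free(self, off: int, size: int) -> None:
        k = size.bit_length() - 1
        while True:                          # merge with buddy while possible
            buddy = off ^ (1 << k)
            lst = self.free_lists.setdefault(k, [])
            if buddy in lst:
                lst.remove(buddy)
                off = min(off, buddy)
                k += 1
            else:
                lst.append(off)
                break
```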

# Install vLLM

```shell
# Enter the home directory or your working directory.
cd /home

# Download the wheel (the commit id will be updated over time; you can copy the command directly).
wget https://vllm-wheels.s3.us-west-2.amazonaws.com/9112b443a042d8d815880b8780633882ad32b183/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl

# Download the code repository.
git clone -b xpyd-v1 https://github.com/Abatom/vllm.git
cd vllm

# Set the installation package path.
export VLLM_PRECOMPILED_WHEEL_LOCATION=/home/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl

# Installation.
pip install -e . -v
```

# Run xPyD

## Instructions
- The following examples are run on an A800 (80GB) device, using the Meta-Llama-3.1-8B-Instruct model.
- Pay attention to the setting of `kv_buffer_size` (in bytes). The empirical value is 10% of the GPU memory size (see the worked example after this list); it is related to the KV cache size. If it is too small, the GPU memory buffer for temporarily storing received KV cache will overflow, causing KV cache to spill into the Tensor memory pool, which increases latency. If it is too large, the KV cache available for inference is reduced, leading to a smaller batch size and decreased throughput.
- For Prefill instances, when using non-GET mode, `kv_buffer_size` can be set to 1, since Prefill currently does not need to receive KV cache. When using GET mode, however, a larger `kv_buffer_size` is required because it must hold the KV cache to be sent to the D instance.
- You may need to modify the `kv_buffer_size` and `port` in the following commands (if there is a conflict).
- `PUT_ASYNC` offers the best performance and should be prioritized.
- The `--port` must be consistent with the `http_port` in `--kv-transfer-config`.
- The `disagg_prefill_proxy_xpyd.py` script uses port 10001 (for receiving client requests) and port 30001 (for receiving service discovery from P and D instances).
- The node running the proxy must have `quart` installed.
- Multiple nodes are supported; you only need to modify the `proxy_ip` and `proxy_port` in `--kv-transfer-config`.
- In the following examples, it is assumed that **the proxy's IP is 10.0.1.1**.
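
Worked example for the 80 GB A800 used here: 10% of 80 GB is about 8 GB, which matches the `"kv_buffer_size":"8e9"` (bytes) passed to the Decode instances below; the Prefill instances run in `PUT_ASYNC` mode and therefore use the placeholder `"1e1"`.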

## Run 1P3D

### Proxy (e.g. 10.0.1.1)

```shell
cd {your vllm directory}/examples/online_serving/disagg_xpyd/
python3 disagg_prefill_proxy_xpyd.py &
```

### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1)

```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \
    --host 0.0.0.0 \
    --port 20005 \
    --tensor-parallel-size 1 \
    --seed 1024 \
    --served-model-name base_model \
    --dtype float16 \
    --max-model-len 10000 \
    --max-num-batched-tokens 10000 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --disable-log-requests \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20005","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```

### Decode1 (e.g. 10.0.1.3 or 10.0.1.1)

```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \
    --host 0.0.0.0 \
    --port 20009 \
    --tensor-parallel-size 1 \
    --seed 1024 \
    --served-model-name base_model \
    --dtype float16 \
    --max-model-len 10000 \
    --max-num-batched-tokens 10000 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.7 \
    --disable-log-requests \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20009","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```

### Decode2 (e.g. 10.0.1.4 or 10.0.1.1)

```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \
    --host 0.0.0.0 \
    --port 20003 \
    --tensor-parallel-size 1 \
    --seed 1024 \
    --served-model-name base_model \
    --dtype float16 \
    --max-model-len 10000 \
    --max-num-batched-tokens 10000 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.7 \
    --disable-log-requests \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```

### Decode3 (e.g. 10.0.1.5 or 10.0.1.1)

```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \
    --host 0.0.0.0 \
    --port 20008 \
    --tensor-parallel-size 1 \
    --seed 1024 \
    --served-model-name base_model \
    --dtype float16 \
    --max-model-len 10000 \
    --max-num-batched-tokens 10000 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.7 \
    --disable-log-requests \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20008","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```

## Run 3P1D

### Proxy (e.g. 10.0.1.1)

```shell
cd {your vllm directory}/examples/online_serving/disagg_xpyd/
python3 disagg_prefill_proxy_xpyd.py &
```

### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1)

```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \
    --host 0.0.0.0 \
    --port 20005 \
    --tensor-parallel-size 1 \
    --seed 1024 \
    --served-model-name base_model \
    --dtype float16 \
    --max-model-len 10000 \
    --max-num-batched-tokens 10000 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --disable-log-requests \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20005","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```

### Prefill2 (e.g. 10.0.1.3 or 10.0.1.1)

```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \
    --host 0.0.0.0 \
    --port 20009 \
    --tensor-parallel-size 1 \
    --seed 1024 \
    --served-model-name base_model \
    --dtype float16 \
    --max-model-len 10000 \
    --max-num-batched-tokens 10000 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --disable-log-requests \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20009","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```

### Prefill3 (e.g. 10.0.1.4 or 10.0.1.1)

```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \
    --host 0.0.0.0 \
    --port 20003 \
    --tensor-parallel-size 1 \
    --seed 1024 \
    --served-model-name base_model \
    --dtype float16 \
    --max-model-len 10000 \
    --max-num-batched-tokens 10000 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --disable-log-requests \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```

### Decode1 (e.g. 10.0.1.5 or 10.0.1.1)

```shell
VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \
    --host 0.0.0.0 \
    --port 20008 \
    --tensor-parallel-size 1 \
    --seed 1024 \
    --served-model-name base_model \
    --dtype float16 \
    --max-model-len 10000 \
    --max-num-batched-tokens 10000 \
    --max-num-seqs 256 \
    --trust-remote-code \
    --gpu-memory-utilization 0.7 \
    --disable-log-requests \
    --kv-transfer-config \
    '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20008","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 &
```

# Single request

```shell
curl -X POST -s http://10.0.1.1:10001/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "base_model",
        "prompt": "San Francisco is a",
        "max_tokens": 10,
        "temperature": 0
    }'
```

# Benchmark

```shell
python3 benchmark_serving.py \
    --backend vllm \
    --model base_model \
    --tokenizer meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name "random" \
    --host 10.0.1.1 \
    --port 10001 \
    --random-input-len 1024 \
    --random-output-len 1024 \
    --ignore-eos \
    --burstiness 100 \
    --percentile-metrics "ttft,tpot,itl,e2el" \
    --metric-percentiles "90,95,99" \
    --seed $(date +%s) \
    --trust-remote-code \
    --request-rate 3 \
    --num-prompts 1000
```

# Shut down

```shell
pgrep python | xargs kill -9 && pkill -f python
```

# Test data

Throughput totals below are for the whole deployment; see the per-GPU normalization note after Scenario 2.

## **Scenario 1**: 1K input & 1K output tokens, E2E P99 latency ~20s
- **1P5D (6×A800) vs vLLM (1×A800)**:
  - Throughput ↑7.2% (1085 → 6979/6)
  - ITL (P99) ↓81.3% (120ms → 22.9ms)
  - TTFT (P99) ↑26.8% (175ms → 222ms)
  - TPOT: No change

- **1P6D (7×A800) vs vLLM (1×A800)**:
  - Throughput ↑9.6% (1085 → 8329/7)
  - ITL (P99) ↓81.0% (120ms → 22.7ms)
  - TTFT (P99) ↑210% (175ms → 543ms)
  - TPOT: No change

## **Scenario 2**: 1K input & 200 output tokens, E2E P99 latency ~4s
- **1P1D (2×A800) vs vLLM (1×A800)**:
  - Throughput ↑37.4% (537 → 1476/2)
  - ITL (P99) ↓81.8% (127ms → 23.1ms)
  - TTFT (P99) ↑41.8% (160ms → 227ms)
  - TPOT: No change
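
Per-GPU normalization of the throughput figures: 6979 ÷ 6 ≈ 1163 vs 1085 (+7.2%); 8329 ÷ 7 ≈ 1190 vs 1085 (+9.6%); 1476 ÷ 2 = 738 vs 537 (+37.4%).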

![disagg_xpyd](https://github.com/user-attachments/assets/2e5eb4a3-9c90-47fa-a0e1-e94e61a77018)
@@ -42,7 +42,7 @@ vLLM is a Python library that supports the following GPU variants. Select your G
 
 === "NVIDIA CUDA"
 
-    --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:create-a-new-python-environment"
+    --8<-- "docs/getting_started/installation/gpu/cuda.inc.md:set-up-using-python"
 
 === "AMD ROCm"
@@ -10,8 +10,6 @@ vLLM contains pre-compiled C++ and CUDA (12.8) binaries.
 # --8<-- [end:requirements]
 # --8<-- [start:set-up-using-python]
 
-### Create a new Python environment
-
 !!! note
     PyTorch installed via `conda` will statically link `NCCL` library, which can cause issues when vLLM tries to use `NCCL`. See <gh-issue:8420> for more details.
examples/online_serving/disagg_xpyd/disagg_prefill_proxy_xpyd.py (new file): 154 lines
@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0

import os
import socket
import threading
import uuid

import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request

count = 0
prefill_instances: dict[str, str] = {}  # http_address: zmq_address
decode_instances: dict[str, str] = {}  # http_address: zmq_address

prefill_cv = threading.Condition()
decode_cv = threading.Condition()


def _listen_for_register(poller, router_socket):
    while True:
        socks = dict(poller.poll())
        if router_socket in socks:
            remote_address, message = router_socket.recv_multipart()
            # data: {"type": "P", "http_address": "ip:port",
            #        "zmq_address": "ip:port"}
            data = msgpack.loads(message)
            if data["type"] == "P":
                global prefill_instances
                global prefill_cv
                with prefill_cv:
                    prefill_instances[data["http_address"]] = data["zmq_address"]
            elif data["type"] == "D":
                global decode_instances
                global decode_cv
                with decode_cv:
                    decode_instances[data["http_address"]] = data["zmq_address"]
            else:
                print(
                    f"Unexpected message received from {remote_address}, "
                    f"data: {data}"
                )


def start_service_discovery(hostname, port):
    if not hostname:
        hostname = socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    context = zmq.Context()
    router_socket = context.socket(zmq.ROUTER)
    router_socket.bind(f"tcp://{hostname}:{port}")

    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    _listener_thread = threading.Thread(
        target=_listen_for_register, args=[poller, router_socket], daemon=True
    )
    _listener_thread.start()
    return _listener_thread


AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

app = Quart(__name__)


def random_uuid() -> str:
    return str(uuid.uuid4().hex)


async def forward_request(url, data, request_id):
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "X-Request-Id": request_id,
        }
        async with session.post(url=url, json=data, headers=headers) as response:
            if response.status == 200:
                # Stream the response back in fixed-size chunks.
                async for chunk_bytes in response.content.iter_chunked(1024):
                    yield chunk_bytes


@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    try:
        original_request_data = await request.get_json()

        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request["max_tokens"] = 1

        global count
        global prefill_instances
        global prefill_cv
        with prefill_cv:
            prefill_list = list(prefill_instances.items())
            prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]

        global decode_instances
        global decode_cv
        with decode_cv:
            decode_list = list(decode_instances.items())
            decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]

        print(
            f"handle_request count: {count}, [HTTP:{prefill_addr}, "
            f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
            f"ZMQ:{decode_zmq_addr}]"
        )
        count += 1

        request_id = (
            f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
            f"{decode_zmq_addr}_{random_uuid()}"
        )

        # finish prefill
        async for _ in forward_request(
            f"http://{prefill_addr}/v1/completions", prefill_request, request_id
        ):
            continue

        # return decode
        generator = forward_request(
            f"http://{decode_addr}/v1/completions", original_request_data, request_id
        )
        response = await make_response(generator)
        response.timeout = None

        return response

    except Exception as e:
        import sys
        import traceback

        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))


if __name__ == "__main__":
    t = start_service_discovery("0.0.0.0", 30001)
    app.run(host="0.0.0.0", port=10001)
    t.join()
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py (new file): 83 lines
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm)
from vllm.platforms import current_platform

# (E, T, H, group_size, seed)
CASES = [
    (1, 1, 128, 64, 0),
    (1, 4, 128, 128, 0),
    (2, 4, 256, 128, 0),
    (32, 64, 256, 128, 0),
    (17, 31, 768, 128, 0),
]


@pytest.mark.parametrize("E,T,H,group_size,seed", CASES)
@torch.inference_mode()
def test_silu_mul_fp8_quant_deep_gemm(E, T, H, group_size, seed):
    current_platform.seed_everything(seed)

    # Input tensor of shape (E, T, 2*H)
    y = torch.randn((E, T, 2 * H), dtype=torch.float32, device="cuda")
    tokens_per_expert = torch.randint(
        low=0,
        high=T,
        size=(E, ),
        dtype=torch.int32,
        device="cuda",
    )

    # Run the Triton kernel
    y_q, y_s = silu_mul_fp8_quant_deep_gemm(y,
                                            tokens_per_expert,
                                            group_size=group_size,
                                            eps=1e-10)

    # Reference implementation
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max = fp8_info.max
    fp8_min = fp8_info.min
    eps = 1e-10

    # Compute silu activation and elementwise multiplication
    y1 = y[..., :H]
    y2 = y[..., H:]
    silu_x = y1 * torch.sigmoid(y1)
    merged = silu_x * y2

    # Compute reference scales and quantized output, skipping padded tokens
    for e in range(E):
        nt = tokens_per_expert[e].item()
        ref_s = torch.empty((T, H // group_size),
                            dtype=torch.float32,
                            device="cuda")
        ref_q = torch.empty((T, H), dtype=torch.float8_e4m3fn, device="cuda")
        for t in range(nt):
            data = merged[e, t]
            data_grp = data.view(H // group_size, group_size)
            amax = data_grp.abs().amax(dim=1).clamp(min=eps)
            scale = amax / fp8_max

            scaled = data / scale.repeat_interleave(group_size)
            clamped = scaled.clamp(fp8_min, fp8_max)
            q = clamped.to(torch.float8_e4m3fn)

            ref_s[t] = scale
            ref_q[t] = q

        y_se = y_s[e]
        y_qe = y_q[e]

        torch.testing.assert_close(y_se[:nt], ref_s[:nt])
        torch.testing.assert_close(
            y_qe[:nt].to(torch.float32),
            ref_q[:nt].to(torch.float32),
            atol=2,
            rtol=2e-1,
        )
tests/v1/test_request.py (new file): 15 lines
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.v1.request import RequestStatus


def test_request_status_fmt_str():
    """Test that the string representation of RequestStatus is correct."""
    assert f"{RequestStatus.WAITING}" == "WAITING"
    assert f"{RequestStatus.WAITING_FOR_FSM}" == "WAITING_FOR_FSM"
    assert f"{RequestStatus.WAITING_FOR_REMOTE_KVS}" == "WAITING_FOR_REMOTE_KVS"
    assert f"{RequestStatus.RUNNING}" == "RUNNING"
    assert f"{RequestStatus.PREEMPTED}" == "PREEMPTED"
    assert f"{RequestStatus.FINISHED_STOPPED}" == "FINISHED_STOPPED"
    assert f"{RequestStatus.FINISHED_LENGTH_CAPPED}" == "FINISHED_LENGTH_CAPPED"
    assert f"{RequestStatus.FINISHED_ABORTED}" == "FINISHED_ABORTED"
    assert f"{RequestStatus.FINISHED_IGNORED}" == "FINISHED_IGNORED"
@@ -1,79 +0,0 @@
# Needed for the proxy server
vllm-directory := "/home/rshaw/vllm/"

# MODEL := "Qwen/Qwen3-0.6B"
MODEL := "meta-llama/Llama-3.1-8B-Instruct"
PROXY_PORT := "8192"
PREFILL_PORT := "8100"
DECODE_PORT := "8200"

prefill:
    VLLM_NIXL_SIDE_CHANNEL_PORT=5557 \
    CUDA_VISIBLE_DEVICES=0,7 \
    vllm serve {{MODEL}} \
        --port {{PREFILL_PORT}} \
        --tensor-parallel-size 2 \
        --enforce-eager \
        --disable-log-requests \
        --block-size 128 \
        --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

decode:
    VLLM_NIXL_SIDE_CHANNEL_PORT=5567 \
    CUDA_VISIBLE_DEVICES=4,5 \
    vllm serve {{MODEL}} \
        --port {{DECODE_PORT}} \
        --tensor-parallel-size 2 \
        --enforce-eager \
        --disable-log-requests \
        --block-size 128 \
        --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

proxy:
    python "{{vllm-directory}}tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
        --port {{PROXY_PORT}} \
        --prefiller-port {{PREFILL_PORT}} \
        --decoder-port {{DECODE_PORT}}

send_request:
    curl -X POST http://localhost:{{PROXY_PORT}}/v1/completions \
        -H "Content-Type: application/json" \
        -d '{ \
            "model": "{{MODEL}}", \
            "prompt": "Red Hat is the best open source company by far across Linux, K8s, and AI, and vLLM has the greatest community in open source AI software infrastructure. I love vLLM because", \
            "max_tokens": 150, \
            "temperature": 0.7 \
        }'

benchmark NUM_PROMPTS:
    python {{vllm-directory}}/benchmarks/benchmark_serving.py \
        --port {{PROXY_PORT}} \
        --model {{MODEL}} \
        --dataset-name random \
        --random-input-len 30000 \
        --random-output-len 10 \
        --num-prompts {{NUM_PROMPTS}} \
        --seed $(date +%s)

benchmark_one INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
        --port {{PROXY_PORT}} \
        --model {{MODEL}} \
        --input-len {{INPUT_LEN}} \
        --output-len 1 \
        --num-requests 10 \
        --seed $(date +%s)

benchmark_one_no_pd INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
        --port {{DECODE_PORT}} \
        --model {{MODEL}} \
        --input-len {{INPUT_LEN}} \
        --output-len 1 \
        --num-requests 10 \
        --seed $(date +%s)

eval:
    lm_eval --model local-completions --tasks gsm8k \
        --model_args model={{MODEL}},base_url=http://127.0.0.1:{{PROXY_PORT}}/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False \
        --limit 1000
@@ -209,7 +209,7 @@ class Attention(nn.Module):
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
-            output = torch.empty(output_shape,
+            output = torch.zeros(output_shape,
                                  dtype=query.dtype,
                                  device=query.device)
             hidden_size = output_shape[-1]
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 try:
     import intel_extension_for_pytorch.llm.modules as ipex_modules
@@ -120,7 +120,7 @@ class _PagedAttention:
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
-        src_to_dists: Dict[int, List[int]],
+        src_to_dists: torch.Tensor,
         *args,
     ) -> None:
         key_caches = [kv_cache[0] for kv_cache in kv_caches]
@@ -2285,7 +2285,7 @@ Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
 class DeviceConfig:
     """Configuration for the device to use for vLLM execution."""
 
-    device: SkipValidation[Union[Device, torch.device]] = "auto"
+    device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
     """Device type for vLLM execution.
     This parameter is deprecated and will be
     removed in a future release.
@@ -2327,7 +2327,10 @@ class DeviceConfig:
                     "to turn on verbose logging to help debug the issue.")
         else:
             # Device type is assigned explicitly
-            self.device_type = self.device
+            if isinstance(self.device, str):
+                self.device_type = self.device
+            elif isinstance(self.device, torch.device):
+                self.device_type = self.device.type
 
         # Some device types require processing inputs on CPU
         if self.device_type in ["neuron"]:
@@ -138,9 +138,29 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
         super().__init__(cpu_group)
         self.handle_cache = Cache()
 
-        # This is the DeepEP default. Stick to it till we can establish
-        # reasonable defaults based on profiling.
-        self.num_sms = 20
+        # Use all SMs for all2all communication
+        # This will need to be adjusted for dual-batch overlap
+        device = self.dp_group.device
+        props = torch.cuda.get_device_properties(device)
+        self.num_sms = props.multi_processor_count
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
+        print(f"Setting num sms to {self.num_sms}")
 
     def get_handle(self, kwargs):
         raise NotImplementedError
@@ -272,6 +272,14 @@ class NCCLLibrary:
                                          ctypes.byref(unique_id)))
         return unique_id
 
+    def unique_id_from_bytes(self, data: bytes) -> ncclUniqueId:
+        if len(data) != 128:
+            raise ValueError(
+                f"Expected 128 bytes for ncclUniqueId, got {len(data)} bytes")
+        unique_id = ncclUniqueId()
+        ctypes.memmove(ctypes.addressof(unique_id.internal), data, 128)
+        return unique_id
+
     def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId,
                          rank: int) -> ncclComm_t:
         comm = ncclComm_t()
@@ -112,6 +112,11 @@ KVConnectorFactory.register_connector(
     "vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
     "SharedStorageConnector")
 
+KVConnectorFactory.register_connector(
+    "P2pNcclConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
+    "P2pNcclConnector")
+
 KVConnectorFactory.register_connector(
     "LMCacheConnectorV1",
     "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
@@ -37,9 +37,6 @@ if TYPE_CHECKING:
 Transfer = tuple[int, float]  # (xfer_handle, start_time)
 GET_META_MSG = b"get_meta_msg"
 
-import os
-LOG_XFER_TIME = os.getenv("VLLM_LOG_XFER_TIME", "0") == "1"
-
 logger = init_logger(__name__)
 
 # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
@@ -332,17 +329,7 @@ class NixlConnectorWorker:
         self.block_size = vllm_config.cache_config.block_size
 
         # Agent.
-        import os
-        num_workers = 32
-        # setting num_workers on the prefiller causes no notifs to be recved???
-        # this is a hack to make sure we set num workers on the prefiller to 1.
-        if os.getenv("VLLM_IS_PREFILL", "0") == "1":
-            num_workers = None
-        print(f"NUM_WORKERS: {num_workers=}")
-        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()),
-                                        None,
-                                        num_workers=None,
-                                        num_shared_workers=num_workers)
+        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
         self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
@@ -384,8 +371,8 @@ class NixlConnectorWorker:
         self._registered_descs: list[Any] = []
 
         # In progress transfers.
-        # [req_id -> list[handles], agent_name, notif_id]
-        self._recving_transfers: dict[str, tuple[list[int], str, str]] = {}
+        # [req_id -> list[handle]]
+        self._recving_transfers = defaultdict[str, list[Transfer]](list)
 
         # Complete transfer tracker. Used by the rank 0 to track finished
         # transactions on ranks 1 to N-1.
@@ -767,11 +754,8 @@ class NixlConnectorWorker:
         to Rank 0 once their transaction is done + Rank 0 returns
         finished sets to Scheduler only once all ranks are done.
         """
-
-        start = time.perf_counter()
         done_sending = self._get_new_notifs()
         done_recving = self._pop_done_transfers(self._recving_transfers)
-
         if len(done_sending) > 0 or len(done_recving) > 0:
             logger.debug(
                 "Rank %s, get_finished: %s requests done sending "
@@ -812,10 +796,6 @@ class NixlConnectorWorker:
             if self._done_sending_count[req_id] == self.world_size:
                 del self._done_sending_count[req_id]
                 all_done_sending.add(req_id)
-
-        end = time.perf_counter()
-        if LOG_XFER_TIME:
-            logger.info("========== .get_finished(): %s ==========", end - start)
 
         return all_done_sending, all_done_recving
@@ -825,10 +805,6 @@ class NixlConnectorWorker:
             self.tp_group.send_object(finished_req_ids, dst=0)
 
         # Unused as only Rank 0 results are sent to scheduler.
-        end = time.perf_counter()
-        if LOG_XFER_TIME:
-            logger.info("========== .get_finished(): %s ==========", end - start)
-
         return done_sending, done_recving
 
     def _get_new_notifs(self) -> set[str]:
@@ -850,8 +826,7 @@ class NixlConnectorWorker:
         return notified_req_ids
 
     def _pop_done_transfers(
-            self, transfers: dict[str, tuple[list[int], str,
-                                             str]]) -> set[str]:
+            self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
         """
         Pop completed xfers by checking for DONE state.
         Args:
@@ -859,30 +834,19 @@ class NixlConnectorWorker:
         Returns:
             set of req_ids that have all done xfers
         """
-        done_req_ids: set[str, float] = set()
-        for req_id, (handles, agent_name, notif_id, start_time) in list(transfers.items()):
-            new_handles = []
-            for handle in handles:
+        done_req_ids: set[str] = set()
+        for req_id, handles in list(transfers.items()):
+            for handle, xfer_stime in handles:
                 xfer_state = self.nixl_wrapper.check_xfer_state(handle)
                 if xfer_state == "DONE":
                     self.nixl_wrapper.release_xfer_handle(handle)
+                    done_req_ids.add(req_id)
+                    del transfers[req_id]
                 elif xfer_state == "PROC":
-                    new_handles.append(handle)
+                    continue
                 else:
                     raise RuntimeError("Transfer failed with state %s",
                                        xfer_state)
-
-            # Done.
-            if len(new_handles) == 0:
-                self.nixl_wrapper.send_notif(agent_name, notif_id)
-                del transfers[req_id]
-                done_req_ids.add(req_id)
-                if LOG_XFER_TIME:
-                    logger.info("========== transmission time: %s ==========", time.perf_counter() - start_time)
-
-            else:
-                transfers[req_id] = (new_handles, agent_name, notif_id, start_time)
-
         return done_req_ids
 
     def start_load_kv(self, metadata: NixlConnectorMetadata):
@ -994,35 +958,22 @@ class NixlConnectorWorker:
|
||||
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
||||
|
||||
# Prepare transfer with Nixl.
|
||||
CHUNK_SIZE = 500
|
||||
handles = []
|
||||
# NOTE: this is a hack to make make_prepped_xfer into threads so that
|
||||
# different workers are allocated for each chuck. Without this change,
|
||||
# nixl was allocating the same worker (0) for all the chunks and the
|
||||
# overall launch time was >300 ms.
|
||||
for i in range(0, len(local_block_descs_ids), CHUNK_SIZE):
|
||||
handle = self.nixl_wrapper.make_prepped_xfer(
|
||||
"READ",
|
||||
local_xfer_side_handle,
|
||||
local_block_descs_ids[i:i + CHUNK_SIZE],
|
||||
remote_xfer_side_handle,
|
||||
remote_block_descs_ids[i:i + CHUNK_SIZE],
|
||||
skip_desc_merge=True,
|
||||
)
|
||||
handles.append(handle)
|
||||
handle = self.nixl_wrapper.make_prepped_xfer(
|
||||
"READ",
|
||||
local_xfer_side_handle,
|
||||
local_block_descs_ids,
|
||||
remote_xfer_side_handle,
|
||||
remote_block_descs_ids,
|
||||
notif_msg=notif_id,
|
||||
)
|
||||
|
||||
# Begin async xfer.
|
||||
start = time.perf_counter()
|
||||
self.nixl_wrapper.transfer_batched(handles)
|
||||
end = time.perf_counter()
|
||||
if LOG_XFER_TIME:
|
||||
logger.info("========== .transfer_batched(): %s ==========", end - start)
|
||||
self.nixl_wrapper.transfer(handle)
|
||||
|
||||
# Keep track of ongoing transfers.
|
||||
remote_rank = self.tp_rank // tp_ratio
|
||||
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
||||
assert request_id not in self._recving_transfers
|
||||
self._recving_transfers[request_id] = (handles, agent_name, notif_id, time.perf_counter())
|
||||
# Use handle to check completion in future step().
|
||||
# TODO (NickLucche) surface xfer elapsed time
|
||||
self._recving_transfers[request_id].append(
|
||||
(handle, time.perf_counter()))
|
||||
|
||||
def _get_block_descs_ids(self,
|
||||
engine_id: str,
|
||||
|
||||
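A minimal sketch of the transfer-tracking pattern this diff settles on (FakeNixlWrapper, its state strings, and pop_done_transfers are illustrative stand-ins, not the real NIXL API): each request maps to a list of (handle, start_time) pairs, the new code issues a single prepped transfer per request, and a poll loop pops a request once its handle reports DONE.

from collections import defaultdict

class FakeNixlWrapper:  # stand-in for nixl_wrapper; states mirror the diff
    def __init__(self, states):
        self._states = states
    def check_xfer_state(self, handle):
        return self._states[handle]
    def release_xfer_handle(self, handle):
        pass

def pop_done_transfers(wrapper, transfers):
    """Pop req_ids whose transfer handle is DONE (mirrors the diff's logic)."""
    done = set()
    for req_id, handles in list(transfers.items()):
        for handle, _stime in handles:
            state = wrapper.check_xfer_state(handle)
            if state == "DONE":
                wrapper.release_xfer_handle(handle)
                done.add(req_id)
                del transfers[req_id]
            elif state == "PROC":
                continue  # still in flight; re-check on the next step()
            else:
                raise RuntimeError(f"Transfer failed with state {state}")
    return done

transfers = defaultdict(list)
transfers["req-0"].append((0, 0.0))
assert pop_done_transfers(FakeNixlWrapper({0: "DONE"}), transfers) == {"req-0"}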
@@ -0,0 +1,481 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import regex as re
import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
    P2pNcclEngine)
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

logger = init_logger(__name__)


@dataclass
class ReqMeta:
    # Request Id
    request_id: str
    # Request tokens
    token_ids: torch.Tensor
    # Slot mappings, should have the same length as token_ids
    slot_mapping: torch.Tensor

    @staticmethod
    def make_meta(request_id: str, token_ids: list[int], block_ids: list[int],
                  block_size: int) -> "ReqMeta":
        valid_num_tokens = len(token_ids)
        token_ids_tensor = torch.tensor(token_ids)
        block_ids_tensor = torch.tensor(block_ids)
        num_blocks = block_ids_tensor.shape[0]
        block_offsets = torch.arange(0, block_size)
        slot_mapping = block_offsets.reshape((1, block_size)) + \
            block_ids_tensor.reshape((num_blocks, 1)) * block_size
        slot_mapping = slot_mapping.flatten()[:valid_num_tokens]

        return ReqMeta(
            request_id=request_id,
            token_ids=token_ids_tensor,
            slot_mapping=slot_mapping,
        )
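As a concrete check of the slot-mapping arithmetic above (values are illustrative): each block contributes slots block_id * block_size + offset, and the flattened result is truncated to the token count.

import torch

block_size, block_ids, num_tokens = 4, [7, 2], 6
offsets = torch.arange(0, block_size)                      # [0, 1, 2, 3]
slots = (offsets.reshape(1, -1) +
         torch.tensor(block_ids).reshape(-1, 1) * block_size)
print(slots.flatten()[:num_tokens].tolist())               # [28, 29, 30, 31, 8, 9]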
@dataclass
class P2pNcclConnectorMetadata(KVConnectorMetadata):
    requests: list[ReqMeta]

    def __init__(self):
        self.requests = []

    def add_request(
        self,
        request_id: str,
        token_ids: list[int],
        block_ids: list[int],
        block_size: int,
    ) -> None:
        self.requests.append(
            ReqMeta.make_meta(request_id, token_ids, block_ids, block_size))


class P2pNcclConnector(KVConnectorBase_V1):

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        super().__init__(vllm_config=vllm_config, role=role)
        self._block_size = vllm_config.cache_config.block_size
        self._requests_need_load: dict[str, Any] = {}
        self.config = vllm_config.kv_transfer_config
        self.is_producer = self.config.is_kv_producer
        self.chunked_prefill: dict[str, Any] = {}

        self._rank = get_world_group().rank \
            if role == KVConnectorRole.WORKER else 0
        self._local_rank = get_world_group().local_rank \
            if role == KVConnectorRole.WORKER else 0

        self.p2p_nccl_engine = P2pNcclEngine(
            local_rank=self._local_rank,
            config=self.config,
            hostname="",
            port_offset=self._rank,
        ) if role == KVConnectorRole.WORKER else None

    # ==============================
    # Worker-side methods
    # ==============================

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Start loading the KV cache from the connector buffer to vLLM's
        paged KV buffer.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.
        """

        # Only consumer/decode loads KV Cache
        if self.is_producer:
            return

        assert self.p2p_nccl_engine is not None

        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            return

        def inject_kv_into_layer(
            dst_kv_cache_layer: torch.Tensor,
            src_kv_cache: torch.Tensor,
            slot_mapping: torch.Tensor,
            request_id: str,
        ) -> None:
            """Inject the KV cache into the layer.

            Args:
                dst_kv_cache_layer (torch.Tensor): the destination KV cache
                    layer. In shape [2, num_pages, page_size, xxx] if not
                    using MLA, [num_pages, page_size, xxx] otherwise.
                src_kv_cache (torch.Tensor): the source KV cache. In shape
                    [2, num_tokens, xxx] if not using MLA, [num_tokens, xxx]
                    otherwise.
                slot_mapping (torch.Tensor): the slot mapping. In shape
                    [num_tokens].
                request_id (str): request id for log
            """
            dst_kv_cache_layer_shape = dst_kv_cache_layer.shape
            if isinstance(attn_metadata, MLACommonMetadata):
                num_pages = dst_kv_cache_layer_shape[0]
                page_size = dst_kv_cache_layer_shape[1]
                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                    num_pages * page_size, -1)
                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
                                              0)
                num_token = src_kv_cache.shape[0]
                if len(slot_mapping) == num_token:
                    dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
                else:
                    dst_kv_cache_layer[slot_mapping[:num_token],
                                       ...] = src_kv_cache
                    logger.warning(
                        "🚧src_kv_cache does not match, num_slot:%d, "
                        "num_token:%d, request_id:%s", len(slot_mapping),
                        num_token, request_id)

                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)
            else:
                num_pages = dst_kv_cache_layer_shape[1]
                page_size = dst_kv_cache_layer_shape[2]
                dst_kv_cache_layer = dst_kv_cache_layer.reshape(
                    2, num_pages * page_size, -1)
                self.check_tensors_except_dim(dst_kv_cache_layer, src_kv_cache,
                                              1)
                num_token = src_kv_cache.shape[1]
                if len(slot_mapping) == num_token:
                    dst_kv_cache_layer[:, slot_mapping, ...] = src_kv_cache
                else:
                    dst_kv_cache_layer[:, slot_mapping[:num_token],
                                       ...] = src_kv_cache
                    logger.warning(
                        "🚧src_kv_cache does not match, num_slot:%d, "
                        "num_token:%d, request_id:%s", len(slot_mapping),
                        num_token, request_id)

                dst_kv_cache_layer.reshape(dst_kv_cache_layer_shape)

        # Get the metadata
        metadata: KVConnectorMetadata = \
            self._get_connector_metadata()
        assert isinstance(metadata, P2pNcclConnectorMetadata)

        if metadata is None:
            return

        # Load the KV for each request each layer
        for request in metadata.requests:
            for layer_name in forward_context.no_compile_layers:
                attn_layer = forward_context.no_compile_layers[layer_name]
                kv_cache_layer = attn_layer.kv_cache[ \
                    forward_context.virtual_engine]

                kv_cache = self.p2p_nccl_engine.recv_tensor(
                    request.request_id + "#" + layer_name)

                if kv_cache is None:
                    logger.warning("🚧src_kv_cache is None, %s",
                                   request.request_id)
                    continue

                inject_kv_into_layer(kv_cache_layer, kv_cache,
                                     request.slot_mapping, request.request_id)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """Blocking until the KV for a specific layer is loaded into vLLM's
        paged buffer.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        return

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """Start saving the KV cache of the layer from vLLM's paged buffer
        to the connector.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """

        # Only producer/prefill saves KV Cache
        if not self.is_producer:
            return

        assert self.p2p_nccl_engine is not None

        def extract_kv_from_layer(
            layer: torch.Tensor,
            slot_mapping: torch.Tensor,
        ) -> torch.Tensor:
            """Extract the KV cache from the layer.

            Assume the shape of the layer is (2, num_pages, page_size, xxx)
            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
            """
            if isinstance(attn_metadata, MLACommonMetadata):
                num_pages, page_size = layer.shape[0], layer.shape[1]
                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
                                                                ...]
            num_pages, page_size = layer.shape[1], layer.shape[2]
            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
                                                               ...]

        connector_metadata = self._get_connector_metadata()
        assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
        for request in connector_metadata.requests:
            request_id = request.request_id
            ip, port = self.parse_request_id(request_id, True)
            remote_address = ip + ":" + str(port + self._rank)
            kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
            self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
                                             kv_cache, remote_address)

    def wait_for_save(self):
        if self.is_producer:
            assert self.p2p_nccl_engine is not None
            self.p2p_nccl_engine.wait_for_sent()

    def get_finished(
            self, finished_req_ids: set[str],
            **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens.

        Returns:
            ids of requests that have finished asynchronous transfer,
            tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """

        assert self.p2p_nccl_engine is not None

        forward_context: ForwardContext = get_forward_context()
        return self.p2p_nccl_engine.get_finished(finished_req_ids,
                                                 forward_context)

    # ==============================
    # Scheduler-side methods
    # ==============================

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            the number of tokens that can be loaded from the
            external KV cache beyond what is already computed.
        """
        if self.is_producer:
            return 0, False

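        # Editorial note: the "- 1" below excludes the last prompt token, which
        # the decode instance recomputes locally so it can produce logits for
        # the first sampled token instead of fetching that KV remotely.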
        num_external_tokens = (len(request.prompt_token_ids) - 1 -
                               num_computed_tokens)

        if num_external_tokens < 0:
            num_external_tokens = 0

        return num_external_tokens, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """
        Update KVConnector state after block allocation.
        """
        if not self.is_producer and num_external_tokens > 0:
            self._requests_need_load[request.request_id] = (
                request, blocks.get_block_ids()[0])

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the connector metadata for this step.

        This function should NOT modify any fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """

        meta = P2pNcclConnectorMetadata()

        for new_req in scheduler_output.scheduled_new_reqs:
            if self.is_producer:
                num_scheduled_tokens = (
                    scheduler_output.num_scheduled_tokens)[new_req.req_id]
                num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
                # the request's prompt is chunked prefill
                if num_tokens < len(new_req.prompt_token_ids):
                    # 'CachedRequestData' has no attribute 'prompt_token_ids'
                    self.chunked_prefill[new_req.req_id] = (
                        new_req.block_ids[0], new_req.prompt_token_ids)
                    continue
                # the request's prompt is not chunked prefill
                meta.add_request(request_id=new_req.req_id,
                                 token_ids=new_req.prompt_token_ids,
                                 block_ids=new_req.block_ids[0],
                                 block_size=self._block_size)
                continue
            if new_req.req_id in self._requests_need_load:
                meta.add_request(request_id=new_req.req_id,
                                 token_ids=new_req.prompt_token_ids,
                                 block_ids=new_req.block_ids[0],
                                 block_size=self._block_size)
                self._requests_need_load.pop(new_req.req_id)

        for cached_req in scheduler_output.scheduled_cached_reqs:
            if self.is_producer:
                num_scheduled_tokens = (
                    scheduler_output.num_scheduled_tokens)[cached_req.req_id]
                num_tokens = (num_scheduled_tokens +
                              cached_req.num_computed_tokens)
                assert cached_req.req_id in self.chunked_prefill
                block_ids = cached_req.new_block_ids[0]
                if not cached_req.resumed_from_preemption:
                    block_ids = (self.chunked_prefill[cached_req.req_id][0] +
                                 block_ids)
                prompt_token_ids = self.chunked_prefill[cached_req.req_id][1]
                # the request's prompt is chunked prefill again
                if num_tokens < len(prompt_token_ids):
                    self.chunked_prefill[cached_req.req_id] = (
                        block_ids, prompt_token_ids)
                    continue
                # the request's prompt is all prefilled finally
                meta.add_request(request_id=cached_req.req_id,
                                 token_ids=prompt_token_ids,
                                 block_ids=block_ids,
                                 block_size=self._block_size)
                self.chunked_prefill.pop(cached_req.req_id, None)
                continue

            # NOTE(rob): here we rely on the resumed requests being
            # the first N requests in the list scheduled_cache_reqs.
            if not cached_req.resumed_from_preemption:
                break
            if cached_req.req_id in self._requests_need_load:
                request, _ = self._requests_need_load.pop(cached_req.req_id)
                total_tokens = cached_req.num_computed_tokens + 1
                token_ids = request.all_token_ids[:total_tokens]

                # NOTE(rob): For resumed req, new_block_ids is all
                # of the block_ids for the request.
                block_ids = cached_req.new_block_ids[0]

                meta.add_request(request_id=cached_req.req_id,
                                 token_ids=token_ids,
                                 block_ids=block_ids,
                                 block_size=self._block_size)

        # Requests loaded asynchronously are not in the scheduler_output.
        # for request_id in self._requests_need_load:
        #     request, block_ids = self._requests_need_load[request_id]
        #     meta.add_request(request_id=request.request_id,
        #                      token_ids=request.prompt_token_ids,
        #                      block_ids=block_ids,
        #                      block_size=self._block_size)

        self._requests_need_load.clear()
        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """

        self.chunked_prefill.pop(request.request_id, None)

        return False, None

    # ==============================
    # Static methods
    # ==============================

    @staticmethod
    def parse_request_id(request_id: str, is_prefill=True) -> tuple[str, int]:
        # Regular expression to match the string hostname and integer port
        if is_prefill:
            pattern = r"___decode_addr_(.*):(\d+)"
        else:
            pattern = r"___prefill_addr_(.*):(\d+)___"

        # Use re.search to find the pattern in the request_id
        match = re.search(pattern, request_id)
        if match:
            # Extract the ranks
            ip = match.group(1)
            port = int(match.group(2))

            return ip, port
        raise ValueError(
            f"Request id {request_id} does not contain hostname and port")

    @staticmethod
    def check_tensors_except_dim(tensor1, tensor2, dim):
        shape1 = tensor1.size()
        shape2 = tensor2.size()

        if len(shape1) != len(shape2) or not all(
                s1 == s2
                for i, (s1, s2) in enumerate(zip(shape1, shape2)) if i != dim):
            raise NotImplementedError(
                "Currently, only symmetric TP is supported. Asymmetric TP, PP,"
                "and others will be supported in future PRs.")
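A quick illustration of the request-id convention parse_request_id expects (the id below is made up): a proxy embeds the peer addresses into the request id, and the prefill side recovers the decode address with the regex above.

import regex as re

request_id = "cmpl-42___prefill_addr_10.0.0.1:21001___decode_addr_10.0.0.2:22001"
match = re.search(r"___decode_addr_(.*):(\d+)", request_id)
assert (match.group(1), int(match.group(2))) == ("10.0.0.2", 22001)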
@@ -0,0 +1,531 @@
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import threading
import time
import typing
from collections import deque
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional

import msgpack
import torch
import zmq

from vllm.config import KVTransferConfig
from vllm.distributed.device_communicators.pynccl_wrapper import (
    NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum)
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import (  # noqa: E501
    TensorMemoryPool)
from vllm.utils import current_stream, get_ip

if TYPE_CHECKING:
    from vllm.forward_context import ForwardContext

logger = logging.getLogger(__name__)

DEFAULT_MEM_POOL_SIZE_GB = 32


@contextmanager
def set_p2p_nccl_context(num_channels: str):
    original_values: dict[str, Any] = {}
    env_vars = [
        'NCCL_MAX_NCHANNELS',
        'NCCL_MIN_NCHANNELS',
        'NCCL_CUMEM_ENABLE',
        'NCCL_BUFFSIZE',
        'NCCL_PROTO',  # LL,LL128,SIMPLE
        'NCCL_ALGO',  # RING,TREE
    ]

    for var in env_vars:
        original_values[var] = os.environ.get(var)

    logger.info("set_p2p_nccl_context, original_values: %s", original_values)

    try:
        os.environ['NCCL_MAX_NCHANNELS'] = num_channels
        os.environ['NCCL_MIN_NCHANNELS'] = num_channels
        os.environ['NCCL_CUMEM_ENABLE'] = '1'
        yield
    finally:
        for var in env_vars:
            if original_values[var] is not None:
                os.environ[var] = original_values[var]
            else:
                os.environ.pop(var, None)
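A minimal usage sketch for the context manager above: the NCCL environment variables are pinned only while the block executes, which matters because NCCL reads them at communicator-creation time.

import os

with set_p2p_nccl_context("8"):
    # Any ncclCommInitRank issued inside this scope sees exactly 8 channels.
    assert os.environ["NCCL_MAX_NCHANNELS"] == "8"
    assert os.environ["NCCL_CUMEM_ENABLE"] == "1"
# On exit the previous values (or their absence) are restored.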
class P2pNcclEngine:

    def __init__(self,
                 local_rank: int,
                 config: KVTransferConfig,
                 hostname: str = "",
                 port_offset: int = 0,
                 library_path: Optional[str] = None) -> None:
        self.config = config
        self.rank = port_offset
        self.local_rank = local_rank
        self.device = torch.device(f"cuda:{self.local_rank}")
        self.nccl = NCCLLibrary(library_path)

        if not hostname:
            hostname = get_ip()
        port = int(self.config.kv_port) + port_offset
        if port == 0:
            raise ValueError("Port cannot be 0")
        self._hostname = hostname
        self._port = port

        # Each card corresponds to a ZMQ address.
        self.zmq_address = f"{self._hostname}:{self._port}"

        # The `http_port` must be consistent with the port of OpenAI.
        self.http_address = (
            f"{self._hostname}:"
            f"{self.config.kv_connector_extra_config['http_port']}")

        # If `proxy_ip` or `proxy_port` is `""`,
        # then the ping thread will not be enabled.
        proxy_ip = self.config.get_from_extra_config("proxy_ip", "")
        proxy_port = self.config.get_from_extra_config("proxy_port", "")
        if proxy_ip == "" or proxy_port == "":
            self.proxy_address = ""
        else:
            self.proxy_address = proxy_ip + ":" + proxy_port

        self.context = zmq.Context()
        self.router_socket = self.context.socket(zmq.ROUTER)
        self.router_socket.bind(f"tcp://{self.zmq_address}")

        self.poller = zmq.Poller()
        self.poller.register(self.router_socket, zmq.POLLIN)

        self.send_store_cv = threading.Condition()
        self.send_queue_cv = threading.Condition()
        self.recv_store_cv = threading.Condition()

        self.send_stream = torch.cuda.Stream()
        self.recv_stream = torch.cuda.Stream()

        mem_pool_size_gb = self.config.get_from_extra_config(
            "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
        self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
                                     1024**3)  # GB

        # The sending type includes three mutually exclusive options:
        # PUT, GET, PUT_ASYNC.
        self.send_type = self.config.get_from_extra_config("send_type", "PUT")
        if self.send_type == "GET":
            # tensor_id: torch.Tensor
            self.send_store: dict[str, torch.Tensor] = {}
        else:
            # PUT or PUT_ASYNC
            # tensor_id: torch.Tensor
            self.send_queue: deque[list[Any]] = deque()
            self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
            if self.send_type == "PUT_ASYNC":
                self._send_thread = threading.Thread(target=self._send_async,
                                                     daemon=True)
                self._send_thread.start()

        # tensor_id: torch.Tensor/(addr, dtype, shape)
        self.recv_store: dict[str, Any] = {}
        self.recv_request_id_to_tensor_ids: dict[str, set[str]] = {}
        self.socks: dict[str, Any] = {}  # remote_address: client socket
        self.comms: dict[str, Any] = {}  # remote_address: (ncclComm_t, rank)

        self.buffer_size = 0
        self.buffer_size_threshold = float(self.config.kv_buffer_size)

        self.nccl_num_channels = self.config.get_from_extra_config(
            "nccl_num_channels", "8")

        self._listener_thread = threading.Thread(
            target=self._listen_for_requests, daemon=True)
        self._listener_thread.start()

        self._ping_thread = None
        if port_offset == 0 and self.proxy_address != "":
            self._ping_thread = threading.Thread(target=self._ping,
                                                 daemon=True)
            self._ping_thread.start()

        logger.info(
            "💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
            "zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
            "threshold:%.2f, nccl_num_channels:%s", self.rank, self.local_rank,
            self.http_address, self.zmq_address, self.proxy_address,
            self.send_type, self.buffer_size_threshold, self.nccl_num_channels)

    def _create_connect(self, remote_address: typing.Optional[str] = None):
        assert remote_address is not None
        if remote_address not in self.socks:
            sock = self.context.socket(zmq.DEALER)
            sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
            sock.connect(f"tcp://{remote_address}")
            self.socks[remote_address] = sock
            if remote_address in self.comms:
                logger.info("👋comm exists, remote_address:%s, comms:%s",
                            remote_address, self.comms)
                return sock, self.comms[remote_address]

            unique_id = self.nccl.ncclGetUniqueId()
            data = {"cmd": "NEW", "unique_id": bytes(unique_id.internal)}
            sock.send(msgpack.dumps(data))

            with torch.cuda.device(self.device):
                rank = 0
                with set_p2p_nccl_context(self.nccl_num_channels):
                    comm: ncclComm_t = self.nccl.ncclCommInitRank(
                        2, unique_id, rank)
                self.comms[remote_address] = (comm, rank)
                logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
                            self.zmq_address, remote_address, rank)

        return self.socks[remote_address], self.comms[remote_address]

    def send_tensor(
        self,
        tensor_id: str,
        tensor: torch.Tensor,
        remote_address: typing.Optional[str] = None,
    ) -> bool:
        if remote_address is None:
            with self.recv_store_cv:
                self.recv_store[tensor_id] = tensor
                self.recv_store_cv.notify()
            return True
        else:
            if self.send_type == "PUT":
                return self._send_sync(tensor_id, tensor, remote_address)
            elif self.send_type == "PUT_ASYNC":
                with self.send_queue_cv:
                    self.send_queue.append([tensor_id, remote_address, tensor])
                    self.send_queue_cv.notify()
            else:  # GET
                with self.send_store_cv:
                    tensor_size = tensor.element_size() * tensor.numel()
                    while (self.buffer_size + tensor_size
                           > self.buffer_size_threshold):
                        oldest_tenser_id = next(iter(self.send_store))
                        oldest_tenser = self.send_store.pop(oldest_tenser_id)
                        oldest_tenser_size = oldest_tenser.element_size(
                        ) * oldest_tenser.numel()
                        self.buffer_size -= oldest_tenser_size
                        logger.info(
                            "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
                            " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
                            remote_address, tensor_id, tensor_size,
                            self.buffer_size, oldest_tenser_size, self.rank)

                    self.send_store[tensor_id] = tensor
                    self.buffer_size += tensor_size
                    logger.debug(
                        "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
                        "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
                        remote_address, tensor_id, tensor_size, tensor.shape,
                        self.rank, self.buffer_size,
                        self.buffer_size / self.buffer_size_threshold * 100)

        return True

    def recv_tensor(
        self,
        tensor_id: str,
        remote_address: typing.Optional[str] = None,
    ) -> torch.Tensor:
        if self.send_type == "PUT" or self.send_type == "PUT_ASYNC":
            start_time = time.time()
            with self.recv_store_cv:
                while tensor_id not in self.recv_store:
                    self.recv_store_cv.wait()
                tensor = self.recv_store[tensor_id]

            if tensor is not None:
                if isinstance(tensor, tuple):
                    addr, dtype, shape = tensor
                    tensor = self.pool.load_tensor(addr, dtype, shape,
                                                   self.device)
                else:
                    self.buffer_size -= (tensor.element_size() *
                                         tensor.numel())
            else:
                duration = time.time() - start_time
                logger.warning(
                    "🔴[PUT]Recv From %s, tensor_id:%s, duration:%.3fms, "
                    "rank:%d", remote_address, tensor_id, duration * 1000,
                    self.rank)
            return tensor

        # GET
        if remote_address is None:
            return None

        if remote_address not in self.socks:
            self._create_connect(remote_address)

        sock = self.socks[remote_address]
        comm, rank = self.comms[remote_address]

        data = {"cmd": "GET", "tensor_id": tensor_id}
        sock.send(msgpack.dumps(data))

        message = sock.recv()
        data = msgpack.loads(message)
        if data["ret"] != 0:
            logger.warning("🔴[GET]Recv From %s, tensor_id: %s, ret: %d",
                           remote_address, tensor_id, data["ret"])
            return None

        tensor = torch.empty(data["shape"],
                             dtype=getattr(torch, data["dtype"]),
                             device=self.device)

        self._recv(comm, tensor, rank ^ 1, self.recv_stream)

        return tensor

    def _listen_for_requests(self):
        while True:
            socks = dict(self.poller.poll())
            if self.router_socket in socks:
                remote_address, message = self.router_socket.recv_multipart()
                data = msgpack.loads(message)
                if data["cmd"] == "NEW":
                    unique_id = self.nccl.unique_id_from_bytes(
                        bytes(data["unique_id"]))
                    with torch.cuda.device(self.device):
                        rank = 1
                        with set_p2p_nccl_context(self.nccl_num_channels):
                            comm: ncclComm_t = self.nccl.ncclCommInitRank(
                                2, unique_id, rank)
                        self.comms[remote_address.decode()] = (comm, rank)
                        logger.info(
                            "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
                            self.zmq_address, remote_address.decode(), rank)
                elif data["cmd"] == "PUT":
                    tensor_id = data["tensor_id"]
                    try:
                        tensor = torch.empty(data["shape"],
                                             dtype=getattr(
                                                 torch, data["dtype"]),
                                             device=self.device)
                        self.router_socket.send_multipart(
                            [remote_address, b"0"])
                        comm, rank = self.comms[remote_address.decode()]
                        self._recv(comm, tensor, rank ^ 1, self.recv_stream)
                        tensor_size = tensor.element_size() * tensor.numel()
                        if (self.buffer_size + tensor_size
                                > self.buffer_size_threshold):
                            # Store Tensor in memory pool
                            addr = self.pool.store_tensor(tensor)
                            tensor = (addr, tensor.dtype, tensor.shape)
                            logger.warning(
                                "🔴[PUT]Recv Tensor, Out Of Threshold, "
                                "%s👈%s, data:%s, addr:%d", self.zmq_address,
                                remote_address.decode(), data, addr)
                        else:
                            self.buffer_size += tensor_size

                    except torch.cuda.OutOfMemoryError:
                        self.router_socket.send_multipart(
                            [remote_address, b"1"])
                        tensor = None
                        logger.warning(
                            "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
                            "data:%s", self.zmq_address,
                            remote_address.decode(), data)

                    with self.recv_store_cv:
                        self.recv_store[tensor_id] = tensor
                        self._have_received_tensor_id(tensor_id)
                        self.recv_store_cv.notify()

                elif data["cmd"] == "GET":
                    tensor_id = data["tensor_id"]
                    with self.send_store_cv:
                        tensor = self.send_store.pop(tensor_id, None)
                        if tensor is not None:
                            data = {
                                "ret": 0,
                                "shape": tensor.shape,
                                "dtype":
                                str(tensor.dtype).replace("torch.", "")
                            }
                            # LRU
                            self.send_store[tensor_id] = tensor
                            self._have_sent_tensor_id(tensor_id)
                        else:
                            data = {"ret": 1}

                    self.router_socket.send_multipart(
                        [remote_address, msgpack.dumps(data)])

                    if data["ret"] == 0:
                        comm, rank = self.comms[remote_address.decode()]
                        self._send(comm, tensor.to(self.device), rank ^ 1,
                                   self.send_stream)
                else:
                    logger.warning(
                        "🚧Unexpected, Received message from %s, data:%s",
                        remote_address, data)

    def _have_sent_tensor_id(self, tensor_id: str):
        request_id = tensor_id.split('#')[0]
        if request_id not in self.send_request_id_to_tensor_ids:
            self.send_request_id_to_tensor_ids[request_id] = set()
        self.send_request_id_to_tensor_ids[request_id].add(tensor_id)

    def _have_received_tensor_id(self, tensor_id: str):
        request_id = tensor_id.split('#')[0]
        if request_id not in self.recv_request_id_to_tensor_ids:
            self.recv_request_id_to_tensor_ids[request_id] = set()
        self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)

    def _send_async(self):
        while True:
            with self.send_queue_cv:
                while not self.send_queue:
                    self.send_queue_cv.wait()
                tensor_id, remote_address, tensor = self.send_queue.popleft()
                if not self.send_queue:
                    self.send_queue_cv.notify()
            self._send_sync(tensor_id, tensor, remote_address)

    def wait_for_sent(self):
        if self.send_type == "PUT_ASYNC":
            start_time = time.time()
            with self.send_queue_cv:
                while self.send_queue:
                    self.send_queue_cv.wait()
                duration = time.time() - start_time
                logger.debug(
                    "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
                    " to be empty, rank:%d", duration * 1000, self.rank)

    def _send_sync(
        self,
        tensor_id: str,
        tensor: torch.Tensor,
        remote_address: typing.Optional[str] = None,
    ) -> bool:
        if remote_address is None:
            return False
        if remote_address not in self.socks:
            self._create_connect(remote_address)

        sock = self.socks[remote_address]
        comm, rank = self.comms[remote_address]
        data = {
            "cmd": "PUT",
            "tensor_id": tensor_id,
            "shape": tensor.shape,
            "dtype": str(tensor.dtype).replace("torch.", "")
        }
        sock.send(msgpack.dumps(data))

        response = sock.recv()
        if response != b"0":
            logger.error(
                "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
                "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
                self.zmq_address, remote_address, rank, data, tensor.shape,
                tensor.element_size() * tensor.numel() / 1024**3,
                response.decode())
            return False

        self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)

        if self.send_type == "PUT_ASYNC":
            self._have_sent_tensor_id(tensor_id)

        return True

    def get_finished(
        self, finished_req_ids: set[str], forward_context: "ForwardContext"
    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens.

        Returns:
            ids of requests that have finished asynchronous transfer,
            tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """

        # Clear the buffer upon request completion.
        for request_id in finished_req_ids:
            for layer_name in forward_context.no_compile_layers:
                tensor_id = request_id + "#" + layer_name
                if tensor_id in self.recv_store:
                    with self.recv_store_cv:
                        tensor = self.recv_store.pop(tensor_id, None)
                        self.send_request_id_to_tensor_ids.pop(
                            request_id, None)
                        self.recv_request_id_to_tensor_ids.pop(
                            request_id, None)
                    addr = 0
                    if isinstance(tensor, tuple):
                        addr, _, _ = tensor
                        self.pool.free(addr)

        # TODO: Retrieve requests that have already sent the KV cache.
        finished_sending: set[str] = set()

        # TODO: Retrieve requests that have already received the KV cache.
        finished_recving: set[str] = set()

        return finished_sending or None, finished_recving or None

    def _ping(self):
        sock = self.context.socket(zmq.DEALER)
        sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
        logger.debug("ping start, zmq_address:%s", self.zmq_address)
        sock.connect(f"tcp://{self.proxy_address}")
        data = {
            "type": "P" if self.config.is_kv_producer else "D",
            "http_address": self.http_address,
            "zmq_address": self.zmq_address
        }
        while True:
            sock.send(msgpack.dumps(data))
            time.sleep(3)

    def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = current_stream()

        with torch.cuda.stream(stream):
            self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                               ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                               comm, cudaStream_t(stream.cuda_stream))
        stream.synchronize()

    def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
        assert tensor.device == self.device, (
            f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
        if stream is None:
            stream = current_stream()

        with torch.cuda.stream(stream):
            self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                               ncclDataTypeEnum.from_torch(tensor.dtype), src,
                               comm, cudaStream_t(stream.cuda_stream))
        stream.synchronize()

    def close(self) -> None:
        self._listener_thread.join()
        if self.send_type == "PUT_ASYNC":
            self._send_thread.join()
        if self._ping_thread is not None:
            self._ping_thread.join()
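To make the flow above concrete, a hedged sketch of how the two sides pair up through the engine (the addresses, the `engine` instances, and the two-process setup are illustrative; in vLLM the P2pNcclConnector drives these calls): the prefill worker pushes each layer's KV under a `request_id#layer_name` key, and the decode worker blocks on the same key.

# Prefill (producer) side: PUT/PUT_ASYNC pushes the tensor to the peer.
engine.send_tensor("cmpl-42#model.layers.0.self_attn.attn",
                   kv_cache, "10.0.0.2:22001")

# Decode (consumer) side: blocks until the listener thread has received
# and stored the tensor under the same id, then pops it for injection.
kv = engine.recv_tensor("cmpl-42#model.layers.0.self_attn.attn")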
@@ -0,0 +1,264 @@
# SPDX-License-Identifier: Apache-2.0

import atexit
import ctypes
import math
from dataclasses import dataclass

import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclass
class MemoryBlock:
    size: int
    addr: int


class TensorMemoryPool:
    """A memory pool for managing pinned host memory allocations for tensors.

    This class implements a buddy allocation system to efficiently manage pinned
    host memory for tensor storage. It supports allocation, deallocation, and
    tensor storage/retrieval operations.

    Key Features:
    - Uses power-of-two block sizes for efficient buddy allocation
    - Supports splitting and merging of memory blocks
    - Provides methods to store CUDA tensors in pinned host memory
    - Allows loading tensors from pinned memory back to device
    - Automatically cleans up memory on destruction

    Attributes:
        max_block_size (int): Maximum block size (rounded to nearest power of two)
        min_block_size (int): Minimum block size (rounded to nearest power of two)
        free_lists (dict): Dictionary of free memory blocks by size
        allocated_blocks (dict): Dictionary of currently allocated blocks
        base_tensor (torch.Tensor): Base pinned memory tensor
        base_address (int): Base memory address of the pinned memory region

    Example:
        >>> pool = TensorMemoryPool(max_block_size=1024*1024)
        >>> tensor = torch.randn(100, device='cuda')
        >>> addr = pool.store_tensor(tensor)
        >>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
        ...                                  tensor.shape, 'cuda')
        >>> pool.free(addr)
    """

    def __init__(self, max_block_size: int, min_block_size: int = 512):
        """Initializes the memory pool with given size constraints.

        Args:
            max_block_size (int): Maximum size of memory blocks to manage
            min_block_size (int, optional): Minimum size of memory blocks
                to manage. Defaults to 512.

        Raises:
            ValueError: If block sizes are invalid or max_block_size is less
                than min_block_size
        """
        if max_block_size <= 0 or min_block_size <= 0:
            raise ValueError("Block sizes must be positive")
        if max_block_size < min_block_size:
            raise ValueError(
                "Max block size must be greater than min block size")

        self.max_block_size = self._round_to_power_of_two(max_block_size)
        self.min_block_size = self._round_to_power_of_two(min_block_size)

        self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
        self.allocated_blocks: dict[int, MemoryBlock] = {}

        self._initialize_free_lists()
        self._allocate_pinned_memory()

        atexit.register(self.cleanup)

    def _round_to_power_of_two(self, size: int) -> int:
        return 1 << (size - 1).bit_length()

    def _initialize_free_lists(self):
        size = self.max_block_size
        while size >= self.min_block_size:
            self.free_lists[size] = {}
            size //= 2

    def _allocate_pinned_memory(self):
        self.base_tensor = torch.empty(self.max_block_size // 4,
                                       dtype=torch.float32,
                                       pin_memory=True)
        self.base_address = self.base_tensor.data_ptr()
        initial_block = MemoryBlock(size=self.max_block_size,
                                    addr=self.base_address)
        self.free_lists[self.max_block_size][
            initial_block.addr] = initial_block
        logger.debug("TensorMemoryPool, base_address:%d, offset:%d",
                     self.base_address,
                     self.base_address % self.max_block_size)

    def allocate(self, size: int) -> int:
        """Allocates a memory block of at least the requested size.

        Args:
            size (int): Minimum size of memory to allocate

        Returns:
            int: Address of the allocated memory block

        Raises:
            ValueError: If size is invalid or insufficient memory is available
        """
        if size <= 0:
            raise ValueError("Allocation size must be positive")

        required_size = self._round_to_power_of_two(
            max(size, self.min_block_size))
        if required_size > self.max_block_size:
            raise ValueError("Requested size exceeds maximum block size")

        current_size = required_size
        while current_size <= self.max_block_size:
            if self.free_lists[current_size]:
                _, block = self.free_lists[current_size].popitem()
                self._split_block(block, required_size)
                self.allocated_blocks[block.addr] = block
                return block.addr
            current_size *= 2

        raise ValueError("Insufficient memory")

    def _split_block(self, block: MemoryBlock, required_size: int):
        while (block.size > required_size
               and block.size // 2 >= self.min_block_size):
            buddy_size = block.size // 2
            buddy_addr = block.addr + buddy_size

            buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
            block.size = buddy_size

            self.free_lists[buddy_size][buddy.addr] = buddy

    def free(self, addr: int):
        """Frees an allocated memory block.

        Args:
            addr (int): Address of the block to free

        Raises:
            ValueError: If address is invalid or not allocated
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to free")

        block = self.allocated_blocks.pop(addr)
        self._merge_buddies(block)

    def _merge_buddies(self, block: MemoryBlock):
        MAX_MERGE_DEPTH = 30
        depth = 0

        while depth < MAX_MERGE_DEPTH:
            buddy_offset = block.size if (block.addr - self.base_address) % (
                2 * block.size) == 0 else -block.size
            buddy_addr = block.addr + buddy_offset
            buddy = self.free_lists[block.size].get(buddy_addr)
            if buddy:
                del self.free_lists[buddy.size][buddy.addr]
                merged_addr = min(block.addr, buddy.addr)
                merged_size = block.size * 2
                block = MemoryBlock(size=merged_size, addr=merged_addr)
                depth += 1
            else:
                break
        self.free_lists[block.size][block.addr] = block

    def store_tensor(self, tensor: torch.Tensor) -> int:
        """Stores a CUDA tensor in pinned host memory.

        Args:
            tensor (torch.Tensor): CUDA tensor to store

        Returns:
            int: Address where the tensor is stored

        Raises:
            ValueError: If tensor is not on CUDA or allocation fails
        """
        if not tensor.is_cuda:
            raise ValueError("Only CUDA tensors can be stored")

        size = tensor.element_size() * tensor.numel()
        addr = self.allocate(size)
        block = self.allocated_blocks[addr]

        if block.size < size:
            self.free(addr)
            raise ValueError(
                f"Allocated block size {block.size} is smaller than "
                f"required size {size}")

        try:
            buffer = (ctypes.c_byte * block.size).from_address(block.addr)
            cpu_tensor = torch.frombuffer(buffer,
                                          dtype=tensor.dtype,
                                          count=tensor.numel()).reshape(
                                              tensor.shape)
        except ValueError as err:
            self.free(addr)
            raise ValueError(f"Failed to create tensor view: {err}") from err

        cpu_tensor.copy_(tensor)

        return addr

    def load_tensor(self, addr: int, dtype: torch.dtype,
                    shape: tuple[int, ...], device) -> torch.Tensor:
        """Loads a tensor from pinned host memory to the specified device.

        Args:
            addr (int): Address where tensor is stored
            dtype (torch.dtype): Data type of the tensor
            shape (tuple[int, ...]): Shape of the tensor
            device: Target device for the loaded tensor

        Returns:
            torch.Tensor: The loaded tensor on the specified device

        Raises:
            ValueError: If address is invalid or sizes don't match
        """
        if addr not in self.allocated_blocks:
            raise ValueError("Invalid address to load")

        block = self.allocated_blocks[addr]
        num_elements = math.prod(shape)
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        required_size = num_elements * dtype_size

        if required_size > block.size:
            raise ValueError("Requested tensor size exceeds block size")

        buffer = (ctypes.c_byte * block.size).from_address(block.addr)
        cpu_tensor = torch.frombuffer(buffer, dtype=dtype,
                                      count=num_elements).reshape(shape)

        cuda_tensor = torch.empty(shape, dtype=dtype, device=device)

        cuda_tensor.copy_(cpu_tensor)

        return cuda_tensor

    def cleanup(self):
        """Cleans up all memory resources and resets the pool state."""
        self.free_lists.clear()
        self.allocated_blocks.clear()
        if hasattr(self, 'base_tensor'):
            del self.base_tensor

    def __del__(self):
        self.cleanup()
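To trace the buddy mechanics above with small, illustrative sizes (the pinned allocation assumes a CUDA-capable host): a 4096-byte pool serving a 600-byte request rounds up to 1024 and splits twice; freeing re-merges back to the full block.

pool = TensorMemoryPool(max_block_size=4096, min_block_size=512)
addr = pool.allocate(600)            # rounds up to 1024
# the single 4096 block split into 2048 (free) + 1024 (free) + 1024 (allocated)
assert {s for s, d in pool.free_lists.items() if d} == {2048, 1024}
pool.free(addr)                      # merges with its 1024 buddy, then with 2048
assert {s for s, d in pool.free_lists.items() if d} == {4096}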
@@ -1018,7 +1018,8 @@ class EngineArgs:
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()

        device_config = DeviceConfig(device=current_platform.device_type)
        device_config = DeviceConfig(
            device=cast(Device, current_platform.device_type))
        model_config = self.create_model_config()

        # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
@@ -1302,7 +1303,7 @@ class EngineArgs:
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        from vllm.platforms import CpuArchEnum, current_platform
        from vllm.platforms import current_platform
        if (current_platform.is_cuda()
                and current_platform.get_device_capability()
                and current_platform.get_device_capability().major < 8):
@@ -1444,14 +1445,10 @@ class EngineArgs:
            _raise_or_fallback(feature_name=name, recommend_to_remove=False)
            return False

        # Non-[CUDA, TPU, x86 CPU] may be supported on V1,
        # but off by default for now.
        v0_hardware = not any(
            (current_platform.is_cuda_alike(), current_platform.is_tpu(),
             (current_platform.is_cpu()
              and current_platform.get_cpu_architecture() == CpuArchEnum.X86)))
        if v0_hardware and _warn_or_fallback(  # noqa: SIM103
                current_platform.device_name):
        # The platform may be supported on V1, but off by default for now.
        if not current_platform.default_v1(  # noqa: SIM103
                model_config=model_config) and _warn_or_fallback(
                    current_platform.device_name):
            return False
        #############################################################
@@ -87,6 +87,7 @@ if TYPE_CHECKING:
    VLLM_ROCM_USE_AITER_MOE: bool = True
    VLLM_ROCM_USE_AITER_RMSNORM: bool = True
    VLLM_ROCM_USE_AITER_MLA: bool = True
    VLLM_ROCM_USE_AITER_MHA: bool = True
    VLLM_ROCM_USE_SKINNY_GEMM: bool = True
    VLLM_ROCM_FP8_PADDING: bool = True
    VLLM_ROCM_MOE_PADDING: bool = True
@@ -653,6 +654,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_ROCM_USE_AITER_MLA":
    lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
             ("true", "1")),

    # Whether to use aiter mha ops.
    # Enabled by default.
    "VLLM_ROCM_USE_AITER_MHA":
    lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
             ("true", "1")),

    # use rocm skinny gemms
    "VLLM_ROCM_USE_SKINNY_GEMM":
    lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in
@@ -557,8 +557,17 @@ class RayDistributedExecutor(DistributedExecutorBase):
    def _compiled_ray_dag(self, enable_asyncio: bool):
        assert self.parallel_config.use_ray
        self._check_ray_cgraph_installation()
        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        # Note: we should set this env var before importing
        # ray.dag, otherwise it will not take effect.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        from ray.dag import InputNode, MultiOutputNode

        logger.info("RAY_CGRAPH_get_timeout is set to %s",
                    os.environ["RAY_CGRAPH_get_timeout"])  # noqa: SIM112
        logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
                    envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
        logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
@@ -570,15 +579,6 @@ class RayDistributedExecutor(DistributedExecutorBase):
                "Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
                f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")

        # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
        # (it is 10 seconds by default). This is a Ray environment variable to
        # control the timeout of getting result from a compiled graph execution,
        # i.e., the distributed execution that includes model forward runs and
        # intermediate tensor communications, in the case of vllm.
        os.environ.setdefault("RAY_CGRAPH_get_timeout", "300")  # noqa: SIM112
        logger.info("RAY_CGRAPH_get_timeout is set to %s",
                    os.environ["RAY_CGRAPH_get_timeout"])  # noqa: SIM112

        with InputNode() as input_data:
            # Example DAG: PP=2, TP=4
            #
@ -6,14 +6,179 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache, per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
|
||||
|
||||
|
||||
@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
    # Pointers ------------------------------------------------------------
    input_ptr,  # 16-bit activations (E, T, 2*H)
    y_q_ptr,  # fp8 quantized activations (E, T, H)
    y_s_ptr,  # fp32 scales (E, T, G)
    counts_ptr,  # int32 num tokens per expert (E)

    # Sizes ---------------------------------------------------------------
    H: tl.constexpr,  # hidden dimension (per output)
    GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)

    # Strides for input (elements) ---------------------------------------
    stride_i_e,
    stride_i_t,
    stride_i_h,

    # Strides for y_q (elements) -----------------------------------------
    stride_yq_e,
    stride_yq_t,
    stride_yq_h,

    # Strides for y_s (elements) -----------------------------------------
    stride_ys_e,
    stride_ys_t,
    stride_ys_g,

    # Stride for counts (elements)
    stride_counts_e,

    # Numeric params ------------------------------------------------------
    eps: tl.constexpr,
    fp8_min: tl.constexpr,
    fp8_max: tl.constexpr,

    # Meta ---------------------------------------------------------------
    BLOCK: tl.constexpr,
):
    G = H // GROUP_SIZE

    # map program id -> (e, g)
    pid = tl.program_id(0)
    e = pid // G
    g = pid % G

    e = e.to(tl.int64)
    g = g.to(tl.int64)

    # number of valid tokens for this expert
    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)

    cols = tl.arange(0, BLOCK)
    cols = cols.to(tl.int64)
    mask_h = cols < BLOCK

    t = tl.zeros([], tl.int64)
    while t < n_tokens:
        base_i_offset = (e * stride_i_e + t * stride_i_t +
                         g * GROUP_SIZE * stride_i_h)
        base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
                          g * GROUP_SIZE * stride_yq_h)
        base_ys_offset = e * stride_ys_e + t * stride_ys_t + g * stride_ys_g

        mask = mask_h
        x = tl.load(input_ptr + base_i_offset + cols * stride_i_h,
                    mask=mask,
                    other=0.0).to(tl.float32)
        y2 = tl.load(input_ptr + base_i_offset + H * stride_i_h +
                     cols * stride_i_h,
                     mask=mask,
                     other=0.0).to(tl.float32)

        x = x * (1.0 / (1.0 + tl.exp(-x)))
        y = x * y2

        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
        y_s = _absmax / fp8_max
        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
        tl.store(y_s_ptr + base_ys_offset, y_s)

        t += 1


def silu_mul_fp8_quant_deep_gemm(
    y: torch.Tensor,  # (E, T, 2*H) float32
    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
    group_size: int = 128,
    eps: float = 1e-10,
):
    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with per-token group scales

    y has shape (E, T, 2*H). The first half of the last dimension is
    silu-activated, multiplied by the second half, then quantized into FP8.

    Returns `(y_q, y_s)` where
    * `y_q` is the FP8 tensor of shape `(E, T, H)`, same layout as `y[..., :H]`.
    * `y_s` has shape `(E, T, H // group_size)` and strides `(T*G, 1, T)`
    """
    assert y.ndim == 3, "y must be (E, T, 2*H)"
    E, T, H2 = y.shape
    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
    H = H2 // 2
    G = H // group_size
    assert H % group_size == 0, "H must be divisible by group_size"
    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, \
        "tokens_per_expert must be shape (E,)"
    tokens_per_expert = tokens_per_expert.to(device=y.device,
                                             dtype=torch.int32)

    # allocate outputs
    fp8_dtype = torch.float8_e4m3fn
    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)

    # strides (elements)
    stride_i_e, stride_i_t, stride_i_h = y.stride()
    stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()

    # desired scale strides (elements): (T*G, 1, T)
    stride_ys_e = T * G
    stride_ys_t = 1
    stride_ys_g = T
    y_s = torch.empty_strided((E, T, G),
                              (stride_ys_e, stride_ys_t, stride_ys_g),
                              dtype=torch.float32,
                              device=y.device)

    stride_cnt_e = tokens_per_expert.stride()[0]

    # static grid over experts and H-groups.
    # A loop inside the kernel handles the token dim
    grid = (E * G, )

    f_info = torch.finfo(fp8_dtype)
    fp8_max = f_info.max
    fp8_min = f_info.min

    _silu_mul_fp8_quant_deep_gemm[grid](
        y,
        y_q,
        y_s,
        tokens_per_expert,
        H,
        group_size,
        stride_i_e,
        stride_i_t,
        stride_i_h,
        stride_yq_e,
        stride_yq_t,
        stride_yq_h,
        stride_ys_e,
        stride_ys_t,
        stride_ys_g,
        stride_cnt_e,
        eps,
        fp8_min,
        fp8_max,
        BLOCK=group_size,
        num_warps=4,
    )

    return y_q, y_s
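
For orientation, a minimal usage sketch of the wrapper above. The shapes (E=4, T=8, H=256) and the CUDA device are illustrative assumptions, not values taken from this PR:

# Minimal sketch (hypothetical shapes; assumes a CUDA device with
# float8_e4m3fn support).
E, T, H = 4, 8, 256  # experts, max tokens per expert, hidden size
y = torch.randn(E, T, 2 * H, device="cuda", dtype=torch.float32)
tokens_per_expert = torch.tensor([8, 3, 0, 5],
                                 device="cuda", dtype=torch.int32)
y_q, y_s = silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=128)
assert y_q.shape == (E, T, H)
assert y_s.shape == (E, T, H // 128)
# Scales are laid out column-major over the token dim: strides (T*G, 1, T).
assert y_s.stride() == (T * (H // 128), 1, T)
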
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):

    # The Deep Gemm kernels only support block size of 128
@ -96,7 +261,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
            hidden_states, w1, w2, topk_ids)

        workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
        workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))

        # (from deepgemm docs) : A value hint (which is a value on CPU)
        # for the M expectation of each batch, correctly setting this value
@ -109,19 +273,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
            masked_m=expert_num_tokens,
            expected_m=expected_m)

        # TODO (varun) [Optimization]: Use a batched version of activation.
        # Similarly for the quant below.
        self.activation(activation, workspace2, workspace1.view(-1, N))

        w2_hidden_size = workspace2.size(-1)
        workspace2 = workspace2.view(-1, w2_hidden_size)

        a2q_scale: Optional[torch.Tensor] = None
        a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
                                                   self.block_shape[1],
                                                   column_major_scales=False)
        a2q = a2q.view(E, max_num_tokens, -1)
        a2q_scale = a2q_scale.view(E, max_num_tokens, -1)
        assert expert_num_tokens is not None
        a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                      expert_num_tokens)

        dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
                                                 (w2, w2_scale),
@ -45,7 +45,8 @@ if current_platform.is_cuda_alike():
    from .pplx_prepare_finalize import PplxPrepareAndFinalize
    if has_deepep:
        from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
        from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize
        from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
                                                 DeepEPLLPrepareAndFinalize)
else:
    fused_experts = None  # type: ignore
    FusedMoEPermuteExpertsUnpermute = None  # type: ignore
@ -377,6 +378,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                    all2all_manager.world_size)
                handle = all2all_manager.get_handle(all_to_all_args)

                # Note: We may want to use FP8 dispatch even otherwise just to
                # reduce data movement
                assert act_quant_block_size is not None
                use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
                                    and act_quant_block_size[1]
                                    == DEEPEP_QUANT_BLOCK_SIZE)

                # Note (varun): Whether to use FP8 dispatch or not needs some
                # profiling. Turning it off for now.
                prepare_finalize = DeepEPLLPrepareAndFinalize(
@ -386,7 +394,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
                    max_tokens_per_rank=moe.max_num_tokens,
                    quant_dtype=quant_dtype,
                    block_shape=act_quant_block_size,
                    use_fp8_dispatch=False,
                    use_fp8_dispatch=use_fp8_dispatch,
                )

        self.topk_indices_dtype = None
@ -75,7 +75,7 @@ enable_hf_transfer()
class DisabledTqdm(tqdm):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, disable=False)
        super().__init__(*args, **kwargs, disable=True)


def get_lock(model_name_or_path: Union[str, Path],
@ -269,3 +269,11 @@ class CpuPlatform(Platform):
        model configuration.
        """
        return True

    @classmethod
    def default_v1(cls, model_config) -> bool:
        """Returns whether the current platform can use v1 by default for the
        supplied model configuration.
        """
        return cls.supports_v1(
            model_config) and cls.get_cpu_architecture() == CpuArchEnum.X86
@ -479,6 +479,13 @@ class Platform:
        """
        return False

    @classmethod
    def default_v1(cls, model_config: ModelConfig) -> bool:
        """
        Returns whether the current platform supports v1 by default.
        """
        return cls.supports_v1(model_config)

    @classmethod
    def use_custom_allreduce(cls) -> bool:
        """
@ -215,9 +215,15 @@ class RocmPlatform(Platform):
            selected_backend = _Backend.ROCM_FLASH

        if envs.VLLM_USE_V1:
            logger.info("Using Triton Attention backend on V1 engine.")
            return ("vllm.v1.attention.backends."
                    "triton_attn.TritonAttentionBackend")
            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
                and on_gfx9():
                logger.info("Using Flash Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "rocm_aiter_fa.AiterFlashAttentionBackend")
            else:
                logger.info("Using Triton Attention backend on V1 engine.")
                return ("vllm.v1.attention.backends."
                        "triton_attn.TritonAttentionBackend")
        if selected_backend == _Backend.ROCM_FLASH:
            if not cls.has_device_capability(90):
                # not Instinct series GPUs.
585
vllm/v1/attention/backends/rocm_aiter_fa.py
Normal file
@ -0,0 +1,585 @@
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with AiterFlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (
    make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

if current_platform.is_rocm():
    import aiter

from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op


@triton.jit
def _vllm_layout_trans_kernel(
    k_buffer_ptr,
    v_buffer_ptr,
    k_values_ptr,
    v_values_ptr,
    b_query_lens_loc,
    b_seq_lens_loc,
    block_table,
    block_table_stride_0,
    E_DIM: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    batch_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
                                  tl.arange(0, 2))
    batch_token_start, batch_token_end = tl.split(batch_token_indexes)
    seq_len = batch_token_end - batch_token_start

    batch_query_indexes = tl.load(b_query_lens_loc + batch_idx +
                                  tl.arange(0, 2))
    batch_query_start, batch_query_end = tl.split(batch_query_indexes)
    query_len = batch_query_end - batch_query_start
    if query_len <= 1:
        return
    if block_idx * BLOCK_SIZE < seq_len:
        block_mask = (block_idx * BLOCK_SIZE +
                      tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len

        kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
                         block_idx)

        kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
            0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
        k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
                         mask=block_mask,
                         other=0.0)
        v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
                         mask=block_mask,
                         other=0.0)

        kv_values_off = batch_token_start * E_DIM + \
            block_idx * BLOCK_SIZE * E_DIM + \
            tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \
            tl.arange(0, E_DIM)[None, :]
        tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
        tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)


def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
                      k_buffer, v_buffer, max_seq_len, total_tokens):
    H_KV = v_buffer.shape[2]
    D = v_buffer.shape[3]
    BLOCK_SIZE = v_buffer.shape[1]
    dtype = k_buffer.dtype
    k_values = torch.empty((total_tokens, H_KV, D),
                           dtype=dtype,
                           device="cuda")
    v_values = torch.empty((total_tokens, H_KV, D),
                           dtype=dtype,
                           device="cuda")

    grid = (block_table.shape[0],
            (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)

    _vllm_layout_trans_kernel[grid](k_buffer,
                                    v_buffer,
                                    k_values,
                                    v_values,
                                    b_query_lens_loc,
                                    b_seq_lens_loc,
                                    block_table,
                                    block_table.stride(0),
                                    E_DIM=H_KV * D,
                                    BLOCK_SIZE=BLOCK_SIZE)

    return k_values, v_values
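
To make the Triton kernel's job concrete, here is a slow pure-PyTorch equivalent of vllm_layout_trans as I read it: for each sequence, gather the KV-cache blocks named in its block-table row and pack the first seq_len entries contiguously. This reference sketch is not part of the PR, and it omits the kernel's early exit for decode-only sequences (query_len <= 1), which the real code sends down the paged-attention path instead:

def vllm_layout_trans_reference(b_seq_lens_loc, block_table, k_buffer,
                                v_buffer, total_tokens):
    # k_buffer/v_buffer: (num_blocks, BLOCK_SIZE, H_KV, D) paged layout.
    block_size, h_kv, d = (k_buffer.shape[1], k_buffer.shape[2],
                           k_buffer.shape[3])
    k_out = torch.empty((total_tokens, h_kv, d),
                        dtype=k_buffer.dtype, device=k_buffer.device)
    v_out = torch.empty_like(k_out)
    for i in range(block_table.shape[0]):
        start, end = int(b_seq_lens_loc[i]), int(b_seq_lens_loc[i + 1])
        seq_len = end - start
        n_blocks = (seq_len + block_size - 1) // block_size
        blocks = block_table[i, :n_blocks].long()
        # Gather whole blocks, then drop the padding past seq_len.
        k_out[start:end] = k_buffer[blocks].reshape(-1, h_kv, d)[:seq_len]
        v_out[start:end] = v_buffer[blocks].reshape(-1, h_kv, d)[:seq_len]
    return k_out, v_out
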
def flash_attn_varlen_func_impl(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    total_tokens: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    window_size: Optional[list[int]],  # -1 means infinite context window
    alibi_slopes: Optional[list[float]],
    block_table: torch.Tensor,
) -> torch.Tensor:
    k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
                             k_cache, v_cache, max_seqlen_k, total_tokens)
    output = aiter.flash_attn_varlen_func(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        min_seqlen_q=1,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=True,
        alibi_slopes=alibi_slopes,
        window_size=window_size,
        out=out,
    )
    return output


def flash_attn_varlen_func_fake(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    out: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    total_tokens: int,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    window_size: Optional[list[int]],  # -1 means infinite context window
    alibi_slopes: Optional[list[float]],
    block_table: torch.Tensor,
) -> torch.Tensor:
    return torch.empty(q.shape[0],
                       q.shape[1],
                       v_cache.shape[-2],
                       dtype=torch.float8_e4m3fnuz,
                       device="cuda")


direct_register_custom_op("flash_attn_varlen_func",
                          flash_attn_varlen_func_impl, ["out"],
                          flash_attn_varlen_func_fake,
                          dispatch_key=current_platform.dispatch_key)
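
The registration above pairs a real implementation with a shape-only fake so that tracing (e.g. torch.compile) never executes the Triton and aiter internals; the ["out"] entry marks the output buffer as mutated. A self-contained illustration of the same impl/fake pairing using the public torch.library API (the op name mylib::double is made up for this sketch; vLLM wraps this machinery in its own direct_register_custom_op helper):

import torch

@torch.library.custom_op("mylib::double", mutates_args=())
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@double.register_fake
def _(x):
    # Like flash_attn_varlen_func_fake above: return only a
    # correctly-shaped empty tensor, never the real computation.
    return torch.empty_like(x)

out = torch.ops.mylib.double(torch.ones(3))
assert torch.equal(out, torch.full((3, ), 2.0))
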
logger = init_logger(__name__)


class AiterFlashAttentionMetadataBuilder:

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        model_config = runner.model_config

        self.runner = runner
        self.num_heads_q = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            runner.parallel_config)
        self.headdim = model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table

        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return False

    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):

        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
        total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table = self.block_table
        block_table_tensor = block_table.get_device_tensor()[:num_reqs]

        block_table.slot_mapping[:num_actual_tokens].copy_(
            block_table.slot_mapping_cpu[:num_actual_tokens],
            non_blocking=True)
        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
        # mode.
        block_table.slot_mapping[num_actual_tokens:].fill_(-1)

        slot_mapping = block_table.slot_mapping[:num_actual_tokens]

        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
                                  dtype=torch.int32,
                                  device="cuda")
        torch.cumsum(seq_lens,
                     dim=0,
                     dtype=cu_seq_lens.dtype,
                     out=cu_seq_lens[1:])

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            return None

        # for local attention
        local_attn_metadata = None
        if self.runner.attention_chunk_size is not None:
            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
                virt_block_table_tensor = make_local_attention_virtual_batches(
                    self.runner.attention_chunk_size,
                    self.runner.query_start_loc_np[:num_reqs + 1],
                    self.runner.seq_lens_np[:num_reqs],
                    block_table_tensor,
                    self.block_size,
                )
            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
                self.runner.device, non_blocking=True)
            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
                self.runner.device, non_blocking=True)
            local_max_query_len = seqlens_q_local_np.max()
            local_max_seq_len = virt_k_seqlens_np.max()
            local_scheduler_metadata = schedule(
                batch_size=local_query_start_loc.shape[0] - 1,
                cu_query_lens=local_query_start_loc,
                max_query_len=local_max_query_len,
                seqlens=local_seqused_k,
                max_seq_len=local_max_seq_len,
                causal=True)

            local_attn_metadata = \
                AiterFlashAttentionMetadata.LocalAttentionMetadata(
                    local_query_start_loc=local_query_start_loc,
                    local_seqused_k=local_seqused_k,
                    local_block_table=virt_block_table_tensor,
                    local_max_query_len=local_max_query_len,
                    local_max_seq_len=local_max_seq_len,
                    local_scheduler_metadata=local_scheduler_metadata,
                )

        use_cascade = common_prefix_len > 0

        cu_prefix_query_lens = None
        prefix_kv_lens = None
        suffix_kv_lens = None

        attn_metadata = AiterFlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            cu_seq_lens=cu_seq_lens,
            total_tokens=total_tokens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            local_attn_metadata=local_attn_metadata,
        )
        return attn_metadata

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        # Full CUDA Graph always supported (FA2 support checked separately)
        return True

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        return False


class AiterFlashAttentionBackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
        return AiterFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return AiterFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
        return AiterFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> tuple[int, ...]:
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)
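
As a quick sanity check on this layout, a back-of-the-envelope sizing sketch with assumed (not PR-specified) dimensions:

# Illustrative sizing for the (2, num_blocks, block_size, num_kv_heads,
# head_size) cache layout above (hypothetical numbers).
num_blocks, block_size, num_kv_heads, head_size = 1024, 16, 8, 128
elems = 2 * num_blocks * block_size * num_kv_heads * head_size
bytes_fp16 = elems * 2  # two bytes per element for a 16-bit cache
assert bytes_fp16 == 64 * 1024 * 1024  # 64 MiB for K and V together
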

@dataclass
class AiterFlashAttentionMetadata:
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    cu_seq_lens: torch.Tensor
    total_tokens: int
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: Optional[torch.Tensor]
    prefix_kv_lens: Optional[torch.Tensor]
    suffix_kv_lens: Optional[torch.Tensor]

    # for local attention
    @dataclass
    class LocalAttentionMetadata:
        local_query_start_loc: torch.Tensor
        local_seqused_k: torch.Tensor
        local_block_table: torch.Tensor
        local_max_query_len: int
        local_max_seq_len: int
        local_scheduler_metadata: Optional[torch.Tensor]

    local_attn_metadata: Optional[LocalAttentionMetadata] = None


class AiterFlashAttentionImpl(AttentionImpl):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
        use_irope: bool = False,
    ) -> None:
        if blocksparse_params is not None:
            raise ValueError(
                "AiterFlashAttention does not support block-sparse attention.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = [-1, -1]
        else:
            self.sliding_window = [sliding_window - 1, 0]
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0.
        self.logits_soft_cap = logits_soft_cap

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = \
            AiterFlashAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by "
                "AiterFlashAttention. "
                f"Supported head sizes are: {support_head_sizes}. "
                "Set VLLM_USE_V1=0 to use another attention backend.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashAttentionImpl")
        self.use_irope = use_irope
        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "AiterFlashAttention does not support fp8 kv-cache on this "
                "device.")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AiterFlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
        output_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with AiterFlashAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        NOTE: For FP8 quantization, flash-attn expects the size of
            {q,k,v}_descale to be (num_sequences, num_kv_heads).
            We use torch's .expand() to avoid duplicating values.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for FlashAttentionImpl")

        if attn_metadata is None:
            # Profiling run.
            return output

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        # Reshape the input keys and values and store them in the cache.
        # NOTE(woosuk): Here, key and value are padded while slot_mapping is
        # not padded. However, we don't need to do key[:num_actual_tokens] and
        # value[:num_actual_tokens] because the reshape_and_cache_flash op uses
        # the slot_mapping's shape to determine the number of actual tokens.
        key_cache, value_cache = kv_cache.unbind(0)
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        if self.kv_cache_dtype.startswith("fp8"):
            key_cache = key_cache.view(torch.float8_e4m3fnuz)
            value_cache = value_cache.view(torch.float8_e4m3fnuz)
            num_tokens, num_heads, head_size = query.shape
            query, _ = ops.scaled_fp8_quant(
                query.reshape(
                    (num_tokens, num_heads * head_size)).contiguous(),
                layer._q_scale)
            query = query.reshape((num_tokens, num_heads, head_size))

        # Compute attention and update output up to `num_actual_tokens`.
        use_local_attn = \
            (self.use_irope and attn_metadata.local_attn_metadata is not None)

        if not attn_metadata.use_cascade or use_local_attn:
            if use_local_attn:
                assert attn_metadata.local_attn_metadata is not None
                local_metadata = attn_metadata.local_attn_metadata
                cu_seqlens_q = local_metadata.local_query_start_loc
                seqused_k = local_metadata.local_seqused_k
                max_seqlen_q = local_metadata.local_max_query_len
                max_seqlen_k = local_metadata.local_max_seq_len
                block_table = local_metadata.local_block_table
            else:
                cu_seqlens_q = attn_metadata.query_start_loc
                seqused_k = attn_metadata.seq_lens
                max_seqlen_q = attn_metadata.max_query_len
                max_seqlen_k = attn_metadata.max_seq_len
                block_table = attn_metadata.block_table

            if max_seqlen_q > 1:
                cu_seq_lens = attn_metadata.cu_seq_lens
                total_tokens = attn_metadata.total_tokens
                torch.ops.vllm.flash_attn_varlen_func(
                    query[:num_actual_tokens],
                    key_cache,
                    value_cache,
                    out=output[:num_actual_tokens],
                    cu_seqlens_q=cu_seqlens_q,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_k=max_seqlen_k,
                    total_tokens=total_tokens,
                    softmax_scale=self.scale,
                    alibi_slopes=self.alibi_slopes,
                    window_size=self.sliding_window,
                    block_table=block_table,
                    cu_seqlens_k=cu_seq_lens,
                )

            _, num_heads, head_size = query.shape
            _PARTITION_SIZE_ROCM = 256
            num_seqs = seqused_k.shape[0]
            nbytes_per_qo_elem = torch.finfo(output.dtype).bits // 8
            max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
                                  1) // _PARTITION_SIZE_ROCM

            workspace_buffer = torch.empty(
                (num_seqs * num_heads * max_num_partitions * head_size) *
                nbytes_per_qo_elem + 2 *
                (num_seqs * num_heads * max_num_partitions) * 4,
                dtype=torch.uint8,
                device=output.device,
            )

            aiter.paged_attention_v1(
                output[:num_actual_tokens],
                workspace_buffer,
                query[:num_actual_tokens],
                key_cache,
                value_cache,
                self.scale,
                block_table,
                cu_seqlens_q,
                seqused_k,
                max_seqlen_k,
                self.alibi_slopes,
                self.kv_cache_dtype,
                "NHD",
                self.logits_soft_cap,
                layer._k_scale,
                layer._v_scale,
                None,
                _PARTITION_SIZE_ROCM,
            )
            return output
        else:
            raise NotImplementedError(
                "Cascade attention is not implemented for ROCM AITER")
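
The workspace sizing in forward() reserves one partial-output buffer plus two float32 accumulators per (sequence, head, partition), which is the usual split-K reduction scratch for paged attention. A worked example under assumed values:

# Worked example of the workspace_buffer sizing above (assumed values).
num_seqs, num_heads, head_size = 8, 32, 128
max_seqlen_k, partition = 4096, 256
max_num_partitions = (max_seqlen_k + partition - 1) // partition  # 16
nbytes_per_qo_elem = 2  # e.g. a bf16 output tensor
workspace_bytes = (num_seqs * num_heads * max_num_partitions * head_size
                   * nbytes_per_qo_elem
                   + 2 * (num_seqs * num_heads * max_num_partitions) * 4)
assert workspace_bytes == 1_081_344  # ~1 MiB of scratch
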
@ -14,6 +14,39 @@ logger = init_logger(__name__)


class EncoderCacheManager:
    """Manages caching of encoder outputs for multimodal models in vLLM V1.

    The EncoderCacheManager handles the lifecycle of multimodal encoder outputs
    (such as vision embeddings from images) during request processing. It
    provides memory-aware caching to avoid recomputing encoder outputs when the
    same multimodal inputs appear in different stages of request processing.

    This manager is particularly important for:
    - Vision-language models (e.g., LLaVA) where image encoder outputs are
      cached
    - Any multimodal model where encoder computation is expensive and
      cacheable

    The cache operates at the granularity of individual multimodal input items
    within requests, allowing for fine-grained memory management and enabling
    chunked processing of multimodal inputs.

    Note that no caching is shared between requests at this time. If the same
    input is used across multiple requests, it will be reprocessed for each
    request.

    Args:
        cache_size: Limit the size of the cache, measured by the number of
            tokens from the input sequence.

    Attributes:
        cache_size: Total cache capacity in encoder tokens
        num_free_slots: Current available cache capacity in encoder tokens
        cached: Mapping from request_id to set of cached input_ids for that
            request
        freed: List of (request_id, input_id) pairs that were recently freed.
            This is cleared after every call to get_freed_ids().
    """

    def __init__(self, cache_size: int):
        self.cache_size = cache_size
@ -24,14 +57,48 @@ class EncoderCacheManager:
        self.freed: list[tuple[str, int]] = []

    def has_cache(self, request: Request, input_id: int) -> bool:
        """Check if encoder output for a specific multimodal input is cached.

        Args:
            request: The request containing the multimodal input
            input_id: Index of the multimodal input within the request

        Returns:
            True if the encoder output for this input is already cached
        """
        req_id = request.request_id
        return req_id in self.cached and input_id in self.cached[req_id]

    def can_allocate(self, request: Request, input_id: int) -> bool:
        """Check if there's sufficient cache space for a multimodal input.

        Args:
            request: The request containing the multimodal input
            input_id: Index of the multimodal input within the request

        Returns:
            True if there's enough free cache space to store the encoder output
            for this multimodal input
        """
        num_tokens = request.get_num_encoder_tokens(input_id)
        return num_tokens <= self.num_free_slots

    def allocate(self, request: Request, input_id: int) -> None:
        """Allocate cache space for a multimodal input's encoder output.

        This method reserves cache space for storing the encoder output of
        the specified multimodal input. The actual encoder output storage
        happens in the model runner, but this method ensures the cache
        manager tracks the allocation.

        Args:
            request: The request containing the multimodal input
            input_id: Index of the multimodal input within the request

        Note:
            This method assumes can_allocate() returned True for the same
            request and input_id. It will reduce available cache space.
        """
        req_id = request.request_id
        if req_id not in self.cached:
            self.cached[req_id] = set()
@ -39,10 +106,30 @@ class EncoderCacheManager:
        self.num_free_slots -= request.get_num_encoder_tokens(input_id)

    def get_cached_input_ids(self, request: Request) -> set[int]:
        """Get all cached multimodal input IDs for a request.

        Args:
            request: The request to query

        Returns:
            Set of input_ids that have cached encoder outputs for this request.
            Returns empty set if no inputs are cached for this request.
        """
        return self.cached.get(request.request_id, set())

    def free_encoder_input(self, request: Request, input_id: int) -> None:
        """Free a single encoder input id for the request."""
        """Free cache space for a single multimodal input's encoder output.

        This method is called when:
        - The encoder output has been fully consumed by the decoder and is
          no longer needed (e.g., in vision-language models after image
          tokens are processed)
        - A request is being cancelled or aborted

        Args:
            request: The request containing the multimodal input
            input_id: Index of the multimodal input to free from cache
        """
        req_id = request.request_id
        if req_id not in self.cached:
            return
@ -54,12 +141,29 @@ class EncoderCacheManager:
        self.freed.append((req_id, input_id))

    def free(self, request: Request) -> None:
        """Free all cached input ids for the request."""
        """Free all cached encoder outputs for a request.

        This method is typically called when a request is finished, cancelled,
        or aborted, and all its encoder outputs should be freed from cache.

        Args:
            request: The request whose encoder outputs should be freed
        """
        input_ids = self.get_cached_input_ids(request).copy()
        for input_id in input_ids:
            self.free_encoder_input(request, input_id)

    def get_freed_ids(self) -> list[tuple[str, int]]:
        """Get and clear the list of recently freed encoder cache entries.

        This method returns all encoder cache entries that were freed since
        the last call to this method. It's used by the scheduler to notify
        workers about which encoder outputs can be removed from their caches.

        Returns:
            List of (request_id, input_id) tuples that were freed since the
            last call. The internal freed list is cleared after this call.
        """
        freed = self.freed
        self.freed = []
        return freed
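
Taken together, the expanded docstrings describe a check-allocate-free protocol. A hedged sketch of how a caller might drive it; FakeRequest is a stand-in for the real vllm.v1.request.Request, exposing only the two members the manager touches, and the flow assumes allocate() records the input id and that num_free_slots starts at cache_size (both initialized in hunk lines not shown here):

class FakeRequest:

    def __init__(self, request_id: str, encoder_token_counts: list[int]):
        self.request_id = request_id
        self._counts = encoder_token_counts

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return self._counts[input_id]


manager = EncoderCacheManager(cache_size=1000)
req = FakeRequest("req-0", [600, 600])

assert manager.can_allocate(req, 0)
manager.allocate(req, 0)                 # 400 token slots remain
assert not manager.can_allocate(req, 1)  # a second 600-token input won't fit

manager.free(req)                        # request finished: release all
assert manager.get_freed_ids() == [("req-0", 0)]
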
@ -171,6 +171,9 @@ class RequestStatus(enum.IntEnum):
    FINISHED_ABORTED = enum.auto()
    FINISHED_IGNORED = enum.auto()

    def __str__(self):
        return self.name

    @staticmethod
    def is_finished(status: "RequestStatus") -> bool:
        return status > RequestStatus.PREEMPTED
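
is_finished() leans on IntEnum declaration order: every FINISHED_* member is declared after PREEMPTED, so a single comparison classifies a status. A small standalone sketch of the idiom (the WAITING/RUNNING members are assumed for illustration):

import enum

class Status(enum.IntEnum):
    WAITING = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Everything declared below PREEMPTED counts as finished.
    FINISHED_STOPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()

assert Status.FINISHED_ABORTED > Status.PREEMPTED
assert not (Status.RUNNING > Status.PREEMPTED)
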
@ -5,7 +5,7 @@ from typing import Optional

import torch

from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch

DEFAULT_SAMPLING_PARAMS = dict(
    temperature=-1.0,
@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Datastructures defining an input batch
# Datastructures defining a GPU input batch

from dataclasses import dataclass
from typing import Optional, cast
@ -453,6 +453,11 @@ class InputBatch:
        self.block_table.swap_row(i1, i2)

    def condense(self, empty_req_indices: list[int]) -> None:
        """Move non-empty requests down into lower, empty indices.

        Args:
            empty_req_indices: empty batch indices, sorted descending.
        """
        num_reqs = self.num_reqs
        if num_reqs == 0:
            # The batched states are empty.
@ -5,6 +5,7 @@ Define LoRA functionality mixin for model runners.
"""

from contextlib import contextmanager
from typing import Union

import numpy as np
import torch.nn as nn
@ -15,7 +16,10 @@ from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor.models import supports_lora, supports_multimodal
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch

InputBatch = Union[TPUInputBatch, GPUInputBatch]

logger = init_logger(__name__)
584
vllm/v1/worker/tpu_input_batch.py
Normal file
@ -0,0 +1,584 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Datastructures defining a TPU input batch
|
||||
|
||||
from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class InputBatch:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self._req_ids: list[Optional[str]] = []
|
||||
self.req_id_to_index: dict[str, int] = {}
|
||||
|
||||
# TODO(woosuk): This buffer could be too large if max_model_len is big.
|
||||
# Find a way to reduce the CPU memory usage.
|
||||
# This buffer is not directly transferred to the GPU, so it does not
|
||||
# need to be pinned.
|
||||
self.token_ids_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=False,
|
||||
)
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs, ),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.num_computed_tokens_cpu = \
|
||||
self.num_computed_tokens_cpu_tensor.numpy()
|
||||
|
||||
# Block table.
|
||||
self.block_table = MultiGroupBlockTable(
|
||||
max_num_reqs=max_num_reqs,
|
||||
max_model_len=max_model_len,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
pin_memory=pin_memory,
|
||||
device=device,
|
||||
block_sizes=block_sizes,
|
||||
)
|
||||
|
||||
# Sampling-related.
|
||||
self.temperature = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.temperature_cpu = self.temperature_cpu_tensor.numpy()
|
||||
self.greedy_reqs: set[str] = set()
|
||||
self.random_reqs: set[str] = set()
|
||||
|
||||
self.top_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_p_cpu = self.top_p_cpu_tensor.numpy()
|
||||
self.top_p_reqs: set[str] = set()
|
||||
|
||||
self.top_k = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
|
||||
self.top_k_reqs: set[str] = set()
|
||||
|
||||
self.min_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
self.min_p_reqs: set[str] = set()
|
||||
|
||||
# Frequency penalty related data structures
|
||||
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.frequency_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.frequency_penalties_cpu = \
|
||||
self.frequency_penalties_cpu_tensor.numpy()
|
||||
self.frequency_penalties_reqs: set[str] = set()
|
||||
|
||||
# Presence penalty related data structures
|
||||
self.presence_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.presence_penalties_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.presence_penalties_cpu = self.presence_penalties_cpu_tensor.numpy(
|
||||
)
|
||||
self.presence_penalties_reqs: set[str] = set()
|
||||
|
||||
# Repetition penalty related data structures
|
||||
self.repetition_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
self.repetition_penalties_cpu_tensor = torch.empty(
|
||||
(max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.repetition_penalties_cpu = \
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: set[str] = set()
|
||||
|
||||
# req_index -> (min_tokens, stop_token_ids)
|
||||
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
||||
dtype=np.int32)
|
||||
self.lora_id_to_request_ids: dict[int, set[str]] = {}
|
||||
self.lora_id_to_lora_request: dict[int, LoRARequest] = {}
|
||||
|
||||
# req_index -> generator
|
||||
# NOTE(woosuk): The indices of the requests that do not have their own
|
||||
# generator should not be included in the dictionary.
|
||||
self.generators: dict[int, torch.Generator] = {}
|
||||
|
||||
self.num_logprobs: dict[str, int] = {}
|
||||
# NOTE(rob): num_prompt_logprobs only includes reqs
|
||||
# that are currently in the prefill phase.
|
||||
self.num_prompt_logprobs: dict[str, int] = {}
|
||||
|
||||
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
||||
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
||||
|
||||
self.logit_bias: list[Optional[dict[int,
|
||||
float]]] = [None] * max_num_reqs
|
||||
self.has_allowed_token_ids: set[str] = set()
|
||||
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
||||
# the value is False. Since we use masked_fill_ to set -inf.
|
||||
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
||||
|
||||
self.req_output_token_ids: list[Optional[list[int]]] = []
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
# None elements should only be present transiently
|
||||
# while performing state updates to the batch.
|
||||
return cast(list[str], self._req_ids)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
req_index: Optional[int] = None,
|
||||
) -> None:
|
||||
if req_index is None:
|
||||
req_index = self.num_reqs
|
||||
assert req_index < self.max_num_reqs
|
||||
|
||||
req_id = request.req_id
|
||||
if req_index == len(self._req_ids):
|
||||
self._req_ids.append(req_id)
|
||||
self.req_output_token_ids.append(request.output_token_ids)
|
||||
else:
|
||||
self._req_ids[req_index] = req_id
|
||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
# Copy the prompt token ids and output token ids.
|
||||
num_prompt_tokens = len(request.prompt_token_ids)
|
||||
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
||||
self.token_ids_cpu[
|
||||
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||
start_idx = num_prompt_tokens
|
||||
end_idx = start_idx + len(request.output_token_ids)
|
||||
self.token_ids_cpu[req_index,
|
||||
start_idx:end_idx] = request.output_token_ids
|
||||
# Number of token ids in token_ids_cpu.
|
||||
# NOTE(woosuk): This may include spec decode tokens.
|
||||
self.num_tokens[req_index] = request.num_tokens
|
||||
# Number of tokens without spec decode tokens.
|
||||
self.num_tokens_no_spec[req_index] = request.num_tokens
|
||||
|
||||
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
sampling_params = request.sampling_params
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Avoid later division by zero.
|
||||
self.temperature_cpu[req_index] = -1.0
|
||||
self.greedy_reqs.add(req_id)
|
||||
else:
|
||||
self.temperature_cpu[req_index] = sampling_params.temperature
|
||||
self.random_reqs.add(req_id)
|
||||
|
||||
self.top_p_cpu[req_index] = sampling_params.top_p
|
||||
if sampling_params.top_p < 1:
|
||||
self.top_p_reqs.add(req_id)
|
||||
top_k = sampling_params.top_k
|
||||
if 0 < top_k < self.vocab_size:
|
||||
self.top_k_reqs.add(req_id)
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k_cpu[req_index] = top_k
|
||||
self.min_p_cpu[req_index] = sampling_params.min_p
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.min_p > _SAMPLING_EPS:
|
||||
self.min_p_reqs.add(req_id)
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[
|
||||
req_index] = sampling_params.presence_penalty
|
||||
if sampling_params.presence_penalty != 0.0:
|
||||
self.presence_penalties_reqs.add(req_id)
|
||||
self.repetition_penalties_cpu[
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
if sampling_params.min_tokens:
|
||||
self.min_tokens[req_index] = (sampling_params.min_tokens,
|
||||
sampling_params.all_stop_token_ids)
|
||||
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
if request.generator is not None:
|
||||
self.generators[req_index] = request.generator
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
|
||||
if sampling_params.logit_bias is not None:
|
||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||
|
||||
if sampling_params.allowed_token_ids:
|
||||
self.has_allowed_token_ids.add(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is None:
|
||||
# Lazy allocation for this tensor, which can be large.
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device=self.device)
|
||||
self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.vocab_size,
|
||||
dtype=torch.bool,
|
||||
device="cpu")
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index] = True
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||
sampling_params.allowed_token_ids] = False
|
||||
|
||||
if sampling_params.bad_words_token_ids:
|
||||
self.bad_words_token_ids[
|
||||
req_index] = sampling_params.bad_words_token_ids
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
lora_id = request.lora_request.lora_int_id
|
||||
if lora_id not in self.lora_id_to_request_ids:
|
||||
self.lora_id_to_request_ids[lora_id] = set()
|
||||
|
||||
self.request_lora_mapping[req_index] = lora_id
|
||||
self.lora_id_to_request_ids[lora_id].add(request.req_id)
|
||||
self.lora_id_to_lora_request[lora_id] = request.lora_request
|
||||
else:
|
||||
# No LoRA
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
def remove_request(self, req_id: str) -> Optional[int]:
|
||||
"""This method must always be followed by a call to condense()."""
|
||||
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
return None
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
|
||||
self.greedy_reqs.discard(req_id)
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.min_p_reqs.discard(req_id)
|
||||
self.min_tokens.pop(req_index, None)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
self.generators.pop(req_index, None)
|
||||
self.num_logprobs.pop(req_id, None)
|
||||
self.num_prompt_logprobs.pop(req_id, None)
|
||||
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
self.lora_id_to_request_ids[lora_id].discard(req_id)
|
||||
if len(self.lora_id_to_request_ids[lora_id]) == 0:
|
||||
self.lora_id_to_request_ids.pop(lora_id)
|
||||
self.lora_id_to_lora_request.pop(lora_id)
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
self.logit_bias[req_index] = None
|
||||
self.has_allowed_token_ids.discard(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
||||
self.bad_words_token_ids.pop(req_index, None)
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
old_id_i2 = self._req_ids[i2]
|
||||
self._req_ids[i1], self._req_ids[i2] =\
|
||||
self._req_ids[i2], self._req_ids[i1] # noqa
|
||||
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
||||
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
||||
assert old_id_i1 is not None and old_id_i2 is not None
|
||||
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
||||
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
||||
self.num_tokens[i1], self.num_tokens[i2] =\
|
||||
self.num_tokens[i2], self.num_tokens[i1]
|
||||
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
|
||||
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
|
||||
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
|
||||
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
|
||||
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
|
||||
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
|
||||
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
|
||||
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
||||
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
|
||||
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
||||
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
|
||||
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
||||
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
|
||||
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
||||
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
|
||||
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
||||
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
||||
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
|
||||
self.min_p_cpu[i2], self.min_p_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
||||
# instead, we need to temporiarily copy the data for one of the indices
|
||||
# TODO(lucas): optimize this by only copying valid indices
|
||||
tmp = self.token_ids_cpu[i1, ...].copy()
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
swap_dict_values(self.generators, i1, i2)
|
||||
swap_dict_values(self.min_tokens, i1, i2)
|
||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
self.logit_bias[i1], self.logit_bias[i2] =\
|
||||
self.logit_bias[i2], self.logit_bias[i1]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1]
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
    def condense(self, empty_req_indices: list[int]) -> None:
        """Move non-empty requests down into lower, empty indices.

        Args:
            empty_req_indices: empty batch indices, sorted descending.
        """
        num_reqs = self.num_reqs
        if num_reqs == 0:
            # The batched states are empty.
            self._req_ids.clear()
            self.req_output_token_ids.clear()
            return

        # NOTE(woosuk): This function assumes that the empty_req_indices
        # is sorted in descending order.
        last_req_index = num_reqs + len(empty_req_indices) - 1
        while empty_req_indices:
            # Find the largest non-empty index.
            while last_req_index in empty_req_indices:
                last_req_index -= 1

            # Find the smallest empty index.
            empty_index = empty_req_indices.pop()
            if empty_index >= last_req_index:
                break

            # Swap the states.
            req_id = self._req_ids[last_req_index]
            output_token_ids = self.req_output_token_ids[last_req_index]
            assert req_id is not None
            self._req_ids[empty_index] = req_id
            self._req_ids[last_req_index] = None
            self.req_output_token_ids[empty_index] = output_token_ids
            self.req_output_token_ids[last_req_index] = None
            self.req_id_to_index[req_id] = empty_index

            num_tokens = self.num_tokens[last_req_index]
            self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                last_req_index, :num_tokens]
            self.num_tokens[empty_index] = num_tokens
            self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                last_req_index]
            self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
                last_req_index]
            self.num_computed_tokens_cpu[
                empty_index] = self.num_computed_tokens_cpu[last_req_index]
            self.block_table.move_row(last_req_index, empty_index)
            self.temperature_cpu[empty_index] = self.temperature_cpu[
                last_req_index]
            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
            self.frequency_penalties_cpu[
                empty_index] = self.frequency_penalties_cpu[last_req_index]
            self.presence_penalties_cpu[
                empty_index] = self.presence_penalties_cpu[last_req_index]
            self.repetition_penalties_cpu[
                empty_index] = self.repetition_penalties_cpu[last_req_index]
            self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
            generator = self.generators.pop(last_req_index, None)
            if generator is not None:
                self.generators[empty_index] = generator

            min_token = self.min_tokens.pop(last_req_index, None)
            if min_token is not None:
                self.min_tokens[empty_index] = min_token

            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
                last_req_index]

            self.logit_bias[empty_index] = self.logit_bias[last_req_index]

            if self.allowed_token_ids_mask_cpu_tensor is not None:
                self.allowed_token_ids_mask_cpu_tensor[
                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                        last_req_index]

            bad_words_token_ids = self.bad_words_token_ids.pop(
                last_req_index, None)
            if bad_words_token_ids is not None:
                self.bad_words_token_ids[empty_index] = bad_words_token_ids
            # Decrement last_req_index since it is now empty.
            last_req_index -= 1

        # Trim lists to the batch size.
        del self._req_ids[self.num_reqs:]
        del self.req_output_token_ids[self.num_reqs:]
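To make the two-pointer compaction above concrete, here is a minimal, hypothetical stand-in that applies the same strategy to a plain list: pop the smallest empty slot (descending order makes that a cheap list.pop()), pull the highest live entry down into it, and stop once the pointers cross. All names below are invented for illustration only.

def compact(slots: list, empty_indices: list[int]) -> None:
    """empty_indices must be sorted descending, as condense() assumes."""
    last = len(slots) - 1
    while empty_indices:
        # Skip trailing slots that are themselves empty.
        while last in empty_indices:
            last -= 1
        # pop() on a descending list yields the smallest empty index.
        empty = empty_indices.pop()
        if empty >= last:
            break
        slots[empty], slots[last] = slots[last], None
        last -= 1

slots = ["a", None, "b", None, "c"]
compact(slots, [3, 1])
print(slots)  # ['a', 'c', 'b', None, None]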
    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
        prompt_token_ids_cpu_tensor = torch.empty(
            (self.num_reqs, max_prompt_len),
            device="cpu",
            dtype=torch.int64,
            pin_memory=self.pin_memory,
        )
        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
        prompt_token_ids[:] = self.token_ids_cpu[:self.num_reqs,
                                                 :max_prompt_len]
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        for i in range(self.num_reqs):
            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
        return prompt_token_ids_cpu_tensor.to(device=self.device,
                                              non_blocking=True)
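A small illustration (toy shapes, not vLLM's API) of the padding trick used above: rows are right-padded with vocab_size, an id no real token can have, so downstream code can recover the padding positions with a single comparison.

import torch

vocab_size = 8
# Two prompts of lengths 3 and 2, batched into one right-padded tensor.
token_ids = torch.tensor([[3, 5, 1], [7, 2, 0]])
prompt_lens = [3, 2]
for i, plen in enumerate(prompt_lens):
    token_ids[i, plen:] = vocab_size
print(token_ids.tolist())  # [[3, 5, 1], [7, 2, 8]]
mask = token_ids == vocab_size  # padding positions, one comparison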
    def make_lora_inputs(
        self, num_scheduled_tokens: np.ndarray
    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
        """
        Given the num_scheduled_tokens for each request in the batch, return
        the data structures used to activate the current LoRAs.

        Returns:
            1. prompt_lora_mapping: A tuple of size self.num_reqs where
               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
               where token_lora_mapping[i] is the LoRA id to use for the ith
               token.
            3. lora_requests: Set of relevant LoRA requests.
        """

        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
        prompt_lora_mapping = tuple(req_lora_mapping)
        token_lora_mapping = tuple(
            req_lora_mapping.repeat(num_scheduled_tokens))
        active_lora_requests: set[LoRARequest] = set(
            self.lora_id_to_lora_request.values())

        return prompt_lora_mapping, token_lora_mapping, active_lora_requests
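The token-level mapping above leans on NumPy's element-wise repeat; a self-contained mini-example (made-up numbers) of what it produces:

import numpy as np

req_lora_mapping = np.array([0, 2, 1])      # LoRA id per request
num_scheduled_tokens = np.array([3, 1, 2])  # tokens scheduled per request
# Each request's LoRA id is repeated once per scheduled token.
token_lora_mapping = tuple(req_lora_mapping.repeat(num_scheduled_tokens))
print(token_lora_mapping)  # (0, 0, 0, 2, 1, 1)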
    @property
    def num_reqs(self) -> int:
        return len(self.req_id_to_index)

    @property
    def all_greedy(self) -> bool:
        return len(self.random_reqs) == 0

    @property
    def all_random(self) -> bool:
        return len(self.greedy_reqs) == 0

    @property
    def no_top_p(self) -> bool:
        return len(self.top_p_reqs) == 0

    @property
    def no_top_k(self) -> bool:
        return len(self.top_k_reqs) == 0

    @property
    def no_min_p(self) -> bool:
        return len(self.min_p_reqs) == 0

    @property
    def no_penalties(self) -> bool:
        return (len(self.presence_penalties_reqs) == 0
                and len(self.frequency_penalties_reqs) == 0
                and len(self.repetition_penalties_reqs) == 0)

    @property
    def max_num_logprobs(self) -> Optional[int]:
        return max(self.num_logprobs.values()) if self.num_logprobs else None

    @property
    def no_prompt_logprob(self) -> bool:
        return not self.num_prompt_logprobs

    @property
    def no_allowed_token_ids(self) -> bool:
        return len(self.has_allowed_token_ids) == 0
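These properties stay O(1) because membership sets are maintained incrementally as requests are added and removed, rather than scanning the batch on each query. A stripped-down, hypothetical illustration of that bookkeeping pattern:

class MiniBatch:
    """Toy stand-in; only sketches the set-based bookkeeping."""

    def __init__(self) -> None:
        self.greedy_reqs: set[str] = set()
        self.random_reqs: set[str] = set()

    def add_request(self, req_id: str, temperature: float) -> None:
        # Classify once on insertion; properties then reduce to len() checks.
        if temperature == 0.0:
            self.greedy_reqs.add(req_id)
        else:
            self.random_reqs.add(req_id)

    @property
    def all_greedy(self) -> bool:
        return len(self.random_reqs) == 0

batch = MiniBatch()
batch.add_request("r1", 0.0)
print(batch.all_greedy)  # True
batch.add_request("r2", 0.7)
print(batch.all_greedy)  # False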
@@ -42,8 +42,8 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
 from vllm.v1.utils import bind_kv_cache
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch

 from .utils import (initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs)