Compare commits

..

10 Commits

28 changed files with 1074 additions and 418 deletions

View File

@ -189,6 +189,9 @@ class BenchmarkDataset(ABC):
"""
if len(requests) < num_requests:
random.seed(self.random_seed)
logger.info("Current number of requests: %d", len(requests))
logger.info("Oversampled requests to reach %d total samples.",
num_requests)
additional = random.choices(requests,
k=num_requests - len(requests))
requests.extend(additional)
@ -402,6 +405,13 @@ class ShareGPTDataset(BenchmarkDataset):
entry["conversations"][1]["value"],
)
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
lora_request, tokenizer = self.get_random_lora_request(
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
prompt_ids = tokenizer(prompt).input_ids
@ -760,6 +770,14 @@ class InstructCoderDataset(HuggingFaceDataset):
if len(sampled_requests) >= num_requests:
break
prompt = f"{item['instruction']}:\n{item['input']}"
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
@ -793,11 +811,18 @@ class AIMODataset(HuggingFaceDataset):
sampled_requests = []
dynamic_output = output_len is None
for item in self.data:
for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt, completion = item['problem'], item["solution"]
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
prompt_ids = tokenizer(prompt).input_ids
completion_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_ids)
@ -895,3 +920,103 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.", skipped)
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
class MTBenchDataset(HuggingFaceDataset):
"""
MT-Bench Dataset.
https://huggingface.co/datasets/philschmid/mt-bench
We create a single turn dataset for MT-Bench.
This is similar to Spec decoding benchmark setup in vLLM
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
""" # noqa: E501
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
SUPPORTED_DATASET_PATHS = {
"philschmid/mt-bench",
}
def sample(self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
prompt = item['turns'][0]
# apply template
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests
class CNNDailyMailDataset(HuggingFaceDataset):
"""
MT-Bench Dataset.
https://huggingface.co/datasets/philschmid/mt-bench
We create a single turn dataset for MT-Bench.
This is similar to Spec decoding benchmark setup in vLLM
https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18
""" # noqa: E501
DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM
SUPPORTED_DATASET_PATHS = {
"abisee/cnn_dailymail",
}
def sample(self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
**kwargs) -> list:
output_len = (output_len
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
sampled_requests = []
for item in self.data:
if len(sampled_requests) >= num_requests:
break
instruction = "Could you summarize the following article, " \
"please reuse text from the article if possible: "
prompt = instruction + item['article']
# apply template
prompt = tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True,
tokenize=False)
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
))
self.maybe_oversample_requests(sampled_requests, num_requests)
return sampled_requests

View File

@ -12,7 +12,8 @@ from typing import Any, Optional, Union
import torch
import uvloop
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
ConversationDataset, InstructCoderDataset,
CNNDailyMailDataset, ConversationDataset,
InstructCoderDataset, MTBenchDataset,
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
@ -57,9 +58,9 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
temperature=0,
top_p=1.0,
ignore_eos=True,
ignore_eos=False,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
))
@ -123,9 +124,9 @@ def run_vllm_chat(
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
temperature=0,
top_p=1.0,
ignore_eos=True,
ignore_eos=False,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
))
@ -167,9 +168,9 @@ async def run_vllm_async(
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
temperature=0,
top_p=1.0,
ignore_eos=True,
ignore_eos=False,
max_tokens=request.expected_output_len,
detokenize=not disable_detokenize,
))
@ -339,6 +340,14 @@ def get_requests(args, tokenizer):
dataset_cls = AIMODataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = MTBenchDataset
common_kwargs['dataset_subset'] = None
common_kwargs['dataset_split'] = "train"
elif args.dataset_path in CNNDailyMailDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = CNNDailyMailDataset
common_kwargs['dataset_subset'] = '3.0.0'
common_kwargs['dataset_split'] = "train"
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
@ -477,8 +486,11 @@ def validate_args(args):
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
| ConversationDataset.SUPPORTED_DATASET_PATHS):
assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501
elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS):
elif args.dataset_path in (
InstructCoderDataset.SUPPORTED_DATASET_PATHS
| AIMODataset.SUPPORTED_DATASET_PATHS
| MTBenchDataset.SUPPORTED_DATASET_PATHS
| CNNDailyMailDataset.SUPPORTED_DATASET_PATHS):
assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501
else:
raise ValueError(

213
benchmarks/run.sh Normal file
View File

@ -0,0 +1,213 @@
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name sonnet \
# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name sharegpt \
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name hf \
# --dataset-path likaixin/InstructCoder \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name sonnet \
# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name sharegpt \
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3-8B-Instruct \
# --dataset-name hf \
# --dataset-path likaixin/InstructCoder \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path likaixin/InstructCoder \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
python benchmarks/benchmark_throughput.py \
--model meta-llama/Meta-Llama-3.1-8B-Instruct\
--dataset-name hf \
--dataset-path philschmid/mt-bench \
--prefix-len 0 \
--output-len 512 \
--num-prompts 200 \
--speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name sharegpt \
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name sonnet \
# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path likaixin/InstructCoder \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name sharegpt \
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path likaixin/InstructCoder \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
# python benchmarks/benchmark_throughput.py \
# --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
# --dataset-name hf \
# --dataset-path AI-MO/aimo-validation-aime \
# --prefix-len 0 \
# --output-len 1024 \
# --num-prompts 90 \
# --speculative_config '{"method": "eagle3", "num_speculative_tokens": 20, "model": "yuhuili/EAGLE3-DeepSeek-R1-Distill-LLaMA-8B"}'
# python benchmarks/benchmark_throughput.py \
# --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
# --dataset-name hf \
# --dataset-path AI-MO/aimo-validation-aime \
# --prefix-len 0 \
# --output-len 1024 \
# --num-prompts 90 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name sharegpt \
# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path philschmid/mt-bench \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path philschmid/mt-bench \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path abisee/cnn_dailymail \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path abisee/cnn_dailymail \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path philschmid/mt-bench \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 10 \
# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}'
# python benchmarks/benchmark_throughput.py \
# --model meta-llama/Meta-Llama-3.1-8B-Instruct \
# --dataset-name hf \
# --dataset-path abisee/cnn_dailymail \
# --prefix-len 0 \
# --output-len 512 \
# --num-prompts 200 \
# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}'

View File

@ -0,0 +1,63 @@
import json
from dataclasses import dataclass
MODEL_TO_NAMES = {
"r1-distill-llama-8B" : "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"llama3-8B" : "meta-llama/Meta-Llama-3-8B-Instruct",
"llama3.1-8B" : "meta-llama/Llama-3.1-8B-Instruct",
"llama3.1-70B" : "meta-llama/Llama-3.1-70B-Instruct",
}
@dataclass
class AccStats:
lens: list[int]
probs: list[float] = None
entropies: list[float] = None
def __post_init__(self):
if self.probs is not None:
assert len(self.lens) == len(self.probs), "Length of lens and probs must match"
if self.entropies is not None:
assert len(self.lens) == len(self.entropies), "Length of lens and entropies must match"
# remove the prefill accepted lens
self.lens = self.lens[1:]
# remove the last proposed tokens
if self.probs:
self.probs = self.probs[:-1]
if self.entropies:
self.entropies = self.entropies[:-1]
@property
def length(self):
return len(self.lens)
# def cleanup(acc_stats: AccStats) ->
# # Remove the prefill phase
# data = data[1:]
# # Cap the maximum value to 10
# data = [min(x, 10) for x in data]
# return data
def load_data(datapath, tokenizer, verbose=False):
acceptance_stats = []
with open(datapath, "r") as f:
lines = f.readlines()
for line in lines:
data = json.loads(line)
stat = AccStats(
lens=data['acc']['acc_len'],
probs=data['acc'].get('acc_prob', None),
entropies=data['acc'].get('acc_entropy', None)
)
acceptance_stats.append(stat)
if verbose:
print("Input:", tokenizer.decode(data['prompt_token_ids']))
print("Output:", tokenizer.decode(data['generated_token_ids']))
print("=============================================")
max_length = max(stats.length for stats in acceptance_stats)
print(f"Load {len(acceptance_stats)} with max length {max_length}")
return acceptance_stats

View File

@ -0,0 +1,108 @@
import json
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import AutoTokenizer
from .common import MODEL_TO_NAMES, load_data
import requests
import os
from pathlib import Path
class AcceptanceStatsClient:
"""Client for fetching and processing acceptance statistics data."""
def __init__(self, model_name, method, dataset, data_path=None):
"""Initialize the client with model and dataset info."""
self.model_name = model_name
self.method = method
self.dataset = dataset
if data_path is None:
self.data_path = f"/data/lily/batch-sd/data/{model_name}/{method}_{dataset}_acceptance_stats.jsonl"
else:
self.data_path = data_path
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model_name], use_fast=False)
self.acceptance_stats = None
def load_data(self):
"""Load the acceptance statistics from file."""
self.acceptance_stats = load_data(self.data_path, self.tokenizer)
return self.acceptance_stats
def plot_heatmap(self, output_dir="figures"):
"""Plot the acceptance statistics as a heatmap."""
if self.acceptance_stats is None:
self.load_data()
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(self.acceptance_stats, cmap="YlGnBu")
plt.xlabel("Position")
plt.ylabel("Request ID")
# Add Y-axis labels on the right
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())
ax2.set_yticks([])
ax2.set_ylabel("# of Accepted Tokens", labelpad=10)
plt.title(f"Acceptance Statistics: {self.model_name} - {self.method} - {self.dataset}")
plt.tight_layout()
# Create output directory if it doesn't exist
output_path = Path(output_dir) / self.model_name
os.makedirs(output_path, exist_ok=True)
output_file = output_path / f"{self.method}_{self.dataset}_acceptance_stats.pdf"
plt.savefig(output_file)
print(f"Saved heatmap to {output_file}")
return fig
def get_summary_stats(self):
"""Get summary statistics about the acceptance data."""
if self.acceptance_stats is None:
self.load_data()
# Calculate average acceptance rate for each position
avg_by_position = [sum(col)/len(col) for col in zip(*self.acceptance_stats) if sum(1 for v in col if v >= 0) > 0]
# Calculate average acceptance rate for each request
avg_by_request = [sum(row)/len(row) for row in self.acceptance_stats]
return {
"total_requests": len(self.acceptance_stats),
"max_position": len(avg_by_position),
"avg_acceptance_rate": sum(avg_by_request)/len(avg_by_request),
"avg_by_position": avg_by_position,
"avg_by_request": avg_by_request
}
# Example model configuration
model = "llama3.1-8B"
# model = "r1-distill-llama-8B"
method = "eagle3"
dataset = "mtbench"
# dataset = "aime"
# method = "ngram"
# dataset = "cnndailymail"
# datapath = f"/data/lily/batch-sd/data/{model}/{method}_{dataset}_acceptance_stats.jsonl"
datapath = "acceptance_stats.jsonl"
tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model], use_fast=False)
if __name__ == "__main__":
# Use the client instead of directly loading data
client = AcceptanceStatsClient(model, method, dataset, datapath)
acceptance_stats = client.load_data()
# Get summary statistics
summary = client.get_summary_stats()
print("Summary Statistics:")
print(f"Total Requests: {summary['total_requests']}")
print(f"Max Position: {summary['max_position']}")
print(f"Average Acceptance Rate: {summary['avg_acceptance_rate']:.2f}")
# Create heatmap visualization
plot_heatmap = False
if plot_heatmap:
client.plot_heatmap()

View File

@ -0,0 +1,69 @@
import json
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
model = "llama3.1-8B"
dataset = "instructcode"
method1 = "ngram"
method2 = "eagle3"
def get_datapath(method):
datapath = f"/data/lily/batch-sd/data/{model}/{method}_{dataset}_acceptance_stats.jsonl"
return datapath
def cleanup(data):
# Remove the prefill phase
data = data[1:]
# Cap the maximum value to 10
data = [min(x, 10) for x in data]
return data
def load_data(datapath):
acceptance_stats = {}
with open(datapath, "r") as f:
lines = f.readlines()
for line in lines:
data = json.loads(line)
key = hash(tuple(data['prompt_token_ids']))
acceptance_stats[key] = cleanup(data['acc'])
# Pad the acceptance stats to the same length
max_length = max(len(stats) for k, stats in acceptance_stats.items())
for key in acceptance_stats:
acceptance_stats[key] += [-2] * (max_length - len(acceptance_stats[key]))
print(f"Load {len(acceptance_stats)} with max length {max_length} from {datapath}")
return acceptance_stats
def diff(acceptance_stats1, acceptance_stats2):
diff = {}
for key in acceptance_stats1:
if key in acceptance_stats2:
diff[key] = [a - b for a, b in zip(acceptance_stats1[key], acceptance_stats2[key])]
return diff
datapath_1 = get_datapath(method1)
datapath_2 = get_datapath(method2)
acceptance_stats_1 = load_data(datapath_1)
acceptance_stats_2 = load_data(datapath_2)
acceptance_stats_diff = diff(acceptance_stats_1, acceptance_stats_2)
acceptance_stats = list(acceptance_stats_diff.values())
fig, ax = plt.subplots()
colors = ["red", "white", "blue"]
custom_cmap = LinearSegmentedColormap.from_list("custom", colors, N=256)
sns.heatmap(acceptance_stats, cmap=custom_cmap, center=0)
plt.xlabel("Position")
plt.ylabel("Request ID")
# Add Y-axis labels on the right
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim()) # Match y-axis range
ax2.set_yticks([]) # Remove right tick marks if undesired
ax2.set_ylabel("# of Accepted Tokens", labelpad=10) # Set right y-axis label
plt.title(f"Diff between {method2} - {method1} acceptance stats for {dataset}")
plt.tight_layout()
plt.savefig(f"figures/{model}/diff_{method2}_{method1}_{dataset}_acceptance_stats.pdf")

View File

@ -0,0 +1,38 @@
from transformers import AutoTokenizer
from common import MODEL_TO_NAMES, load_data
import matplotlib.pyplot as plt
def plot_prob_entropy(acceptance_stats,
output_path):
acc_probs = []
rej_probs = []
for stat in acceptance_stats:
for i, acc_len in enumerate(stat.lens):
acc_probs.extend(stat.probs[i][:acc_len-1])
rej_probs.extend(stat.probs[i][acc_len-1:])
fig, ax = plt.subplots(figsize=(12, 8))
plt.hist(acc_probs, bins=100, alpha=0.5,
label='Accepted Probabilities', color='green')
plt.hist(rej_probs, bins=100, alpha=0.5,
label='Rejected Probabilities', color='red')
plt.xlabel('Probability')
plt.ylabel('Frequency')
plt.title('Distribution of Accepted and Rejected Probabilities')
plt.legend()
plt.tight_layout()
plt.savefig(output_path)
if __name__ == "__main__":
datapath = "/data/lily/sd-benchmark-paper/batch-sd/acceptance_stats.jsonl"
model = "llama3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model],
use_fast=False)
acceptance_stats = load_data(datapath, tokenizer)
plot_prob_entropy(acceptance_stats, output_path="prob_entropy_figures")

View File

@ -0,0 +1,58 @@
# Security Guide
## Inter-Node Communication
All communications between nodes in a multi-node vLLM deployment are **insecure by default** and must be protected by placing the nodes on an isolated network. This includes:
1. PyTorch Distributed communications
2. KV cache transfer communications
3. Tensor, Pipeline, and Data parallel communications
### Configuration Options for Inter-Node Communications
The following options control inter-node communications in vLLM:
1. **Environment Variables:**
- `VLLM_HOST_IP`: Sets the IP address for vLLM processes to communicate on
2. **KV Cache Transfer Configuration:**
- `--kv-ip`: The IP address for KV cache transfer communications (default: 127.0.0.1)
- `--kv-port`: The port for KV cache transfer communications (default: 14579)
3. **Data Parallel Configuration:**
- `data_parallel_master_ip`: IP of the data parallel master (default: 127.0.0.1)
- `data_parallel_master_port`: Port of the data parallel master (default: 29500)
### Notes on PyTorch Distributed
vLLM uses PyTorch's distributed features for some inter-node communication. For
detailed information about PyTorch Distributed security considerations, please
refer to the [PyTorch Security
Guide](https://github.com/pytorch/pytorch/security/policy#using-distributed-features).
Key points from the PyTorch security guide:
- PyTorch Distributed features are intended for internal communication only
- They are not built for use in untrusted environments or networks
- No authorization protocol is included for performance reasons
- Messages are sent unencrypted
- Connections are accepted from anywhere without checks
### Security Recommendations
1. **Network Isolation:**
- Deploy vLLM nodes on a dedicated, isolated network
- Use network segmentation to prevent unauthorized access
- Implement appropriate firewall rules
2. **Configuration Best Practices:**
- Always set `VLLM_HOST_IP` to a specific IP address rather than using defaults
- Configure firewalls to only allow necessary ports between nodes
3. **Access Control:**
- Restrict physical and network access to the deployment environment
- Implement proper authentication and authorization for management interfaces
- Follow the principle of least privilege for all system components
## Reporting Security Vulnerabilities
If you believe you have found a security vulnerability in vLLM, please report it following the project's security policy. For more information on how to report security issues and the project's security policy, please see the [vLLM Security Policy](https://github.com/vllm-project/vllm/blob/main/SECURITY.md).

View File

@ -132,6 +132,7 @@ serving/integrations/index
:caption: Deployment
:maxdepth: 1
deployment/security
deployment/docker
deployment/k8s
deployment/nginx

View File

@ -77,6 +77,10 @@ bash run_cluster.sh \
Then you get a ray cluster of **containers**. Note that you need to keep the shells running these commands alive to hold the cluster. Any shell disconnect will terminate the cluster. In addition, please note that the argument `ip_of_head_node` should be the IP address of the head node, which is accessible by all the worker nodes. The IP addresses of each worker node should be specified in the `VLLM_HOST_IP` environment variable, and should be different for each worker node. Please check the network configuration of your cluster to make sure the nodes can communicate with each other through the specified IP addresses.
:::{warning}
It is considered best practice to set `VLLM_HOST_IP` to an address on a private network segment for the vLLM cluster. The traffic sent here is not encrypted. The endpoints are also exchanging data in a format that could be exploited to execute arbitrary code should a malicious party gain access to the network. Please ensure that this network is not reachable by any untrusted parties.
:::
:::{warning}
Since this is a ray cluster of **containers**, all the following commands should be executed in the **containers**, otherwise you are executing the commands on the host machine, which is not connected to the ray cluster. To enter the container, you can use `docker exec -it node /bin/bash`.
:::

View File

@ -10,12 +10,12 @@ prompts = [
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=10)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m", disable_cascade_attn=True)
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.

View File

@ -20,15 +20,11 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
"dtype": torch.float16,
"quantization": "compressed-tensors"
}),
("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
"dtype": torch.float16,
"quantization": "compressed-tensors"
}),
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {
"quantization": "compressed-tensors"
}),
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {}),
("meta-llama/Llama-3.2-1B-Instruct", {}),
]

View File

@ -1,14 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
import json
from argparse import ArgumentError, ArgumentTypeError
from contextlib import nullcontext
from dataclasses import dataclass, field
from typing import Literal, Optional
import pytest
from vllm.config import PoolerConfig
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
from vllm.config import PoolerConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
nullable_kvs, optional_type)
from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("type", "value", "expected"), [
(int, "42", 42),
(int, "None", None),
(float, "3.14", 3.14),
(float, "None", None),
(str, "Hello World!", "Hello World!"),
(str, "None", None),
(json.loads, '{"foo":1,"bar":2}', {
"foo": 1,
"bar": 2
}),
(json.loads, "foo=1,bar=2", {
"foo": 1,
"bar": 2
}),
(json.loads, "None", None),
])
def test_optional_type(type, value, expected):
optional_type_func = optional_type(type)
context = nullcontext()
if value == "foo=1,bar=2":
context = pytest.warns(DeprecationWarning)
with context:
assert optional_type_func(value) == expected
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
(int, int, True),
(int, float, False),
(list[int], list, True),
(list[int], tuple, False),
(Literal[0, 1], Literal, True),
])
def test_is_type(type_hint, type, expected):
assert is_type(type_hint, type) == expected
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
({float, int}, int, True),
({int, tuple[int]}, int, True),
({int, tuple[int]}, float, False),
({str, Literal["x", "y"]}, Literal, True),
])
def test_contains_type(type_hints, type, expected):
assert contains_type(type_hints, type) == expected
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
({int, float}, int, int),
({int, float}, str, None),
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
])
def test_get_type(type_hints, type, expected):
assert get_type(type_hints, type) == expected
@config
@dataclass
class DummyConfigClass:
regular_bool: bool = True
"""Regular bool with default True"""
optional_bool: Optional[bool] = None
"""Optional bool with default None"""
optional_literal: Optional[Literal["x", "y"]] = None
"""Optional literal with default None"""
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
"""Tuple with default (1, 2, 3)"""
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
"""Tuple with default (1, 2)"""
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
"""List with default [1, 2, 3]"""
@pytest.mark.parametrize(("type_hint", "expected"), [
(int, False),
(DummyConfigClass, True),
])
def test_is_not_builtin(type_hint, expected):
assert is_not_builtin(type_hint) == expected
def test_get_kwargs():
kwargs = get_kwargs(DummyConfigClass)
print(kwargs)
# bools should not have their type set
assert kwargs["regular_bool"].get("type") is None
assert kwargs["optional_bool"].get("type") is None
# optional literals should have None as a choice
assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"]
# tuples should have the correct nargs
assert kwargs["tuple_n"]["nargs"] == "+"
assert kwargs["tuple_2"]["nargs"] == 2
# lists should work
assert kwargs["list_n"]["type"] is int
assert kwargs["list_n"]["nargs"] == "+"
@pytest.mark.parametrize(("arg", "expected"), [
(None, dict()),
("image=16", {

View File

@ -28,6 +28,7 @@ import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
QuantizationMethods,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum, current_platform
@ -752,9 +753,8 @@ class ModelConfig:
supported_quantization = QUANTIZATION_METHODS
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8", "quark", "nvfp4", "bitblas",
"gptq_bitblas"
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "nvfp4", "bitblas", "gptq_bitblas"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()
@ -764,13 +764,47 @@ class ModelConfig:
if quant_cfg is not None:
quant_method = quant_cfg.get("quant_method", "").lower()
quant_method = quant_method.replace("compressed_tensors",
"compressed-tensors")
quant_cfg["quant_method"] = quant_method
# Quantization methods which are overrides (i.e. they have a
# `override_quantization_method` method) must be checked in order
# of preference (this is particularly important for GPTQ).
overrides = [
"marlin",
"bitblas",
"gptq_marlin_24",
"gptq_marlin",
"gptq_bitblas",
"awq_marlin",
"ipex",
"moe_wna16",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides
]
# Any custom overrides will be in quantization_methods so we place
# them at the start of the list so custom overrides have preference
# over the built in ones.
quantization_methods = quantization_methods + overrides
# Detect which checkpoint is it
for name in QUANTIZATION_METHODS:
for name in quantization_methods:
method = get_quantization_config(name)
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization)
if quantization_override:
if quantization_override is not None:
# Raise error if the override is not custom (custom would
# be in QUANTIZATION_METHODS but not QuantizationMethods)
# and hasn't been added to the overrides list.
if (name in get_args(QuantizationMethods)
and name not in overrides):
raise ValueError(
f"Quantization method {name} is an override but "
"is has not been added to the `overrides` list "
"above. This is necessary to ensure that the "
"overrides are checked in order of preference.")
quant_method = quantization_override
self.quantization = quantization_override
break

View File

@ -241,7 +241,7 @@ class MessageQueue:
self.remote_socket.setsockopt(IPV6, 1)
remote_addr_ipv6 = True
connect_ip = f"[{connect_ip}]"
socket_addr = f"tcp://*:{remote_subscribe_port}"
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
self.remote_socket.bind(socket_addr)
remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
else:

View File

@ -11,7 +11,7 @@ from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
TypeVar, Union, cast, get_args, get_origin)
import torch
from typing_extensions import TypeIs
from typing_extensions import TypeIs, deprecated
import vllm.envs as envs
from vllm import version
@ -48,33 +48,29 @@ TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
if val == "" or val == "None":
return None
try:
return return_type(val)
except ValueError as e:
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
def optional_type(
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
try:
if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val))
return return_type(val)
except ValueError as e:
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
return _optional_type
def optional_str(val: str) -> Optional[str]:
return optional_arg(val, str)
def optional_int(val: str) -> Optional[int]:
return optional_arg(val, int)
def optional_float(val: str) -> Optional[float]:
return optional_arg(val, float)
def nullable_kvs(val: str) -> Optional[dict[str, int]]:
"""NOTE: This function is deprecated, args should be passed as JSON
strings instead.
Parses a string containing comma separate key [str] to value [int]
@deprecated(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
"string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
"""Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary.
Args:
@ -83,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
Returns:
Dictionary with parsed values.
"""
if len(val) == 0:
return None
out_dict: Dict[str, int] = {}
out_dict: dict[str, int] = {}
for item in val.split(","):
kv_parts = [part.lower().strip() for part in item.split("=")]
if len(kv_parts) != 2:
@ -108,15 +101,103 @@ def nullable_kvs(val: str) -> Optional[dict[str, int]]:
return out_dict
def optional_dict(val: str) -> Optional[dict[str, int]]:
if re.match("^{.*}$", val):
return optional_arg(val, json.loads)
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
"""Check if the type hint is a specific type."""
return type_hint is type or get_origin(type_hint) is type
logger.warning(
"Failed to parse JSON string. Attempting to parse as "
"comma-separated key=value pairs. This will be deprecated in a "
"future release.")
return nullable_kvs(val)
def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
"""Check if the type hints contain a specific type."""
return any(is_type(type_hint, type) for type_hint in type_hints)
def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
"""Get the specific type from the type hints."""
return next((th for th in type_hints if is_type(th, type)), None)
def is_not_builtin(type_hint: TypeHint) -> bool:
"""Check if the class is not a built-in type."""
return type_hint.__module__ != "builtins"
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the default value of the field
default = field.default
if field.default_factory is not MISSING:
default = field.default_factory()
# Get the help text for the field
name = field.name
help = cls_docs[name]
# Escape % for argparse
help = help.replace("%", "%%")
# Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help}
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) is Union:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# Set other kwargs based on the type hints
if contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
elif contains_type(type_hints, Literal):
# Creates choices from Literal arguments
type_hint = get_type(type_hints, Literal)
choices = sorted(get_args(type_hint))
kwargs[name]["choices"] = choices
choice_type = type(choices[0])
assert all(type(c) is choice_type for c in choices), (
"All choices must be of the same type. "
f"Got {choices} with types {[type(c) for c in choices]}")
kwargs[name]["type"] = choice_type
elif contains_type(type_hints, tuple):
type_hint = get_type(type_hints, tuple)
types = get_args(type_hint)
tuple_type = types[0]
assert all(t is tuple_type for t in types if t is not Ellipsis), (
"All non-Ellipsis tuple elements must be of the same "
f"type. Got {types}.")
kwargs[name]["type"] = tuple_type
kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types)
elif contains_type(type_hints, list):
type_hint = get_type(type_hints, list)
types = get_args(type_hint)
assert len(types) == 1, (
"List type must have exactly one type. Got "
f"{type_hint} with types {types}")
kwargs[name]["type"] = types[0]
kwargs[name]["nargs"] = "+"
elif contains_type(type_hints, int):
kwargs[name]["type"] = int
elif contains_type(type_hints, float):
kwargs[name]["type"] = float
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads)
elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)):
kwargs[name]["type"] = str
else:
raise ValueError(
f"Unsupported type {type_hints} for argument {name}.")
# If None is in type_hints, make the argument optional.
# But not if it's a bool, argparse will handle this better.
if type(None) in type_hints and not contains_type(type_hints, bool):
kwargs[name]["type"] = optional_type(kwargs[name]["type"])
if kwargs[name].get("choices"):
kwargs[name]["choices"].append("None")
return kwargs
@dataclass
@ -279,100 +360,6 @@ class EngineArgs:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""
def is_type_in_union(cls: TypeHint, type: TypeHint) -> bool:
"""Check if the class is a type in a union type."""
is_union = get_origin(cls) is Union
type_in_union = type in [get_origin(a) or a for a in get_args(cls)]
return is_union and type_in_union
def get_type_from_union(cls: TypeHint, type: TypeHintT) -> TypeHintT:
"""Get the type in a union type."""
for arg in get_args(cls):
if (get_origin(arg) or arg) is type:
return arg
raise ValueError(f"Type {type} not found in union type {cls}.")
def is_optional(cls: TypeHint) -> TypeIs[Union[Any, None]]:
"""Check if the class is an optional type."""
return is_type_in_union(cls, type(None))
def can_be_type(cls: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
"""Check if the class can be of type."""
return cls is type or get_origin(cls) is type or is_type_in_union(
cls, type)
def is_custom_type(cls: TypeHint) -> bool:
"""Check if the class is a custom type."""
return cls.__module__ != "builtins"
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the default value of the field
default = field.default
if field.default_factory is not MISSING:
default = field.default_factory()
# Get the help text for the field
name = field.name
help = cls_docs[name]
# Escape % for argparse
help = help.replace("%", "%%")
# Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help}
# Make note of if the field is optional and get the actual
# type of the field if it is
optional = is_optional(field.type)
field_type = get_args(
field.type)[0] if optional else field.type
# Set type, action and choices for the field depending on the
# type of the field
if can_be_type(field_type, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
kwargs[name]["type"] = bool
elif can_be_type(field_type, Literal):
# Creates choices from Literal arguments
if is_type_in_union(field_type, Literal):
field_type = get_type_from_union(field_type, Literal)
choices = get_args(field_type)
kwargs[name]["choices"] = choices
choice_type = type(choices[0])
assert all(type(c) is choice_type for c in choices), (
"All choices must be of the same type. "
f"Got {choices} with types {[type(c) for c in choices]}"
)
kwargs[name]["type"] = choice_type
elif can_be_type(field_type, tuple):
if is_type_in_union(field_type, tuple):
field_type = get_type_from_union(field_type, tuple)
dtypes = get_args(field_type)
dtype = dtypes[0]
assert all(
d is dtype for d in dtypes if d is not Ellipsis
), ("All non-Ellipsis tuple elements must be of the same "
f"type. Got {dtypes}.")
kwargs[name]["type"] = dtype
kwargs[name]["nargs"] = "+"
elif can_be_type(field_type, int):
kwargs[name]["type"] = optional_int if optional else int
elif can_be_type(field_type, float):
kwargs[name][
"type"] = optional_float if optional else float
elif can_be_type(field_type, dict):
kwargs[name]["type"] = optional_dict
elif (can_be_type(field_type, str)
or is_custom_type(field_type)):
kwargs[name]["type"] = optional_str if optional else str
else:
raise ValueError(
f"Unsupported type {field.type} for argument {name}. ")
return kwargs
# Model arguments
parser.add_argument(
'--model',
@ -390,13 +377,13 @@ class EngineArgs:
'which task to use.')
parser.add_argument(
'--tokenizer',
type=optional_str,
type=optional_type(str),
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
"--hf-config-path",
type=optional_str,
type=optional_type(str),
default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.')
@ -408,21 +395,21 @@ class EngineArgs:
'the input. The generated output will contain token ids.')
parser.add_argument(
'--revision',
type=optional_str,
type=optional_type(str),
default=None,
help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
type=optional_str,
type=optional_type(str),
default=None,
help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
type=optional_str,
type=optional_type(str),
default=None,
help='Revision of the huggingface tokenizer to use. '
'It can be a branch name, a tag name, or a commit id. '
@ -513,7 +500,7 @@ class EngineArgs:
parser.add_argument(
'--logits-processor-pattern',
type=optional_str,
type=optional_type(str),
default=None,
help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` '
@ -612,7 +599,7 @@ class EngineArgs:
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=optional_str,
type=optional_type(str),
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
@ -921,7 +908,7 @@ class EngineArgs:
'class without changing the existing functions.')
parser.add_argument(
"--generation-config",
type=optional_str,
type=optional_type(str),
default="auto",
help="The folder path to the generation config. "
"Defaults to 'auto', the generation config will be loaded from "

View File

@ -11,7 +11,7 @@ import ssl
from collections.abc import Sequence
from typing import Optional, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action):
def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--host",
type=optional_str,
type=optional_type(str),
default=None,
help="Host name.")
parser.add_argument("--port", type=int, default=8000, help="Port number.")
@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=["*"],
help="Allowed headers.")
parser.add_argument("--api-key",
type=optional_str,
type=optional_type(str),
default=None,
help="If provided, the server will require this key "
"to be presented in the header.")
parser.add_argument(
"--lora-modules",
type=optional_str,
type=optional_type(str),
default=None,
nargs='+',
action=LoRAParserAction,
@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"\"base_model_name\": \"id\"}``")
parser.add_argument(
"--prompt-adapters",
type=optional_str,
type=optional_type(str),
default=None,
nargs='+',
action=PromptAdapterParserAction,
help="Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified.")
parser.add_argument("--chat-template",
type=optional_str,
type=optional_type(str),
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'similar to OpenAI schema. '
'Example: ``[{"type": "text", "text": "Hello world!"}]``')
parser.add_argument("--response-role",
type=optional_str,
type=optional_type(str),
default="assistant",
help="The role name to return if "
"``request.add_generation_prompt=true``.")
parser.add_argument("--ssl-keyfile",
type=optional_str,
type=optional_type(str),
default=None,
help="The file path to the SSL key file.")
parser.add_argument("--ssl-certfile",
type=optional_str,
type=optional_type(str),
default=None,
help="The file path to the SSL cert file.")
parser.add_argument("--ssl-ca-certs",
type=optional_str,
type=optional_type(str),
default=None,
help="The CA certificates file.")
parser.add_argument(
@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
)
parser.add_argument(
"--root-path",
type=optional_str,
type=optional_type(str),
default=None,
help="FastAPI root_path when app is behind a path based routing proxy."
)
parser.add_argument(
"--middleware",
type=optional_str,
type=optional_type(str),
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "

View File

@ -12,7 +12,7 @@ import torch
from prometheus_client import start_http_server
from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, optional_str
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger, logger
# yapf: disable
@ -61,7 +61,7 @@ def parse_args():
"to the output URL.",
)
parser.add_argument("--response-role",
type=optional_str,
type=optional_type(str),
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=True`.")

View File

@ -85,7 +85,6 @@ if TYPE_CHECKING:
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_ENABLE_V1_ADVANCE_STEP: bool = False
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
Q_SCALE_CONSTANT: int = 200
@ -601,8 +600,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: float(os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")),
"VLLM_DISABLE_COMPILE_CACHE":
lambda: bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))),
"VLLM_ENABLE_V1_ADVANCE_STEP":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_ADVANCE_STEP", "0"))),
# If set, vllm will run in development mode, which will enable
# some additional endpoints for developing and debugging,

View File

@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Type
from typing import Literal, Type, get_args
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QUANTIZATION_METHODS: List[str] = [
QuantizationMethods = Literal[
"aqlm",
"awq",
"deepspeedfp",
@ -15,8 +15,6 @@ QUANTIZATION_METHODS: List[str] = [
"fbgemm_fp8",
"modelopt",
"nvfp4",
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin",
"bitblas",
"gguf",
@ -36,6 +34,7 @@ QUANTIZATION_METHODS: List[str] = [
"moe_wna16",
"torchao",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
# The customized quantization methods which will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
@ -111,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig
method_to_config: Dict[str, Type[QuantizationConfig]] = {
method_to_config: dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
@ -120,8 +119,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
"nvfp4": ModelOptNvFp4Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
"bitblas": BitBLASConfig,
"gguf": GGUFConfig,
@ -150,6 +147,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
__all__ = [
"QuantizationConfig",
"QuantizationMethods",
"get_quantization_config",
"QUANTIZATION_METHODS",
]

View File

@ -72,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
return 70
def get_name(self) -> str:
return "compressed_tensors"
return "compressed-tensors"
def get_quant_method(
self,

View File

@ -130,8 +130,8 @@ class RocmPlatform(Platform):
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf", "quark", "ptpc_fp8"
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8"
]
@classmethod

View File

@ -30,9 +30,7 @@ class TpuPlatform(Platform):
ray_device_key: str = "TPU"
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
supported_quantization: list[str] = [
"tpu_int8", "compressed-tensors", "compressed_tensors"
]
supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"]
additional_env_vars: list[str] = [
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"

View File

@ -28,6 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
import json
logger = init_logger(__name__)
@ -632,6 +633,7 @@ class Scheduler(SchedulerInterface):
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
self.acceptance_stats = model_runner_output.acceptance_stats
new_running: list[Request] = []
outputs: list[EngineCoreOutput] = []
@ -789,6 +791,18 @@ class Scheduler(SchedulerInterface):
self._free_request(request)
def _free_request(self, request: Request) -> None:
req_id = request.request_id
data = self.acceptance_stats.pop(req_id)
with open('acceptance_stats.jsonl', 'a') as f:
f.write(json.dumps({
"id": req_id,
"acc": data,
"prompt_token_ids": request.prompt_token_ids,
"generated_token_ids": request.output_token_ids._x
}))
f.write('\n')
assert request.is_finished()
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)

View File

@ -99,6 +99,8 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
acceptance_stats: Optional[dict[str, list]] = None
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(

View File

@ -1,8 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from torch.distributions import Categorical
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
@ -98,12 +100,23 @@ class EagleProposer:
)
sample_hidden_states = hidden_states_logits[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
all_draft_probs = []
all_draft_entropy = []
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
draft_token_ids = logits.argmax(dim=-1)
# Get the probabilities of the draft tokens.
draft_probs = probs.gather(dim=1, index=draft_token_ids.unsqueeze(1))
dist = Categorical(logits=logits)
entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
all_draft_probs.append(draft_probs)
all_draft_entropy.append(entropy)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
return draft_token_ids.view(-1,
1), all_draft_probs, all_draft_entropy
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
@ -164,9 +177,17 @@ class EagleProposer:
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
draft_probs = probs.gather(dim=1,
index=draft_token_ids.unsqueeze(1))
dist = Categorical(logits=logits)
entropy = dist.entropy().unsqueeze(-1) # [batch_size, 1]
all_draft_probs.append(draft_probs)
all_draft_entropy.append(entropy)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
return draft_token_ids, all_draft_probs, all_draft_entropy
@staticmethod
def prepare_inputs(

View File

@ -3,7 +3,6 @@
import numpy as np
import torch
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -37,9 +36,6 @@ class BlockTable:
self.block_table_np = self.block_table_cpu.numpy()
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
self.prev_num_reqs = 0
self.is_updated = True
def append_row(
self,
block_ids: list[int],
@ -52,22 +48,16 @@ class BlockTable:
self.num_blocks_per_row[row_idx] += num_blocks
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
self.is_updated = True
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
self.append_row(block_ids, row_idx)
self.is_updated = True
def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks
self.is_updated = True
def swap_row(self, src: int, tgt: int) -> None:
num_blocks_src = self.num_blocks_per_row[src]
num_blocks_tgt = self.num_blocks_per_row[tgt]
@ -76,28 +66,14 @@ class BlockTable:
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
self.is_updated = True
def commit(self, num_reqs: int) -> None:
if envs.VLLM_ENABLE_V1_ADVANCE_STEP:
# Incremental copy
if self.prev_num_reqs != num_reqs or self.is_updated:
self.block_table[:num_reqs].copy_(
self.block_table_cpu[:num_reqs], non_blocking=True)
self.prev_num_reqs = num_reqs
self.is_updated = False
else:
# Always copy
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)
def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)
self.is_updated = True
def get_device_tensor(self) -> torch.Tensor:
"""Ruturns the device tensor of the block table."""
return self.block_table

View File

@ -10,7 +10,6 @@ import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
@ -143,15 +142,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
weakref.proxy(self))
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
if envs.VLLM_ENABLE_V1_ADVANCE_STEP:
logger.info("Advance_step is enabled")
if self.cascade_attn_enabled:
logger.warning(
"Disabling cascade attn (since advance_step is on)")
self.cascade_attn_enabled = False
else:
logger.info("Advance_step is disabled")
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
@ -281,51 +271,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.slot_mapping_gpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.query_start_loc_gpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.seq_lens_gpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
# Cached
self.prev_num_reqs = 0
self.req_indices_gpu = torch.arange(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.req_indices_block_table_offsets_gpu = (
self.req_indices_gpu * self.max_num_blocks_per_req)
self.num_scheduled_tokens_gpu = torch.ones(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
self.cu_num_tokens_gpu = torch.cumsum(self.num_scheduled_tokens_gpu, 0)
self.query_start_loc_gpu[0] = 0
self.query_start_loc_gpu[1:self.max_num_reqs +
1] = self.cu_num_tokens_gpu
self.logits_indices_gpu = self.query_start_loc_gpu[1:] - 1
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
self.prev_attn_metadata = None
self.is_first_advance_decode = True
self.acceptance_stats = {}
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
@ -530,119 +487,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if batch_changed or batch_reordered:
self.input_batch.refresh_sampling_metadata()
def _advance_decode_step(
self,
scheduler_output,
num_scheduled_tokens,
):
# print(" -- inside advance_decode_step")
num_reqs = self.input_batch.num_reqs
assert num_reqs > 0
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens == num_reqs
# TODO: Add if needed
# Get request indices.
# E.g., num_reqs == 3 -> [0, 1, 2]
# req_indices_gpu = self.req_indices_gpu[:num_reqs]
# Get cu_sums
# cu_num_tokens = self.cu_num_tokens_gpu[:num_reqs]
# Increment positions
positions_gpu = self.positions[:total_num_scheduled_tokens]
positions_gpu[:total_num_scheduled_tokens] += 1
# TODO: Verify MROPE is ok here
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
# Set next tokens
# (prev iteration tokens are cached in prev_sampled_token_ids tensor)
assert self.prev_sampled_token_ids is not None
self.input_ids[:total_num_scheduled_tokens] = \
self.prev_sampled_token_ids[:,0]
# Calculate the slot mapping
block_table_indices_gpu = (
self.req_indices_block_table_offsets_gpu[:num_reqs] +
positions_gpu // self.block_size)
block_table_gpu = self.input_batch.block_table.get_device_tensor()
# Note: The block table tensor is async copied from CPU to GPU
# (inside the .commit() call) if was previously modified
block_numbers_gpu = block_table_gpu.flatten()[block_table_indices_gpu]
block_offsets_gpu = positions_gpu % self.block_size
slot_mapping_gpu = self.slot_mapping_gpu[:total_num_scheduled_tokens]
slot_mapping_gpu[:] = (block_numbers_gpu * self.block_size +
block_offsets_gpu)
# Prepare the attention metadata.
# query_start_loc is always the same for all decode iterations
query_start_loc_gpu = self.query_start_loc_gpu[:num_reqs + 1]
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
# TODO: Add cascade attn support
# Verify cascade attention is disabled
assert not self.cascade_attn_enabled
# TODO: Add support for other attn backends
assert self.prev_attn_metadata is not None
assert isinstance(self.prev_attn_metadata, FlashAttentionMetadata)
attn_metadata = self.prev_attn_metadata
attn_metadata.max_seq_len += 1
attn_metadata.query_start_loc = query_start_loc_gpu
attn_metadata.seq_lens += 1
attn_metadata.slot_mapping = slot_mapping_gpu
# print("attn_metadata.seq_lens: shape = {} data = {}".format(
# attn_metadata.seq_lens.shape, attn_metadata.seq_lens))
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = self.logits_indices_gpu[:num_reqs]
spec_decode_metadata = None
else:
# TODO: Check if spec_decode can be enabled here
raise Exception("advance_step has no support for spec_decode yet")
# # Get the number of draft tokens for each request.
# # Iterate over the dictionary rather than all requests since
# # not all requests have draft tokens.
# num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
# for req_id, draft_token_ids in (
# scheduler_output.scheduled_spec_decode_tokens.items()):
# req_idx = self.input_batch.req_id_to_index[req_id]
# num_draft_tokens[req_idx] = len(draft_token_ids)
# spec_decode_metadata = self._calc_spec_decode_metadata(
# num_draft_tokens, cu_num_tokens)
# logits_indices = spec_decode_metadata.logits_indices
# Hot-Swap lora model
if self.lora_config:
# TODO: Check if this works
raise Exception("advance_step has no LORA support yet")
self.set_active_loras(self.input_batch, num_scheduled_tokens)
return attn_metadata, logits_indices, spec_decode_metadata
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
@ -663,38 +507,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
# Determine if advance step can be used
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
is_flash_attn = self.prev_attn_metadata is not None and isinstance(
self.prev_attn_metadata, FlashAttentionMetadata)
is_advance_decode = (envs.VLLM_ENABLE_V1_ADVANCE_STEP
and self.prev_num_reqs == num_reqs
and max_num_scheduled_tokens == 1
and not use_spec_decode
and not self.cascade_attn_enabled
and is_flash_attn)
if is_advance_decode:
if self.is_first_advance_decode:
# The first time advance_step can be used,
# we run the usual prepare, so that positions tensor
# is initialized
self.is_first_advance_decode = False
else:
# This is the fast-path advance_step
# (all tensors are on the GPU and are updated on the GPU)
(attn_metadata, logits_indices,
spec_decode_metadata) = self._advance_decode_step(
scheduler_output, num_scheduled_tokens)
return attn_metadata, logits_indices, spec_decode_metadata
else:
self.is_first_advance_decode = True
self.prev_num_reqs = num_reqs
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs],
@ -713,7 +525,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
# Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
@ -790,7 +601,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
)
self.prev_attn_metadata = attn_metadata
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@ -1369,8 +1179,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
self.prev_sampled_token_ids = sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
@ -1381,6 +1189,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampled_token_ids,
self.input_batch.vocab_size,
)
for i, token_ids in enumerate(valid_sampled_token_ids):
req_id = self.input_batch.req_ids[i]
if req_id not in self.acceptance_stats:
self.acceptance_stats[req_id] = {
'acc_len': [],
'acc_prob': [],
'acc_entropy': [],
}
self.acceptance_stats[req_id]['acc_len'].append(len(token_ids))
# Force 1 generated token per request.
for i, token_ids in enumerate(valid_sampled_token_ids):
valid_sampled_token_ids[i] = token_ids[:1]
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
@ -1456,7 +1277,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(target_hidden_states, dim=-1)
draft_token_ids = self.drafter.propose(
draft_token_ids, draft_probs, draft_entropy = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
@ -1468,6 +1289,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
spec_token_ids = draft_token_ids.tolist()
for req_id in self.input_batch.req_ids:
if req_id not in self.acceptance_stats:
self.acceptance_stats[req_id] = {
'acc_len': [],
'acc_prob': [],
'acc_entropy': [],
}
req_index = self.input_batch.req_id_to_index[req_id]
step_probs, step_entropy = [], []
for prob, entropy in zip(draft_probs, draft_entropy):
step_probs.append(prob[req_index].item())
step_entropy.append(entropy[req_index].item())
self.acceptance_stats[req_id]['acc_prob'].append(step_probs)
self.acceptance_stats[req_id]['acc_entropy'].append(step_entropy)
# Clear KVConnector state after all KVs are generated.
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
@ -1479,6 +1316,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
acceptance_stats=self.acceptance_stats,
)
def generate_draft_token_ids(