Compare commits
2 Commits
revert-222
...
remove_mam
| Author | SHA1 | Date | |
|---|---|---|---|
| ddb65dad96 | |||
| c41ea52634 |
@ -168,9 +168,9 @@ See [nightly-descriptions.md](nightly-descriptions.md) for the detailed descript
|
||||
### Workflow
|
||||
|
||||
- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines.
|
||||
- Inside each container, we run [scripts/run-nightly-benchmarks.sh](scripts/run-nightly-benchmarks.sh), which will probe the serving engine of the current container.
|
||||
- The `scripts/run-nightly-benchmarks.sh` will parse the workload described in [nightly-tests.json](tests/nightly-tests.json) and launch the right benchmark for the specified serving engine via `scripts/launch-server.sh`.
|
||||
- At last, we run [scripts/summary-nightly-results.py](scripts/summary-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite.
|
||||
- Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which will probe the serving engine of the current container.
|
||||
- The `run-nightly-suite.sh` will redirect the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark.
|
||||
- At last, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite.
|
||||
|
||||
### Nightly tests
|
||||
|
||||
@ -180,6 +180,6 @@ In [nightly-tests.json](tests/nightly-tests.json), we include the command line a
|
||||
|
||||
The docker containers for benchmarking are specified in `nightly-pipeline.yaml`.
|
||||
|
||||
WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `scripts/run-nightly-benchmarks.sh` and `scripts/launch-server.sh`.
|
||||
WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`.
|
||||
|
||||
WARNING: populating `trt-llm` to latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git).
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -4,9 +4,6 @@
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/*
|
||||
|
||||
# triton jit
|
||||
.triton
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
@ -427,7 +427,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
|
||||
@ -3,8 +3,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from packaging import version
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
MINIMUM_BITBLAS_VERSION,
|
||||
)
|
||||
@ -12,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||
try:
|
||||
import bitblas
|
||||
|
||||
if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION):
|
||||
if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
|
||||
raise ImportError(
|
||||
"bitblas version is wrong. Please "
|
||||
f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
|
||||
|
||||
@ -1,71 +0,0 @@
|
||||
# Benchmark KV Cache Offloading with Multi-Turn Conversations
|
||||
|
||||
The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `requirements.txt`
|
||||
|
||||
First start serving your model
|
||||
|
||||
```bash
|
||||
export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||
|
||||
vllm serve $MODEL_NAME --disable-log-requests
|
||||
```
|
||||
|
||||
## Synthetic Multi-Turn Conversations
|
||||
|
||||
Download the following text file (used for generation of synthetic conversations)
|
||||
|
||||
```bash
|
||||
wget https://www.gutenberg.org/ebooks/1184.txt.utf-8
|
||||
mv 1184.txt.utf-8 pg1184.txt
|
||||
```
|
||||
|
||||
The filename `pg1184.txt` is used in `generate_multi_turn.json` (see `"text_files"`).
|
||||
|
||||
But you may use other text files if you prefer (using this specific file is not required).
|
||||
|
||||
Then run the benchmarking script
|
||||
|
||||
```bash
|
||||
export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||
|
||||
python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \
|
||||
--num-clients 2 --max-active-conversations 6
|
||||
```
|
||||
|
||||
You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.).
|
||||
|
||||
If successful, you will see the following output
|
||||
|
||||
```bash
|
||||
----------------------------------------------------------------------------------------------------
|
||||
Statistics summary:
|
||||
runtime_sec = 215.810
|
||||
requests_per_sec = 0.769
|
||||
----------------------------------------------------------------------------------------------------
|
||||
count mean std min 25% 50% 75% 90% 99% max
|
||||
ttft_ms 166.0 78.22 67.63 45.91 59.94 62.26 64.43 69.66 353.18 567.54
|
||||
tpot_ms 166.0 25.37 0.57 24.40 25.07 25.31 25.50 25.84 27.50 28.05
|
||||
latency_ms 166.0 2591.07 326.90 1998.53 2341.62 2573.01 2860.10 3003.50 3268.46 3862.94
|
||||
input_num_turns 166.0 7.43 4.57 1.00 3.00 7.00 11.00 13.00 17.00 17.00
|
||||
input_num_tokens 166.0 2006.20 893.56 522.00 1247.75 2019.00 2718.00 3233.00 3736.45 3899.00
|
||||
output_num_tokens 166.0 100.01 11.80 80.00 91.00 99.00 109.75 116.00 120.00 120.00
|
||||
output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75 115.00 119.00 119.00
|
||||
----------------------------------------------------------------------------------------------------
|
||||
```
|
||||
|
||||
## ShareGPT Conversations
|
||||
|
||||
To run with the ShareGPT data, download the following ShareGPT dataset:
|
||||
`https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json`
|
||||
|
||||
Use the `convert_sharegpt_to_openai.py` script to convert the dataset to a format supported by `benchmark_serving_multi_turn.py`
|
||||
|
||||
```bash
|
||||
python convert_sharegpt_to_openai.py sharegpt_20230401_clean_lang_split.json sharegpt_conv_128.json --seed=99 --max-items=128
|
||||
```
|
||||
|
||||
The script will convert the ShareGPT dataset to a dataset with the standard user/assistant roles.
|
||||
|
||||
The flag `--max-items=128` is used to sample 128 conversations from the original dataset (change as needed).
|
||||
|
||||
Use the output JSON file `sharegpt_conv_128.json` as the `--input-file` for `benchmark_serving_multi_turn.py`.
|
||||
@ -1,493 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from statistics import mean
|
||||
from typing import Any, NamedTuple, Optional, Union
|
||||
|
||||
import numpy as np # type: ignore
|
||||
import pandas as pd # type: ignore
|
||||
from bench_utils import (
|
||||
TEXT_SEPARATOR,
|
||||
Color,
|
||||
logger,
|
||||
)
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
# Conversation ID is a string (e.g: "UzTK34D")
|
||||
ConvId = str
|
||||
|
||||
# A list of dicts (dicts with keys "id" and "messages")
|
||||
ShareGptConversations = list[dict[str, Any]]
|
||||
|
||||
# A list of dicts (dicts with keys "role" and "content")
|
||||
MessagesList = list[dict[str, str]]
|
||||
|
||||
# Map conversation ID to conversation messages
|
||||
ConversationsMap = list[ConvId, MessagesList]
|
||||
|
||||
|
||||
class Distribution(ABC):
|
||||
@abstractmethod
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
pass
|
||||
|
||||
|
||||
class UniformDistribution(Distribution):
|
||||
def __init__(
|
||||
self,
|
||||
min_val: Union[int, float],
|
||||
max_val: Union[int, float],
|
||||
is_integer: bool = True,
|
||||
) -> None:
|
||||
self.min_val = min_val
|
||||
self.max_val = max_val
|
||||
self.is_integer = is_integer
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
if self.is_integer:
|
||||
return np.random.randint(
|
||||
int(self.min_val), int(self.max_val + 1), size=size
|
||||
)
|
||||
else:
|
||||
return np.random.uniform(self.min_val, self.max_val, size=size)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"UniformDistribution[{self.min_val}, {self.max_val}]"
|
||||
|
||||
|
||||
class ConstantDistribution(Distribution):
|
||||
def __init__(self, value: Union[int, float]) -> None:
|
||||
self.value = value
|
||||
self.max_val = value
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
return np.full(shape=size, fill_value=self.value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Constant[{self.value}]"
|
||||
|
||||
|
||||
class ZipfDistribution(Distribution):
|
||||
def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
|
||||
self.alpha = alpha
|
||||
self.max_val = max_val
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
samples = np.random.zipf(self.alpha, size=size)
|
||||
if self.max_val:
|
||||
samples = np.minimum(samples, self.max_val)
|
||||
return samples
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ZipfDistribution[{self.alpha}]"
|
||||
|
||||
|
||||
class PoissonDistribution(Distribution):
|
||||
def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
|
||||
self.alpha = alpha
|
||||
self.max_val = max_val
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
samples = np.random.poisson(self.alpha, size=size)
|
||||
if self.max_val:
|
||||
samples = np.minimum(samples, self.max_val)
|
||||
return samples
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PoissonDistribution[{self.alpha}]"
|
||||
|
||||
|
||||
class LognormalDistribution(Distribution):
|
||||
def __init__(
|
||||
self, mean: float, sigma: float, max_val: Optional[int] = None
|
||||
) -> None:
|
||||
self.mean = mean
|
||||
self.sigma = sigma
|
||||
self.max_val = max_val
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size)
|
||||
if self.max_val:
|
||||
samples = np.minimum(samples, self.max_val)
|
||||
|
||||
return np.round(samples).astype(int)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LognormalDistribution[{self.mean}, {self.sigma}]"
|
||||
|
||||
|
||||
class GenConvArgs(NamedTuple):
|
||||
num_conversations: int
|
||||
text_files: list[str]
|
||||
input_num_turns: Distribution
|
||||
input_common_prefix_num_tokens: Distribution
|
||||
input_prefix_num_tokens: Distribution
|
||||
input_num_tokens: Distribution
|
||||
output_num_tokens: Distribution
|
||||
print_stats: bool
|
||||
|
||||
|
||||
def verify_field_exists(
|
||||
conf: dict, field_name: str, section: str, subsection: str
|
||||
) -> None:
|
||||
if field_name not in conf:
|
||||
raise ValueError(
|
||||
f"Missing field '{field_name}' in {section=} and {subsection=}"
|
||||
)
|
||||
|
||||
|
||||
def get_random_distribution(
|
||||
conf: dict, section: str, subsection: str, optional: bool = False
|
||||
) -> Distribution:
|
||||
# section can be "prompt_input" or "prompt_output" (both required)
|
||||
conf = conf[section]
|
||||
|
||||
if optional and subsection not in conf:
|
||||
# Optional subsection, if not found assume the value is always 0
|
||||
return ConstantDistribution(0)
|
||||
|
||||
# subsection can be "num_turns", "num_tokens" or "prefix_num_tokens"
|
||||
if subsection not in conf:
|
||||
raise ValueError(f"Missing subsection {subsection} in section {section}")
|
||||
|
||||
conf = conf[subsection]
|
||||
|
||||
distribution = conf.get("distribution")
|
||||
if distribution is None:
|
||||
raise ValueError(
|
||||
f"Missing field 'distribution' in {section=} and {subsection=}"
|
||||
)
|
||||
|
||||
if distribution == "constant":
|
||||
verify_field_exists(conf, "value", section, subsection)
|
||||
return ConstantDistribution(conf["value"])
|
||||
|
||||
elif distribution == "zipf":
|
||||
verify_field_exists(conf, "alpha", section, subsection)
|
||||
max_val = conf.get("max", None)
|
||||
return ZipfDistribution(conf["alpha"], max_val=max_val)
|
||||
|
||||
elif distribution == "poisson":
|
||||
verify_field_exists(conf, "alpha", section, subsection)
|
||||
max_val = conf.get("max", None)
|
||||
return PoissonDistribution(conf["alpha"], max_val=max_val)
|
||||
|
||||
elif distribution == "lognormal":
|
||||
verify_field_exists(conf, "mean", section, subsection)
|
||||
verify_field_exists(conf, "sigma", section, subsection)
|
||||
max_val = conf.get("max", None)
|
||||
return LognormalDistribution(conf["mean"], conf["sigma"], max_val=max_val)
|
||||
|
||||
elif distribution == "uniform":
|
||||
verify_field_exists(conf, "min", section, subsection)
|
||||
verify_field_exists(conf, "max", section, subsection)
|
||||
|
||||
min_value = conf["min"]
|
||||
max_value = conf["max"]
|
||||
|
||||
assert min_value > 0
|
||||
assert min_value <= max_value
|
||||
|
||||
is_integer = isinstance(min_value, int) and isinstance(max_value, int)
|
||||
return UniformDistribution(min_value, max_value, is_integer)
|
||||
else:
|
||||
raise ValueError(f"Unknown distribution: {distribution}")
|
||||
|
||||
|
||||
def parse_input_json_file(conf: dict) -> GenConvArgs:
|
||||
# Validate the input file
|
||||
assert isinstance(conf, dict)
|
||||
required_fields = [
|
||||
"filetype",
|
||||
"num_conversations",
|
||||
"text_files",
|
||||
"prompt_input",
|
||||
"prompt_output",
|
||||
]
|
||||
for field in required_fields:
|
||||
assert field in conf, f"Missing field {field} in input {conf}"
|
||||
|
||||
assert conf["filetype"] == "generate_conversations"
|
||||
|
||||
assert conf["num_conversations"] > 0, "num_conversations should be larger than zero"
|
||||
|
||||
text_files = conf["text_files"]
|
||||
|
||||
assert isinstance(text_files, list), "Field 'text_files' should be a list"
|
||||
assert len(text_files) > 0, (
|
||||
"Field 'text_files' should be a list with at least one file"
|
||||
)
|
||||
|
||||
# Parse the parameters for the prompt input/output workload
|
||||
input_num_turns = get_random_distribution(conf, "prompt_input", "num_turns")
|
||||
input_num_tokens = get_random_distribution(conf, "prompt_input", "num_tokens")
|
||||
input_common_prefix_num_tokens = get_random_distribution(
|
||||
conf, "prompt_input", "common_prefix_num_tokens", optional=True
|
||||
)
|
||||
input_prefix_num_tokens = get_random_distribution(
|
||||
conf, "prompt_input", "prefix_num_tokens"
|
||||
)
|
||||
output_num_tokens = get_random_distribution(conf, "prompt_output", "num_tokens")
|
||||
|
||||
print_stats: bool = conf.get("print_stats", False)
|
||||
assert isinstance(print_stats, bool), (
|
||||
"Field 'print_stats' should be either 'true' or 'false'"
|
||||
)
|
||||
|
||||
args = GenConvArgs(
|
||||
num_conversations=conf["num_conversations"],
|
||||
text_files=text_files,
|
||||
input_num_turns=input_num_turns,
|
||||
input_common_prefix_num_tokens=input_common_prefix_num_tokens,
|
||||
input_prefix_num_tokens=input_prefix_num_tokens,
|
||||
input_num_tokens=input_num_tokens,
|
||||
output_num_tokens=output_num_tokens,
|
||||
print_stats=print_stats,
|
||||
)
|
||||
return args
|
||||
|
||||
|
||||
def print_conv_stats(conversations: ConversationsMap, tokenizer: AutoTokenizer) -> None:
|
||||
# Collect statistics
|
||||
conv_stats: list[dict[Any, Any]] = []
|
||||
req_stats: list[int] = []
|
||||
|
||||
print("\nCollecting statistics...")
|
||||
for messages in conversations.values():
|
||||
# messages is a list of dicts
|
||||
user_tokens: list[int] = []
|
||||
assistant_tokens: list[int] = []
|
||||
request_tokens: list[int] = []
|
||||
|
||||
req_tokens = 0
|
||||
for m in messages:
|
||||
content = m["content"]
|
||||
num_tokens = len(tokenizer(content).input_ids)
|
||||
|
||||
if m["role"] == "user":
|
||||
user_tokens.append(num_tokens)
|
||||
# New user prompt including all chat history
|
||||
req_tokens += num_tokens
|
||||
request_tokens.append(req_tokens)
|
||||
|
||||
elif m["role"] == "assistant":
|
||||
assistant_tokens.append(num_tokens)
|
||||
# Update assistant answer
|
||||
# (will be part of chat history for the next user prompt)
|
||||
req_tokens += num_tokens
|
||||
|
||||
item_stats = {
|
||||
"conversation_turns": len(messages),
|
||||
"user_tokens": mean(user_tokens),
|
||||
"assistant_tokens": mean(assistant_tokens),
|
||||
}
|
||||
|
||||
conv_stats.append(item_stats)
|
||||
req_stats.extend(request_tokens)
|
||||
|
||||
# Print statistics
|
||||
percentiles = [0.25, 0.5, 0.75, 0.9, 0.99]
|
||||
|
||||
print(TEXT_SEPARATOR)
|
||||
print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
|
||||
print(TEXT_SEPARATOR)
|
||||
df = pd.DataFrame(conv_stats)
|
||||
print(df.describe(percentiles=percentiles).transpose())
|
||||
print(TEXT_SEPARATOR)
|
||||
print(f"{Color.YELLOW}Request statistics:{Color.RESET}")
|
||||
print(TEXT_SEPARATOR)
|
||||
df = pd.DataFrame(req_stats, columns=["request_tokens"])
|
||||
print(df.describe(percentiles=percentiles).transpose())
|
||||
print(TEXT_SEPARATOR)
|
||||
|
||||
|
||||
def generate_conversations(
|
||||
args: GenConvArgs, tokenizer: AutoTokenizer
|
||||
) -> ConversationsMap:
|
||||
# Text for all user prompts
|
||||
# (text from the input text files will be appended to this line)
|
||||
base_prompt_text = "Please rewrite the following text and add more content: "
|
||||
base_prompt_token_count = len(
|
||||
tokenizer.encode(base_prompt_text, add_special_tokens=False)
|
||||
)
|
||||
|
||||
logger.info(f"{Color.PURPLE}Generating conversations...{Color.RESET}")
|
||||
logger.info(args)
|
||||
|
||||
list_of_tokens = []
|
||||
|
||||
for filename in args.text_files:
|
||||
# Load text file that will be used to generate prompts
|
||||
with open(filename) as file:
|
||||
data = file.read()
|
||||
tokens_in_file = tokenizer.encode(data, add_special_tokens=False)
|
||||
list_of_tokens.extend(tokens_in_file)
|
||||
|
||||
conversations: ConversationsMap = {}
|
||||
conv_id = 0
|
||||
|
||||
# Generate number of turns for every conversation
|
||||
turn_count: np.ndarray = args.input_num_turns.sample(args.num_conversations)
|
||||
|
||||
# Turn count should be at least 2 (one user prompt and one assistant answer)
|
||||
turn_count = np.maximum(turn_count, 2)
|
||||
|
||||
# Round up to an even number (every user prompt should have an answer)
|
||||
turn_count = turn_count + (turn_count % 2)
|
||||
|
||||
# Generate number of prefix tokens for every conversation
|
||||
conv_prefix_tokens: np.ndarray = args.input_prefix_num_tokens.sample(
|
||||
args.num_conversations
|
||||
)
|
||||
|
||||
# Used to reduce shared text between conversations
|
||||
# (jump/skip over text sections between conversations)
|
||||
base_offset = 0
|
||||
|
||||
# Common prefix size for all conversations (only 1 sample required)
|
||||
common_prefix_text = ""
|
||||
common_prefix_tokens: int = args.input_common_prefix_num_tokens.sample(1)[0]
|
||||
if common_prefix_tokens > 0:
|
||||
# Using "." at the end to separate sentences
|
||||
common_prefix_text = (
|
||||
tokenizer.decode(list_of_tokens[: common_prefix_tokens - 2]) + "."
|
||||
)
|
||||
base_offset += common_prefix_tokens
|
||||
|
||||
for conv_id in range(args.num_conversations):
|
||||
# Generate a single conversation
|
||||
messages: MessagesList = []
|
||||
|
||||
nturns = turn_count[conv_id]
|
||||
|
||||
# User prompt token count per turn (with lower limit)
|
||||
input_token_count: np.ndarray = args.input_num_tokens.sample(nturns)
|
||||
input_token_count = np.maximum(input_token_count, base_prompt_token_count)
|
||||
|
||||
# Assistant answer token count per turn (with lower limit)
|
||||
output_token_count: np.ndarray = args.output_num_tokens.sample(nturns)
|
||||
output_token_count = np.maximum(output_token_count, 1)
|
||||
|
||||
user_turn = True
|
||||
for turn_id in range(nturns):
|
||||
if user_turn:
|
||||
role = "user"
|
||||
num_tokens = input_token_count[turn_id]
|
||||
|
||||
# Generate the user prompt,
|
||||
# use a unique prefix (the conv_id) for each conversation
|
||||
# (to avoid shared prefix between conversations)
|
||||
content = f"{conv_id} is a nice number... "
|
||||
|
||||
if len(common_prefix_text) > 0 and turn_id == 0:
|
||||
content = common_prefix_text + content
|
||||
|
||||
# Update the number of tokens left for the content
|
||||
num_tokens -= len(tokenizer.encode(content, add_special_tokens=False))
|
||||
|
||||
if turn_id == 0:
|
||||
prefix_num_tokens = conv_prefix_tokens[conv_id]
|
||||
if prefix_num_tokens > 0:
|
||||
# Add prefix text (context) to the first turn
|
||||
start_offset = base_offset
|
||||
end_offset = start_offset + prefix_num_tokens
|
||||
assert len(list_of_tokens) > end_offset, (
|
||||
"Not enough input text to generate "
|
||||
f"{prefix_num_tokens} tokens for the "
|
||||
f"prefix text ({start_offset=}, {end_offset=})"
|
||||
)
|
||||
|
||||
content += f"{conv_id}, " + tokenizer.decode(
|
||||
list_of_tokens[start_offset:end_offset]
|
||||
)
|
||||
base_offset += prefix_num_tokens
|
||||
|
||||
# Add the actual user prompt/question after the prefix text
|
||||
content += base_prompt_text
|
||||
num_tokens -= base_prompt_token_count
|
||||
|
||||
if num_tokens > 0:
|
||||
# Add text from the input file (to reach the desired token count)
|
||||
start_offset = base_offset + turn_id * input_token_count.max()
|
||||
end_offset = start_offset + num_tokens
|
||||
assert len(list_of_tokens) > end_offset, (
|
||||
f"Not enough input text to generate {num_tokens} tokens "
|
||||
f"for the prompt ({start_offset=}, {end_offset=})"
|
||||
)
|
||||
|
||||
# Convert tokens back to text
|
||||
content += tokenizer.decode(list_of_tokens[start_offset:end_offset])
|
||||
else:
|
||||
role = "assistant"
|
||||
# This content will not be used as input to the LLM server
|
||||
# (actual answers will be used instead).
|
||||
# Content is only required to determine the min_tokens/max_tokens
|
||||
# (inputs to the LLM server).
|
||||
num_tokens = output_token_count[turn_id]
|
||||
assert len(list_of_tokens) > num_tokens, (
|
||||
f"Not enough input text to generate {num_tokens} "
|
||||
"tokens for assistant content"
|
||||
)
|
||||
content = tokenizer.decode(list_of_tokens[:num_tokens])
|
||||
|
||||
# Append the user/assistant message to the list of messages
|
||||
messages.append({"role": role, "content": content})
|
||||
user_turn = not user_turn
|
||||
|
||||
# Add the new conversation
|
||||
conversations[f"CONV_ID_{conv_id}"] = messages
|
||||
|
||||
# Increase base offset for the next conversation
|
||||
base_offset += nturns
|
||||
|
||||
if args.print_stats:
|
||||
print_conv_stats(conversations, tokenizer)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
def conversations_list_to_dict(input_list: ShareGptConversations) -> ConversationsMap:
|
||||
conversations: ConversationsMap = {}
|
||||
|
||||
for item in input_list:
|
||||
conv_id: str = item["id"]
|
||||
assert isinstance(conv_id, str)
|
||||
|
||||
assert conv_id not in conversations, (
|
||||
f"Conversation ID {conv_id} found more than once in the input"
|
||||
)
|
||||
|
||||
messages: MessagesList = item["messages"]
|
||||
assert isinstance(messages, list), (
|
||||
f"Conversation messages should be a list (ID: {conv_id})"
|
||||
)
|
||||
assert len(messages) > 0, f"Conversation with no messages (ID: {conv_id})"
|
||||
|
||||
conversations[conv_id] = messages
|
||||
|
||||
logger.info(f"Using {len(conversations)} unique conversations (IDs)")
|
||||
assert len(conversations) == len(input_list)
|
||||
|
||||
# Print statistics about the selected conversations
|
||||
stats: list[dict[str, Any]] = []
|
||||
for conv_data in conversations.values():
|
||||
stats.append({"num_turns": len(conv_data)})
|
||||
|
||||
print(TEXT_SEPARATOR)
|
||||
print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
|
||||
print(TEXT_SEPARATOR)
|
||||
percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
|
||||
conv_stats = pd.DataFrame(stats).describe(percentiles=percentiles)
|
||||
print(conv_stats.transpose())
|
||||
print(TEXT_SEPARATOR)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
def conversations_dict_to_list(input_dict: ConversationsMap) -> ShareGptConversations:
|
||||
output: ShareGptConversations = []
|
||||
for conv_id, conv_data in input_dict.items():
|
||||
new_item = {"id": conv_id, "messages": conv_data}
|
||||
output.append(new_item)
|
||||
|
||||
return output
|
||||
@ -1,25 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Color(str, Enum):
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
BLUE = "\033[94m"
|
||||
PURPLE = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
YELLOW = "\033[93m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
|
||||
TEXT_SEPARATOR = "-" * 100
|
||||
|
||||
# Configure the logger
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] - %(message)s",
|
||||
datefmt="%d-%m-%Y %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,354 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Download dataset from:
|
||||
https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json
|
||||
|
||||
Convert to OpenAI API:
|
||||
export INPUT_FILE=sharegpt_20230401_clean_lang_split.json
|
||||
python convert_sharegpt_to_openai.py $INPUT_FILE sharegpt_conv_128.json --max-items=128
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
from statistics import mean
|
||||
from typing import Any, Optional
|
||||
|
||||
import pandas as pd # type: ignore
|
||||
import tqdm # type: ignore
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
|
||||
def has_non_english_chars(text: str) -> bool:
|
||||
return not text.isascii()
|
||||
|
||||
|
||||
def content_is_valid(
|
||||
content: str, min_content_len: Optional[int], max_content_len: Optional[int]
|
||||
) -> bool:
|
||||
if min_content_len and len(content) < min_content_len:
|
||||
return False
|
||||
|
||||
if max_content_len and len(content) > max_content_len:
|
||||
return False
|
||||
|
||||
return has_non_english_chars(content)
|
||||
|
||||
|
||||
def print_stats(
|
||||
conversations: "list[dict[Any, Any]]", tokenizer: Optional[AutoTokenizer] = None
|
||||
) -> None:
|
||||
# Collect statistics
|
||||
stats = []
|
||||
|
||||
print("\nCollecting statistics...")
|
||||
for item in tqdm.tqdm(conversations):
|
||||
# item has "id" and "messages"
|
||||
messages = item["messages"]
|
||||
|
||||
user_turns = 0
|
||||
assistant_turns = 0
|
||||
user_words = 0
|
||||
assistant_words = 0
|
||||
conv_chars = 0
|
||||
|
||||
user_tokens: list[int] = []
|
||||
assistant_tokens: list[int] = []
|
||||
|
||||
for m in messages:
|
||||
content = m["content"]
|
||||
conv_chars += len(content)
|
||||
content_num_words = content.count(" ") + 1
|
||||
|
||||
num_tokens = 0
|
||||
if tokenizer:
|
||||
num_tokens = len(tokenizer(m["content"]).input_ids)
|
||||
|
||||
if m["role"] == "user":
|
||||
user_turns += 1
|
||||
user_words += content_num_words
|
||||
if tokenizer:
|
||||
user_tokens.append(num_tokens)
|
||||
|
||||
elif m["role"] == "assistant":
|
||||
assistant_turns += 1
|
||||
assistant_words += content_num_words
|
||||
if tokenizer:
|
||||
assistant_tokens.append(num_tokens)
|
||||
|
||||
# assert user_turns == assistant_turns, \
|
||||
# f"Invalid conversation ID {item['id']}"
|
||||
|
||||
conv_words = user_words + assistant_words
|
||||
item_stats = {
|
||||
"user_turns": user_turns,
|
||||
"assistant_turns": assistant_turns,
|
||||
"user_words": user_words,
|
||||
"assistant_words": assistant_words,
|
||||
"conv_turns": len(messages),
|
||||
"conv_words": conv_words,
|
||||
"conv_characters": conv_chars,
|
||||
}
|
||||
|
||||
if len(user_tokens) > 0:
|
||||
item_stats["user_tokens"] = int(mean(user_tokens))
|
||||
|
||||
if len(assistant_tokens) > 0:
|
||||
item_stats["assistant_tokens"] = int(mean(assistant_tokens))
|
||||
|
||||
stats.append(item_stats)
|
||||
|
||||
print("\nStatistics:")
|
||||
percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
|
||||
df = pd.DataFrame(stats)
|
||||
print(df.describe(percentiles=percentiles).transpose())
|
||||
|
||||
|
||||
def _merge_conversation_parts(sharegpt_data: "list[dict[str, Any]]") -> "list[dict[str, Any]]":
    """Merge multi-part ShareGPT items ("<conv_id>_<part>") into one item per ID."""
    # Map conversation ID to all of its parts (each part is a list of turns).
    conversation_parts: dict[str, list[Any]] = {}

    for item in tqdm.tqdm(sharegpt_data):
        assert "id" in item, "Missing key 'id'"
        assert "conversations" in item, "Missing key 'conversations'"

        # Conversation ID (e.g: "hiWPlMD") and part/session (0, 1, 2, etc.).
        # rsplit tolerates conversation IDs that themselves contain
        # underscores (a plain split("_") raises ValueError for those).
        conv_id, _ = item["id"].rsplit("_", 1)
        new_turns = item["conversations"]

        if conv_id not in conversation_parts:
            # Start new conversation
            conversation_parts[conv_id] = []
        elif len(conversation_parts[conv_id]) > 0 and len(new_turns) > 0:
            prev_turns = conversation_parts[conv_id][-1]
            if prev_turns[-1]["from"] == new_turns[0]["from"]:
                # Drop the duplicated first turn when a part opens with the
                # same speaker that closed the previous part.
                new_turns = new_turns[1:]

        if len(new_turns) > 0:
            # We assume that parts are in order in the ShareGPT dataset
            conversation_parts[conv_id].append(new_turns)

    dataset: list[dict[str, Any]] = []
    for conv_id, conv_parts in conversation_parts.items():
        # Merge all parts into a single flat list of turns
        conversations: list[dict[str, str]] = []
        for conv_part in conv_parts:
            conversations.extend(conv_part)

        if len(conversations) > 0:
            dataset.append({"id": conv_id, "conversations": conversations})

    return dataset


def _convert_turns_to_messages(
    conv_id: str,
    conversations: "list[dict[str, str]]",
    max_turns: Optional[int],
) -> "list[dict[str, str]]":
    """Convert ShareGPT turns to OpenAI messages, honoring ``max_turns``.

    A leading non-user turn (e.g. a conversation that is a follow-up to a
    previous one) is dropped.  ``max_turns`` is an index bound on the
    original turn list, matching the historical behavior.
    """
    messages: list[dict[str, str]] = []

    for i, turn in enumerate(conversations):
        assert "from" in turn and "value" in turn, (
            f"Invalid conversation ID {conv_id} - missing 'from' or 'value'"
        )

        role = None
        turn_from = turn["from"]

        if turn_from in {"human", "user"}:
            role = "user"
        elif turn_from in {"gpt", "bing", "chatgpt", "bard"}:
            role = "assistant"
        elif turn_from == "system":
            role = "system"

        assert role is not None, (
            f"Invalid conversation ID {conv_id} - 'from'='{turn_from}' is invalid"
        )

        if i == 0 and role != "user":
            # If the first message is from assistant (gpt), skip it.
            # this happens when the conversation is a follow-up
            # to a previous conversation (from the same user).
            continue

        if max_turns is not None and i >= max_turns:
            break

        # Convert message to OpenAI format (with "role" and "content")
        messages.append({"role": role, "content": turn["value"]})

    return messages


def _messages_are_valid(
    messages: "list[dict[str, str]]",
    min_content_len: Optional[int],
    max_content_len: Optional[int],
) -> bool:
    """Return True when turns alternate user/assistant (starting with user)
    and every message content passes the length constraints."""
    # First turn should always be from the user
    user_turn = True

    for m in messages:
        # Make sure that turns alternate between user and assistant
        if (user_turn and m["role"] != "user") or (
            not user_turn and m["role"] != "assistant"
        ):
            return False

        user_turn = not user_turn

        if not content_is_valid(m["content"], min_content_len, max_content_len):
            return False

    return True


def convert_sharegpt_to_openai(
    seed: int,
    input_file: str,
    output_file: str,
    max_items: Optional[int],
    min_content_len: Optional[int] = None,
    max_content_len: Optional[int] = None,
    min_turns: Optional[int] = None,
    max_turns: Optional[int] = None,
    model: Optional[str] = None,
) -> None:
    """Convert a ShareGPT JSON dataset to the OpenAI chat-completions format.

    The input file must contain a list of dicts, each with an "id" of the
    form "<conv_id>_<part>" and a "conversations" list of
    ``{"from": ..., "value": ...}`` turns.  Parts that share a conversation
    ID are merged, converted to ``{"role", "content"}`` messages, filtered
    by the turn/length constraints, optionally down-sampled to
    ``max_items``, and written to ``output_file``.

    Args:
        seed: Seed for ``random`` (shuffling and sampling).
        input_file: Path to the ShareGPT JSON file.
        output_file: Path for the converted OpenAI-format JSON file.
        max_items: If set, randomly sample at most this many conversations.
        min_content_len: Minimum characters allowed per message content.
        max_content_len: Maximum characters allowed per message content.
        min_turns: Skip conversations with fewer turns than this.
        max_turns: Keep at most this many leading turns per conversation.
        model: Optional model name; only its tokenizer is used (for stats).
    """
    if min_turns and max_turns:
        assert min_turns <= max_turns

    if min_content_len and max_content_len:
        # Verify that min is not larger than max if both were given
        assert min_content_len <= max_content_len

    print(
        f"Input parameters:\n{seed=}, {max_items=}, {min_content_len=},"
        f" {max_content_len=}, {min_turns=}, {max_turns=}\n"
    )

    random.seed(seed)

    tokenizer = None
    if model is not None:
        print(f"Loading tokenizer from: {model}")
        tokenizer = AutoTokenizer.from_pretrained(model)

    # Read the ShareGPT JSON file
    print(f"Reading file: {input_file}")
    with open(input_file, encoding="utf-8") as f:
        # Should be a list of dicts
        # Each dict should have "id" (string) and "conversations" (list of dicts)
        sharegpt_data = json.load(f)

    assert isinstance(sharegpt_data, list), "Input file should contain a list of dicts"

    print(f"Total items in input file: {len(sharegpt_data):,}")

    print(f"Shuffling dataset with seed {seed}")
    random.shuffle(sharegpt_data)

    dataset = _merge_conversation_parts(sharegpt_data)

    print(f"Total unique conversations (IDs) in input file: {len(dataset):,}")

    # Final output data
    final_openai_dataset: list[dict] = []

    # Filter conversations from the ShareGPT dataset and convert to OpenAI format
    for item in tqdm.tqdm(dataset):
        assert "id" in item, "Missing key 'id'"
        assert "conversations" in item, "Missing key 'conversations'"

        conv_id = item["id"]
        conversations = item["conversations"]

        if min_turns is not None and len(conversations) < min_turns:
            # Skip short conversations
            continue

        messages = _convert_turns_to_messages(conv_id, conversations, max_turns)

        # Add the converted conversation to the OpenAI format
        if len(messages) > 0 and _messages_are_valid(
            messages, min_content_len, max_content_len
        ):
            final_openai_dataset.append({"id": conv_id, "messages": messages})

    assert len(final_openai_dataset) > 0, "Final number of conversations is zero"

    # NOTE(review): this first call intentionally omits the tokenizer
    # (token stats are only computed on the final sampled set below).
    print_stats(final_openai_dataset)

    print_stats_again = False
    if max_items is not None and len(final_openai_dataset) > max_items:
        print(f"\n\nSampling {max_items} items from the dataset...")
        print_stats_again = True
        final_openai_dataset = random.sample(final_openai_dataset, max_items)

    if print_stats_again:
        # Print stats after the dataset changed
        print_stats(final_openai_dataset, tokenizer)

    # Write the converted data to a new JSON file
    final_size = len(final_openai_dataset)
    print(f"\nTotal conversations converted (after filtering): {final_size:,}")
    print(f"\nWriting file: {output_file}")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(final_openai_dataset, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def main() -> None:
    """Command-line entry point: parse arguments and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Convert ShareGPT dataset to OpenAI API format"
    )
    parser.add_argument("input_file", help="Path to the input ShareGPT JSON file")
    parser.add_argument(
        "output_file", help="Path to the output OpenAI format JSON file"
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Seed for random number generators"
    )

    # The optional integer filters all share the same shape, so declare
    # them from a table instead of repeating the add_argument boilerplate.
    optional_int_flags = (
        ("--max-items", "Maximum number of items in the output file"),
        ("--min-turns", "Minimum number of turns per conversation"),
        ("--max-turns", "Maximum number of turns per conversation"),
        ("--min-content-len", "Min number of characters in the messages' content"),
        ("--max-content-len", "Max number of characters in the messages' content"),
    )
    for flag, flag_help in optional_int_flags:
        parser.add_argument(flag, type=int, default=None, help=flag_help)

    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="LLM model, only the tokenizer will be used",
    )

    args = parser.parse_args()

    convert_sharegpt_to_openai(
        args.seed,
        args.input_file,
        args.output_file,
        args.max_items,
        args.min_content_len,
        args.max_content_len,
        args.min_turns,
        args.max_turns,
        args.model,
    )


if __name__ == "__main__":
    main()
|
||||
@ -1,35 +0,0 @@
|
||||
{
|
||||
"filetype": "generate_conversations",
|
||||
"num_conversations": 24,
|
||||
"text_files": ["pg1184.txt"],
|
||||
"print_stats": false,
|
||||
"prompt_input": {
|
||||
"num_turns": {
|
||||
"distribution": "uniform",
|
||||
"min": 12,
|
||||
"max": 18
|
||||
},
|
||||
"common_prefix_num_tokens": {
|
||||
"distribution": "constant",
|
||||
"value": 500
|
||||
},
|
||||
"prefix_num_tokens": {
|
||||
"distribution": "lognormal",
|
||||
"mean": 6,
|
||||
"sigma": 4,
|
||||
"max": 1500
|
||||
},
|
||||
"num_tokens": {
|
||||
"distribution": "uniform",
|
||||
"min": 120,
|
||||
"max": 160
|
||||
}
|
||||
},
|
||||
"prompt_output": {
|
||||
"num_tokens": {
|
||||
"distribution": "uniform",
|
||||
"min": 80,
|
||||
"max": 120
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
numpy>=1.24
|
||||
pandas>=2.0.0
|
||||
aiohttp>=3.10
|
||||
transformers>=4.46
|
||||
xlsxwriter>=3.2.1
|
||||
@ -19,7 +19,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
|
||||
GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
|
||||
GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
@ -37,9 +37,9 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
|
||||
set(FlashMLA_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)
|
||||
|
||||
set(FlashMLA_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
|
||||
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 93cf5a08f421a3efd0c4a7e005ef8f742b578ce0
|
||||
GIT_TAG 6dbc6e011a3ebe9349eeb74578940dd7095436ba
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
@ -60,13 +60,3 @@ struct enable_sm100_only : Kernel {
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm120_only : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@ -45,9 +45,6 @@ struct SSMParamsBase {
|
||||
index_t out_d_stride;
|
||||
index_t out_z_batch_stride;
|
||||
index_t out_z_d_stride;
|
||||
index_t ssm_states_batch_stride;
|
||||
index_t ssm_states_dim_stride;
|
||||
index_t ssm_states_dstate_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ A_ptr;
|
||||
|
||||
@ -132,10 +132,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
|
||||
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
||||
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
|
||||
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
|
||||
cache_index * params.ssm_states_batch_stride +
|
||||
dim_id * kNRows * params.ssm_states_dim_stride;
|
||||
|
||||
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate;
|
||||
|
||||
float D_val[kNRows] = {0};
|
||||
if (params.D_ptr != nullptr) {
|
||||
#pragma unroll
|
||||
@ -250,7 +248,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
}
|
||||
// Initialize running total
|
||||
|
||||
scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0);
|
||||
scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0);
|
||||
|
||||
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
||||
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
||||
@ -261,7 +259,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
if (threadIdx.x == 0) {
|
||||
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
||||
if (chunk == n_chunks - 1) {
|
||||
ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y);
|
||||
ssm_states[state_idx] = input_t(prefix_op.running_prefix.y);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
@ -483,10 +481,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
params.out_batch_stride = out.stride(1);
|
||||
params.out_d_stride = out.stride(0);
|
||||
|
||||
params.ssm_states_batch_stride = ssm_states.stride(0);
|
||||
params.ssm_states_dim_stride = ssm_states.stride(1);
|
||||
params.ssm_states_dstate_stride = ssm_states.stride(2);
|
||||
|
||||
}
|
||||
else{
|
||||
if (!is_variable_B) {
|
||||
@ -515,10 +509,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
}
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_d_stride = out.stride(1);
|
||||
|
||||
params.ssm_states_batch_stride = ssm_states.stride(0);
|
||||
params.ssm_states_dim_stride = ssm_states.stride(1);
|
||||
params.ssm_states_dstate_stride = ssm_states.stride(2);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,183 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// clang-format off
|
||||
template <class OutType, int ScaleGranularityM,
|
||||
int ScaleGranularityN, int ScaleGranularityK,
|
||||
class MmaTileShape, class ClusterShape,
|
||||
class EpilogueScheduler, class MainloopScheduler>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using ElementAB = cutlass::float_e4m3_t;
|
||||
|
||||
using ElementA = ElementAB;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
using ElementB = ElementAB;
|
||||
// ColumnMajor is used for B to match the CUTLASS convention.
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ElementD = OutType;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using ElementC = void; // TODO: support bias
|
||||
using LayoutC = LayoutD;
|
||||
using LayoutC_Transpose = LayoutD_Transpose;
|
||||
static constexpr int AlignmentC = AlignmentD;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
|
||||
|
||||
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
using ArchTag = cutlass::arch::Sm120;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
using ElementScalar = float;
|
||||
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
AlignmentD,
|
||||
EpilogueScheduler,
|
||||
DefaultOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto;
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
cute::tuple<LayoutA, LayoutSFA>,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
cute::tuple<LayoutB, LayoutSFB>,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduler
|
||||
>::CollectiveOp;
|
||||
|
||||
using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutSFA = typename Gemm::LayoutSFA;
|
||||
using LayoutSFB = typename Gemm::LayoutSFB;
|
||||
using ScaleConfig = typename Gemm::ScaleConfig;
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
|
||||
StrideA a_stride;
|
||||
StrideB b_stride;
|
||||
StrideC c_stride;
|
||||
a_stride =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
b_stride =
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
c_stride =
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
|
||||
|
||||
LayoutSFA layout_SFA =
|
||||
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||
LayoutSFB layout_SFB =
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
||||
|
||||
auto mainloop_args = [&](){
|
||||
return typename GemmKernel::MainloopArguments{
|
||||
a_ptr, a_stride, b_ptr, b_stride,
|
||||
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
|
||||
};
|
||||
}();
|
||||
auto prob_shape = cute::make_shape(m, n, k, 1);
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
// TODO: better heuristics
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, 128, 128, Shape<_128, _128, _128>,
|
||||
Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -47,10 +47,4 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
} // namespace vllm
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
#include "c3x/scaled_mm_helper.hpp"
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm120 (Blackwell).
|
||||
NVIDIA GPUs with sm120 (Blackwell Geforce).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
@ -13,10 +15,20 @@ void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
vllm::cutlass_scaled_mm_sm120_fp8,
|
||||
nullptr, // int8 not supported on SM120
|
||||
vllm::cutlass_scaled_mm_blockwise_sm120_fp8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
TORCH_CHECK(
|
||||
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||
vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -392,7 +392,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt
|
||||
# We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel.
|
||||
ARG FLASHINFER_GIT_REF="v0.2.10"
|
||||
ARG FLASHINFER_GIT_REF="v0.2.9"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
|
||||
@ -113,7 +113,6 @@ WORKDIR /workspace/vllm
|
||||
|
||||
RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
|
||||
cp requirements/test.in requirements/cpu-test.in && \
|
||||
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
|
||||
sed -i 's/^torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
|
||||
sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
|
||||
sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base
|
||||
# oneapi 2025.0.2 docker base image uses the rolling 2448 package. https://dgpu-docs.intel.com/releases/packages.html?release=Rolling+2448.13&os=Ubuntu+22.04, and we don't need to install the driver manually.
|
||||
FROM intel/deep-learning-essentials:2025.0.2-0-devel-ubuntu22.04 AS vllm-base
|
||||
|
||||
RUN rm /etc/apt/sources.list.d/intel-graphics.list
|
||||
|
||||
RUN apt clean && apt-get update -y && \
|
||||
apt-get install -y software-properties-common && \
|
||||
add-apt-repository ppa:deadsnakes/ppa && \
|
||||
apt-get install -y python3.10 python3.10-distutils && \
|
||||
curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && \
|
||||
RUN apt-get update -y && \
|
||||
apt-get install -y --no-install-recommends --fix-missing \
|
||||
curl \
|
||||
ffmpeg \
|
||||
@ -17,13 +14,11 @@ RUN apt clean && apt-get update -y && \
|
||||
libgl1 \
|
||||
lsb-release \
|
||||
numactl \
|
||||
python3.10-dev \
|
||||
python3 \
|
||||
python3-dev \
|
||||
python3-pip \
|
||||
wget
|
||||
|
||||
|
||||
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
|
||||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt
|
||||
COPY requirements/common.txt /workspace/vllm/requirements/common.txt
|
||||
|
||||
@ -58,9 +58,10 @@ nav:
|
||||
- CI: contributing/ci
|
||||
- Design Documents: design
|
||||
- API Reference:
|
||||
- Summary: api/summary.md
|
||||
- Summary: api/README.md
|
||||
- Contents:
|
||||
- api/vllm/*
|
||||
- glob: api/vllm/*
|
||||
preserve_directory_names: true
|
||||
- CLI Reference:
|
||||
- Summary: cli/README.md
|
||||
- Community:
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 91 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 88 KiB |
@ -29,9 +29,6 @@ Start the vLLM OpenAI Compatible API server.
|
||||
# Specify the port
|
||||
vllm serve meta-llama/Llama-2-7b-hf --port 8100
|
||||
|
||||
# Serve over a Unix domain socket
|
||||
vllm serve meta-llama/Llama-2-7b-hf --uds /tmp/vllm.sock
|
||||
|
||||
# Check with --help for more options
|
||||
# To list all groups
|
||||
vllm serve --help=listgroup
|
||||
|
||||
@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
|
||||
If you run out of CPU RAM, try the following options:
|
||||
|
||||
- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process)
|
||||
- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB).
|
||||
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
|
||||
|
||||
## Multi-modal input limits
|
||||
@ -129,18 +129,20 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory.
|
||||
|
||||
Here are some examples:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
??? code
|
||||
|
||||
# Available for Qwen2-VL series models
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
mm_processor_kwargs={
|
||||
"max_pixels": 768 * 768, # Default is 1280 * 28 * 28
|
||||
})
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
# Available for InternVL series models
|
||||
llm = LLM(model="OpenGVLab/InternVL2-2B",
|
||||
mm_processor_kwargs={
|
||||
"max_dynamic_patch": 4, # Default is 12
|
||||
})
|
||||
```
|
||||
# Available for Qwen2-VL series models
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
mm_processor_kwargs={
|
||||
"max_pixels": 768 * 768, # Default is 1280 * 28 * 28
|
||||
})
|
||||
|
||||
# Available for InternVL series models
|
||||
llm = LLM(model="OpenGVLab/InternVL2-2B",
|
||||
mm_processor_kwargs={
|
||||
"max_dynamic_patch": 4, # Default is 12
|
||||
})
|
||||
```
|
||||
|
||||
@ -2,9 +2,6 @@
|
||||
|
||||
This guide covers optimization strategies and performance tuning for vLLM V1.
|
||||
|
||||
!!! tip
|
||||
Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory.
|
||||
|
||||
## Preemption
|
||||
|
||||
Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
|
||||
@ -129,50 +126,62 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
|
||||
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
|
||||
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
|
||||
|
||||
## Input Processing
|
||||
## Reducing Memory Usage
|
||||
|
||||
### Parallel Processing
|
||||
If you encounter out-of-memory issues, consider these strategies:
|
||||
|
||||
You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing).
|
||||
This is useful when input processing (which is run inside the API server)
|
||||
becomes a bottleneck compared to model execution (which is run inside engine core)
|
||||
and you have excess CPU capacity.
|
||||
### Context Length and Batch Size
|
||||
|
||||
```console
|
||||
# Run 4 API processes and 1 engine core process
|
||||
vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4
|
||||
|
||||
# Run 4 API processes and 2 engine core processes
|
||||
vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
|
||||
```
|
||||
|
||||
!!! note
|
||||
API server scale-out is only available for online inference.
|
||||
|
||||
!!! note
|
||||
[Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
|
||||
because it requires a one-to-one correspondence between API and engine core processes.
|
||||
|
||||
## Multi-Modal Caching
|
||||
|
||||
### Processor Cache
|
||||
|
||||
By default, the multi-modal processor cache is enabled to avoid repeatedly processing
|
||||
the same multi-modal inputs via Hugging Face `AutoProcessor`,
|
||||
which commonly occurs in multi-turn conversations.
|
||||
|
||||
You can adjust the size of the cache by setting the value of `mm_processor_cache_gb`
|
||||
(default 4 GiB per API process + 4 GiB per engine core process).
|
||||
If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`.
|
||||
|
||||
Examples:
|
||||
You can reduce memory usage by limiting the context length and batch size:
|
||||
|
||||
```python
|
||||
# Use a larger cache
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
mm_processor_cache_gb=8)
|
||||
from vllm import LLM
|
||||
|
||||
# Disable the cache
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
mm_processor_cache_gb=0)
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_model_len=2048, # Limit context window
|
||||
max_num_seqs=4 # Limit batch size
|
||||
)
|
||||
```
|
||||
|
||||
### Adjust CUDA Graph Compilation
|
||||
|
||||
CUDA graph compilation in V1 uses more memory than in V0. You can reduce memory usage by adjusting the compilation level:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
from vllm.config import CompilationConfig, CompilationLevel
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
cudagraph_capture_sizes=[1, 2, 4, 8] # Capture fewer batch sizes
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
Or, if you are not concerned about latency or overall performance, disable CUDA graph compilation entirely with `enforce_eager=True`:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
enforce_eager=True # Disable CUDA graph compilation
|
||||
)
|
||||
```
|
||||
|
||||
### Multimodal Models
|
||||
|
||||
For multi-modal models, you can reduce memory usage by limiting the number of images/videos per request:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
# Accept up to 2 images per prompt
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
limit_mm_per_prompt={"image": 2}
|
||||
)
|
||||
```
|
||||
|
||||
@ -131,19 +131,6 @@ MAX_JOBS=16 uv pip install --system \
|
||||
--no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30"
|
||||
```
|
||||
|
||||
### Mamba
|
||||
|
||||
```bash
|
||||
uv pip install --system \
|
||||
--no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.5"
|
||||
```
|
||||
|
||||
### causal-conv1d
|
||||
|
||||
```bash
|
||||
uv pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
```
|
||||
|
||||
## Update all the different vLLM platforms
|
||||
|
||||
Rather than attempting to update all vLLM platforms in a single pull request, it's more manageable
|
||||
|
||||
@ -200,8 +200,7 @@ vision-language model.
|
||||
lora_config = vllm_config.lora_config
|
||||
super().__init__(config, cache_config, quant_config, lora_config, prefix)
|
||||
|
||||
from packaging import version
|
||||
if version.parse(__version__) >= version.parse("0.6.4"):
|
||||
if __version__ >= "0.6.4":
|
||||
MyModel = MyNewModel
|
||||
else:
|
||||
MyModel = MyOldModel
|
||||
|
||||
@ -57,11 +57,11 @@ In v0, the following metrics are exposed via a Prometheus-compatible `/metrics`
|
||||
- `vllm:spec_decode_num_draft_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)
|
||||
|
||||
These are documented under [Inferencing and Serving -> Production Metrics](../usage/metrics.md).
|
||||
These are documented under [Inferencing and Serving -> Production Metrics](../../usage/metrics.md).
|
||||
|
||||
### Grafana Dashboard
|
||||
|
||||
vLLM also provides [a reference example](../examples/online_serving/prometheus_grafana.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard.
|
||||
vLLM also provides [a reference example](../../examples/online_serving/prometheus_grafana.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard.
|
||||
|
||||
The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important:
|
||||
|
||||
@ -455,7 +455,7 @@ In general:
|
||||
[an escape hatch](https://kubernetes.io/docs/concepts/cluster-administration/system-metrics/#show-hidden-metrics)
|
||||
for some time before deleting them.
|
||||
|
||||
See the [deprecation policy](../contributing/deprecation_policy.md) for
|
||||
See the [deprecation policy](../../contributing/deprecation_policy.md) for
|
||||
the project-wide deprecation policy.
|
||||
|
||||
### Unimplemented - `vllm:tokens_total`
|
||||
@ -655,7 +655,7 @@ v0 has support for OpenTelemetry tracing:
|
||||
- Added by <gh-pr:4687>
|
||||
- Configured with `--oltp-traces-endpoint` and `--collect-detailed-traces`
|
||||
- [OpenTelemetry blog post](https://opentelemetry.io/blog/2024/llm-observability/)
|
||||
- [User-facing docs](../examples/online_serving/opentelemetry.md)
|
||||
- [User-facing docs](../../examples/online_serving/opentelemetry.md)
|
||||
- [Blog post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f)
|
||||
- [IBM product docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part.
|
||||
|
||||
!!! note
|
||||
Technical details on how vLLM implements APC can be found [here](../design/prefix_caching.md).
|
||||
Technical details on how vLLM implements APC can be found [here](../design/automatic_prefix_caching.md).
|
||||
|
||||
## Enabling APC in vLLM
|
||||
|
||||
|
||||
@ -19,18 +19,6 @@ Two main reasons:
|
||||
|
||||
Please refer to <gh-file:examples/online_serving/disaggregated_prefill.sh> for the example usage of disaggregated prefilling.
|
||||
|
||||
Now supports 5 types of connectors:
|
||||
|
||||
- **SharedStorageConnector**: refer to <gh-file:examples/offline_inference/disaggregated-prefill-v1/run.sh> for the example usage of SharedStorageConnector disaggregated prefilling.
|
||||
- **LMCacheConnectorV1**: refer to <gh-file:examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh> for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
|
||||
- **NixlConnector**: refer to <gh-file:tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh> for the example usage of NixlConnector disaggregated prefilling, which supports fully asynchronous send/recv.
|
||||
- **P2pNcclConnector**: refer to <gh-file:examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh> for the example usage of P2pNcclConnector disaggregated prefilling.
|
||||
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs, such as:
|
||||
|
||||
```bash
|
||||
--kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
|
||||
```
|
||||
|
||||
## Benchmarks
|
||||
|
||||
Please refer to <gh-file:benchmarks/disagg_benchmarks> for disaggregated prefilling benchmarks.
|
||||
@ -60,19 +48,6 @@ The workflow of disaggregated prefilling is as follows:
|
||||
|
||||
The `buffer` corresponds to `insert` API in LookupBuffer, and the `drop_select` corresponds to `drop_select` API in LookupBuffer.
|
||||
|
||||
Now every process in vLLM will have a corresponding connector. Specifically, we have:
|
||||
|
||||
- Scheduler connector: the connector that locates in the same process as the scheduler process. It schedules the KV cache transfer ops.
|
||||
- Worker connectors: the connectors that locate in the worker processes. They execute KV cache transfer ops.
|
||||
|
||||
Here is a figure illustrating how the above 2 connectors are organized:
|
||||
|
||||

|
||||
|
||||
The figure below shows how the worker connector works with the attention module to achieve layer-by-layer KV cache store and load:
|
||||
|
||||

|
||||
|
||||
## Third-party contributions
|
||||
|
||||
Disaggregated prefilling is highly related to infrastructure, so vLLM relies on third-party connectors for production-level disaggregated prefilling (and vLLM team will actively review and merge new PRs for third-party connectors).
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
# FP8 INC
|
||||
---
|
||||
title: FP8 INC
|
||||
---
|
||||
[](){ #inc }
|
||||
|
||||
vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators.
|
||||
Currently, quantization is validated only in Llama models.
|
||||
|
||||
@ -1,80 +0,0 @@
|
||||
# Sleep Mode
|
||||
|
||||
vLLM's Sleep Mode allows you to temporarily release most GPU memory used by a model, including model weights and KV cache, without stopping the server or unloading the Docker container. This is especially useful for RLHF, training, or cost-saving scenarios where GPU resources need to be freed between inference workloads.
|
||||
|
||||
Key benefits:
|
||||
|
||||
- **Frees GPU memory**: Offloads model weights to CPU RAM and discards KV cache, releasing up to 90%+ of GPU memory for other tasks.
|
||||
- **Fast resume**: Quickly wake up the engine and resume inference without full model reload.
|
||||
- **API endpoints**: Control sleep/wake_up state via HTTP endpoints or Python API.
|
||||
- **Supports distributed workloads**: Works with tensor parallelism, pipeline parallelism, etc.
|
||||
- **Fine-grained control**: Optionally wake up only model weights or KV cache to avoid OOM during weight updates.
|
||||
|
||||
!!! note
|
||||
This feature is only supported on CUDA platform.
|
||||
|
||||
## Sleep levels
|
||||
|
||||
Level 1 sleep will offload the model weights and discard the KV cache. The content of KV cache is forgotten. Level 1 sleep is good for sleeping and waking up the engine to run the same model again. The model weights are backed up in CPU memory. Please make sure there's enough CPU memory to store the model weights. Level 2 sleep will discard both the model weights and the KV cache (while the model's buffers are kept in CPU, like rope scaling tensors). The content of both the model weights and KV cache is forgotten. Level 2 sleep is good for sleeping and waking up the engine to run a different model or update the model, where previous model weights are not needed, e.g. RLHF weight update.
|
||||
|
||||
## Usage
|
||||
|
||||
### Offline inference
|
||||
|
||||
Enable sleep mode by passing `enable_sleep_mode=True` to the `LLM` class.
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM("Qwen/Qwen3-0.6B", enable_sleep_mode=True)
|
||||
```
|
||||
|
||||
#### Python API
|
||||
|
||||
```python
|
||||
# Put the engine to sleep (level=1: offload weights to CPU RAM, discard KV cache)
|
||||
llm.sleep(level=1)
|
||||
|
||||
# Wake up the engine (restore weights)
|
||||
llm.wake_up()
|
||||
```
|
||||
|
||||
#### RLHF weight updates
|
||||
|
||||
During RLHF training, vLLM allows you to selectively wake up only the model weights or the KV cache using the tags argument in wake_up(). This fine-grained control is especially useful when updating model weights: by waking up just the weights (e.g., llm.wake_up(tags=["weights"])), you avoid allocating memory for the KV cache until after the weight update is complete. This approach helps prevent GPU out-of-memory (OOM) errors, particularly with large models, by minimizing peak memory usage during weight synchronization and update operations.
|
||||
|
||||
Use `tags=["weights"]` or `tags=["kv_cache"]` to control which resources are restored, useful for RLHF and weight updates. **Note** that `is_sleeping` will report `true` until all components are awake.
|
||||
|
||||
```python
|
||||
# Put engine to deep sleep (level=2)
|
||||
llm.sleep(level=2)
|
||||
# ... Get the new weights
|
||||
# Wake up only weights to avoid OOM
|
||||
llm.wake_up(tags=["weights"])
|
||||
# ... Update the weights
|
||||
# wake up KV cache after weights are updated
|
||||
llm.wake_up(tags=["kv_cache"])
|
||||
```
|
||||
|
||||
### Online Serving
|
||||
|
||||
To enable sleep mode in a vLLM server you need to initialize it with the flag `VLLM_SERVER_DEV_MODE=1` and pass `--enable-sleep-mode` to the vLLM server.
|
||||
|
||||
#### Server in development mode
|
||||
|
||||
When using the flag `VLLM_SERVER_DEV_MODE=1` you enable development endpoints, and these endpoints should not be exposed to users.
|
||||
|
||||
```bash
|
||||
VLLM_SERVER_DEV_MODE=1 python -m vllm.entrypoints.openai.api_server \
|
||||
--model Qwen/Qwen3-0.6B \
|
||||
--enable-sleep-mode \
|
||||
--port 8000
|
||||
```
|
||||
|
||||
#### HTTP endpoints
|
||||
|
||||
- `POST /sleep?level=1` — Put the model to sleep (`level=1`).
|
||||
- `POST /wake_up` — Wake up the model. Supports optional `tags` query parameters for partial wake-up (e.g., `?tags=weights`).
|
||||
- `GET /is_sleeping` — Check if the model is sleeping.
|
||||
|
||||
!!! note
|
||||
These endpoints are only available when passing `VLLM_SERVER_DEV_MODE=1`.
|
||||
@ -105,7 +105,7 @@ class Example:
|
||||
return fix_case(self.path.stem.replace("_", " ").title())
|
||||
|
||||
def generate(self) -> str:
|
||||
content = f"# {self.title}\n\n"
|
||||
content = f"---\ntitle: {self.title}\n---\n\n"
|
||||
content += f"Source <gh-file:{self.path.relative_to(ROOT_DIR)}>.\n\n"
|
||||
|
||||
# Use long code fence to avoid issues with
|
||||
|
||||
@ -320,7 +320,7 @@ th {
|
||||
}
|
||||
</style>
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -370,9 +370,9 @@ th {
|
||||
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
|
||||
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | |
|
||||
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | |
|
||||
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
|
||||
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -426,7 +426,7 @@ See [this page](./pooling_models.md) for more information on how to use pooling
|
||||
|
||||
These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | |
|
||||
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ |
|
||||
@ -466,7 +466,7 @@ of the whole prompt are extracted from the normalized hidden state corresponding
|
||||
|
||||
These models primarily support the [`LLM.classify`](./pooling_models.md#llmclassify) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
|
||||
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
|
||||
@ -483,7 +483,7 @@ If your model is not in the above list, we will try to automatically convert the
|
||||
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
|
||||
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | |
|
||||
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -521,7 +521,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
|
||||
These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -594,7 +594,7 @@ See [this page](generative_models.md) for more information on how to use generat
|
||||
|
||||
These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate) API. Chat/Instruct models additionally support the [`LLM.chat`](./generative_models.md#llmchat) API.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ |
|
||||
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -622,7 +622,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ |
|
||||
| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ |
|
||||
| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
|
||||
@ -647,7 +647,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
|
||||
Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
|
||||
| `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ | ✅︎ |
|
||||
|
||||
@ -726,7 +726,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th
|
||||
|
||||
Speech2Text models trained specifically for Automatic Speech Recognition.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | |
|
||||
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -744,7 +744,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|
||||
|
||||
The following table lists those that are tested in vLLM.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | |
|
||||
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | |
|
||||
@ -760,7 +760,7 @@ The following table lists those that are tested in vLLM.
|
||||
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
|
||||
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][parallelism-scaling] | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
|
||||
|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------|
|
||||
| `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | | | ✅︎ |
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Parallelism and Scaling
|
||||
# Distributed inference and serving
|
||||
|
||||
## Distributed inference strategies for a single-model replica
|
||||
|
||||
@ -128,17 +128,12 @@ vllm serve /path/to/the/model/in/the/container \
|
||||
--tensor-parallel-size 16
|
||||
```
|
||||
|
||||
## Optimizing network communication for tensor parallelism
|
||||
## Troubleshooting distributed deployments
|
||||
|
||||
Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand.
|
||||
To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the
|
||||
<gh-file:examples/online_serving/run_cluster.sh> helper script.
|
||||
Contact your system administrator for more information about the required flags.
|
||||
To make tensor parallelism performant, ensure that communication between nodes is efficient, for example, by using high-speed network cards such as InfiniBand. To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the `run_cluster.sh` script. Contact your system administrator for more information about the required flags. One way to confirm if InfiniBand is working is to run `vllm` with the `NCCL_DEBUG=TRACE` environment variable set, for example `NCCL_DEBUG=TRACE vllm serve ...`, and check the logs for the NCCL version and the network used. If you find `[send] via NET/Socket` in the logs, NCCL uses a raw TCP socket, which is not efficient for cross-node tensor parallelism. If you find `[send] via NET/IB/GDRDMA` in the logs, NCCL uses InfiniBand with GPUDirect RDMA, which is efficient.
|
||||
|
||||
## Enabling GPUDirect RDMA
|
||||
|
||||
GPUDirect RDMA (Remote Direct Memory Access) is an NVIDIA technology that allows network adapters to directly access GPU memory, bypassing the CPU and system memory. This direct access reduces latency and CPU overhead, which is beneficial for large data transfers between GPUs across nodes.
|
||||
|
||||
To enable GPUDirect RDMA with vLLM, configure the following settings:
|
||||
|
||||
- `IPC_LOCK` security context: add the `IPC_LOCK` capability to the container's security context to lock memory pages and prevent swapping to disk.
|
||||
@ -180,17 +175,21 @@ spec:
|
||||
...
|
||||
```
|
||||
|
||||
!!! tip "Confirm GPUDirect RDMA operation"
|
||||
To confirm your InfiniBand card is using GPUDirect RDMA, run vLLM with detailed NCCL logs: `NCCL_DEBUG=TRACE vllm serve ...`.
|
||||
Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand. To enable InfiniBand, append flags such as `--privileged -e NCCL_IB_HCA=mlx5` to `run_cluster.sh`. For cluster-specific settings, consult your system administrator.
|
||||
|
||||
Then look for the NCCL version and the network used.
|
||||
To confirm InfiniBand operation, enable detailed NCCL logs:
|
||||
|
||||
- If you find `[send] via NET/IB/GDRDMA` in the logs, then NCCL is using InfiniBand with GPUDirect RDMA, which *is* efficient.
|
||||
- If you find `[send] via NET/Socket` in the logs, NCCL used a raw TCP socket, which *is not* efficient for cross-node tensor parallelism.
|
||||
```bash
|
||||
NCCL_DEBUG=TRACE vllm serve ...
|
||||
```
|
||||
|
||||
Search the logs for the transport method. Entries containing `[send] via NET/Socket` indicate raw TCP sockets, which perform poorly for cross-node tensor parallelism. Entries containing `[send] via NET/IB/GDRDMA` indicate InfiniBand with GPUDirect RDMA, which provides high performance.
|
||||
|
||||
!!! tip "Verify inter-node GPU communication"
|
||||
After you start the Ray cluster, verify GPU-to-GPU communication across nodes. Proper configuration can be non-trivial. For more information, see [troubleshooting script][troubleshooting-incorrect-hardware-driver]. If you need additional environment variables for communication configuration, append them to `run_cluster.sh`, for example `-e NCCL_SOCKET_IFNAME=eth0`. Setting environment variables during cluster creation is recommended because the variables propagate to all nodes. In contrast, setting environment variables in the shell affects only the local node. For more information, see <gh-issue:6803>.
|
||||
|
||||
!!! tip "Pre-download Hugging Face models"
|
||||
If you use Hugging Face models, downloading the model before starting vLLM is recommended. Download the model on every node to the same path, or store the model on a distributed file system accessible by all nodes. Then pass the path to the model in place of the repository ID. Otherwise, supply a Hugging Face token by appending `-e HF_TOKEN=<TOKEN>` to `run_cluster.sh`.
|
||||
|
||||
## Troubleshooting distributed deployments
|
||||
|
||||
For information about distributed debugging, see [Troubleshooting distributed deployments](distributed_troubleshooting.md).
|
||||
!!! tip
|
||||
The error message `Error: No available node types can fulfill resource request` can appear even when the cluster has enough GPUs. The issue often occurs when nodes have multiple IP addresses and vLLM can't select the correct one. Ensure that vLLM and Ray use the same IP address by setting `VLLM_HOST_IP` in `run_cluster.sh` (with a different value on each node). Use `ray status` and `ray list nodes` to verify the chosen IP address. For more information, see <gh-issue:7815>.
|
||||
@ -1,16 +0,0 @@
|
||||
# Troubleshooting distributed deployments
|
||||
|
||||
For general troubleshooting, see [Troubleshooting](../usage/troubleshooting.md).
|
||||
|
||||
## Verify inter-node GPU communication
|
||||
|
||||
After you start the Ray cluster, verify GPU-to-GPU communication across nodes. Proper configuration can be non-trivial. For more information, see [troubleshooting script][troubleshooting-incorrect-hardware-driver]. If you need additional environment variables for communication configuration, append them to <gh-file:examples/online_serving/run_cluster.sh>, for example `-e NCCL_SOCKET_IFNAME=eth0`. Setting environment variables during cluster creation is recommended because the variables propagate to all nodes. In contrast, setting environment variables in the shell affects only the local node. For more information, see <gh-issue:6803>.
|
||||
|
||||
## No available node types can fulfill resource request
|
||||
|
||||
The error message `Error: No available node types can fulfill resource request` can appear even when the cluster has enough GPUs. The issue often occurs when nodes have multiple IP addresses and vLLM can't select the correct one. Ensure that vLLM and Ray use the same IP address by setting `VLLM_HOST_IP` in <gh-file:examples/online_serving/run_cluster.sh> (with a different value on each node). Use `ray status` and `ray list nodes` to verify the chosen IP address. For more information, see <gh-issue:7815>.
|
||||
|
||||
## Ray observability
|
||||
|
||||
Debugging a distributed system can be challenging due to the large scale and complexity. Ray provides a suite of tools to help monitor, debug, and optimize Ray applications and clusters. For more information about Ray observability, visit the [official Ray observability docs](https://docs.ray.io/en/latest/ray-observability/index.html). For more information about debugging Ray applications, visit the [Ray Debugging Guide](https://docs.ray.io/en/latest/ray-observability/user-guides/debug-apps/index.html). For information about troubleshooting Kubernetes clusters, see the
|
||||
[official KubeRay troubleshooting guide](https://docs.ray.io/en/latest/serve/advanced-guides/multi-node-gpu-troubleshooting.html).
|
||||
@ -289,7 +289,7 @@ Traceback (most recent call last):
|
||||
...
|
||||
```
|
||||
|
||||
This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` linux capability or an unmounted `/dev/shm`. Refer to [Enabling GPUDirect RDMA](../serving/parallelism_scaling.md#enabling-gpudirect-rdma) for guidance on properly configuring the environment for GPUDirect RDMA.
|
||||
This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` linux capability or an unmounted `/dev/shm`. Refer to [Distributed Inference and Serving](../serving/distributed_serving.md#running-vllm-on-multiple-nodes) for guidance on properly configuring the environment for distributed serving.
|
||||
|
||||
## Known Issues
|
||||
|
||||
|
||||
@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
|
||||
| **Decoder-only Models** | <nobr>🚀 Optimized</nobr> |
|
||||
| **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> |
|
||||
| **Embedding Models** | <nobr>🟢 Functional</nobr> |
|
||||
| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> |
|
||||
| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟡 (Mamba-1)</nobr> |
|
||||
| **Multimodal Models** | <nobr>🟢 Functional</nobr> |
|
||||
|
||||
vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol.
|
||||
@ -104,11 +104,13 @@ to enable simultaneous generation and embedding using the same engine instance i
|
||||
|
||||
#### Mamba Models
|
||||
|
||||
Models using selective state-space mechanisms instead of standard transformer attention are supported.
|
||||
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.
|
||||
Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
|
||||
Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
|
||||
(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
|
||||
disabling prefix caching in V1.
|
||||
|
||||
Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
|
||||
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that
|
||||
Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
|
||||
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
|
||||
these models currently require disabling prefix caching and using the FlashInfer attention backend in V1.
|
||||
|
||||
#### Encoder-Decoder Models
|
||||
|
||||
@ -68,7 +68,7 @@ def run_simple_demo(args: argparse.Namespace):
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
tensor_parallel_size=2,
|
||||
mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
prompt = "Describe this image in one sentence."
|
||||
@ -105,7 +105,7 @@ def run_advanced_demo(args: argparse.Namespace):
|
||||
limit_mm_per_prompt={"image": max_img_per_msg},
|
||||
max_model_len=max_img_per_msg * max_tokens_per_img,
|
||||
tensor_parallel_size=2,
|
||||
mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
prompt = "Describe the following image."
|
||||
@ -164,9 +164,9 @@ def parse_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-mm-processor-cache",
|
||||
"--disable-mm-preprocessor-cache",
|
||||
action="store_true",
|
||||
help="If True, disables caching of multi-modal processor.",
|
||||
help="If True, disables caching of multi-modal preprocessor/mapper.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -1563,9 +1563,9 @@ def parse_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-mm-processor-cache",
|
||||
"--disable-mm-preprocessor-cache",
|
||||
action="store_true",
|
||||
help="If True, disables caching of multi-modal processor.",
|
||||
help="If True, disables caching of multi-modal preprocessor/mapper.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -1603,7 +1603,7 @@ def main(args):
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {
|
||||
"seed": args.seed,
|
||||
"mm_processor_cache_gb": 0 if args.disable_mm_processor_cache else 4,
|
||||
"disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
|
||||
}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
|
||||
@ -40,7 +40,6 @@ theme:
|
||||
- navigation.sections
|
||||
- navigation.prune
|
||||
- navigation.top
|
||||
- navigation.indexes
|
||||
- search.highlight
|
||||
- search.share
|
||||
- toc.follow
|
||||
@ -52,6 +51,11 @@ hooks:
|
||||
- docs/mkdocs/hooks/generate_argparse.py
|
||||
- docs/mkdocs/hooks/url_schemes.py
|
||||
|
||||
# Required to stop api-autonav from raising an error
|
||||
# https://github.com/tlambert03/mkdocs-api-autonav/issues/16
|
||||
nav:
|
||||
- api
|
||||
|
||||
plugins:
|
||||
- meta
|
||||
- search
|
||||
|
||||
@ -73,6 +73,8 @@ line-length = 80
|
||||
"vllm/engine/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/executor/**/*.py" = ["UP006", "UP035"]
|
||||
"vllm/worker/**/*.py" = ["UP006", "UP035"]
|
||||
# Python 3.8 typing - skip utils for ROCm
|
||||
"vllm/utils/__init__.py" = ["UP006", "UP035"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
|
||||
@ -8,11 +8,12 @@ tqdm
|
||||
blake3
|
||||
py-cpuinfo
|
||||
transformers >= 4.55.0
|
||||
huggingface-hub[hf_xet] >= 0.33.0 # Required for Xet downloads.
|
||||
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
|
||||
protobuf # Required by LlamaTokenizer.
|
||||
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
||||
aiohttp
|
||||
openai >= 1.99.1 # For Responses API with reasoning content
|
||||
openai >= 1.98.0 # For Responses API with reasoning content
|
||||
pydantic >= 2.10
|
||||
prometheus_client >= 0.18.0
|
||||
pillow # Required for image processing
|
||||
|
||||
@ -19,7 +19,6 @@ cloudpickle
|
||||
fastapi
|
||||
msgspec
|
||||
openai
|
||||
openai-harmony
|
||||
partial-json-parser
|
||||
pillow
|
||||
psutil
|
||||
|
||||
@ -31,6 +31,7 @@ lm-eval[api]==0.4.8 # required for model evaluation test
|
||||
mteb>=1.38.11, <2 # required for mteb test
|
||||
transformers==4.52.4
|
||||
tokenizers==0.21.1
|
||||
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
|
||||
schemathesis>=3.39.15 # Required for openai schema test.
|
||||
# quantization
|
||||
bitsandbytes>=0.46.1
|
||||
|
||||
@ -26,7 +26,6 @@ torch==2.7.1
|
||||
torchaudio==2.7.1
|
||||
torchvision==0.22.1
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
mamba_ssm==2.2.5 # required for plamo2 test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
@ -37,6 +36,7 @@ lm-eval[api]==0.4.8 # required for model evaluation test
|
||||
mteb[bm25s]>=1.38.11, <2 # required for mteb test
|
||||
transformers==4.55.0
|
||||
tokenizers==0.21.1
|
||||
huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads.
|
||||
schemathesis>=3.39.15 # Required for openai schema test.
|
||||
# quantization
|
||||
bitsandbytes==0.46.1
|
||||
@ -53,4 +53,4 @@ runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
fastsafetensors>=0.1.10
|
||||
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
||||
terratorch==1.1rc2 # required for PrithviMAE test
|
||||
terratorch==1.1rc2 # required for PrithviMAE test
|
||||
|
||||
@ -178,7 +178,6 @@ einops==0.8.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# encodec
|
||||
# mamba-ssm
|
||||
# terratorch
|
||||
# torchgeo
|
||||
# vector-quantize-pytorch
|
||||
@ -276,7 +275,7 @@ h5py==3.13.0
|
||||
# via terratorch
|
||||
harfile==0.3.0
|
||||
# via schemathesis
|
||||
hf-xet==1.1.7
|
||||
hf-xet==1.1.3
|
||||
# via huggingface-hub
|
||||
hiredis==3.0.0
|
||||
# via tensorizer
|
||||
@ -288,6 +287,7 @@ httpx==0.27.2
|
||||
# schemathesis
|
||||
huggingface-hub==0.34.3
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# accelerate
|
||||
# datasets
|
||||
# evaluate
|
||||
@ -417,8 +417,6 @@ lxml==5.3.0
|
||||
# sacrebleu
|
||||
mako==1.3.10
|
||||
# via alembic
|
||||
mamba-ssm==2.2.5
|
||||
# via -r requirements/test.in
|
||||
markdown==3.8.2
|
||||
# via mlflow
|
||||
markdown-it-py==3.0.0
|
||||
@ -475,8 +473,6 @@ networkx==3.2.1
|
||||
# via
|
||||
# scikit-image
|
||||
# torch
|
||||
ninja==1.11.1.3
|
||||
# via mamba-ssm
|
||||
nltk==3.9.1
|
||||
# via rouge-score
|
||||
num2words==0.5.14
|
||||
@ -629,7 +625,6 @@ packaging==24.2
|
||||
# lazy-loader
|
||||
# lightning
|
||||
# lightning-utilities
|
||||
# mamba-ssm
|
||||
# matplotlib
|
||||
# mlflow-skinny
|
||||
# peft
|
||||
@ -973,7 +968,6 @@ sentencepiece==0.2.0
|
||||
setuptools==77.0.3
|
||||
# via
|
||||
# lightning-utilities
|
||||
# mamba-ssm
|
||||
# pytablewriter
|
||||
# torch
|
||||
# triton
|
||||
@ -1085,7 +1079,6 @@ torch==2.7.1+cu128
|
||||
# lightly
|
||||
# lightning
|
||||
# lm-eval
|
||||
# mamba-ssm
|
||||
# mteb
|
||||
# open-clip-torch
|
||||
# peft
|
||||
@ -1152,16 +1145,13 @@ transformers==4.55.0
|
||||
# -r requirements/test.in
|
||||
# genai-perf
|
||||
# lm-eval
|
||||
# mamba-ssm
|
||||
# peft
|
||||
# sentence-transformers
|
||||
# transformers-stream-generator
|
||||
transformers-stream-generator==0.0.5
|
||||
# via -r requirements/test.in
|
||||
triton==3.3.1
|
||||
# via
|
||||
# mamba-ssm
|
||||
# torch
|
||||
# via torch
|
||||
tritonclient==2.51.0
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
|
||||
@ -10,10 +10,15 @@ wheel
|
||||
jinja2>=3.1.6
|
||||
datasets # for benchmark scripts
|
||||
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
|
||||
--extra-index-url=https://download.pytorch.org/whl/xpu
|
||||
torch==2.8.0+xpu
|
||||
|
||||
torch==2.7.0+xpu
|
||||
torchaudio
|
||||
torchvision
|
||||
pytorch-triton-xpu
|
||||
--extra-index-url=https://download.pytorch.org/whl/xpu
|
||||
|
||||
# Please refer xpu doc, we need manually install intel-extension-for-pytorch 2.6.10+xpu due to there are some conflict dependencies with torch 2.6.0+xpu
|
||||
# FIXME: This will be fix in ipex 2.7. just leave this here for awareness.
|
||||
intel-extension-for-pytorch==2.7.10+xpu
|
||||
oneccl_bind_pt==2.7.0+xpu
|
||||
--extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
intel-extension-for-pytorch==2.8.10+xpu
|
||||
|
||||
2
setup.py
2
setup.py
@ -665,7 +665,7 @@ setup(
|
||||
"mistral_common[audio]"], # Required for audio processing
|
||||
"video": [], # Kept for backwards compatibility
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
"flashinfer": ["flashinfer-python==0.2.10"],
|
||||
"flashinfer": ["flashinfer-python==0.2.9"],
|
||||
},
|
||||
cmdclass=cmdclass,
|
||||
package_data=package_data,
|
||||
|
||||
@ -10,7 +10,8 @@ import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.communication_op import ( # noqa
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
||||
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
|
||||
get_tp_group, graph_capture)
|
||||
|
||||
from ..utils import (ensure_model_parallel_initialized,
|
||||
init_test_distributed_environment, multi_process_parallel)
|
||||
@ -36,7 +37,7 @@ def graph_allreduce(
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
distributed_init_port)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
group = get_tp_group().device_group
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
# this is needed because device communicators might be created lazily
|
||||
|
||||
@ -10,7 +10,8 @@ import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.communication_op import ( # noqa
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_tp_group, graph_capture
|
||||
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
|
||||
get_tp_group, graph_capture)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import (ensure_model_parallel_initialized,
|
||||
@ -41,7 +42,7 @@ def graph_quickreduce(
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
distributed_init_port)
|
||||
ensure_model_parallel_initialized(tp_size, pp_size)
|
||||
group = get_tp_group().device_group
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
# this is needed because device communicators might be created lazily
|
||||
|
||||
@ -93,6 +93,32 @@ class NestedConfig:
|
||||
"""field"""
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class FromCliConfig1:
|
||||
field: int = 1
|
||||
"""field"""
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, cli_value: str):
|
||||
inst = cls(**json.loads(cli_value))
|
||||
inst.field += 1
|
||||
return inst
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class FromCliConfig2:
|
||||
field: int = 1
|
||||
"""field"""
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, cli_value: str):
|
||||
inst = cls(**json.loads(cli_value))
|
||||
inst.field += 2
|
||||
return inst
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class DummyConfig:
|
||||
@ -118,6 +144,10 @@ class DummyConfig:
|
||||
"""Dict which will be JSON in CLI"""
|
||||
nested_config: NestedConfig = field(default_factory=NestedConfig)
|
||||
"""Nested config"""
|
||||
from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1)
|
||||
"""Config with from_cli method"""
|
||||
from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2)
|
||||
"""Different config with from_cli method"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hint", "expected"), [
|
||||
@ -169,6 +199,9 @@ def test_get_kwargs():
|
||||
assert json_tip in kwargs["json_tip"]["help"]
|
||||
# nested config should should construct the nested config
|
||||
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
|
||||
# from_cli configs should be constructed with the correct method
|
||||
assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3
|
||||
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@ -2,12 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import random
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
@ -84,3 +87,54 @@ async def test_with_and_without_truncate(
|
||||
|
||||
responses = await asyncio.gather(*[get_status_code(**b) for b in bodies])
|
||||
assert 500 not in responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
ids=["single completion", "multiple completions", "chat"],
|
||||
argnames=["create_func_gen", "content_body"],
|
||||
argvalues=[
|
||||
(lambda x: x.completions.create, {
|
||||
"prompt": " ".join(['A'] * 300_000)
|
||||
}),
|
||||
(lambda x: x.completions.create, {
|
||||
"prompt": [" ".join(['A'] * 300_000)] * 2
|
||||
}),
|
||||
(lambda x: x.chat.completions.create, {
|
||||
"messages": [{
|
||||
"role": "user",
|
||||
"content": " ".join(['A'] * 300_000)
|
||||
}]
|
||||
}),
|
||||
],
|
||||
)
|
||||
async def test_healthcheck_response_time(
|
||||
server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI,
|
||||
create_func_gen: Callable,
|
||||
content_body: dict,
|
||||
):
|
||||
num_requests = 50
|
||||
|
||||
create_func = create_func_gen(client)
|
||||
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}
|
||||
|
||||
def get_response_time(url):
|
||||
start_time = time.monotonic()
|
||||
res = requests.get(url)
|
||||
end_time = time.monotonic()
|
||||
assert res.status_code == 200
|
||||
return end_time - start_time
|
||||
|
||||
no_load_response_time = get_response_time(server.url_for("health"))
|
||||
tasks = [
|
||||
asyncio.create_task(create_func(**body)) for _ in range(num_requests)
|
||||
]
|
||||
await asyncio.sleep(1) # give the tasks a chance to start running
|
||||
load_response_time = get_response_time(server.url_for("health"))
|
||||
|
||||
with contextlib.suppress(openai.APIStatusError):
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
assert load_response_time < 100 * no_load_response_time
|
||||
assert load_response_time < 0.1
|
||||
|
||||
@ -121,7 +121,8 @@ def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer,
|
||||
|
||||
error = classification_response.json()
|
||||
assert classification_response.status_code == 400
|
||||
assert "truncate_prompt_tokens" in error["error"]["message"]
|
||||
assert error["object"] == "error"
|
||||
assert "truncate_prompt_tokens" in error["message"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@ -136,7 +137,7 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
error = classification_response.json()
|
||||
assert classification_response.status_code == 400
|
||||
assert "error" in error
|
||||
assert error["object"] == "error"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
|
||||
@ -160,8 +160,8 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup,
|
||||
mock_engine.generate.assert_not_called()
|
||||
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.code == HTTPStatus.NOT_FOUND.value
|
||||
assert non_existent_model in response.error.message
|
||||
assert response.code == HTTPStatus.NOT_FOUND.value
|
||||
assert non_existent_model in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -190,8 +190,8 @@ async def test_serving_completion_resolver_add_lora_fails(
|
||||
|
||||
# Assert the correct error response
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.code == HTTPStatus.BAD_REQUEST.value
|
||||
assert invalid_model in response.error.message
|
||||
assert response.code == HTTPStatus.BAD_REQUEST.value
|
||||
assert invalid_model in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -66,8 +66,8 @@ async def test_load_lora_adapter_missing_fields():
|
||||
request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.type == "InvalidUserInput"
|
||||
assert response.error.code == HTTPStatus.BAD_REQUEST
|
||||
assert response.type == "InvalidUserInput"
|
||||
assert response.code == HTTPStatus.BAD_REQUEST
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -84,8 +84,8 @@ async def test_load_lora_adapter_duplicate():
|
||||
lora_path="/path/to/adapter1")
|
||||
response = await serving_models.load_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.type == "InvalidUserInput"
|
||||
assert response.error.code == HTTPStatus.BAD_REQUEST
|
||||
assert response.type == "InvalidUserInput"
|
||||
assert response.code == HTTPStatus.BAD_REQUEST
|
||||
assert len(serving_models.lora_requests) == 1
|
||||
|
||||
|
||||
@ -110,8 +110,8 @@ async def test_unload_lora_adapter_missing_fields():
|
||||
request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
|
||||
response = await serving_models.unload_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.type == "InvalidUserInput"
|
||||
assert response.error.code == HTTPStatus.BAD_REQUEST
|
||||
assert response.type == "InvalidUserInput"
|
||||
assert response.code == HTTPStatus.BAD_REQUEST
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -120,5 +120,5 @@ async def test_unload_lora_adapter_not_found():
|
||||
request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
|
||||
response = await serving_models.unload_lora_adapter(request)
|
||||
assert isinstance(response, ErrorResponse)
|
||||
assert response.error.type == "NotFoundError"
|
||||
assert response.error.code == HTTPStatus.NOT_FOUND
|
||||
assert response.type == "NotFoundError"
|
||||
assert response.code == HTTPStatus.NOT_FOUND
|
||||
|
||||
@ -116,10 +116,8 @@ async def test_non_asr_model(winning_call):
|
||||
file=winning_call,
|
||||
language="en",
|
||||
temperature=0.0)
|
||||
err = res.error
|
||||
assert err["code"] == 400 and not res.text
|
||||
assert err[
|
||||
"message"] == "The model does not support Transcriptions API"
|
||||
assert res.code == 400 and not res.text
|
||||
assert res.message == "The model does not support Transcriptions API"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -135,15 +133,12 @@ async def test_completion_endpoints():
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}])
|
||||
err = res.error
|
||||
assert err["code"] == 400
|
||||
assert err[
|
||||
"message"] == "The model does not support Chat Completions API"
|
||||
assert res.code == 400
|
||||
assert res.message == "The model does not support Chat Completions API"
|
||||
|
||||
res = await client.completions.create(model=model_name, prompt="Hello")
|
||||
err = res.error
|
||||
assert err["code"] == 400
|
||||
assert err["message"] == "The model does not support Completions API"
|
||||
assert res.code == 400
|
||||
assert res.message == "The model does not support Completions API"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -73,9 +73,8 @@ async def test_non_asr_model(foscolo):
|
||||
res = await client.audio.translations.create(model=model_name,
|
||||
file=foscolo,
|
||||
temperature=0.0)
|
||||
err = res.error
|
||||
assert err["code"] == 400 and not res.text
|
||||
assert err["message"] == "The model does not support Translations API"
|
||||
assert res.code == 400 and not res.text
|
||||
assert res.message == "The model does not support Translations API"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -1,43 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--enforce-eager",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--uds",
|
||||
f"{tmpdir}/vllm.sock",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_show_version(server: RemoteOpenAIServer):
|
||||
transport = httpx.HTTPTransport(uds=server.uds)
|
||||
client = httpx.Client(transport=transport)
|
||||
response = client.get(server.url_for("version"))
|
||||
response.raise_for_status()
|
||||
|
||||
assert response.json() == {"version": VLLM_VERSION}
|
||||
@ -1,375 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton_kernels.swiglu
|
||||
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
|
||||
from triton_kernels.numerics import InFlexData
|
||||
from triton_kernels.numerics_details.mxfp import (downcast_to_mxfp,
|
||||
upcast_from_mxfp)
|
||||
from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
|
||||
from triton_kernels.tensor_details import layout
|
||||
from triton_kernels.testing import assert_close
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
BatchedOAITritonExperts, triton_kernel_moe_forward)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.utils import shuffle_weight
|
||||
from vllm.utils import round_up
|
||||
|
||||
|
||||
def deshuffle(w: torch.Tensor):
|
||||
first = w[..., ::2]
|
||||
second = w[..., 1::2]
|
||||
|
||||
deshuffled = torch.concat((first, second), dim=-1)
|
||||
return deshuffled
|
||||
|
||||
|
||||
def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
|
||||
randbits = [torch.randperm(E) for _ in range(M)]
|
||||
x_list = [
|
||||
(-1)**i *
|
||||
((16384 +
|
||||
((i * 512) % 4096) + bits).to(torch.int16).view(torch.bfloat16))
|
||||
for i, bits in enumerate(randbits)
|
||||
]
|
||||
exp_data = torch.stack(x_list).to(
|
||||
device="cuda") # simulating gate_output (M, E)
|
||||
|
||||
# create input tensor
|
||||
x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda")
|
||||
w1 = torch.randn((E, 2 * N, K), dtype=torch.bfloat16, device="cuda")
|
||||
w1_bias = torch.randn((E, 2 * N), dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
w2 = torch.randn((E, K, N), dtype=torch.bfloat16, device="cuda")
|
||||
w2_bias = torch.randn((E, K), dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
exp_data_tri = exp_data.clone()
|
||||
x_tri = x.clone()
|
||||
w1_tri = w1.clone()
|
||||
w2_tri = w2.clone()
|
||||
|
||||
w1_bias_tri = w1_bias.clone()
|
||||
w2_bias_tri = w2_bias.clone()
|
||||
w1_bias_tri = w1_bias_tri.to(torch.float32)
|
||||
w2_bias_tri = w2_bias_tri.to(torch.float32)
|
||||
|
||||
dtype_dict = {
|
||||
"bf16": torch.bfloat16,
|
||||
"fp8_e4m3": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2
|
||||
}
|
||||
|
||||
x = x.to(dtype_dict[a_dtype]).to(torch.bfloat16)
|
||||
if w_dtype != "mx4":
|
||||
# simulate quantization support on reference impl
|
||||
w1 = w1.to(dtype_dict[w_dtype]).to(torch.bfloat16)
|
||||
w2 = w2.to(dtype_dict[w_dtype]).to(torch.bfloat16)
|
||||
|
||||
# triton moe kernel use transposed shape for matmul
|
||||
w1_tri = w1_tri.transpose(-2, -1)
|
||||
w2_tri = w2_tri.transpose(-2, -1)
|
||||
|
||||
# shuffle weights
|
||||
w1_tri = shuffle_weight(w1_tri)
|
||||
w1_bias_tri = shuffle_weight(w1_bias_tri)
|
||||
|
||||
# quant triton_weights
|
||||
x_tri = x.to(dtype_dict[a_dtype])
|
||||
if w_dtype != "mx4":
|
||||
pytest.skip("NYI")
|
||||
else: # quantize to mx4
|
||||
# careful on the padding here, the activation padding need to be
|
||||
# multiple of 64, the actual engine is not implemented
|
||||
w1_bottom_pad = round_up(w1_tri.shape[1], 64) - w1_tri.shape[1]
|
||||
w1_right_pad = round_up(w1_tri.shape[2], 128) - w1_tri.shape[2]
|
||||
|
||||
w2_bottom_pad = w1_right_pad // 2
|
||||
w2_right_pad = w1_bottom_pad
|
||||
|
||||
x_pad = w1_bottom_pad
|
||||
|
||||
w1_tri = F.pad(w1_tri, (0, w1_right_pad, 0, w1_bottom_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0)
|
||||
w2_tri = F.pad(w2_tri, (0, w2_right_pad, 0, w2_bottom_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0)
|
||||
|
||||
w1_bias_tri = F.pad(w1_bias_tri, (0, w1_right_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0)
|
||||
w2_bias_tri = F.pad(w2_bias_tri, (0, w2_right_pad, 0, 0),
|
||||
mode="constant",
|
||||
value=0)
|
||||
|
||||
x_tri = F.pad(x_tri, (0, x_pad, 0, 0), mode="constant", value=0)
|
||||
|
||||
w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(
|
||||
mx_axis=1)
|
||||
w_scale_layout, w_scale_layout_opts = (
|
||||
layout.make_default_matmul_mxfp4_w_scale_layout(
|
||||
mx_axis=1, num_warps=num_warps))
|
||||
|
||||
w1_tri, w1_scale_tri = downcast_to_mxfp(w1_tri, torch.uint8, axis=1)
|
||||
w1 = upcast_from_mxfp(w1_tri, w1_scale_tri, torch.bfloat16, axis=1)
|
||||
|
||||
w2_tri, w2_scale_tri = downcast_to_mxfp(w2_tri, torch.uint8, axis=1)
|
||||
w2 = upcast_from_mxfp(w2_tri, w2_scale_tri, torch.bfloat16, axis=1)
|
||||
|
||||
w1_tri = convert_layout(wrap_torch_tensor(w1_tri, FP4), w_layout,
|
||||
**w_layout_opts)
|
||||
w1_scale_tri = convert_layout(wrap_torch_tensor(w1_scale_tri),
|
||||
w_scale_layout, **w_scale_layout_opts)
|
||||
|
||||
w2_tri = convert_layout(wrap_torch_tensor(w2_tri, FP4), w_layout,
|
||||
**w_layout_opts)
|
||||
w2_scale_tri = convert_layout(wrap_torch_tensor(w2_scale_tri),
|
||||
w_scale_layout, **w_scale_layout_opts)
|
||||
|
||||
pc1 = PrecisionConfig(weight_scale=w1_scale_tri,
|
||||
flex_ctx=FlexCtx(rhs_data=InFlexData()))
|
||||
pc2 = PrecisionConfig(weight_scale=w2_scale_tri,
|
||||
flex_ctx=FlexCtx(rhs_data=InFlexData()))
|
||||
|
||||
# tucuate so the rest can run properly
|
||||
w1 = w1[..., :K, :2 * N]
|
||||
w2 = w2[..., :N, :K]
|
||||
|
||||
w1 = deshuffle(w1)
|
||||
|
||||
w1 = w1.transpose(-1, -2).contiguous()
|
||||
w2 = w2.transpose(-1, -2).contiguous()
|
||||
|
||||
return (x, w1, w1_bias, w2, w2_bias, exp_data, x_tri, w1_tri, w2_tri,
|
||||
exp_data_tri, w1_bias_tri, w2_bias_tri, pc1, pc2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
num_hidden_layers: int = 36
|
||||
num_experts: int = 128
|
||||
experts_per_token: int = 4
|
||||
vocab_size: int = 201088
|
||||
hidden_size: int = 2880
|
||||
intermediate_size: int = 2880
|
||||
head_dim: int = 64
|
||||
num_attention_heads: int = 64
|
||||
num_key_value_heads: int = 8
|
||||
sliding_window: int = 128
|
||||
initial_context_length: int = 4096
|
||||
rope_theta: float = 150000.0
|
||||
rope_scaling_factor: float = 32.0
|
||||
rope_ntk_alpha: float = 1.0
|
||||
rope_ntk_beta: float = 32.0
|
||||
|
||||
|
||||
def swiglu(x, alpha: float = 1.702, limit: float = 1.0):
|
||||
# Note we add an extra bias of 1 to the linear layer
|
||||
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
|
||||
if limit is not None:
|
||||
x_glu = x_glu.clamp(max=limit)
|
||||
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
|
||||
if limit is not None:
|
||||
x_linear = x_linear.clamp(min=-limit, max=limit)
|
||||
return out_glu * (x_linear + 1)
|
||||
|
||||
|
||||
def oai_moe_forward(
|
||||
hidden_states: torch.Tensor, # (M, K)
|
||||
w1: torch.Tensor, # (E, 2N)
|
||||
w1_bias: torch.Tensor, # (E, 2N, K)
|
||||
w2: torch.Tensor, # (E, K, N)
|
||||
w2_bias: torch.Tensor, # (E, N)
|
||||
gating_output: torch.Tensor, # (M, E)
|
||||
topk: int):
|
||||
# model.py 309:330, assuming gating and norm
|
||||
t = hidden_states
|
||||
experts = torch.topk(gating_output, k=topk, dim=-1, sorted=True)
|
||||
expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
|
||||
expert_indices = experts.indices
|
||||
|
||||
# MLP #1
|
||||
mlp1_weight = w1[expert_indices, ...]
|
||||
mlp1_bias = w1_bias[expert_indices, ...]
|
||||
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
|
||||
t = swiglu(t, limit=7)
|
||||
|
||||
# MLP #2
|
||||
mlp2_weight = w2[expert_indices, ...]
|
||||
mlp2_bias = w2_bias[expert_indices, ...]
|
||||
t = torch.einsum("beck,bek->bec", mlp2_weight, t)
|
||||
t += mlp2_bias
|
||||
|
||||
# Weighted sum of experts
|
||||
t = torch.einsum("bec,be->bc", t, expert_weights)
|
||||
|
||||
return t
|
||||
|
||||
|
||||
@dataclass
|
||||
class Case:
|
||||
a_dtype: str
|
||||
w_dtype: str
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
", ".join(f.name for f in fields(Case)),
|
||||
[
|
||||
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
|
||||
# Case(a_dtype="bf16", w_dtype="bf16"),
|
||||
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
|
||||
Case(a_dtype="bf16", w_dtype="mx4")
|
||||
]
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("num_token", [2])
|
||||
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
|
||||
def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
M = num_token
|
||||
E = ModelConfig.num_experts
|
||||
K = ModelConfig.hidden_size
|
||||
N = ModelConfig.intermediate_size // tp
|
||||
topk = ModelConfig.experts_per_token
|
||||
|
||||
x, w1, w1_bias, w2, w2_bias, exp_data, \
|
||||
x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri,\
|
||||
w2_bias_tri, pc1, pc2 = init_compute_data(
|
||||
M, K, N, E, a_dtype, w_dtype, num_warps=8)
|
||||
|
||||
out_triton_monolithic = triton_kernel_moe_forward(
|
||||
hidden_states=x_tri,
|
||||
w1=w1_tri,
|
||||
w2=w2_tri,
|
||||
gating_output=exp_data_tri,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2)
|
||||
out_triton_monolithic = out_triton_monolithic[..., :K]
|
||||
|
||||
out_ref = oai_moe_forward(hidden_states=x,
|
||||
w1=w1,
|
||||
w1_bias=w1_bias,
|
||||
w2=w2,
|
||||
w2_bias=w2_bias,
|
||||
gating_output=exp_data,
|
||||
topk=topk)
|
||||
assert_close(ref=out_ref,
|
||||
tri=out_triton_monolithic,
|
||||
maxtol=0.025,
|
||||
rmstol=0.005)
|
||||
|
||||
|
||||
def batched_moe(a: torch.Tensor, w1, w2, gating_output: torch.Tensor,
|
||||
topk: int, renormalize: bool, w1_bias: torch.Tensor,
|
||||
w2_bias: torch.Tensor, w1_precision: PrecisionConfig,
|
||||
w2_precision: PrecisionConfig) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
num_local_experts=w1.shape[0],
|
||||
rank=0),
|
||||
BatchedOAITritonExperts(
|
||||
None,
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=1,
|
||||
w1_precision=w1_precision,
|
||||
w2_precision=w2_precision,
|
||||
),
|
||||
)
|
||||
|
||||
extra_expert_args = {
|
||||
"w1_bias": w1_bias,
|
||||
"w2_bias": w2_bias,
|
||||
}
|
||||
|
||||
topk_weight, topk_ids, _ = fused_topk(a, gating_output, topk, renormalize)
|
||||
|
||||
return fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
extra_expert_args=extra_expert_args,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
", ".join(f.name for f in fields(Case)),
|
||||
[
|
||||
tuple(getattr(case, f.name) for f in fields(Case)) for case in [
|
||||
# Case(a_dtype="bf16", w_dtype="bf16"),
|
||||
# Case(a_dtype="fp8_e4m3", w_dtype="fp8_e5m2"),
|
||||
Case(a_dtype="bf16", w_dtype="mx4")
|
||||
]
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("num_token", [64])
|
||||
@pytest.mark.parametrize("ep", [1, 2, 4, 8])
|
||||
def test_triton_kernel_batched_moe(num_token, a_dtype, w_dtype, ep):
|
||||
M = num_token
|
||||
E = ModelConfig.num_experts // ep
|
||||
K = ModelConfig.hidden_size
|
||||
N = ModelConfig.intermediate_size
|
||||
topk = ModelConfig.experts_per_token
|
||||
|
||||
x, w1, w1_bias, w2, w2_bias, exp_data, \
|
||||
x_tri, w1_tri, w2_tri, exp_data_tri, w1_bias_tri, \
|
||||
w2_bias_tri, pc1, pc2 = init_compute_data(
|
||||
M, K, N, E, a_dtype, w_dtype, num_warps=4)
|
||||
|
||||
out_tri = batched_moe(a=x_tri,
|
||||
w1=w1_tri,
|
||||
w2=w2_tri,
|
||||
gating_output=exp_data_tri,
|
||||
topk=topk,
|
||||
renormalize=True,
|
||||
w1_bias=w1_bias_tri,
|
||||
w2_bias=w2_bias_tri,
|
||||
w1_precision=pc1,
|
||||
w2_precision=pc2)
|
||||
out_tri = out_tri[..., :K]
|
||||
|
||||
out_ref = oai_moe_forward(hidden_states=x,
|
||||
w1=w1,
|
||||
w1_bias=w1_bias,
|
||||
w2=w2,
|
||||
w2_bias=w2_bias,
|
||||
gating_output=exp_data,
|
||||
topk=topk)
|
||||
assert_close(ref=out_ref, tri=out_tri, maxtol=0.025, rmstol=0.005)
|
||||
|
||||
|
||||
def test_unit_shuffle():
|
||||
N = ModelConfig.intermediate_size
|
||||
K = ModelConfig.hidden_size
|
||||
m = torch.randn((K, 2 * N), dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
x = torch.randn(K, dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
m_shuffled = shuffle_weight(m)
|
||||
|
||||
out_ref = x @ m
|
||||
out_ref = swiglu(out_ref, limit=1.0)
|
||||
|
||||
out = x @ m_shuffled
|
||||
out = triton_kernels.swiglu.swiglu_torch(
|
||||
out,
|
||||
alpha=1.702,
|
||||
precision_config=triton_kernels.swiglu.PrecisionConfig(limit=1.0))
|
||||
|
||||
assert_close(ref=out_ref, tri=out)
|
||||
@ -9,9 +9,7 @@ import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
from ..models.utils import check_embeddings_close
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
TORCH_VERSION = version.parse(torch.__version__)
|
||||
MINIMUM_TORCH_VERSION = version.parse("2.7.0")
|
||||
@ -30,7 +28,7 @@ def set_seed(seed):
|
||||
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
|
||||
reason="CUDA not available or PyTorch version < 2.7",
|
||||
)
|
||||
def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
def test_flex_attention_vs_default_backend(monkeypatch):
|
||||
"""Test that FlexAttention produces the same outputs as the default backend.
|
||||
|
||||
This test compares the outputs from the FlexAttention backend with
|
||||
@ -38,7 +36,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
"""
|
||||
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
seed = 42
|
||||
max_tokens = 24
|
||||
max_tokens = 32
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
@ -56,30 +54,33 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
|
||||
set_seed(seed)
|
||||
with vllm_runner(model_name,
|
||||
runner="generate",
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True) as llm_flex:
|
||||
output_flex = llm_flex.generate(prompts, sampling_params)
|
||||
|
||||
llm_flex = LLM(
|
||||
model_name,
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True,
|
||||
)
|
||||
output_flex = llm_flex.generate(prompts, sampling_params)
|
||||
|
||||
# Run with default backend
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
set_seed(seed)
|
||||
with vllm_runner(model_name,
|
||||
runner="generate",
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True) as llm_default:
|
||||
output_default = llm_default.generate(prompts, sampling_params)
|
||||
llm_default = LLM(
|
||||
model_name,
|
||||
tensor_parallel_size=1,
|
||||
num_gpu_blocks_override=128,
|
||||
enforce_eager=True,
|
||||
)
|
||||
output_default = llm_default.generate(prompts, sampling_params)
|
||||
|
||||
# Compare outputs from both backends
|
||||
for i, (flex_result,
|
||||
default_result) in enumerate(zip(output_flex, output_default)):
|
||||
prompt = prompts[i]
|
||||
flex_text = flex_result[1][0]
|
||||
default_text = default_result[1][0]
|
||||
flex_text = flex_result.outputs[0].text
|
||||
default_text = default_result.outputs[0].text
|
||||
|
||||
assert flex_text == default_text, (
|
||||
f"FlexAttention output doesn't match default for: {prompt!r}\n"
|
||||
@ -87,54 +88,5 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
f"Default: {default_text!r}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
|
||||
reason="CUDA not available or PyTorch version < 2.7",
|
||||
)
|
||||
def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
|
||||
"""Test that FlexAttention produces the same outputs as the default backend.
|
||||
|
||||
This test compares the outputs from the FlexAttention backend with
|
||||
the default backend for encoder models.
|
||||
"""
|
||||
model_name = "BAAI/bge-base-en-v1.5"
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
]
|
||||
|
||||
# Run with flex attention
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||
with vllm_runner(model_name,
|
||||
runner="pooling",
|
||||
dtype=torch.bfloat16,
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=100,
|
||||
enforce_eager=True) as llm_flex:
|
||||
flex_outputs = llm_flex.embed(prompts)
|
||||
|
||||
# Run with default backend
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
with vllm_runner(model_name,
|
||||
runner="pooling",
|
||||
dtype=torch.bfloat16,
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=100,
|
||||
enforce_eager=True) as llm_default:
|
||||
default_outputs = llm_default.embed(prompts)
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=flex_outputs,
|
||||
embeddings_1_lst=default_outputs,
|
||||
name_0="flex",
|
||||
name_1="default",
|
||||
tol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
@ -25,9 +25,6 @@ SSM_MODELS = [
|
||||
|
||||
HYBRID_MODELS = [
|
||||
"ai21labs/Jamba-tiny-dev",
|
||||
# NOTE: Running Plamo2 in transformers implementation requires to install
|
||||
# causal-conv1d package, which is not listed as a test dependency as it's
|
||||
# not compatible with pip-compile.
|
||||
"pfnet/plamo-2-1b",
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
"hmellor/tiny-random-BambaForCausalLM",
|
||||
@ -50,11 +47,13 @@ HF_UNSUPPORTED_MODELS = [
|
||||
# https://github.com/huggingface/transformers/pull/39033
|
||||
# We will enable vLLM test for Granite after next HF transformers release.
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
# NOTE: Plamo2 requires both mamba_ssm and causal-conv1d libraries
|
||||
# (see https://huggingface.co/pfnet/plamo-2-1b/blob/main/modeling_plamo.py),
|
||||
# Don't compare it to HF, to avoid managing the dependency.
|
||||
"pfnet/plamo-2-1b",
|
||||
]
|
||||
|
||||
V1_SUPPORTED_MODELS = [
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"ai21labs/Jamba-tiny-dev",
|
||||
"mistralai/Mamba-Codestral-7B-v0.1",
|
||||
"ibm-ai-platform/Bamba-9B-v1",
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
|
||||
@ -162,8 +162,7 @@ def mteb_test_embed_models(hf_runner,
|
||||
vllm_runner,
|
||||
model_info: EmbedModelInfo,
|
||||
vllm_extra_kwargs=None,
|
||||
hf_model_callback=None,
|
||||
atol=MTEB_RERANK_TOL):
|
||||
hf_model_callback=None):
|
||||
if not model_info.enable_test:
|
||||
# A model family has many models with the same architecture,
|
||||
# and we don't need to test each one.
|
||||
@ -199,7 +198,7 @@ def mteb_test_embed_models(hf_runner,
|
||||
print("SentenceTransformers:", st_dtype, st_main_score)
|
||||
print("Difference:", st_main_score - vllm_main_score)
|
||||
|
||||
assert st_main_score == pytest.approx(vllm_main_score, abs=atol)
|
||||
assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
|
||||
|
||||
|
||||
def run_mteb_rerank(cross_encoder, tasks, languages):
|
||||
|
||||
@ -7,7 +7,7 @@ import pytest
|
||||
from vllm.config import PoolerConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import check_embeddings_close, check_transformers_version
|
||||
from ...utils import check_embeddings_close
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -56,9 +56,6 @@ def test_models(
|
||||
model,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
if model == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
|
||||
check_transformers_version(model, max_transformers_version="4.53.2")
|
||||
|
||||
if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm():
|
||||
# ROCm Triton FA does not currently support sliding window attention
|
||||
# switch to use ROCm CK FA backend
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from ...utils import check_transformers_version
|
||||
from .embed_utils import EmbedModelInfo, correctness_test_embed_models
|
||||
from .mteb_utils import mteb_test_embed_models
|
||||
|
||||
@ -61,10 +60,6 @@ MODELS = [
|
||||
@pytest.mark.parametrize("model_info", MODELS)
|
||||
def test_embed_models_mteb(hf_runner, vllm_runner,
|
||||
model_info: EmbedModelInfo) -> None:
|
||||
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
|
||||
check_transformers_version(model_info.name,
|
||||
max_transformers_version="4.53.2")
|
||||
|
||||
vllm_extra_kwargs: dict[str, Any] = {}
|
||||
if model_info.architecture == "GteNewModel":
|
||||
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
|
||||
@ -77,10 +72,6 @@ def test_embed_models_mteb(hf_runner, vllm_runner,
|
||||
def test_embed_models_correctness(hf_runner, vllm_runner,
|
||||
model_info: EmbedModelInfo,
|
||||
example_prompts) -> None:
|
||||
if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
|
||||
check_transformers_version(model_info.name,
|
||||
max_transformers_version="4.53.2")
|
||||
|
||||
vllm_extra_kwargs: dict[str, Any] = {}
|
||||
if model_info.architecture == "GteNewModel":
|
||||
vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
|
||||
|
||||
@ -10,7 +10,6 @@ from transformers import AutoModel
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ....conftest import HfRunner
|
||||
from ...utils import check_transformers_version
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -87,9 +86,6 @@ def test_prm_models(
|
||||
dtype: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
check_transformers_version("Qwen/Qwen2.5-Math-PRM-7B",
|
||||
max_transformers_version="4.53.2")
|
||||
|
||||
if current_platform.is_cpu() and os.environ.get("VLLM_USE_V1", "0") == "0":
|
||||
pytest.skip("CPU only supports V1")
|
||||
|
||||
|
||||
@ -62,7 +62,9 @@ def run_test(
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
vllm_runner_kwargs_: dict[str, Any] = {"mm_processor_cache_gb": 0}
|
||||
vllm_runner_kwargs_: dict[str, Any] = {
|
||||
"disable_mm_preprocessor_cache": True,
|
||||
}
|
||||
if model_info.tokenizer:
|
||||
vllm_runner_kwargs_["tokenizer_name"] = model_info.tokenizer
|
||||
if model_info.tokenizer_mode:
|
||||
|
||||
@ -15,14 +15,14 @@ from ...utils import build_model_context
|
||||
["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
|
||||
@pytest.mark.parametrize("mm_processor_kwargs", [{}])
|
||||
@pytest.mark.parametrize("num_imgs", [1, 5])
|
||||
@pytest.mark.parametrize("mm_processor_cache_gb", [0, 4])
|
||||
@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
|
||||
@pytest.mark.parametrize("tokenized_prompt", [True, False])
|
||||
def test_processor_override(
|
||||
image_assets: ImageTestAssets,
|
||||
model_id: str,
|
||||
mm_processor_kwargs: dict,
|
||||
num_imgs: int,
|
||||
mm_processor_cache_gb: int,
|
||||
disable_mm_preprocessor_cache: bool,
|
||||
tokenized_prompt: bool,
|
||||
):
|
||||
"""Ensure llama4 processor works properly."""
|
||||
@ -30,7 +30,7 @@ def test_processor_override(
|
||||
model_id,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
limit_mm_per_prompt={"image": num_imgs},
|
||||
mm_processor_cache_gb=mm_processor_cache_gb,
|
||||
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
config = processor.info.get_hf_config()
|
||||
|
||||
@ -278,8 +278,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
transformers_version_reason="vLLM impl inherits PreTrainedModel and clashes with get_input_embeddings", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"QWenLMHeadModel": _HfExamplesInfo("Qwen/Qwen-7B-Chat",
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason="HF model uses remote code that is not compatible with latest Transformers", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"Qwen2ForCausalLM": _HfExamplesInfo("Qwen/Qwen2-0.5B-Instruct",
|
||||
extras={"2.5": "Qwen/Qwen2.5-0.5B-Instruct"}), # noqa: E501
|
||||
@ -429,7 +427,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
|
||||
trust_remote_code=True),
|
||||
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
|
||||
extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"}, # noqa: E501
|
||||
extras={"2.6": "openbmb/MiniCPM-V-2_6"}, # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501
|
||||
trust_remote_code=True,
|
||||
|
||||
@ -9,7 +9,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, ModelDType, RunnerOption
|
||||
from vllm.config import ModelConfig, RunnerOption
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
|
||||
|
||||
@ -257,11 +257,11 @@ def check_logprobs_close(
|
||||
def build_model_context(
|
||||
model_id: str,
|
||||
runner: RunnerOption = "auto",
|
||||
dtype: ModelDType = "auto",
|
||||
dtype: Union[str, torch.dtype] = "auto",
|
||||
model_config_kwargs: Optional[dict[str, Any]] = None,
|
||||
mm_processor_kwargs: Optional[dict[str, Any]] = None,
|
||||
limit_mm_per_prompt: Optional[dict[str, int]] = None,
|
||||
mm_processor_cache_gb: int = 0,
|
||||
disable_mm_preprocessor_cache: bool = True,
|
||||
):
|
||||
"""Creates an InputContext for a given model.
|
||||
|
||||
@ -279,7 +279,6 @@ def build_model_context(
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
model_config_kwargs = model_config_kwargs or {}
|
||||
limit_mm_per_prompt = limit_mm_per_prompt or {}
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
runner=runner,
|
||||
@ -291,7 +290,7 @@ def build_model_context(
|
||||
seed=0,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
mm_processor_cache_gb=mm_processor_cache_gb,
|
||||
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
**model_config_kwargs,
|
||||
)
|
||||
@ -413,14 +412,3 @@ def dummy_hf_overrides(
|
||||
})
|
||||
|
||||
return hf_config
|
||||
|
||||
|
||||
def check_transformers_version(model: str,
|
||||
min_transformers_version: Optional[str] = None,
|
||||
max_transformers_version: Optional[str] = None):
|
||||
from .registry import _HfExamplesInfo
|
||||
|
||||
return _HfExamplesInfo(model,
|
||||
min_transformers_version=min_transformers_version,
|
||||
max_transformers_version=max_transformers_version
|
||||
).check_transformers_version(on_fail="skip")
|
||||
|
||||
@ -1,51 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
|
||||
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalSharedField)
|
||||
|
||||
|
||||
def _dummy_elem(modality: str, key: str, size: int):
|
||||
return MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key=key,
|
||||
data=torch.empty((size, ), dtype=torch.int8),
|
||||
field=MultiModalSharedField(1),
|
||||
)
|
||||
|
||||
|
||||
def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
||||
return MultiModalKwargsItem.from_elems([
|
||||
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
|
||||
])
|
||||
|
||||
|
||||
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
|
||||
return MultiModalKwargs.from_items([
|
||||
_dummy_item(modality, size_by_key)
|
||||
for modality, size_by_key in size_by_key_modality.items()
|
||||
])
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("item", "expected_size"),
|
||||
[
|
||||
(_dummy_item("a", {"a1": 100}), 100),
|
||||
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
|
||||
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
|
||||
],
|
||||
)
|
||||
# yapf: enable
|
||||
def test_cache_item_size(item, expected_size):
|
||||
cache = MultiModalCache.get_lru_cache(2048, type(item))
|
||||
|
||||
cache[""] = item
|
||||
assert cache.currsize == expected_size
|
||||
|
||||
cache[""] = MultiModalCacheItemMetadata.wraps(item)
|
||||
assert cache.currsize == expected_size
|
||||
@ -6,15 +6,20 @@ from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalSharedField)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
|
||||
PromptIndexTargets, PromptInsertion,
|
||||
PromptReplacement, apply_text_matches,
|
||||
ProcessingCache, PromptIndexTargets,
|
||||
PromptInsertion, PromptReplacement,
|
||||
apply_text_matches,
|
||||
apply_token_matches,
|
||||
find_mm_placeholders,
|
||||
find_text_matches, find_token_matches,
|
||||
@ -897,6 +902,45 @@ def test_find_mm_placeholders(
|
||||
assert result == expected
|
||||
|
||||
|
||||
def _dummy_elem(modality: str, key: str, size: int):
|
||||
return MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key=key,
|
||||
data=torch.empty((size, ), dtype=torch.int8),
|
||||
field=MultiModalSharedField(1),
|
||||
)
|
||||
|
||||
|
||||
def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
||||
return MultiModalKwargsItem.from_elems([
|
||||
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
|
||||
])
|
||||
|
||||
|
||||
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
|
||||
return MultiModalKwargs.from_items([
|
||||
_dummy_item(modality, size_by_key)
|
||||
for modality, size_by_key in size_by_key_modality.items()
|
||||
])
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("item", "expected_size"),
|
||||
[
|
||||
(_dummy_item("a", {"a1": 100}), 100),
|
||||
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
|
||||
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
|
||||
],
|
||||
)
|
||||
# yapf: enable
|
||||
def test_cache_item_size(item, expected_size):
|
||||
cache = ProcessingCache.get_lru_cache(2048, type(item))
|
||||
cache[""] = item
|
||||
|
||||
assert cache.currsize == expected_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||
@pytest.mark.parametrize(
|
||||
("limit", "num_supported", "is_valid"),
|
||||
|
||||
@ -10,12 +10,11 @@ from dataclasses import dataclass
|
||||
from json.decoder import JSONDecodeError
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger,
|
||||
enable_trace_function_call, init_logger)
|
||||
from vllm.logging_utils import NewLineFormatter
|
||||
@ -229,10 +228,9 @@ def test_prepare_object_to_dump():
|
||||
list_obj = [1, 2, 3]
|
||||
assert prepare_object_to_dump(list_obj) == '[1, 2, 3]'
|
||||
|
||||
dict_obj = {"a": 1, "b": "b"}
|
||||
dict_obj = {'a': 1, 'b': 'b'}
|
||||
assert prepare_object_to_dump(dict_obj) in [
|
||||
"{a: 1, b: 'b'}",
|
||||
"{b: 'b', a: 1}",
|
||||
"{a: 1, b: 'b'}", "{b: 'b', a: 1}"
|
||||
]
|
||||
|
||||
set_obj = {1, 2, 3}
|
||||
@ -254,246 +252,4 @@ def test_prepare_object_to_dump():
|
||||
b: str
|
||||
|
||||
assert (prepare_object_to_dump(CustomClass(
|
||||
1, "b")) == "CustomClass(a=1, b='b')")
|
||||
|
||||
|
||||
def test_request_logger_log_outputs():
|
||||
"""Test the new log_outputs functionality."""
|
||||
# Create a mock logger to capture log calls
|
||||
mock_logger = MagicMock()
|
||||
|
||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
||||
request_logger = RequestLogger(max_log_len=None)
|
||||
|
||||
# Test basic output logging
|
||||
request_logger.log_outputs(
|
||||
request_id="test-123",
|
||||
outputs="Hello, world!",
|
||||
output_token_ids=[1, 2, 3, 4],
|
||||
finish_reason="stop",
|
||||
is_streaming=False,
|
||||
delta=False,
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args.args
|
||||
assert "Generated response %s%s" in call_args[0]
|
||||
assert call_args[1] == "test-123"
|
||||
assert call_args[3] == "Hello, world!"
|
||||
assert call_args[4] == [1, 2, 3, 4]
|
||||
assert call_args[5] == "stop"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_streaming_delta():
|
||||
"""Test log_outputs with streaming delta mode."""
|
||||
mock_logger = MagicMock()
|
||||
|
||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
||||
request_logger = RequestLogger(max_log_len=None)
|
||||
|
||||
# Test streaming delta logging
|
||||
request_logger.log_outputs(
|
||||
request_id="test-456",
|
||||
outputs="Hello",
|
||||
output_token_ids=[1],
|
||||
finish_reason=None,
|
||||
is_streaming=True,
|
||||
delta=True,
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args.args
|
||||
assert "Generated response %s%s" in call_args[0]
|
||||
assert call_args[1] == "test-456"
|
||||
assert call_args[2] == " (streaming delta)"
|
||||
assert call_args[3] == "Hello"
|
||||
assert call_args[4] == [1]
|
||||
assert call_args[5] is None
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_streaming_complete():
|
||||
"""Test log_outputs with streaming complete mode."""
|
||||
mock_logger = MagicMock()
|
||||
|
||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
||||
request_logger = RequestLogger(max_log_len=None)
|
||||
|
||||
# Test streaming complete logging
|
||||
request_logger.log_outputs(
|
||||
request_id="test-789",
|
||||
outputs="Complete response",
|
||||
output_token_ids=[1, 2, 3],
|
||||
finish_reason="length",
|
||||
is_streaming=True,
|
||||
delta=False,
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args.args
|
||||
assert "Generated response %s%s" in call_args[0]
|
||||
assert call_args[1] == "test-789"
|
||||
assert call_args[2] == " (streaming complete)"
|
||||
assert call_args[3] == "Complete response"
|
||||
assert call_args[4] == [1, 2, 3]
|
||||
assert call_args[5] == "length"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_with_truncation():
|
||||
"""Test log_outputs respects max_log_len setting."""
|
||||
mock_logger = MagicMock()
|
||||
|
||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
||||
# Set max_log_len to 10
|
||||
request_logger = RequestLogger(max_log_len=10)
|
||||
|
||||
# Test output truncation
|
||||
long_output = "This is a very long output that should be truncated"
|
||||
long_token_ids = list(range(20)) # 20 tokens
|
||||
|
||||
request_logger.log_outputs(
|
||||
request_id="test-truncate",
|
||||
outputs=long_output,
|
||||
output_token_ids=long_token_ids,
|
||||
finish_reason="stop",
|
||||
is_streaming=False,
|
||||
delta=False,
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args
|
||||
|
||||
# Check that output was truncated to first 10 characters
|
||||
logged_output = call_args[0][3]
|
||||
assert logged_output == "This is a "
|
||||
assert len(logged_output) == 10
|
||||
|
||||
# Check that token IDs were truncated to first 10 tokens
|
||||
logged_token_ids = call_args[0][4]
|
||||
assert logged_token_ids == list(range(10))
|
||||
assert len(logged_token_ids) == 10
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_none_values():
|
||||
"""Test log_outputs handles None values correctly."""
|
||||
mock_logger = MagicMock()
|
||||
|
||||
with patch("vllm.entrypoints.logger.logger", mock_logger):
|
||||
request_logger = RequestLogger(max_log_len=None)
|
||||
|
||||
# Test with None output_token_ids
|
||||
request_logger.log_outputs(
|
||||
request_id="test-none",
|
||||
outputs="Test output",
|
||||
output_token_ids=None,
|
||||
finish_reason="stop",
|
||||
is_streaming=False,
|
||||
delta=False,
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args.args
|
||||
assert "Generated response %s%s" in call_args[0]
|
||||
assert call_args[1] == "test-none"
|
||||
assert call_args[3] == "Test output"
|
||||
assert call_args[4] is None
|
||||
assert call_args[5] == "stop"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_empty_output():
    """Test log_outputs handles empty output correctly."""
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        logger_under_test = RequestLogger(max_log_len=5)

        # An empty string and an empty token list must be logged verbatim,
        # not dropped or altered by the max_log_len truncation.
        logger_under_test.log_outputs(
            request_id="test-empty",
            outputs="",
            output_token_ids=[],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        fake_logger.info.assert_called_once()
        logged = fake_logger.info.call_args.args
        assert "Generated response %s%s" in logged[0]
        assert logged[1] == "test-empty"
        assert logged[3] == ""
        assert logged[4] == []
        assert logged[5] == "stop"
|
||||
|
||||
|
||||
def test_request_logger_log_outputs_integration():
    """Test that log_outputs can be called alongside log_inputs."""
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        logger_under_test = RequestLogger(max_log_len=None)

        # Log an input and then an output for the same request id; the two
        # calls must not interfere with each other.
        logger_under_test.log_inputs(
            request_id="test-integration",
            prompt="Test prompt",
            prompt_token_ids=[1, 2, 3],
            prompt_embeds=None,
            params=None,
            lora_request=None,
        )
        logger_under_test.log_outputs(
            request_id="test-integration",
            outputs="Test output",
            output_token_ids=[4, 5, 6],
            finish_reason="stop",
            is_streaming=False,
            delta=False,
        )

        # Exactly two log records: one for the input, one for the output.
        assert fake_logger.info.call_count == 2

        first_call = fake_logger.info.call_args_list[0][0]
        second_call = fake_logger.info.call_args_list[1][0]

        # Input record uses the "Received request" template.
        assert "Received request %s" in first_call[0]
        assert first_call[1] == "test-integration"

        # Output record uses the "Generated response" template.
        assert "Generated response %s%s" in second_call[0]
        assert second_call[1] == "test-integration"
|
||||
|
||||
|
||||
def test_streaming_complete_logs_full_text_content():
    """Test that streaming complete logging includes
    full accumulated text, not just token count."""
    fake_logger = MagicMock()

    with patch("vllm.entrypoints.logger.logger", fake_logger):
        logger_under_test = RequestLogger(max_log_len=None)

        # Log a finished stream carrying the accumulated text rather than a
        # "<N tokens>" style placeholder.
        accumulated_text = "This is a complete response from streaming"
        logger_under_test.log_outputs(
            request_id="test-streaming-full-text",
            outputs=accumulated_text,
            output_token_ids=None,
            finish_reason="streaming_complete",
            is_streaming=True,
            delta=False,
        )

        fake_logger.info.assert_called_once()
        logged = fake_logger.info.call_args.args

        # The logged payload must be the full text itself, with no
        # token-count placeholder or finish-reason text leaking into it.
        logged_text = logged[3]
        assert logged_text == accumulated_text
        assert "tokens>" not in logged_text
        assert "streaming_complete" not in logged_text

        # Remaining positional arguments of the log call.
        assert logged[1] == "test-streaming-full-text"
        assert logged[2] == " (streaming complete)"
        assert logged[5] == "streaming_complete"
|
||||
1, 'b')) == "CustomClass(a=1, b='b')")
|
||||
|
||||
@ -1,171 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test script for the token-to-expert routing simulator.
|
||||
|
||||
This script demonstrates how to use the routing simulator to test
|
||||
different routing strategies and analyze their performance, including
|
||||
integration tests with FusedMoE layer.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.routing_simulator import (
|
||||
DistributionBasedRouting, RoutingSimulator)
|
||||
|
||||
|
||||
@pytest.fixture
def device():
    """Fixture to provide the appropriate device for testing."""
    # Prefer the GPU when one is present; otherwise fall back to the CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [1, 16, 256])
@pytest.mark.parametrize("hidden_size", [64, 1024])
@pytest.mark.parametrize("num_experts", [16, 128])
@pytest.mark.parametrize("top_k", [1, 4])
def test_basic_functionality(
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    top_k: int,
    device,
):
    """Test basic functionality of the routing simulator."""
    activations = torch.randn(num_tokens, hidden_size, device=device)
    logits = torch.randn(num_tokens, num_experts, device=device)

    # Every registered strategy must produce well-formed routing output.
    for strategy in RoutingSimulator.get_available_strategies():
        weights, ids = RoutingSimulator.simulate_routing(
            hidden_states=activations,
            router_logits=logits,
            strategy_name=strategy,
            top_k=top_k,
        )

        # One (weight, expert-id) pair per token per top-k slot.
        expected_shape = (num_tokens, top_k)
        assert weights.shape == expected_shape, \
            f"Wrong weights shape for {strategy}"
        assert ids.shape == expected_shape, \
            f"Wrong ids shape for {strategy}"

        # Expert indices must lie in the valid range [0, num_experts).
        assert ids.min() >= 0, \
            f"Invalid expert ID (negative) for {strategy}"
        assert ids.max() < num_experts, \
            f"Invalid expert ID (too large) for {strategy}"
|
||||
|
||||
|
||||
def test_routing_strategy_integration(monkeypatch, device):
    """Test that the routing strategy environment variable works with
    FusedMoE."""
    pytest.importorskip("vllm.model_executor.layers.fused_moe.layer")

    import vllm.envs as envs
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE

    # Problem size for the synthetic routing inputs.
    num_tokens = 32
    hidden_size = 16
    num_experts = 4
    top_k = 2

    activations = torch.randn(num_tokens, hidden_size, device=device)
    logits = torch.randn(num_tokens, num_experts, device=device)

    env_name = "VLLM_MOE_ROUTING_SIMULATION_STRATEGY"
    for strategy in RoutingSimulator.get_available_strategies():
        # Point the simulation-strategy env var at this strategy.
        monkeypatch.setenv(env_name, strategy)

        # vllm.envs resolves variables through accessor callables; override
        # the accessor directly so the new value takes effect immediately.
        envs.environment_variables[env_name] = lambda s=strategy: s

        # Route through the public FusedMoE entry point.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=activations,
            router_logits=logits,
            top_k=top_k,
            use_grouped_topk=False,
            renormalize=True,
            indices_type=torch.long)

        # Shape checks: one (weight, id) pair per token per slot.
        assert topk_weights.shape == (
            num_tokens, top_k), f"Wrong weights shape for {strategy}"
        assert topk_ids.shape == (
            num_tokens, top_k), f"Wrong ids shape for {strategy}"

        # Range checks on the selected expert indices.
        assert topk_ids.min() >= 0, \
            f"Invalid expert ID (negative) for {strategy}"
        assert topk_ids.max() < num_experts, \
            f"Invalid expert ID (too large) for {strategy}"
|
||||
|
||||
|
||||
def test_distribution_based_routing_with_custom_strategy():
    """Test registering and using DistributionBasedRouting with custom
    parameters."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Register a normal-distribution routing strategy under a custom name.
    RoutingSimulator.register_strategy(
        "custom_normal",
        DistributionBasedRouting(distribution="normal", mean=2.0, std=0.5))

    # Problem size for the synthetic routing inputs.
    num_tokens = 60
    hidden_size = 48
    num_experts = 6
    top_k = 3

    activations = torch.randn(num_tokens, hidden_size, device=device)
    logits = torch.randn(num_tokens, num_experts, device=device)

    # Route through the freshly registered strategy.
    weights, ids = RoutingSimulator.simulate_routing(
        hidden_states=activations,
        router_logits=logits,
        strategy_name="custom_normal",
        top_k=top_k)

    # Shapes: one (weight, expert-id) pair per token per slot.
    assert weights.shape == (num_tokens, top_k)
    assert ids.shape == (num_tokens, top_k)

    # Expert indices must lie in the valid range [0, num_experts).
    assert ids.min() >= 0
    assert ids.max() < num_experts
|
||||
|
||||
|
||||
def test_instance_compatibility():
    """Test that static methods work correctly."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Call the static entry point directly on a small random problem.
    weights, ids = RoutingSimulator.simulate_routing(
        hidden_states=torch.randn(10, 8, device=device),
        router_logits=torch.randn(10, 4, device=device),
        strategy_name="uniform_random",
        top_k=2)

    # 10 tokens, 2 experts selected per token.
    assert weights.shape == (10, 2)
    assert ids.shape == (10, 2)
|
||||
@ -118,17 +118,8 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
|
||||
tensor_parallel_size=tp_size,
|
||||
))
|
||||
p.start()
|
||||
# Call queue.get() before p.join() to prevent deadlock:
|
||||
# If p.join() is called before queue.get() and the queue is full,
|
||||
# the child process may block while writing to the queue and never
|
||||
# terminate, causing the parent to wait indefinitely on p.join().
|
||||
# See: https://github.com/vllm-project/vllm/pull/22371#discussion_r2257773814
|
||||
out_before = queue.get()
|
||||
p.join()
|
||||
queue.close()
|
||||
queue.join_thread()
|
||||
|
||||
queue = ctx.Queue()
|
||||
out_before = queue.get()
|
||||
|
||||
p = ctx.Process(target=_run_generate,
|
||||
args=(output_dir, queue),
|
||||
@ -140,14 +131,7 @@ def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
|
||||
load_format="sharded_state",
|
||||
))
|
||||
p.start()
|
||||
# Call queue.get() before p.join() to prevent deadlock:
|
||||
# If p.join() is called before queue.get() and the queue is full,
|
||||
# the child process may block while writing to the queue and never
|
||||
# terminate, causing the parent to wait indefinitely on p.join().
|
||||
# See: https://github.com/vllm-project/vllm/pull/22371#discussion_r2257773814
|
||||
out_after = queue.get()
|
||||
p.join()
|
||||
queue.close()
|
||||
queue.join_thread()
|
||||
out_after = queue.get()
|
||||
|
||||
assert out_before == out_after
|
||||
|
||||
@ -3,12 +3,10 @@
|
||||
# ruff: noqa: E501
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
@ -26,57 +24,6 @@ def minimax_tool_parser(minimax_tokenizer):
|
||||
return MinimaxToolParser(minimax_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_tools():
    """Two OpenAI-style tool definitions used by the parser tests."""
    # Weather lookup: two required string params plus an optional enum.
    weather_tool = ChatCompletionToolsParam(
        type="function",
        function={
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city name"
                    },
                    "state": {
                        "type": "string",
                        "description": "The state code"
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["fahrenheit", "celsius"]
                    }
                },
                "required": ["city", "state"]
            }
        })
    # Area calculation: mixed param types, nothing required.
    area_tool = ChatCompletionToolsParam(
        type="function",
        function={
            "name": "calculate_area",
            "description": "Calculate area of a shape",
            "parameters": {
                "type": "object",
                "properties": {
                    "shape": {
                        "type": "string"
                    },
                    "dimensions": {
                        "type": "object"
                    },
                    "precision": {
                        "type": "integer"
                    }
                }
            }
        })
    return [weather_tool, area_tool]
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
@ -423,794 +370,3 @@ def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser):
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content is None
|
||||
|
||||
|
||||
def test_streaming_arguments_incremental_output(minimax_tool_parser):
    """Test that streaming arguments are returned incrementally, not cumulatively."""
    # Reset streaming state so this test is independent of earlier ones.
    minimax_tool_parser.current_tool_name_sent = False
    minimax_tool_parser.prev_tool_call_arr = []
    minimax_tool_parser.current_tool_id = -1
    minimax_tool_parser.streamed_args_for_tool = []

    # Simulate progressive tool call building.
    # NOTE(review): stage labels in the comments are 1-based while the loop
    # index `i` below is 0-based, and the "Stage 6" comment covers the last
    # two list entries (partial and complete close tag).
    stages = [
        # Stage 1: Function name complete
        '<tool_calls>\n{"name": "get_current_weather", "arguments": ',
        # Stage 2: Arguments object starts with first key
        '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": ',
        # Stage 3: First parameter value added
        '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle"',
        # Stage 4: Second parameter added
        '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"',
        # Stage 5: Third parameter added, arguments complete
        '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}',
        # Stage 6: Tool calls closed
        '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool',
        '<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n</tool_calls>'
    ]

    function_name_sent = False
    previous_args_content = ""

    for i, current_text in enumerate(stages):
        # Each stage extends the previous one, so the delta is the suffix.
        previous_text = stages[i - 1] if i > 0 else ""
        delta_text = current_text[len(previous_text
                                      ):] if i > 0 else current_text

        result = minimax_tool_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )

        print(f"Stage {i}: Current text: {repr(current_text)}")
        print(f"Stage {i}: Delta text: {repr(delta_text)}")

        if result is not None and hasattr(result,
                                          'tool_calls') and result.tool_calls:
            tool_call = result.tool_calls[0]

            # Check if function name is sent (should happen only once)
            if tool_call.function and tool_call.function.name:
                assert tool_call.function.name == "get_current_weather"
                function_name_sent = True
                print(
                    f"Stage {i}: Function name sent: {tool_call.function.name}"
                )

            # Check if arguments are sent incrementally
            if tool_call.function and tool_call.function.arguments:
                args_fragment = tool_call.function.arguments
                print(
                    f"Stage {i}: Got arguments fragment: {repr(args_fragment)}"
                )

                # For incremental output, each fragment should be new content only
                # The fragment should not contain all previous content
                if i >= 2 and previous_args_content:  # After we start getting arguments
                    # The new fragment should not be identical to or contain all previous content
                    assert args_fragment != previous_args_content, f"Fragment should be incremental, not cumulative: {args_fragment}"

                    # If this is truly incremental, the fragment should be relatively small
                    # compared to the complete arguments so far.
                    # NOTE(review): this only prints a warning instead of
                    # failing — intentional tolerance, presumably.
                    if len(args_fragment) > len(previous_args_content):
                        print(
                            "Warning: Fragment seems cumulative rather than incremental"
                        )

                previous_args_content = args_fragment

    # Verify function name was sent at least once
    assert function_name_sent, "Function name should have been sent"
|
||||
|
||||
|
||||
def test_streaming_arguments_delta_only(minimax_tool_parser):
    """Test that each streaming call returns only the delta (new part) of arguments."""
    # Reset streaming state so this test is independent of earlier ones.
    minimax_tool_parser.current_tool_name_sent = False
    minimax_tool_parser.prev_tool_call_arr = []
    minimax_tool_parser.current_tool_id = -1
    minimax_tool_parser.streamed_args_for_tool = []

    # Simulate two consecutive calls with growing arguments: call 2 extends
    # call 1 by one extra parameter and the closing brace.
    call1_text = '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1"}}'
    call2_text = '<tool_calls>\n{"name": "test_tool", "arguments": {"param1": "value1", "param2": "value2"}}'

    print(f"Call 1 text: {repr(call1_text)}")
    print(f"Call 2 text: {repr(call2_text)}")

    # First call - should get the function name and initial arguments
    result1 = minimax_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=call1_text,
        delta_text=call1_text,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    print(f"Result 1: {result1}")
    if result1 and hasattr(result1, 'tool_calls') and result1.tool_calls:
        for i, tc in enumerate(result1.tool_calls):
            print(f" Tool call {i}: {tc}")

    # Second call - should only get the delta (new part) of arguments
    result2 = minimax_tool_parser.extract_tool_calls_streaming(
        previous_text=call1_text,
        current_text=call2_text,
        delta_text=', "param2": "value2"}',
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    print(f"Result 2: {result2}")
    if result2 and hasattr(result2, 'tool_calls') and result2.tool_calls:
        for i, tc in enumerate(result2.tool_calls):
            print(f" Tool call {i}: {tc}")

    # Verify the second call only returns the delta
    if result2 is not None and hasattr(result2,
                                       'tool_calls') and result2.tool_calls:
        tool_call = result2.tool_calls[0]
        if tool_call.function and tool_call.function.arguments:
            args_delta = tool_call.function.arguments
            print(f"Arguments delta from second call: {repr(args_delta)}")

            # Should only contain the new part, not the full arguments
            # The delta should be something like ', "param2": "value2"}' or just '"param2": "value2"'
            assert ', "param2": "value2"}' in args_delta or '"param2": "value2"' in args_delta, f"Expected delta containing param2, got: {args_delta}"

            # Should NOT contain the previous parameter data
            assert '"param1": "value1"' not in args_delta, f"Arguments delta should not contain previous data: {args_delta}"

            # The delta should be relatively short (incremental, not cumulative)
            expected_max_length = len(
                ', "param2": "value2"}') + 10  # Some tolerance
            assert len(
                args_delta
            ) <= expected_max_length, f"Delta seems too long (possibly cumulative): {args_delta}"

            print("✓ Delta validation passed")
        else:
            print("No arguments in result2 tool call")
    else:
        print("No tool calls in result2 or result2 is None")
        # This might be acceptable if no incremental update is needed
        # But let's at least verify that result1 had some content
        assert result1 is not None, "At least the first call should return something"
|
||||
|
||||
|
||||
def test_streaming_openai_compatibility(minimax_tool_parser):
    """Test that streaming behavior with buffering works correctly."""
    # Reset streaming state
    minimax_tool_parser.current_tool_name_sent = False
    minimax_tool_parser.prev_tool_call_arr = []
    minimax_tool_parser.current_tool_id = -1
    minimax_tool_parser.streamed_args_for_tool = []
    # Reset buffering state
    minimax_tool_parser.pending_buffer = ""
    minimax_tool_parser.in_thinking_tag = False
    minimax_tool_parser.thinking_depth = 0

    # Test scenario: simple buffering without complex tool call context.
    # Each case feeds one delta and states what content (if any) the parser
    # should emit; None means the delta must be buffered/suppressed.
    test_cases: list[dict[str, Any]] = [
        {
            'stage': 'Token: <',
            'previous': '',
            'current': '<',
            'delta': '<',
            'expected_content': None,  # Should be buffered
        },
        {
            'stage': 'Token: tool_calls>',
            'previous': '<',
            'current': '<tool_calls>',
            'delta': 'tool_calls>',
            'expected_content': None,  # Complete tag, should not output
        },
        {
            'stage': 'Regular content',
            'previous': 'Hello',
            'current': 'Hello world',
            'delta': ' world',
            'expected_content': ' world',  # Normal content should pass through
        },
        {
            'stage': 'Content with end tag start',
            'previous': 'Text',
            'current': 'Text content</tool_',
            'delta': ' content</tool_',
            'expected_content':
            ' content',  # Content part output, </tool_ buffered
        },
        {
            'stage': 'Complete end tag',
            'previous': 'Text content</tool_',
            'current': 'Text content</tool_calls>',
            'delta': 'calls>',
            'expected_content': None,  # Complete close tag, should not output
        },
    ]

    for i, test_case in enumerate(test_cases):
        print(f"\n--- Stage {i}: {test_case['stage']} ---")
        print(f"Previous: {repr(test_case['previous'])}")
        print(f"Current: {repr(test_case['current'])}")
        print(f"Delta: {repr(test_case['delta'])}")

        result = minimax_tool_parser.extract_tool_calls_streaming(
            previous_text=test_case['previous'],
            current_text=test_case['current'],
            delta_text=test_case['delta'],
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )

        print(f"Result: {result}")

        # Check expected content
        if test_case['expected_content'] is None:
            # Either no delta message at all, or one with falsy content.
            assert result is None or not getattr(result, 'content', None), \
                f"Stage {i}: Expected no content, got {result}"
            print("✓ No content output as expected")
        else:
            assert result is not None and hasattr(result, 'content'), \
                f"Stage {i}: Expected content, got {result}"
            assert result.content == test_case['expected_content'], \
                f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}"
            print(f"✓ Content matches: {repr(result.content)}")

    print("✓ Streaming test with buffering completed successfully")
|
||||
|
||||
|
||||
def test_streaming_thinking_tag_buffering(minimax_tool_parser):
    """Test that tool calls within thinking tags are properly handled during streaming."""
    # Reset streaming state
    minimax_tool_parser.current_tool_name_sent = False
    minimax_tool_parser.prev_tool_call_arr = []
    minimax_tool_parser.current_tool_id = -1
    minimax_tool_parser.streamed_args_for_tool = []
    # Reset buffering state
    minimax_tool_parser.pending_buffer = ""
    minimax_tool_parser.in_thinking_tag = False
    minimax_tool_parser.thinking_depth = 0

    # Test scenario: tool calls within thinking tags should be ignored.
    # Tool-call markup inside <think>...</think> is treated as plain
    # content, while markup after </think> is parsed as a real tool call.
    test_cases: list[dict[str, Any]] = [
        {
            'stage': 'Start thinking',
            'previous': '',
            'current': '<think>I need to use a tool. <tool_calls>',
            'delta': '<think>I need to use a tool. <tool_calls>',
            'expected_content':
            '<think>I need to use a tool. <tool_calls>',  # Should pass through as content
        },
        {
            'stage':
            'Tool call in thinking',
            'previous':
            '<think>I need to use a tool. <tool_calls>',
            'current':
            '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>',
            'delta':
            '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>',
            'expected_content':
            '\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls>',  # </tool_calls> should be preserved in thinking tags
        },
        {
            'stage': 'Real tool call after thinking',
            'previous':
            '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>',
            'current':
            '<think>I need to use a tool. <tool_calls>\n{"name": "ignored_tool", "arguments": {"param": "value"}}\n</tool_calls></think>\n<tool_calls>',
            'delta': '\n<tool_calls>',
            'expected_content':
            '\n',  # Should output '\n' and suppress <tool_calls>
        }
    ]

    for i, test_case in enumerate(test_cases):
        print(f"\n--- Stage {i}: {test_case['stage']} ---")
        print(f"Previous: {repr(test_case['previous'])}")
        print(f"Current: {repr(test_case['current'])}")
        print(f"Delta: {repr(test_case['delta'])}")

        result = minimax_tool_parser.extract_tool_calls_streaming(
            previous_text=test_case['previous'],
            current_text=test_case['current'],
            delta_text=test_case['delta'],
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )

        print(f"Result: {result}")

        # Check expected content
        if 'expected_content' in test_case:
            if test_case['expected_content'] is None:
                assert result is None or not getattr(result, 'content', None), \
                    f"Stage {i}: Expected no content, got {result}"
            else:
                assert result is not None and hasattr(result, 'content'), \
                    f"Stage {i}: Expected content, got {result}"
                assert result.content == test_case['expected_content'], \
                    f"Stage {i}: Expected content {test_case['expected_content']}, got {result.content}"
                print(f"✓ Content matches: {repr(result.content)}")

        # Check tool calls
        # NOTE(review): no test case above defines 'expected_tool_call', so
        # this branch (and the "real_tool" assertion) is currently dead
        # code — confirm whether a stage was meant to set that key.
        if test_case.get('expected_tool_call'):
            assert result is not None and hasattr(result, 'tool_calls') and result.tool_calls, \
                f"Stage {i}: Expected tool call, got {result}"

            tool_call = result.tool_calls[0]
            assert tool_call.function.name == "real_tool", \
                f"Expected real_tool, got {tool_call.function.name}"
            print(f"✓ Real tool call detected: {tool_call.function.name}")

    print("✓ Thinking tag buffering test completed successfully")
|
||||
|
||||
|
||||
def reset_streaming_state(minimax_tool_parser):
    """Helper function to properly reset the streaming state for MinimaxToolParser."""
    # Clear the parser-specific incremental-parsing state first.
    minimax_tool_parser._reset_streaming_state()

    # Then re-initialise the base-class bookkeeping (these should still be
    # reset for compatibility) so every test starts from a pristine parser.
    minimax_tool_parser.current_tool_name_sent = False
    minimax_tool_parser.prev_tool_call_arr = []
    minimax_tool_parser.current_tool_id = -1
    minimax_tool_parser.streamed_args_for_tool = []
|
||||
|
||||
|
||||
def test_streaming_complex_scenario_with_multiple_tools(minimax_tool_parser):
    """Test complex streaming scenario: tools inside <think> tags and multiple tool calls in one group.

    Each stage feeds one (previous, current, delta) triple to
    ``extract_tool_calls_streaming`` and checks both the content the parser
    emits and the tool calls it extracts.  Tool-call markup that appears
    inside <think>...</think> must be passed through verbatim as content,
    while the tool group after the thinking block must be parsed into real
    tool calls (with the <tool_calls> tags filtered from the content).
    """
    # Reset streaming state
    reset_streaming_state(minimax_tool_parser)

    # Complex scenario: tools inside thinking tags and multiple tools in one group
    test_stages: list[dict[str, Any]] = [
        {
            'stage': 'Initial content',
            'previous': '',
            'current': 'Let me help you with this task.',
            'delta': 'Let me help you with this task.',
            'expected_content': 'Let me help you with this task.',
            'expected_tool_calls': 0,
        },
        {
            'stage': 'Start thinking tag',
            'previous': 'Let me help you with this task.',
            'current':
            'Let me help you with this task.<think>I need to analyze this situation first.',
            'delta': '<think>I need to analyze this situation first.',
            'expected_content':
            '<think>I need to analyze this situation first.',
            'expected_tool_calls': 0,
        },
        {
            'stage': 'Tool call inside thinking tag starts',
            'previous':
            'Let me help you with this task.<think>I need to analyze this situation first.',
            'current':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>',
            'delta': '<tool_calls>',
            'expected_content':
            '<tool_calls>',  # Inside thinking tags, tool tags should be preserved as content
            'expected_tool_calls': 0,
        },
        {
            'stage': 'Complete tool call inside thinking tag',
            'previous':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>',
            'current':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>',
            'delta':
            '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>',
            'expected_content':
            '\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>',
            'expected_tool_calls':
            0,  # Tools inside thinking tags should be ignored
        },
        {
            'stage': 'End thinking tag',
            'previous':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls>',
            'current':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>',
            'delta': '</think>',
            'expected_content': '</think>',
            'expected_tool_calls': 0,
        },
        {
            'stage': 'Multiple tools group starts',
            'previous':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>',
            'current':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>',
            'delta':
            '\nNow I need to get weather information and calculate area.<tool_calls>',
            'expected_content':
            '\nNow I need to get weather information and calculate area.',  # <tool_calls> should be filtered
            'expected_tool_calls': 0,
        },
        {
            'stage': 'First tool in group',
            'previous':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>',
            'current':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}',
            'delta':
            '\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}',
            'expected_content':
            None,  # No content should be output when tool call is in progress
            'expected_tool_calls': 1,
            'expected_tool_name': 'get_current_weather',
        },
        {
            'stage': 'Second tool in group',
            'previous':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}',
            'current':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}',
            'delta':
            '\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}',
            'expected_content': None,
            'expected_tool_calls': 1,
            'expected_tool_name': 'calculate_area',
        },
        {
            'stage': 'Complete tool calls group',
            'previous':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}',
            'current':
            'Let me help you with this task.<think>I need to analyze this situation first.<tool_calls>\n{"name": "internal_analysis", "arguments": {"query": "analyze situation"}}\n</tool_calls></think>\nNow I need to get weather information and calculate area.<tool_calls>\n{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}\n{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}</tool_calls>',
            'delta': '</tool_calls>',
            'expected_content': None,
            'expected_tool_calls': 0,
        }
    ]

    # Running total of real tool calls detected across all stages.
    tool_calls_count = 0

    # Drive the parser one delta at a time, in order; the parser is stateful
    # so the stage order matters.
    for i, test_case in enumerate(test_stages):
        print(f"\n--- Stage {i}: {test_case['stage']} ---")
        print(
            f"Previous: {repr(test_case['previous'][:100])}{'...' if len(test_case['previous']) > 100 else ''}"
        )
        print(f"Current: {repr(test_case['current'][-100:])}")
        print(f"Delta: {repr(test_case['delta'])}")

        result = minimax_tool_parser.extract_tool_calls_streaming(
            previous_text=test_case['previous'],
            current_text=test_case['current'],
            delta_text=test_case['delta'],
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )

        print(f"Result: {result}")

        # Check expected content
        if test_case['expected_content'] is None:
            assert result is None or not getattr(result, 'content', None), \
                f"Stage {i}: Expected no content output, got {result}"
            print("✓ No content output as expected")
        else:
            assert result is not None and hasattr(result, 'content'), \
                f"Stage {i}: Expected content output, got {result}"
            assert result.content == test_case['expected_content'], \
                f"Stage {i}: Expected content {repr(test_case['expected_content'])}, got {repr(result.content)}"
            print(f"✓ Content matches: {repr(result.content)}")

        # Check tool calls
        expected_tool_calls = test_case['expected_tool_calls']
        actual_tool_calls = len(result.tool_calls) if result and hasattr(
            result, 'tool_calls') and result.tool_calls else 0

        if expected_tool_calls > 0:
            # The parser may emit more deltas than the minimum expected, so
            # this is a lower-bound check.
            assert actual_tool_calls >= expected_tool_calls, \
                f"Stage {i}: Expected at least {expected_tool_calls} tool calls, got {actual_tool_calls}"

            if 'expected_tool_name' in test_case:
                # Find the tool call with the expected name
                found_tool_call = None
                for tool_call in result.tool_calls:
                    if tool_call.function.name == test_case[
                            'expected_tool_name']:
                        found_tool_call = tool_call
                        break

                assert found_tool_call is not None, \
                    f"Stage {i}: Expected tool name {test_case['expected_tool_name']} not found in tool calls: {[tc.function.name for tc in result.tool_calls]}"
                print(f"✓ Tool call correct: {found_tool_call.function.name}")

                # Ensure tools inside thinking tags are not called
                assert found_tool_call.function.name != "internal_analysis", \
                    f"Stage {i}: Tool 'internal_analysis' inside thinking tags should not be called"

            tool_calls_count += actual_tool_calls
            print(f"✓ Detected {actual_tool_calls} tool calls")
        else:
            assert actual_tool_calls == 0, \
                f"Stage {i}: Expected no tool calls, got {actual_tool_calls}"

    # Verify overall results
    print("\n=== Test Summary ===")
    print(f"Total tool calls count: {tool_calls_count}")
    # Only the two tools outside the thinking block should have been counted.
    assert tool_calls_count >= 2, f"Expected at least 2 valid tool calls (outside thinking tags), but got {tool_calls_count}"

    print("✓ Complex streaming test completed:")
    print(" - ✓ Tools inside thinking tags correctly ignored")
    print(" - ✓ Two tool groups outside thinking tags correctly parsed")
    print(" - ✓ Content and tool call streaming correctly handled")
    print(" - ✓ Buffering mechanism works correctly")
||||
def test_streaming_character_by_character_output(minimax_tool_parser):
    """Test character-by-character streaming output to simulate real streaming scenarios.

    The full text is replayed one character per call to
    ``extract_tool_calls_streaming``; the test then checks that thinking-tag
    content (including the tool-call markup inside it) is passed through as
    content, real tool calls outside the thinking block are extracted, and
    <tool_calls> tags outside the thinking block are filtered from content.
    """
    # Reset streaming state
    reset_streaming_state(minimax_tool_parser)

    # Complete text that will be streamed character by character
    complete_text = """I'll help you with the weather analysis. <think>Let me think about this. <tool_calls>
{"name": "internal_analysis", "arguments": {"type": "thinking"}}
</tool_calls>This tool should be ignored.</think>

Now I'll get the weather information for you. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}
{"name": "calculate_area", "arguments": {"shape": "rectangle", "dimensions": {"width": 10, "height": 5}}}
</tool_calls>Here are the results."""

    print("\n=== Starting character-by-character streaming test ===")
    print(f"Complete text length: {len(complete_text)} characters")

    # Track the streaming results
    content_fragments = []
    tool_calls_detected = []

    # Stream character by character
    for i in range(1, len(complete_text) + 1):
        current_text = complete_text[:i]
        previous_text = complete_text[:i - 1] if i > 1 else ""
        delta_text = complete_text[i - 1:i]

        # Show progress every 50 characters
        if i % 50 == 0 or i == len(complete_text):
            print(f"Progress: {i}/{len(complete_text)} characters")

        # Call the streaming parser
        result = minimax_tool_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )

        # Collect results
        if result is not None:
            if hasattr(result, 'content') and result.content:
                content_fragments.append(result.content)
                # Log important content fragments
                if any(
                        keyword in result.content for keyword in
                    ['<think>', '</think>', '<tool_calls>', '</tool_calls>']):
                    print(
                        f" Char {i}: Content fragment: {repr(result.content)}"
                    )

            if hasattr(result, 'tool_calls') and result.tool_calls:
                for tool_call in result.tool_calls:
                    # Record where in the stream each tool-call delta arrived.
                    tool_info = {
                        'character_position':
                        i,
                        'function_name':
                        tool_call.function.name
                        if tool_call.function else None,
                        'arguments':
                        tool_call.function.arguments
                        if tool_call.function else None,
                    }
                    tool_calls_detected.append(tool_info)
                    print(
                        f" Char {i}: Tool call detected: {tool_call.function.name}"
                    )
                    if tool_call.function.arguments:
                        print(
                            f" Arguments: {repr(tool_call.function.arguments)}"
                        )

    # Verify results
    print("\n=== Streaming Test Results ===")
    print(f"Total content fragments: {len(content_fragments)}")
    print(f"Total tool calls detected: {len(tool_calls_detected)}")

    # Reconstruct content from fragments
    reconstructed_content = ''.join(content_fragments)
    print(f"Reconstructed content length: {len(reconstructed_content)}")

    # Verify thinking tags content is preserved
    assert '<think>' in reconstructed_content, "Opening thinking tag should be preserved in content"
    assert '</think>' in reconstructed_content, "Closing thinking tag should be preserved in content"

    # Verify that tool calls inside thinking tags are NOT extracted as actual tool calls
    thinking_tool_calls = [
        tc for tc in tool_calls_detected
        if tc['function_name'] == 'internal_analysis'
    ]
    assert len(
        thinking_tool_calls
    ) == 0, f"Tool calls inside thinking tags should be ignored, but found: {thinking_tool_calls}"

    # Verify that real tool calls outside thinking tags ARE extracted
    weather_tool_calls = [
        tc for tc in tool_calls_detected
        if tc['function_name'] == 'get_current_weather'
    ]
    area_tool_calls = [
        tc for tc in tool_calls_detected
        if tc['function_name'] == 'calculate_area'
    ]
    print(tool_calls_detected)
    assert len(weather_tool_calls
               ) > 0, "get_current_weather tool call should be detected"
    assert len(
        area_tool_calls) > 0, "calculate_area tool call should be detected"

    # Verify tool call arguments are properly streamed
    # NOTE(review): these two flags are printed for diagnostics but never
    # asserted — argument streaming is not actually enforced here.
    weather_args_found = any(tc['arguments'] for tc in weather_tool_calls
                             if tc['arguments'])
    area_args_found = any(tc['arguments'] for tc in area_tool_calls
                          if tc['arguments'])

    print(f"Weather tool call with arguments: {weather_args_found}")
    print(f"Area tool call with arguments: {area_args_found}")

    # Verify content before and after tool calls
    assert 'I\'ll help you with the weather analysis.' in reconstructed_content, "Initial content should be preserved"
    assert 'Here are the results.' in reconstructed_content, "Final content should be preserved"

    # Verify that <tool_calls> and </tool_calls> tags are not included in the final content
    # (they should be filtered out when not inside thinking tags)
    content_outside_thinking = reconstructed_content
    # Remove thinking tag content to check content outside
    if '<think>' in content_outside_thinking and '</think>' in content_outside_thinking:
        start_think = content_outside_thinking.find('<think>')
        end_think = content_outside_thinking.find('</think>') + len('</think>')
        content_outside_thinking = content_outside_thinking[:
                                                            start_think] + content_outside_thinking[
                                                                end_think:]

    # Outside thinking tags, tool_calls tags should be filtered
    tool_calls_in_content = content_outside_thinking.count('<tool_calls>')
    assert tool_calls_in_content == 0, f"<tool_calls> tags should be filtered from content outside thinking tags, but found {tool_calls_in_content}"

    print(
        "\n=== Character-by-character streaming test completed successfully ==="
    )
    print("✓ Tool calls inside thinking tags correctly ignored")
    print("✓ Tool calls outside thinking tags correctly detected")
    print("✓ Content properly streamed and reconstructed")
    print("✓ Tool call tags properly filtered from content")
    print("✓ Character-level streaming works correctly")
|
||||
def test_streaming_character_by_character_simple_tool_call(
        minimax_tool_parser):
    """Test character-by-character streaming for a simple tool call scenario."""
    # Start from a clean parser so no state leaks in from earlier tests.
    reset_streaming_state(minimax_tool_parser)

    # One short message followed by a single well-formed tool call.
    simple_text = 'Let me check the weather. <tool_calls>\n{"name": "get_weather", "arguments": {"city": "NYC"}}\n</tool_calls>'

    print("\n=== Simple character-by-character test ===")
    print(f"Text: {repr(simple_text)}")

    content_parts = []
    tool_name_sent = False
    tool_args_sent = False

    # Feed the text one character at a time, as a real token stream would.
    for i, delta_text in enumerate(simple_text, start=1):
        previous_text = simple_text[:i - 1]
        current_text = previous_text + delta_text

        result = minimax_tool_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )

        if not result:
            continue

        if getattr(result, 'content', None):
            content_parts.append(result.content)
            print(
                f" Char {i} ({repr(delta_text)}): Content: {repr(result.content)}"
            )

        for tool_call in getattr(result, 'tool_calls', None) or []:
            function = tool_call.function
            if function and function.name:
                tool_name_sent = True
                print(
                    f" Char {i}: Tool name: {tool_call.function.name}"
                )
            if function and function.arguments:
                tool_args_sent = True
                print(
                    f" Char {i}: Tool args: {repr(tool_call.function.arguments)}"
                )

    # Verify basic expectations
    reconstructed_content = ''.join(content_parts)
    print(f"Final reconstructed content: {repr(reconstructed_content)}")

    assert tool_name_sent, "Tool name should be sent during streaming"
    assert tool_args_sent, "Tool arguments should be sent during streaming"
    assert "Let me check the weather." in reconstructed_content, "Initial content should be preserved"

    print("✓ Simple character-by-character test passed")
|
||||
def test_streaming_character_by_character_with_buffering(minimax_tool_parser):
    """Test character-by-character streaming with edge cases that trigger buffering."""
    # Make sure no streaming state survives from previous tests.
    reset_streaming_state(minimax_tool_parser)

    # Content glued directly to the tool-call tags exercises the buffering path.
    buffering_text = 'Hello world<tool_calls>\n{"name": "test"}\n</tool_calls>done'

    print("\n=== Buffering character-by-character test ===")
    print(f"Text: {repr(buffering_text)}")

    all_content = []

    # Replay the text one character per parser call.
    for i, delta_text in enumerate(buffering_text, start=1):
        previous_text = buffering_text[:i - 1]
        current_text = previous_text + delta_text

        result = minimax_tool_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=None,
        )

        if result and getattr(result, 'content', None):
            all_content.append(result.content)
            print(f" Char {i} ({repr(delta_text)}): {repr(result.content)}")

    final_content = ''.join(all_content)
    print(f"Final content: {repr(final_content)}")

    # The parser should handle the edge case where </tool_calls> appears before <tool_calls>
    assert "Hello" in final_content, "Initial 'Hello' should be preserved"
    assert "world" in final_content, "Content after false closing tag should be preserved"
    assert "done" in final_content, "Final content should be preserved"

    print("✓ Buffering character-by-character test passed")
|
||||
@ -17,7 +17,6 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
|
||||
import cloudpickle
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
import requests
|
||||
@ -89,12 +88,10 @@ class RemoteOpenAIServer:
|
||||
raise ValueError("You have manually specified the port "
|
||||
"when `auto_port=True`.")
|
||||
|
||||
# No need for a port if using unix sockets
|
||||
if "--uds" not in vllm_serve_args:
|
||||
# Don't mutate the input args
|
||||
vllm_serve_args = vllm_serve_args + [
|
||||
"--port", str(get_open_port())
|
||||
]
|
||||
# Don't mutate the input args
|
||||
vllm_serve_args = vllm_serve_args + [
|
||||
"--port", str(get_open_port())
|
||||
]
|
||||
if seed is not None:
|
||||
if "--seed" in vllm_serve_args:
|
||||
raise ValueError("You have manually specified the seed "
|
||||
@ -107,13 +104,8 @@ class RemoteOpenAIServer:
|
||||
subparsers = parser.add_subparsers(required=False, dest="subparser")
|
||||
parser = ServeSubcommand().subparser_init(subparsers)
|
||||
args = parser.parse_args(["--model", model, *vllm_serve_args])
|
||||
self.uds = args.uds
|
||||
if args.uds:
|
||||
self.host = None
|
||||
self.port = None
|
||||
else:
|
||||
self.host = str(args.host or 'localhost')
|
||||
self.port = int(args.port)
|
||||
self.host = str(args.host or 'localhost')
|
||||
self.port = int(args.port)
|
||||
|
||||
self.show_hidden_metrics = \
|
||||
args.show_hidden_metrics_for_version is not None
|
||||
@ -158,11 +150,9 @@ class RemoteOpenAIServer:
|
||||
def _wait_for_server(self, *, url: str, timeout: float):
|
||||
# run health check
|
||||
start = time.time()
|
||||
client = (httpx.Client(transport=httpx.HTTPTransport(
|
||||
uds=self.uds)) if self.uds else requests)
|
||||
while True:
|
||||
try:
|
||||
if client.get(url).status_code == 200:
|
||||
if requests.get(url).status_code == 200:
|
||||
break
|
||||
except Exception:
|
||||
# this exception can only be raised by requests.get,
|
||||
@ -180,8 +170,7 @@ class RemoteOpenAIServer:
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
return (f"http://{self.uds.split('/')[-1]}"
|
||||
if self.uds else f"http://{self.host}:{self.port}")
|
||||
return f"http://{self.host}:{self.port}"
|
||||
|
||||
def url_for(self, *parts: str) -> str:
|
||||
return self.url_root + "/" + "/".join(parts)
|
||||
@ -997,19 +986,3 @@ def has_module_attribute(module_name, attribute_name):
|
||||
return hasattr(module, attribute_name)
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_attn_backend_list_based_on_platform() -> list[str]:
    """Return the attention backend names to exercise on the current platform.

    Returns:
        A list of backend identifiers suitable for ``VLLM_ATTENTION_BACKEND``.

    Raises:
        ValueError: if the platform is neither CUDA nor ROCm.
    """
    if current_platform.is_cuda():
        return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"]
    elif current_platform.is_rocm():
        attn_backend_list = ["TRITON_ATTN_VLLM_V1"]
        try:
            # Flash attention on ROCm is only available through the optional
            # `aiter` package, so advertise it only when the import succeeds.
            import aiter  # noqa: F401
            attn_backend_list.append("FLASH_ATTN_VLLM_V1")
        except Exception:
            print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed")

        return attn_backend_list
    else:
        raise ValueError("Unsupported platform")
|
||||
|
||||
@ -11,7 +11,7 @@ import torch
|
||||
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
|
||||
LoadConfig, ModelConfig, ModelDType, ParallelConfig,
|
||||
SchedulerConfig, VllmConfig)
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
@ -119,10 +119,7 @@ def get_attention_backend(backend_name: _Backend):
|
||||
"""
|
||||
backend_map = {
|
||||
_Backend.FLASH_ATTN_VLLM_V1:
|
||||
("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
if current_platform.is_cuda() else
|
||||
"vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||
),
|
||||
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",
|
||||
_Backend.FLASHINFER_VLLM_V1:
|
||||
"vllm.v1.attention.backends.flashinfer.FlashInferBackend",
|
||||
_Backend.FLEX_ATTENTION:
|
||||
|
||||
@ -8,12 +8,10 @@ from typing import Any, Union
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import get_attn_backend_list_based_on_platform
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def get_test_prompts(mm_enabled: bool):
|
||||
@ -143,14 +141,11 @@ def test_ngram_correctness(
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
],
|
||||
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
def test_eagle_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
model_setup: tuple[str, str, str, int],
|
||||
mm_enabled: bool,
|
||||
attn_backend: str,
|
||||
):
|
||||
# Generate test prompts inside the function instead of using fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
@ -161,16 +156,6 @@ def test_eagle_correctness(
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||
and not current_platform.is_rocm()):
|
||||
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||
"multi-token eagle spec decode on current platform")
|
||||
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
method, model_name, spec_model_name, tp_size = model_setup
|
||||
|
||||
ref_llm = LLM(model=model_name,
|
||||
|
||||
@ -17,7 +17,7 @@ async def test_simple_input(client: openai.AsyncOpenAI):
|
||||
|
||||
# Whether the output contains the reasoning.
|
||||
assert outputs[0].type == "reasoning"
|
||||
assert outputs[0].content[0].text != ""
|
||||
assert outputs[0].text != ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -20,8 +20,9 @@ MODEL_NAME = "facebook/opt-125m"
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args():
|
||||
return [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"float32",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-seqs",
|
||||
|
||||
@ -1,103 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import torch
|
||||
from transformers import AutoConfig
|
||||
|
||||
from tests.conftest import ImageTestAssets
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
|
||||
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
|
||||
MAXIMUM_IMAGES = 2
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def default_image_embeds_server_args() -> list[str]:
    """CLI arguments for the vLLM server used by the image-embeds tests.

    Kept deliberately small (short context, few sequences, eager mode) so
    the module-scoped server starts quickly in CI.
    """
    return [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "4",
        "--enforce-eager",
        # Allow up to MAXIMUM_IMAGES image items per prompt.
        "--limit-mm-per-prompt",
        json.dumps({"image": MAXIMUM_IMAGES}),
    ]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def server_with_image_embeds(default_image_embeds_server_args):
    # Module-scoped so every test in this file shares one server process;
    # the context manager shuts the server down at module teardown.
    with RemoteOpenAIServer(MODEL_NAME,
                            default_image_embeds_server_args) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client_with_image_embeds(server_with_image_embeds):
    # A fresh async OpenAI client per test; closed automatically on teardown.
    async with server_with_image_embeds.get_async_client() as async_client:
        yield async_client
|
||||
|
||||
|
||||
def encode_image_embedding_to_base64(image_embedding) -> str:
    """Serialize an image embedding with ``torch.save`` and base64-encode it.

    The result is the string form expected by the ``image_embeds`` content
    part of the chat API.
    """
    with io.BytesIO() as buffer:
        torch.save(image_embedding, buffer)
        serialized = buffer.getvalue()
    return base64.b64encode(serialized).decode('utf-8')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
# NOTE: torch.half is an alias of torch.float16, so the previous list
# [torch.half, torch.float16, torch.float32] ran the identical fp16 case
# twice; keep one fp16 entry plus fp32.
@pytest.mark.parametrize("dtype", [torch.half, torch.float32])
async def test_completions_with_image_embeds(
    client_with_image_embeds: openai.AsyncOpenAI,
    model_name: str,
    image_assets: ImageTestAssets,
    dtype: torch.dtype,
):
    """Chat completions should accept a base64-encoded image embedding.

    Converts a precomputed image embedding to *dtype*, sends it through the
    OpenAI-compatible chat API as an ``image_embeds`` content part, and
    verifies the model returns a non-empty string reply.
    """
    # Test case: Single image embeds input
    image_embeds = image_assets[0].image_embeds.to(dtype=dtype)
    base64_image_embedding = encode_image_embedding_to_base64(image_embeds)
    chat_completion = await client_with_image_embeds.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role":
                "user",
                "content": [
                    {
                        "type":
                        "text",
                        # The trailing space matters: these adjacent literals
                        # are implicitly concatenated into a single prompt
                        # (previously "image,reply" ran together).
                        "text":
                        "Describe these images separately. For each image, "
                        "reply with a short sentence (no more than 10 words).",
                    },
                    {
                        "type": "image_embeds",
                        "image_embeds": base64_image_embedding,
                    },
                ],
            },
        ],
        model=model_name,
    )
    # A successful multimodal request must yield non-empty assistant text.
    assert chat_completion.choices[0].message.content is not None
    assert isinstance(chat_completion.choices[0].message.content, str)
    assert len(chat_completion.choices[0].message.content) > 0
|
||||
@ -6,7 +6,6 @@ from unittest import mock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import get_attn_backend_list_based_on_platform
|
||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
@ -121,28 +120,17 @@ def test_prepare_inputs():
|
||||
assert torch.equal(token_indices, expected_token_indices)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
@pytest.mark.parametrize("method,proposer_helper", [
|
||||
("eagle", lambda k: _create_proposer("eagle", k)),
|
||||
("eagle3", lambda k: _create_proposer("eagle3", k)),
|
||||
])
|
||||
@pytest.mark.parametrize("pp_size", [1, 2])
|
||||
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
|
||||
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
attn_backend, pp_size, use_distinct_embed_tokens,
|
||||
monkeypatch):
|
||||
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||
and not current_platform.is_rocm()):
|
||||
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||
"multi-token eagle spec decode on current platform")
|
||||
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
proposer_helper, pp_size, use_distinct_embed_tokens):
|
||||
# Setup draft model mock
|
||||
mock_model = mock.MagicMock()
|
||||
if use_distinct_embed_tokens:
|
||||
@ -189,7 +177,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
target_model.lm_head = mock.MagicMock()
|
||||
|
||||
# Create proposer using the helper function
|
||||
proposer = _create_proposer(method, k=8)
|
||||
proposer = proposer_helper(k=8)
|
||||
|
||||
# Call the method under test
|
||||
proposer.load_model(target_model)
|
||||
@ -213,22 +201,10 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
target_model.model.embed_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
|
||||
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||
and not current_platform.is_rocm()):
|
||||
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||
"multi-token eagle spec decode on current platform")
|
||||
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
@pytest.mark.parametrize("backend",
|
||||
[_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
|
||||
def test_propose(num_speculative_tokens, backend):
|
||||
# Use GPU device
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
@ -327,18 +303,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
device=device)
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1":
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||
_Backend.FLASH_ATTN_VLLM_V1)
|
||||
elif attn_backend == "TRITON_ATTN_VLLM_V1":
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||
_Backend.TRITON_ATTN_VLLM_V1)
|
||||
elif attn_backend == "TREE_ATTN":
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||
_Backend.TREE_ATTN)
|
||||
else:
|
||||
raise ValueError(f"Unsupported attention backend: {attn_backend}")
|
||||
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(backend)
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
@ -348,8 +313,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
|
||||
# Mock runner for attention metadata building
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_groups.append([mock.MagicMock()])
|
||||
proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder
|
||||
proposer.runner.attn_metadata_builders = [attn_metadata_builder]
|
||||
|
||||
result = proposer.propose(target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
|
||||
@ -4,9 +4,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import get_attn_backend_list_based_on_platform
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
_PROMPTS = [
|
||||
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
|
||||
@ -16,39 +14,35 @@ _PROMPTS = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
|
||||
def test_ngram_max_len(num_speculative_tokens: int):
|
||||
llm = LLM(
|
||||
model="facebook/opt-125m",
|
||||
max_model_len=100,
|
||||
enforce_eager=True, # For faster initialization.
|
||||
speculative_config={
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": num_speculative_tokens,
|
||||
},
|
||||
)
|
||||
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
|
||||
llm.generate(_PROMPTS, sampling_params)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
|
||||
num_speculative_tokens: int, attn_backend: str):
|
||||
def test_ngram_max_len(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
num_speculative_tokens: int,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
llm = LLM(
|
||||
model="facebook/opt-125m",
|
||||
max_model_len=100,
|
||||
enforce_eager=True, # For faster initialization.
|
||||
speculative_config={
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": num_speculative_tokens,
|
||||
},
|
||||
)
|
||||
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
|
||||
llm.generate(_PROMPTS, sampling_params)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||
and not current_platform.is_rocm()):
|
||||
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||
"multi-token eagle spec decode on current platform")
|
||||
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
|
||||
def test_eagle_max_len(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
num_speculative_tokens: int,
|
||||
):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
|
||||
@ -47,12 +47,13 @@ def test_ngram_proposer():
|
||||
model_config = ModelConfig(model="facebook/opt-125m")
|
||||
return NgramProposer(
|
||||
vllm_config=VllmConfig(model_config=model_config,
|
||||
speculative_config=SpeculativeConfig(
|
||||
prompt_lookup_min=min_n,
|
||||
prompt_lookup_max=max_n,
|
||||
num_speculative_tokens=k,
|
||||
method="ngram",
|
||||
)))
|
||||
speculative_config=SpeculativeConfig.
|
||||
from_dict({
|
||||
"prompt_lookup_min": min_n,
|
||||
"prompt_lookup_max": max_n,
|
||||
"num_speculative_tokens": k,
|
||||
"method": "ngram",
|
||||
})))
|
||||
|
||||
# No match.
|
||||
result = ngram_proposer(
|
||||
|
||||
@ -12,6 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
UNSUPPORTED_MODELS_V1 = [
|
||||
"openai/whisper-large-v3", # transcription
|
||||
"facebook/bart-large-cnn", # encoder decoder
|
||||
"state-spaces/mamba-130m-hf", # mamba1
|
||||
]
|
||||
|
||||
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
@ -417,12 +417,12 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
|
||||
return rnd_stride
|
||||
|
||||
# Patch the attention backend class and re-trigger the KV cache creation.
|
||||
for attn_group in model_runner._attn_group_iterator():
|
||||
attn_backend = attn_group.backend
|
||||
for attn_backend in model_runner.attn_backends:
|
||||
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
|
||||
rnd_stride_order)
|
||||
|
||||
model_runner.attn_groups = []
|
||||
model_runner.attn_backends = []
|
||||
model_runner.attn_metadata_builders = []
|
||||
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
|
||||
|
||||
# Shape is unchanged, but layout may differ
|
||||
|
||||
@ -271,7 +271,6 @@ class ipex_ops:
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
num_splits=0,
|
||||
s_aux: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if cu_seqlens_k is None:
|
||||
# cu_seqlens_k is not used in ipex kernel.
|
||||
|
||||
@ -106,10 +106,6 @@ class AttentionBackend(ABC):
|
||||
block_size: int, num_seqs: int, num_queries: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def full_cls_name(cls) -> tuple[str, str]:
|
||||
return (cls.__module__, cls.__qualname__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionMetadata:
|
||||
|
||||
@ -9,7 +9,6 @@ import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
@ -81,7 +80,6 @@ class Attention(nn.Module):
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
attn_backend: Optional[type[AttentionBackend]] = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
@ -139,6 +137,15 @@ class Attention(nn.Module):
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
# For v1 we have backend agnostic iRoPE (local chunked attention)
|
||||
# we have to store the flag on the layer so gpu model runner can
|
||||
# set KVSpec appropriately (and pop it so it doesnt get passed to
|
||||
# the backends)
|
||||
if envs.VLLM_USE_V1:
|
||||
self.use_irope = extra_impl_args.pop("use_irope", False)
|
||||
else:
|
||||
self.use_irope = extra_impl_args.get("use_irope", False)
|
||||
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self, prefix=prefix) if quant_config else None
|
||||
if quant_method is not None and not isinstance(
|
||||
@ -159,22 +166,18 @@ class Attention(nn.Module):
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
if attn_backend is None:
|
||||
self.attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
use_mla=use_mla)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
impl_cls = self.attn_backend.get_impl_cls()
|
||||
attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
use_mla=use_mla)
|
||||
impl_cls = attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **extra_impl_args)
|
||||
self.backend = backend_name_to_enum(self.attn_backend.get_name())
|
||||
self.backend = backend_name_to_enum(attn_backend.get_name())
|
||||
self.dtype = dtype
|
||||
|
||||
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||
@ -184,7 +187,7 @@ class Attention(nn.Module):
|
||||
self.use_direct_call = not current_platform.is_cuda_alike(
|
||||
) and not current_platform.is_cpu()
|
||||
|
||||
self.use_output = self.attn_backend.accept_output_buffer
|
||||
self.use_output = attn_backend.accept_output_buffer
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
@ -306,9 +309,6 @@ class Attention(nn.Module):
|
||||
if hasattr(self.impl, "process_weights_after_loading"):
|
||||
self.impl.process_weights_after_loading(act_dtype)
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return self.attn_backend
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""Multi-headed attention without any cache, used for ViT."""
|
||||
|
||||
@ -1,88 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, QuantizationConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata, make_local_attention_virtual_batches,
|
||||
subclass_attention_backend, subclass_attention_metadata_builder)
|
||||
|
||||
from ..layer import Attention
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def create_chunked_local_attention_backend(
|
||||
underlying_attn_backend: AttentionBackend,
|
||||
attention_chunk_size: int,
|
||||
block_size: int,
|
||||
) -> type[AttentionBackend]:
|
||||
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
|
||||
|
||||
def build_preprocess_fn(cm: CommonAttentionMetadata):
|
||||
return make_local_attention_virtual_batches(attention_chunk_size, cm,
|
||||
block_size)
|
||||
|
||||
# Dynamically create a new attention backend that wraps the
|
||||
# underlying attention backend but applies
|
||||
# `make_local_attention_virtual_batches` before calling `build(...)`
|
||||
builder_cls = subclass_attention_metadata_builder(
|
||||
name_prefix=prefix,
|
||||
builder_cls=underlying_attn_backend.get_builder_cls(),
|
||||
build_preprocess_fn=build_preprocess_fn)
|
||||
attn_backend = subclass_attention_backend(
|
||||
name_prefix=prefix,
|
||||
attention_backend_cls=underlying_attn_backend,
|
||||
builder_cls=builder_cls)
|
||||
|
||||
return attn_backend
|
||||
|
||||
|
||||
class ChunkedLocalAttention(Attention):
|
||||
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
attention_chunk_size: int,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
prefix: str = ""):
|
||||
dtype = torch.get_default_dtype()
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
underlying_attn_backend = get_attn_backend(head_size, dtype,
|
||||
kv_cache_dtype,
|
||||
block_size)
|
||||
|
||||
attn_backend = create_chunked_local_attention_backend(
|
||||
underlying_attn_backend, attention_chunk_size, block_size)
|
||||
else:
|
||||
# in v0 the local attention is handled inside the backends
|
||||
attn_backend = None
|
||||
|
||||
super().__init__(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads,
|
||||
alibi_slopes=alibi_slopes,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
|
||||
attn_backend=attn_backend)
|
||||
@ -91,6 +91,7 @@ def flash_mla_with_kvcache(
|
||||
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
q,
|
||||
k_cache,
|
||||
None,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
|
||||
@ -31,8 +31,6 @@ It supports page size >= 1.
|
||||
|
||||
import logging
|
||||
|
||||
from packaging import version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
@ -42,7 +40,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Only print the following warnings when triton version < 3.2.0.
|
||||
# The issue won't affect performance or accuracy.
|
||||
if version.parse(triton.__version__) < version.parse('3.2.0'):
|
||||
if triton.__version__ < '3.2.0':
|
||||
logger.warning(
|
||||
"The following error message 'operation scheduled before its operands' "
|
||||
"can be ignored.")
|
||||
|
||||
@ -142,7 +142,7 @@ def get_attn_backend(
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool = False,
|
||||
is_attention_free: bool,
|
||||
use_mla: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
|
||||
@ -13,6 +13,7 @@ import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.benchmarks.lib.utils import (convert_to_pytorch_benchmark_format,
|
||||
write_to_json)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
@ -84,9 +85,6 @@ def main(args: argparse.Namespace):
|
||||
"Please set it to a valid path to use torch profiler.")
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
|
||||
# Lazy import to avoid importing LLM when the bench command is not selected.
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user