Compare commits
1 Commits
prune-samp
...
remove-asy
| Author | SHA1 | Date | |
|---|---|---|---|
| 0470cac520 |
@ -8,8 +8,7 @@ template = """<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
<h1>Links for vLLM</h1/>
|
||||
<a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
|
||||
<a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
|
||||
<a href="../{wheel_html_escaped}">{wheel}</a><br/>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
@ -22,25 +21,7 @@ filename = os.path.basename(args.wheel)
|
||||
|
||||
with open("index.html", "w") as f:
|
||||
print(f"Generated index.html for {args.wheel}")
|
||||
# sync the abi tag with .buildkite/scripts/upload-wheels.sh
|
||||
if "x86_64" in filename:
|
||||
x86_wheel = filename
|
||||
arm_wheel = filename.replace("x86_64", "aarch64").replace(
|
||||
"manylinux1", "manylinux2014"
|
||||
)
|
||||
elif "aarch64" in filename:
|
||||
x86_wheel = filename.replace("aarch64", "x86_64").replace(
|
||||
"manylinux2014", "manylinux1"
|
||||
)
|
||||
arm_wheel = filename
|
||||
else:
|
||||
raise ValueError(f"Unsupported wheel: {filename}")
|
||||
# cloudfront requires escaping the '+' character
|
||||
f.write(
|
||||
template.format(
|
||||
x86_wheel=x86_wheel,
|
||||
x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"),
|
||||
arm_wheel=arm_wheel,
|
||||
arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"),
|
||||
)
|
||||
template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B"))
|
||||
)
|
||||
|
||||
12
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
Normal file
12
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
Normal file
@ -0,0 +1,12 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
|
||||
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.419
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.416
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
||||
@ -3,3 +3,4 @@ Meta-Llama-3-70B-Instruct.yaml
|
||||
Mixtral-8x7B-Instruct-v0.1.yaml
|
||||
Qwen2-57B-A14-Instruct.yaml
|
||||
DeepSeek-V2-Lite-Chat.yaml
|
||||
Meta-Llama-3-8B-QQQ.yaml
|
||||
|
||||
@ -7,7 +7,7 @@ This directory contains two sets of benchmark for vllm.
|
||||
- Performance benchmark: benchmark vllm's performance under various workload, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance
|
||||
- Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm.
|
||||
|
||||
See [vLLM performance dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results.
|
||||
See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results.
|
||||
|
||||
## Performance benchmark quick overview
|
||||
|
||||
@ -138,20 +138,28 @@ The raw benchmarking results (in the format of json files) are in the `Artifacts
|
||||
|
||||
The `compare-json-results.py` helps to compare benchmark results JSON files converted using `convert-results-json-to-markdown.py`.
|
||||
When run, benchmark script generates results under `benchmark/results` folder, along with the `benchmark_results.md` and `benchmark_results.json`.
|
||||
`compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT.
|
||||
If only one benchmark_results.json is passed, `compare-json-results.py` compares different TP and PP configurations in the benchmark_results.json instead.
|
||||
`compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT.
|
||||
|
||||
Here is an example using the script to compare result_a and result_b with Model, Dataset name, input/output lenght, max concurrency and qps.
|
||||
Here is an example using the script to compare result_a and result_b without detail test name.
|
||||
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name`
|
||||
|
||||
| | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
|
||||
|----|----------------------------------------|----------------------------------------|----------|
|
||||
| 0 | 142.633982 | 156.526018 | 1.097396 |
|
||||
| 1 | 241.620334 | 294.018783 | 1.216863 |
|
||||
| 2 | 218.298905 | 262.664916 | 1.203235 |
|
||||
| 3 | 242.743860 | 299.816190 | 1.235113 |
|
||||
|
||||
Here is an example using the script to compare result_a and result_b with detail test name.
|
||||
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json`
|
||||
|
||||
| | Model | Dataset Name | Input Len | Output Len | # of max concurrency | qps | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
|
||||
|----|---------------------------------------|--------|-----|-----|------|-----|-----------|----------|----------|
|
||||
| 0 | meta-llama/Meta-Llama-3.1-8B-Instruct | random | 128 | 128 | 1000 | 1 | 142.633982 | 156.526018 | 1.097396 |
|
||||
| 1 | meta-llama/Meta-Llama-3.1-8B-Instruct | random | 128 | 128 | 1000 | inf| 241.620334 | 294.018783 | 1.216863 |
|
||||
|
||||
A comparison diagram will be generated below the table.
|
||||
Here is an example to compare between 96c/results_gnr_96c_091_tp2pp3 and 128c/results_gnr_128c_091_tp2pp3
|
||||
<img width="1886" height="828" alt="image" src="https://github.com/user-attachments/assets/c02a43ef-25d0-4fd6-90e5-2169a28682dd" />
|
||||
| | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio |
|
||||
|---|---------------------------------------------|----------------------------------------|---------------------------------------------|----------------------------------------|----------|
|
||||
| 0 | serving_llama8B_tp1_sharegpt_qps_1 | 142.633982 | serving_llama8B_tp1_sharegpt_qps_1 | 156.526018 | 1.097396 |
|
||||
| 1 | serving_llama8B_tp1_sharegpt_qps_16 | 241.620334 | serving_llama8B_tp1_sharegpt_qps_16 | 294.018783 | 1.216863 |
|
||||
| 2 | serving_llama8B_tp1_sharegpt_qps_4 | 218.298905 | serving_llama8B_tp1_sharegpt_qps_4 | 262.664916 | 1.203235 |
|
||||
| 3 | serving_llama8B_tp1_sharegpt_qps_inf | 242.743860 | serving_llama8B_tp1_sharegpt_qps_inf | 299.816190 | 1.235113 |
|
||||
| 4 | serving_llama8B_tp2_random_1024_128_qps_1 | 96.613390 | serving_llama8B_tp4_random_1024_128_qps_1 | 108.404853 | 1.122048 |
|
||||
|
||||
## Nightly test details
|
||||
|
||||
|
||||
@ -1,202 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from importlib import util
|
||||
|
||||
import pandas as pd
|
||||
|
||||
plotly_found = util.find_spec("plotly.express") is not None
|
||||
|
||||
|
||||
def compare_data_columns(
|
||||
files, name_column, data_column, info_cols, drop_column, debug=False
|
||||
files, name_column, data_column, drop_column, ignore_test_name=False
|
||||
):
|
||||
"""
|
||||
Align concatenation by keys derived from info_cols instead of row order.
|
||||
- Pick one canonical key list: subset of info_cols present in ALL files.
|
||||
- For each file: set index to those keys, aggregate duplicates
|
||||
- (mean for metric, first for names).
|
||||
- Concat along axis=1 (indexes align), then reset_index so callers can
|
||||
- group by columns.
|
||||
- If --debug, add a <file_label>_name column per file.
|
||||
"""
|
||||
print("\ncompare_data_column:", data_column)
|
||||
|
||||
print("\ncompare_data_column: " + data_column)
|
||||
frames = []
|
||||
raw_data_cols = []
|
||||
compare_frames = []
|
||||
|
||||
# 1) choose a canonical key list from info_cols that exists in ALL files
|
||||
cols_per_file = []
|
||||
for f in files:
|
||||
try:
|
||||
df_tmp = pd.read_json(f, orient="records")
|
||||
except Exception as err:
|
||||
raise ValueError(f"Failed to read {f}") from err
|
||||
cols_per_file.append(set(df_tmp.columns))
|
||||
|
||||
key_cols = [c for c in info_cols if all(c in cset for cset in cols_per_file)]
|
||||
if not key_cols:
|
||||
# soft fallback: use any info_cols present in the first file
|
||||
key_cols = [c for c in info_cols if c in list(cols_per_file[0])]
|
||||
if not key_cols:
|
||||
raise ValueError(
|
||||
"No common key columns found from info_cols across the input files."
|
||||
)
|
||||
|
||||
# 2) build a single "meta" block (keys as columns) once, aligned by the key index
|
||||
meta_added = False
|
||||
|
||||
for file in files:
|
||||
df = pd.read_json(file, orient="records")
|
||||
|
||||
# Keep rows that actually have the compared metric (same as original behavior)
|
||||
if drop_column in df.columns:
|
||||
df = df.dropna(subset=[drop_column], ignore_index=True)
|
||||
|
||||
# Stabilize numeric key columns (harmless if missing)
|
||||
for c in (
|
||||
"Input Len",
|
||||
"Output Len",
|
||||
"TP Size",
|
||||
"PP Size",
|
||||
"# of max concurrency.",
|
||||
"qps",
|
||||
):
|
||||
if c in df.columns:
|
||||
df[c] = pd.to_numeric(df[c], errors="coerce")
|
||||
|
||||
# Ensure all key columns exist
|
||||
for c in key_cols:
|
||||
if c not in df.columns:
|
||||
df[c] = pd.NA
|
||||
|
||||
# Set index = key_cols and aggregate duplicates → unique MultiIndex
|
||||
df_idx = df.set_index(key_cols, drop=False)
|
||||
|
||||
# meta (key columns), unique per key
|
||||
meta = df_idx[key_cols]
|
||||
if not meta.index.is_unique:
|
||||
meta = meta.groupby(level=key_cols, dropna=False).first()
|
||||
|
||||
# metric series for this file, aggregated to one row per key
|
||||
file_label = "/".join(file.split("/")[:-1]) or os.path.basename(file)
|
||||
s = df_idx[data_column]
|
||||
if not s.index.is_unique:
|
||||
s = s.groupby(level=key_cols, dropna=False).mean()
|
||||
s.name = file_label # column label like original
|
||||
|
||||
# add meta once (from first file) so keys are the leftmost columns
|
||||
if not meta_added:
|
||||
frames.append(meta)
|
||||
meta_added = True
|
||||
|
||||
# (NEW) debug: aligned test-name column per file
|
||||
if debug and name_column in df_idx.columns:
|
||||
name_s = df_idx[name_column]
|
||||
if not name_s.index.is_unique:
|
||||
name_s = name_s.groupby(level=key_cols, dropna=False).first()
|
||||
name_s.name = f"{file_label}_name"
|
||||
frames.append(name_s)
|
||||
|
||||
frames.append(s)
|
||||
raw_data_cols.append(file_label)
|
||||
compare_frames.append(s)
|
||||
|
||||
# Generalize ratio: for any file N>=2, add ratio (fileN / file1)
|
||||
data_df = pd.read_json(file)
|
||||
serving_df = data_df.dropna(subset=[drop_column], ignore_index=True)
|
||||
if ignore_test_name is False:
|
||||
serving_df = serving_df.rename(columns={name_column: file + "_name"})
|
||||
frames.append(serving_df[file + "_name"])
|
||||
serving_df = serving_df.rename(columns={data_column: file})
|
||||
frames.append(serving_df[file])
|
||||
compare_frames.append(serving_df[file])
|
||||
if len(compare_frames) >= 2:
|
||||
base = compare_frames[0]
|
||||
current = compare_frames[-1]
|
||||
ratio = current / base
|
||||
ratio = ratio.mask(base == 0) # avoid inf when baseline is 0
|
||||
ratio.name = f"Ratio 1 vs {len(compare_frames)}"
|
||||
frames.append(ratio)
|
||||
# Compare numbers among two files
|
||||
ratio_df = compare_frames[1] / compare_frames[0]
|
||||
frames.append(ratio_df)
|
||||
compare_frames.pop(1)
|
||||
|
||||
# 4) concat on columns with aligned MultiIndex;
|
||||
# then reset_index to return keys as columns
|
||||
concat_df = pd.concat(frames, axis=1)
|
||||
concat_df = concat_df.reset_index(drop=True).reset_index()
|
||||
if "index" in concat_df.columns:
|
||||
concat_df = concat_df.drop(columns=["index"])
|
||||
|
||||
# Ensure key/info columns appear first (in your info_cols order)
|
||||
front = [c for c in info_cols if c in concat_df.columns]
|
||||
rest = [c for c in concat_df.columns if c not in front]
|
||||
concat_df = concat_df[front + rest]
|
||||
|
||||
print(raw_data_cols)
|
||||
return concat_df, raw_data_cols
|
||||
|
||||
|
||||
def split_json_by_tp_pp(
|
||||
input_file: str = "benchmark_results.json", output_root: str = "."
|
||||
) -> list[str]:
|
||||
"""
|
||||
Split a benchmark JSON into separate folders by (TP Size, PP Size).
|
||||
|
||||
Creates: <output_root>/tp{TP}_pp{PP}/benchmark_results.json
|
||||
Returns: list of file paths written.
|
||||
"""
|
||||
# Load JSON data into DataFrame
|
||||
with open(input_file, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# If the JSON is a dict with a list under common keys, use that list
|
||||
if isinstance(data, dict):
|
||||
for key in ("results", "serving_results", "benchmarks", "data"):
|
||||
if isinstance(data.get(key), list):
|
||||
data = data[key]
|
||||
break
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Keep only "serving" tests
|
||||
name_col = next(
|
||||
(c for c in ["Test name", "test_name", "Test Name"] if c in df.columns), None
|
||||
)
|
||||
if name_col:
|
||||
df = df[
|
||||
df[name_col].astype(str).str.contains(r"serving", case=False, na=False)
|
||||
].copy()
|
||||
|
||||
# Handle alias column names
|
||||
rename_map = {
|
||||
"tp_size": "TP Size",
|
||||
"tensor_parallel_size": "TP Size",
|
||||
"pp_size": "PP Size",
|
||||
"pipeline_parallel_size": "PP Size",
|
||||
}
|
||||
df.rename(
|
||||
columns={k: v for k, v in rename_map.items() if k in df.columns}, inplace=True
|
||||
)
|
||||
|
||||
# Ensure TP/PP columns exist (default to 1 if missing)
|
||||
if "TP Size" not in df.columns:
|
||||
df["TP Size"] = 1
|
||||
if "PP Size" not in df.columns:
|
||||
df["PP Size"] = 1
|
||||
|
||||
# make sure TP/PP are numeric ints with no NaN
|
||||
df["TP Size"] = (
|
||||
pd.to_numeric(df.get("TP Size", 1), errors="coerce").fillna(1).astype(int)
|
||||
)
|
||||
df["PP Size"] = (
|
||||
pd.to_numeric(df.get("PP Size", 1), errors="coerce").fillna(1).astype(int)
|
||||
)
|
||||
|
||||
# Split into separate folders
|
||||
saved_paths: list[str] = []
|
||||
for (tp, pp), group_df in df.groupby(["TP Size", "PP Size"], dropna=False):
|
||||
folder_name = os.path.join(output_root, f"tp{int(tp)}_pp{int(pp)}")
|
||||
os.makedirs(folder_name, exist_ok=True)
|
||||
filepath = os.path.join(folder_name, "benchmark_results.json")
|
||||
group_df.to_json(filepath, orient="records", indent=2, force_ascii=False)
|
||||
print(f"Saved: {filepath}")
|
||||
saved_paths.append(filepath)
|
||||
|
||||
return saved_paths
|
||||
return concat_df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -205,103 +36,31 @@ if __name__ == "__main__":
|
||||
"-f", "--file", action="append", type=str, help="input file name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug", action="store_true", help="show all information for debugging"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--plot",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help="plot perf diagrams or not --no-plot --plot",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-x",
|
||||
"--xaxis",
|
||||
type=str,
|
||||
default="# of max concurrency.",
|
||||
help="column name to use as X Axis in comparision graph",
|
||||
"--ignore_test_name", action="store_true", help="ignore_test_name or not"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
files = args.file
|
||||
print("comparing : " + ", ".join(files))
|
||||
|
||||
drop_column = "P99"
|
||||
name_column = "Test name"
|
||||
info_cols = [
|
||||
"Model",
|
||||
"Dataset Name",
|
||||
"Input Len",
|
||||
"Output Len",
|
||||
"TP Size",
|
||||
"PP Size",
|
||||
"# of max concurrency.",
|
||||
"qps",
|
||||
]
|
||||
data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"]
|
||||
html_msgs_for_data_cols = [
|
||||
"Compare Output Tokens /n",
|
||||
"Median TTFT /n",
|
||||
"Median TPOT /n",
|
||||
]
|
||||
|
||||
if len(args.file) == 1:
|
||||
files = split_json_by_tp_pp(args.file[0], output_root="splits")
|
||||
info_cols = [c for c in info_cols if c not in ("TP Size", "PP Size")]
|
||||
else:
|
||||
files = args.file
|
||||
print("comparing : " + ", ".join(files))
|
||||
debug = args.debug
|
||||
plot = args.plot
|
||||
# For Plot feature, assign y axis from one of info_cols
|
||||
y_axis_index = info_cols.index(args.xaxis) if args.xaxis in info_cols else 6
|
||||
ignore_test_name = args.ignore_test_name
|
||||
with open("perf_comparison.html", "w") as text_file:
|
||||
for i in range(len(data_cols_to_compare)):
|
||||
output_df, raw_data_cols = compare_data_columns(
|
||||
output_df = compare_data_columns(
|
||||
files,
|
||||
name_column,
|
||||
data_cols_to_compare[i],
|
||||
info_cols,
|
||||
drop_column,
|
||||
debug=debug,
|
||||
ignore_test_name=ignore_test_name,
|
||||
)
|
||||
|
||||
# For Plot feature, insert y axis from one of info_cols
|
||||
raw_data_cols.insert(0, info_cols[y_axis_index])
|
||||
|
||||
filtered_info_cols = info_cols[:-2]
|
||||
existing_group_cols = [
|
||||
c for c in filtered_info_cols if c in output_df.columns
|
||||
]
|
||||
if not existing_group_cols:
|
||||
raise ValueError(
|
||||
f"No valid group-by columns "
|
||||
f"Expected subset: {filtered_info_cols}, "
|
||||
f"but DataFrame has: {list(output_df.columns)}"
|
||||
)
|
||||
output_df_sorted = output_df.sort_values(by=existing_group_cols)
|
||||
output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False)
|
||||
for name, group in output_groups:
|
||||
html = group.to_html()
|
||||
text_file.write(html_msgs_for_data_cols[i])
|
||||
text_file.write(html)
|
||||
|
||||
if plot and plotly_found:
|
||||
import plotly.express as px
|
||||
|
||||
df = group[raw_data_cols]
|
||||
df_sorted = df.sort_values(by=info_cols[y_axis_index])
|
||||
# Melt DataFrame for plotting
|
||||
df_melted = df_sorted.melt(
|
||||
id_vars=info_cols[y_axis_index],
|
||||
var_name="Configuration",
|
||||
value_name=data_cols_to_compare[i],
|
||||
)
|
||||
title = data_cols_to_compare[i] + " vs " + info_cols[y_axis_index]
|
||||
# Create Plotly line chart
|
||||
fig = px.line(
|
||||
df_melted,
|
||||
x=info_cols[y_axis_index],
|
||||
y=data_cols_to_compare[i],
|
||||
color="Configuration",
|
||||
title=title,
|
||||
markers=True,
|
||||
)
|
||||
# Export to HTML
|
||||
text_file.write(fig.to_html(full_html=True, include_plotlyjs="cdn"))
|
||||
print(output_df)
|
||||
html = output_df.to_html()
|
||||
text_file.write(html_msgs_for_data_cols[i])
|
||||
text_file.write(html)
|
||||
|
||||
@ -1,19 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
import psutil
|
||||
import regex as re
|
||||
from tabulate import tabulate
|
||||
|
||||
results_folder = Path("results/")
|
||||
|
||||
# latency results and the keys that will be printed into markdown
|
||||
latency_results = []
|
||||
latency_column_mapping = {
|
||||
@ -44,22 +42,14 @@ throughput_results_column_mapping = {
|
||||
serving_results = []
|
||||
serving_column_mapping = {
|
||||
"test_name": "Test name",
|
||||
"model_id": "Model",
|
||||
"dataset_name": "Dataset Name",
|
||||
"input_len": "Input Len",
|
||||
"output_len": "Output Len",
|
||||
"tp_size": "TP Size",
|
||||
"pp_size": "PP Size",
|
||||
"dtype": "dtype",
|
||||
"gpu_type": "GPU",
|
||||
"completed": "# of req.",
|
||||
"qps": "qps",
|
||||
"max_concurrency": "# of max concurrency.",
|
||||
"request_throughput": "Tput (req/s)",
|
||||
"total_token_throughput": "Total Token Tput (tok/s)",
|
||||
"output_throughput": "Output Tput (tok/s)",
|
||||
# "total_input_tokens": "Total input tokens",
|
||||
# "total_output_tokens": "Total output tokens",
|
||||
"total_input_tokens": "Total input tokens",
|
||||
"total_output_tokens": "Total output tokens",
|
||||
"mean_ttft_ms": "Mean TTFT (ms)",
|
||||
"median_ttft_ms": "Median TTFT (ms)",
|
||||
"p99_ttft_ms": "P99 TTFT (ms)",
|
||||
@ -104,104 +94,7 @@ def get_size_with_unit(bytes, suffix="B"):
|
||||
bytes /= factor
|
||||
|
||||
|
||||
def _coerce(val: str) -> Any:
|
||||
"""Best-effort type coercion from string to Python types."""
|
||||
low = val.lower()
|
||||
if low == "null":
|
||||
return None
|
||||
if low == "true":
|
||||
return True
|
||||
if low == "false":
|
||||
return False
|
||||
# integers
|
||||
if re.fullmatch(r"[+-]?\d+", val):
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
pass
|
||||
# floats (keep 'inf'/'-inf'/'nan' as strings)
|
||||
if re.fullmatch(r"[+-]?\d*\.\d+", val):
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
pass
|
||||
return val
|
||||
|
||||
|
||||
def parse_client_command(cmd: str) -> dict[str, Any]:
|
||||
"""Parse the client_command shell string into {executable, script, args}."""
|
||||
toks = shlex.split(cmd)
|
||||
if len(toks) < 2:
|
||||
raise ValueError("client_command must include an executable and a script")
|
||||
executable, script = toks[0], toks[1]
|
||||
args: dict[str, Any] = {}
|
||||
|
||||
i = 2
|
||||
while i < len(toks):
|
||||
t = toks[i]
|
||||
if t.startswith("--"):
|
||||
# --key=value or --key (value) or boolean flag
|
||||
if "=" in t:
|
||||
key, val = t.split("=", 1)
|
||||
if key == "--metadata":
|
||||
md = {}
|
||||
if val:
|
||||
if "=" in val:
|
||||
k, v = val.split("=", 1)
|
||||
md[k] = _coerce(v)
|
||||
else:
|
||||
md[val] = True
|
||||
args[key] = md
|
||||
else:
|
||||
args[key] = _coerce(val)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
key = t
|
||||
|
||||
# Special: consume metadata k=v pairs until next --flag
|
||||
if key == "--metadata":
|
||||
i += 1
|
||||
md = {}
|
||||
while i < len(toks) and not toks[i].startswith("--"):
|
||||
pair = toks[i]
|
||||
if "=" in pair:
|
||||
k, v = pair.split("=", 1)
|
||||
md[k] = _coerce(v)
|
||||
else:
|
||||
md[pair] = True
|
||||
i += 1
|
||||
args[key] = md
|
||||
continue
|
||||
|
||||
# Standard: check if next token is a value (not a flag)
|
||||
if i + 1 < len(toks) and not toks[i + 1].startswith("--"):
|
||||
args[key] = _coerce(toks[i + 1])
|
||||
i += 2
|
||||
else:
|
||||
# lone flag -> True
|
||||
args[key] = True
|
||||
i += 1
|
||||
else:
|
||||
# unexpected positional; skip
|
||||
i += 1
|
||||
|
||||
return {"executable": executable, "script": script, "args": args}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--result",
|
||||
type=str,
|
||||
default="results",
|
||||
help="Folder name for benchmark output results.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
results_folder = Path(args.result)
|
||||
if not results_folder.exists():
|
||||
raise FileNotFoundError(f"results folder does not exist: {results_folder}")
|
||||
# collect results
|
||||
for test_file in results_folder.glob("*.json"):
|
||||
with open(test_file) as f:
|
||||
@ -209,6 +102,7 @@ if __name__ == "__main__":
|
||||
|
||||
if "serving" in str(test_file):
|
||||
# this result is generated via `vllm bench serve` command
|
||||
|
||||
# attach the benchmarking command to raw_result
|
||||
try:
|
||||
with open(test_file.with_suffix(".commands")) as f:
|
||||
@ -216,44 +110,12 @@ if __name__ == "__main__":
|
||||
except OSError as e:
|
||||
print(e)
|
||||
continue
|
||||
# Parse Server Command Arg
|
||||
out: dict[str, Any] = {
|
||||
"server_command": parse_client_command(command["server_command"])
|
||||
}
|
||||
parse_args = [
|
||||
"--tensor-parallel-size",
|
||||
"--pipeline-parallel-size",
|
||||
"--dtype",
|
||||
]
|
||||
col_mapping = ["tp_size", "pp_size", "dtype"]
|
||||
for index, arg in enumerate(parse_args):
|
||||
if arg in out["server_command"]["args"]:
|
||||
raw_result.update(
|
||||
{col_mapping[index]: out["server_command"]["args"][arg]}
|
||||
)
|
||||
|
||||
# Parse Client Command Arg
|
||||
out: dict[str, Any] = {
|
||||
"client_command": parse_client_command(command["client_command"])
|
||||
}
|
||||
parse_args = [
|
||||
"--dataset-name",
|
||||
"--random-input-len",
|
||||
"--random-output-len",
|
||||
"--request-rate",
|
||||
]
|
||||
col_mapping = ["dataset_name", "input_len", "output_len", "qps"]
|
||||
|
||||
for index, arg in enumerate(parse_args):
|
||||
if arg in out["client_command"]["args"]:
|
||||
raw_result.update(
|
||||
{col_mapping[index]: out["client_command"]["args"][arg]}
|
||||
)
|
||||
# Add Server, Client command
|
||||
raw_result.update(command)
|
||||
|
||||
# update the test name of this result
|
||||
raw_result.update({"test_name": test_file.stem})
|
||||
|
||||
# add the result to raw_result
|
||||
serving_results.append(raw_result)
|
||||
continue
|
||||
@ -343,10 +205,7 @@ if __name__ == "__main__":
|
||||
columns=latency_column_mapping
|
||||
)
|
||||
if not serving_results.empty:
|
||||
valid_columns = [
|
||||
col for col in serving_column_mapping if col in serving_results.columns
|
||||
]
|
||||
serving_results = serving_results[valid_columns].rename(
|
||||
serving_results = serving_results[list(serving_column_mapping.keys())].rename(
|
||||
columns=serving_column_mapping
|
||||
)
|
||||
if not throughput_results.empty:
|
||||
@ -386,9 +245,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# document the result
|
||||
md_file = "benchmark_results.md"
|
||||
json_file = "benchmark_results.json"
|
||||
with open(results_folder / md_file, "w") as f:
|
||||
with open(results_folder / "benchmark_results.md", "w") as f:
|
||||
results = read_markdown(
|
||||
"../.buildkite/nightly-benchmarks/"
|
||||
+ "performance-benchmarks-descriptions.md"
|
||||
@ -403,7 +260,7 @@ if __name__ == "__main__":
|
||||
f.write(results)
|
||||
|
||||
# document benchmarking results in json
|
||||
with open(results_folder / json_file, "w") as f:
|
||||
with open(results_folder / "benchmark_results.json", "w") as f:
|
||||
results = (
|
||||
latency_results.to_dict(orient="records")
|
||||
+ throughput_results.to_dict(orient="records")
|
||||
|
||||
@ -194,11 +194,9 @@ run_latency_tests() {
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size')
|
||||
if [ "$ON_CPU" == "1" ]; then
|
||||
pp=$(echo "$latency_params" | jq -r '.pipeline_parallel_size')
|
||||
world_size=$(($tp*$pp))
|
||||
if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then
|
||||
echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
@ -263,11 +261,9 @@ run_throughput_tests() {
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size')
|
||||
if [ "$ON_CPU" == "1" ]; then
|
||||
pp=$(echo "$throughput_params" | jq -r '.pipeline_parallel_size')
|
||||
world_size=$(($tp*$pp))
|
||||
if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then
|
||||
echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
@ -333,21 +329,12 @@ run_serving_tests() {
|
||||
qps_list=$(echo "$params" | jq -r '.qps_list')
|
||||
qps_list=$(echo "$qps_list" | jq -r '.[] | @sh')
|
||||
echo "Running over qps list $qps_list"
|
||||
max_concurrency_list=$(echo "$params" | jq -r '.max_concurrency_list')
|
||||
if [[ -z "$max_concurrency_list" || "$max_concurrency_list" == "null" ]]; then
|
||||
num_prompts=$(echo "$client_params" | jq -r '.num_prompts')
|
||||
max_concurrency_list="[$num_prompts]"
|
||||
fi
|
||||
max_concurrency_list=$(echo "$max_concurrency_list" | jq -r '.[] | @sh')
|
||||
echo "Running over max concurrency list $max_concurrency_list"
|
||||
|
||||
# check if there is enough resources to run the test
|
||||
tp=$(echo "$server_params" | jq -r '.tensor_parallel_size')
|
||||
if [ "$ON_CPU" == "1" ]; then
|
||||
pp=$(echo "$server_params" | jq -r '.pipeline_parallel_size')
|
||||
world_size=$(($tp*$pp))
|
||||
if [[ $numa_count -lt $world_size && -z "${REMOTE_HOST}" ]]; then
|
||||
echo "Required world-size $world_size but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
@ -403,39 +390,35 @@ run_serving_tests() {
|
||||
echo "now qps is $qps"
|
||||
fi
|
||||
|
||||
# iterate over different max_concurrency
|
||||
for max_concurrency in $max_concurrency_list; do
|
||||
new_test_name=$test_name"_qps_"$qps"_concurrency_"$max_concurrency
|
||||
echo " new test name $new_test_name"
|
||||
# pass the tensor parallel size to the client so that it can be displayed
|
||||
# on the benchmark dashboard
|
||||
client_command="vllm bench serve \
|
||||
--save-result \
|
||||
--result-dir $RESULTS_FOLDER \
|
||||
--result-filename ${new_test_name}.json \
|
||||
--request-rate $qps \
|
||||
--max-concurrency $max_concurrency \
|
||||
--metadata "tensor_parallel_size=$tp" \
|
||||
$client_args $client_remote_args "
|
||||
new_test_name=$test_name"_qps_"$qps
|
||||
|
||||
echo "Running test case $test_name with qps $qps"
|
||||
echo "Client command: $client_command"
|
||||
# pass the tensor parallel size to the client so that it can be displayed
|
||||
# on the benchmark dashboard
|
||||
client_command="vllm bench serve \
|
||||
--save-result \
|
||||
--result-dir $RESULTS_FOLDER \
|
||||
--result-filename ${new_test_name}.json \
|
||||
--request-rate $qps \
|
||||
--metadata "tensor_parallel_size=$tp" \
|
||||
$client_args $client_remote_args "
|
||||
|
||||
bash -c "$client_command"
|
||||
echo "Running test case $test_name with qps $qps"
|
||||
echo "Client command: $client_command"
|
||||
|
||||
# record the benchmarking commands
|
||||
jq_output=$(jq -n \
|
||||
--arg server "$server_command" \
|
||||
--arg client "$client_command" \
|
||||
--arg gpu "$gpu_type" \
|
||||
'{
|
||||
server_command: $server,
|
||||
client_command: $client,
|
||||
gpu_type: $gpu
|
||||
}')
|
||||
echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands"
|
||||
bash -c "$client_command"
|
||||
|
||||
# record the benchmarking commands
|
||||
jq_output=$(jq -n \
|
||||
--arg server "$server_command" \
|
||||
--arg client "$client_command" \
|
||||
--arg gpu "$gpu_type" \
|
||||
'{
|
||||
server_command: $server,
|
||||
client_command: $client,
|
||||
gpu_type: $gpu
|
||||
}')
|
||||
echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands"
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
# clean up
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
@ -20,7 +20,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
[
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_sharegpt",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -11,7 +10,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -24,17 +23,17 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_sharegpt",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -43,7 +42,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 2,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -56,17 +55,17 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_sharegpt",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -75,7 +74,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -88,17 +87,17 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_random_128_128",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -107,7 +106,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -121,19 +120,19 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 128,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 1000,
|
||||
"num_prompts": 1000
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_random_128_128",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -142,7 +141,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 2,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -156,19 +155,19 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 128,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 1000,
|
||||
"num_prompts": 1000
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_random_128_128",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -177,7 +176,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -191,11 +190,13 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 128,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 1000,
|
||||
"num_prompts": 1000
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,8 +1,7 @@
|
||||
[
|
||||
{
|
||||
"test_name": "serving_llama8B_pp1_sharegpt",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -11,7 +10,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"pipeline_parallel_size": 1,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -24,17 +23,17 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_pp3_sharegpt",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -43,7 +42,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"pipeline_parallel_size": 3,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -56,17 +55,17 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2pp3_sharegpt",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"test_name": "serving_llama8B_tp2pp6_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -75,7 +74,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 2,
|
||||
"pipeline_parallel_size": 3,
|
||||
"dtype": "bfloat16",
|
||||
@ -89,17 +88,17 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_pp1_random_128_128",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -108,7 +107,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"pipeline_parallel_size": 1,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -122,28 +121,28 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 128,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 1000,
|
||||
"num_prompts": 1000
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_pp3_random_128_128",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_SGL_KERNEL": 1,
|
||||
"VLLM_CPU_SGL_KERNEL:": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"pipeline_parallel_size": 3,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -157,19 +156,19 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 128,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 1000,
|
||||
"num_prompts": 1000
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2pp3_random_128_128",
|
||||
"qps_list": ["inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200, 1000],
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -178,7 +177,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 2,
|
||||
"pipeline_parallel_size": 3,
|
||||
"dtype": "bfloat16",
|
||||
@ -193,12 +192,13 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 128,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 1000,
|
||||
"num_prompts": 1000
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -11,7 +10,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -24,17 +23,17 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -43,7 +42,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 2,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -56,17 +55,17 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -75,7 +74,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -88,17 +87,17 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_random_1024_128",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -107,7 +106,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -121,19 +120,19 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 1024,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 100,
|
||||
"num_prompts": 100
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_pp6_random_1024_128",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"max_concurrency_list": [12, 16, 24, 32, 64, 128, 200],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
@ -142,7 +141,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"pipeline_parallel_size": 6,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
@ -156,12 +155,13 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 1024,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 100,
|
||||
"num_prompts": 100
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
@ -21,7 +21,7 @@
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
|
||||
@ -1,20 +1,4 @@
|
||||
steps:
|
||||
# aarch64 + CUDA builds
|
||||
- label: "Build arm64 wheel - CUDA 12.8"
|
||||
id: build-wheel-arm64-cuda-12-8
|
||||
agents:
|
||||
queue: arm64_cpu_queue_postmerge
|
||||
commands:
|
||||
# #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here:
|
||||
# https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
# x86 + CUDA builds
|
||||
- label: "Build wheel - CUDA 12.8"
|
||||
id: build-wheel-cuda-12-8
|
||||
agents:
|
||||
@ -27,12 +11,7 @@ steps:
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- block: "Build CUDA 12.6 wheel"
|
||||
key: block-build-cu126-wheel
|
||||
depends_on: ~
|
||||
|
||||
- label: "Build wheel - CUDA 12.6"
|
||||
depends_on: block-build-cu126-wheel
|
||||
id: build-wheel-cuda-12-6
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
@ -73,7 +52,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
|
||||
|
||||
- label: "Annotate release workflow"
|
||||
|
||||
@ -121,6 +121,7 @@ fi
|
||||
if [[ $commands == *" kernels/quantization"* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/quantization/test_int8_quant.py \
|
||||
--ignore=kernels/quantization/test_aqlm.py \
|
||||
--ignore=kernels/quantization/test_machete_mm.py \
|
||||
--ignore=kernels/quantization/test_block_fp8.py \
|
||||
--ignore=kernels/quantization/test_block_int8.py \
|
||||
|
||||
@ -46,11 +46,6 @@ function cpu_tests() {
|
||||
set -e
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
|
||||
|
||||
# Run kernel tests
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -v -s tests/kernels/test_onednn.py"
|
||||
|
||||
# Run basic model test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
@ -104,4 +99,4 @@ function cpu_tests() {
|
||||
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
export -f cpu_tests
|
||||
timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
|
||||
@ -23,13 +23,9 @@ docker run \
|
||||
--device /dev/dri \
|
||||
-v /dev/dri/by-path:/dev/dri/by-path \
|
||||
--entrypoint="" \
|
||||
-e "HF_TOKEN=${HF_TOKEN}" \
|
||||
-e "ZE_AFFINITY_MASK=${ZE_AFFINITY_MASK}" \
|
||||
--name "${container_name}" \
|
||||
"${image_name}" \
|
||||
bash -c '
|
||||
set -e
|
||||
echo $ZE_AFFINITY_MASK
|
||||
sh -c '
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
|
||||
@ -39,8 +35,8 @@ docker run \
|
||||
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
|
||||
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
|
||||
pytest -v -s v1/structured_output
|
||||
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
|
||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
|
||||
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py
|
||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py
|
||||
pytest -v -s v1/test_serial_utils.py
|
||||
pytest -v -s v1/test_utils.py
|
||||
pytest -v -s v1/test_metrics_reader.py
|
||||
|
||||
@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then
|
||||
# Remove dangling images (those that are not tagged and not used by any container)
|
||||
docker image prune -f
|
||||
# Remove unused volumes / force the system prune for old images as well.
|
||||
docker volume prune -f && docker system prune --force --filter "until=24h" --all
|
||||
docker volume prune -f && docker system prune --force --filter "until=72h" --all
|
||||
echo "Docker images and volumes cleanup completed."
|
||||
else
|
||||
echo "Disk usage is below $threshold%. No cleanup needed."
|
||||
|
||||
@ -14,19 +14,8 @@ fi
|
||||
# Get the single wheel file
|
||||
wheel="${wheel_files[0]}"
|
||||
|
||||
# Detect architecture and rename 'linux' to appropriate manylinux version
|
||||
arch=$(uname -m)
|
||||
if [[ $arch == "x86_64" ]]; then
|
||||
manylinux_version="manylinux1"
|
||||
elif [[ $arch == "aarch64" ]]; then
|
||||
manylinux_version="manylinux2014"
|
||||
else
|
||||
echo "Warning: Unknown architecture $arch, using manylinux1 as default"
|
||||
manylinux_version="manylinux1"
|
||||
fi
|
||||
|
||||
# Rename 'linux' to the appropriate manylinux version in the wheel filename
|
||||
new_wheel="${wheel/linux/$manylinux_version}"
|
||||
# Rename 'linux' to 'manylinux1' in the wheel filename
|
||||
new_wheel="${wheel/linux/manylinux1}"
|
||||
mv -- "$wheel" "$new_wheel"
|
||||
wheel="$new_wheel"
|
||||
|
||||
|
||||
@ -31,6 +31,16 @@
|
||||
steps:
|
||||
##### fast check tests #####
|
||||
|
||||
- label: Documentation Build # 2min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/test_docs"
|
||||
fast_check: true
|
||||
no_gpu: True
|
||||
commands:
|
||||
- pip install -r ../requirements/docs.txt
|
||||
# TODO: add `--strict` once warnings in docstrings are fixed
|
||||
- mkdocs build
|
||||
|
||||
- label: Pytorch Nightly Dependency Override Check # 2min
|
||||
# if this test fails, it means the nightly torch version is not compatible with some
|
||||
# of the dependencies. Please check the error message and add the package to whitelist
|
||||
@ -46,7 +56,6 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/mq_llm_engine
|
||||
- tests/async_engine
|
||||
- tests/test_inputs.py
|
||||
- tests/test_outputs.py
|
||||
- tests/multimodal
|
||||
@ -56,7 +65,6 @@ steps:
|
||||
commands:
|
||||
- python3 standalone_tests/lazy_imports.py
|
||||
- pytest -v -s mq_llm_engine # MQLLMEngine
|
||||
- pytest -v -s async_engine # AsyncLLMEngine
|
||||
- pytest -v -s test_inputs.py
|
||||
- pytest -v -s test_outputs.py
|
||||
- pytest -v -s multimodal
|
||||
@ -88,6 +96,15 @@ steps:
|
||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
|
||||
- label: Chunked Prefill Test
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/basic_correctness/test_chunked_prefill
|
||||
commands:
|
||||
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
|
||||
- label: Core Test # 10min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
fast_check: true
|
||||
@ -126,8 +143,7 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
@ -245,7 +261,6 @@ steps:
|
||||
- pytest -v -s v1/engine
|
||||
- pytest -v -s v1/entrypoints
|
||||
- pytest -v -s v1/sample
|
||||
- pytest -v -s v1/logits_processors
|
||||
- pytest -v -s v1/worker
|
||||
- pytest -v -s v1/structured_output
|
||||
- pytest -v -s v1/spec_decode
|
||||
@ -287,6 +302,15 @@ steps:
|
||||
- python3 offline_inference/basic/score.py
|
||||
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
|
||||
|
||||
- label: Prefix Caching Test # 9min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/prefix_caching
|
||||
commands:
|
||||
- pytest -v -s prefix_caching
|
||||
|
||||
|
||||
- label: Platform Tests (CUDA)
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
@ -304,6 +328,7 @@ steps:
|
||||
- tests/conftest.py
|
||||
commands:
|
||||
- pytest -v -s samplers
|
||||
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
||||
|
||||
- label: LoRA Test %N # 15min each
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -327,7 +352,6 @@ steps:
|
||||
- pytest -v -s compile/test_sequence_parallelism.py
|
||||
- pytest -v -s compile/test_async_tp.py
|
||||
- pytest -v -s compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s compile/test_decorator.py
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -341,7 +365,6 @@ steps:
|
||||
- pytest -v -s compile/piecewise/test_simple.py
|
||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||
- pytest -v -s compile/piecewise/test_full_cudagraph.py
|
||||
- pytest -v -s compile/piecewise/test_multiple_graphs.py
|
||||
|
||||
- label: PyTorch Fullgraph Test # 18min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -384,7 +407,6 @@ steps:
|
||||
- label: Kernels MoE Test %N
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/cutlass_w8a8/moe/
|
||||
- csrc/moe/
|
||||
- tests/kernels/moe
|
||||
- vllm/model_executor/layers/fused_moe/
|
||||
@ -452,11 +474,13 @@ steps:
|
||||
|
||||
- label: LM Eval Small Models # 53min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
|
||||
- label: OpenAI API correctness
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -544,15 +568,6 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s models/language/pooling -m 'not core_model'
|
||||
|
||||
- label: Multi-Modal Processor Test
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
|
||||
- pytest -v -s models/multimodal/processing/test_tensor_schema.py
|
||||
|
||||
- label: Multi-Modal Models Test (Standard)
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
@ -562,7 +577,9 @@ steps:
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
|
||||
- pytest -v -s models/multimodal/processing
|
||||
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model
|
||||
- pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn"
|
||||
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 1
|
||||
@ -573,7 +590,7 @@ steps:
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing
|
||||
- pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 2
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -636,10 +653,8 @@ steps:
|
||||
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
|
||||
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/compilation/fusion.py
|
||||
- vllm/compilation/fusion_attn.py
|
||||
commands:
|
||||
- nvidia-smi
|
||||
- python3 examples/offline_inference/basic/chat.py
|
||||
@ -652,13 +667,9 @@ steps:
|
||||
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
|
||||
- pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
|
||||
# Fusion
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
|
||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
10
.github/CODEOWNERS
vendored
10
.github/CODEOWNERS
vendored
@ -10,7 +10,6 @@
|
||||
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
|
||||
/vllm/model_executor/layers/mamba @tdoublep
|
||||
/vllm/multimodal @DarkLight1337 @ywang96
|
||||
/vllm/vllm_flash_attn @LucasWilkinson
|
||||
/vllm/lora @jeejeelee
|
||||
@ -26,11 +25,11 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
# vLLM V1
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||
/vllm/v1/structured_output @mgoin @russellb @aarnphm
|
||||
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
||||
|
||||
# Test ownership
|
||||
/.buildkite/lm-eval-harness @mgoin @simon-mo
|
||||
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
|
||||
/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac
|
||||
/tests/distributed/test_multi_node_assignment.py @youkaichao
|
||||
/tests/distributed/test_pipeline_parallel.py @youkaichao
|
||||
/tests/distributed/test_same_node.py @youkaichao
|
||||
@ -45,7 +44,6 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/tests/v1/structured_output @mgoin @russellb @aarnphm
|
||||
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
||||
/tests/lora @jeejeelee
|
||||
/tests/models/language/generation/test_hybrid.py @tdoublep
|
||||
|
||||
# Docs
|
||||
/docs @hmellor
|
||||
@ -74,9 +72,3 @@ mkdocs.yaml @hmellor
|
||||
/vllm/model_executor/models/pixtral*.py @patrickvonplaten
|
||||
/vllm/transformers_utils/configs/mistral.py @patrickvonplaten
|
||||
/vllm/transformers_utils/tokenizers/mistral.py @patrickvonplaten
|
||||
|
||||
# Kernels
|
||||
/vllm/attention/ops/chunked_prefill_paged_decode.py @tdoublep
|
||||
/vllm/attention/ops/triton_unified_attention.py @tdoublep
|
||||
|
||||
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@ -207,6 +207,3 @@ shellcheck*/
|
||||
|
||||
# Ignore moe/marlin_moe gen code
|
||||
csrc/moe/marlin_moe_wna16/kernel_*
|
||||
|
||||
# Ignore ep_kernels_workspace folder
|
||||
ep_kernels_workspace/
|
||||
@ -30,7 +30,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
||||
# Supported python versions. These versions will be searched in order, the
|
||||
# first match will be selected. These should be kept in sync with setup.py.
|
||||
#
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12", "3.13")
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
|
||||
|
||||
# Supported AMD GPU architectures.
|
||||
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
|
||||
@ -249,6 +249,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/quantization/activation_kernels.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/prepare_inputs/advance_step.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
@ -286,6 +287,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/permute_cols.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
@ -349,27 +351,20 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MARLIN_TEMPLATE_KERNEL_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MARLIN_SRCS}"
|
||||
CUDA_ARCHS "${MARLIN_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties("csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
|
||||
|
||||
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building Marlin kernels as no compatible archs found"
|
||||
@ -859,10 +854,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${MOE_WNAA16_MARLIN_SRC}"
|
||||
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
set_source_files_properties(${MOE_WNAA16_MARLIN_SRC}
|
||||
PROPERTIES COMPILE_FLAGS "-static-global-template-stub=false")
|
||||
endif()
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})
|
||||
|
||||
|
||||
@ -22,25 +22,6 @@ become available.
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>ShareGPT4V (Image)</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td>
|
||||
<code>wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json</code>
|
||||
<br>
|
||||
<div>Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:</div>
|
||||
<code>wget http://images.cocodataset.org/zips/train2017.zip</code>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>ShareGPT4Video (Video)</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td>
|
||||
<code>git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video</code>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>BurstGPT</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
@ -48,7 +29,7 @@ become available.
|
||||
<td><code>wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Sonnet (deprecated)</strong></td>
|
||||
<td><strong>Sonnet</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td>Local file: <code>benchmarks/sonnet.txt</code></td>
|
||||
@ -59,12 +40,6 @@ become available.
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>synthetic</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Prefix Repetition</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>synthetic</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>HuggingFace-VisionArena</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
@ -202,7 +177,6 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--endpoint-type openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
@ -239,7 +213,6 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--endpoint-type openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
@ -254,7 +227,6 @@ vllm bench serve \
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--endpoint-type openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
@ -609,20 +581,6 @@ python3 benchmarks/benchmark_prefix_caching.py \
|
||||
--input-length-range 128:256
|
||||
```
|
||||
|
||||
### Prefix Repetition Dataset
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--dataset-name prefix_repetition \
|
||||
--num-prompts 100 \
|
||||
--prefix-repetition-prefix-len 512 \
|
||||
--prefix-repetition-suffix-len 128 \
|
||||
--prefix-repetition-num-prefixes 5 \
|
||||
--prefix-repetition-output-len 128
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## ⚡ Example - Request Prioritization Benchmark
|
||||
@ -658,68 +616,3 @@ python3 benchmarks/benchmark_prioritization.py \
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 👁️ Example - Multi-Modal Benchmark
|
||||
|
||||
<details>
|
||||
<summary>Show more</summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of multi-modal requests in vLLM.
|
||||
|
||||
### Images (ShareGPT4V)
|
||||
|
||||
Start vLLM:
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dtype bfloat16 \
|
||||
--limit-mm-per-prompt '{"image": 1}' \
|
||||
--allowed-local-media-path /path/to/sharegpt4v/images
|
||||
```
|
||||
|
||||
Send requests with images:
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \
|
||||
--num-prompts 100 \
|
||||
--save-result \
|
||||
--result-dir ~/vllm_benchmark_results \
|
||||
--save-detailed \
|
||||
--endpoint /v1/chat/completion
|
||||
```
|
||||
|
||||
### Videos (ShareGPT4Video)
|
||||
|
||||
Start vLLM:
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dtype bfloat16 \
|
||||
--limit-mm-per-prompt '{"video": 1}' \
|
||||
--allowed-local-media-path /path/to/sharegpt4video/videos
|
||||
```
|
||||
|
||||
Send requests with videos:
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \
|
||||
--num-prompts 100 \
|
||||
--save-result \
|
||||
--result-dir ~/vllm_benchmark_results \
|
||||
--save-detailed \
|
||||
--endpoint /v1/chat/completion
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
@ -34,7 +34,6 @@ class RequestFuncInput:
|
||||
multi_modal_content: Optional[dict | list[dict]] = None
|
||||
ignore_eos: bool = False
|
||||
language: Optional[str] = None
|
||||
request_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -72,9 +71,6 @@ async def async_request_tgi(
|
||||
"inputs": request_func_input.prompt,
|
||||
"parameters": params,
|
||||
}
|
||||
headers = None
|
||||
if request_func_input.request_id:
|
||||
headers = {"x-request-id": request_func_input.request_id}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
if request_func_input.ignore_eos:
|
||||
@ -86,9 +82,7 @@ async def async_request_tgi(
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(
|
||||
url=api_url, json=payload, headers=headers
|
||||
) as response:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
@ -151,9 +145,6 @@ async def async_request_trt_llm(
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["min_length"] = request_func_input.output_len
|
||||
headers = None
|
||||
if request_func_input.request_id:
|
||||
headers = {"x-request-id": request_func_input.request_id}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
@ -161,9 +152,7 @@ async def async_request_trt_llm(
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(
|
||||
url=api_url, json=payload, headers=headers
|
||||
) as response:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
@ -222,8 +211,6 @@ async def async_request_deepspeed_mii(
|
||||
"top_p": 1.0,
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
@ -296,8 +283,6 @@ async def async_request_openai_completions(
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
@ -410,8 +395,6 @@ async def async_request_openai_chat_completions(
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
@ -508,8 +491,6 @@ async def async_request_openai_audio(
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
if request_func_input.request_id:
|
||||
headers["x-request-id"] = request_func_input.request_id
|
||||
|
||||
# Send audio file
|
||||
def to_bytes(y, sr):
|
||||
|
||||
@ -19,7 +19,6 @@ import logging
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from io import BytesIO
|
||||
@ -55,7 +54,6 @@ class SampleRequest:
|
||||
expected_output_len: int
|
||||
multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
request_id: Optional[str] = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@ -157,10 +155,7 @@ class BenchmarkDataset(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
request_id_prefix: str = "",
|
||||
self, tokenizer: PreTrainedTokenizerBase, num_requests: int
|
||||
) -> list[SampleRequest]:
|
||||
"""
|
||||
Abstract method to generate sample requests from the dataset.
|
||||
@ -172,7 +167,6 @@ class BenchmarkDataset(ABC):
|
||||
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
|
||||
for processing the dataset's text.
|
||||
num_requests (int): The number of sample requests to generate.
|
||||
request_id_prefix (str) The prefix of request_id.
|
||||
|
||||
Returns:
|
||||
list[SampleRequest]: A list of sample requests generated from the
|
||||
@ -181,10 +175,7 @@ class BenchmarkDataset(ABC):
|
||||
raise NotImplementedError("sample must be implemented in subclasses.")
|
||||
|
||||
def maybe_oversample_requests(
|
||||
self,
|
||||
requests: list[SampleRequest],
|
||||
num_requests: int,
|
||||
request_id_prefix: str = "",
|
||||
self, requests: list[SampleRequest], num_requests: int
|
||||
) -> None:
|
||||
"""
|
||||
Oversamples the list of requests if its size is less than the desired
|
||||
@ -192,18 +183,11 @@ class BenchmarkDataset(ABC):
|
||||
|
||||
Args:
|
||||
requests (List[SampleRequest]): The current list of sampled
|
||||
requests.
|
||||
num_requests (int): The target number of requests.
|
||||
request_id_prefix (str) The prefix of the request ids.
|
||||
requests. num_requests (int): The target number of requests.
|
||||
"""
|
||||
if len(requests) < num_requests:
|
||||
random.seed(self.random_seed)
|
||||
additional = deepcopy(
|
||||
random.choices(requests, k=num_requests - len(requests))
|
||||
)
|
||||
for i in range(len(additional)):
|
||||
req = additional[i]
|
||||
req.request_id = request_id_prefix + str(len(requests) + i)
|
||||
additional = random.choices(requests, k=num_requests - len(requests))
|
||||
requests.extend(additional)
|
||||
logger.info("Oversampled requests to reach %d total samples.", num_requests)
|
||||
|
||||
@ -293,41 +277,6 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
||||
)
|
||||
|
||||
|
||||
def process_video(video: Any) -> Mapping[str, Any]:
|
||||
"""
|
||||
Process a single video input and return a multimedia content dictionary.
|
||||
|
||||
Supports the following input types:
|
||||
|
||||
1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key
|
||||
containing raw video data.
|
||||
|
||||
2. String input: - Treats the string as a URL or local file path. -
|
||||
Prepends "file://" if the string doesn't start with "http://" or
|
||||
"file://". - Returns a dictionary with the image URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input is not a supported type.
|
||||
"""
|
||||
if isinstance(video, dict) and "bytes" in video:
|
||||
video_bytes = video["bytes"]
|
||||
video_base64 = base64.b64encode(video_bytes).decode("utf-8")
|
||||
return {
|
||||
"type": "video_url",
|
||||
"video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
|
||||
}
|
||||
|
||||
if isinstance(video, str):
|
||||
video_url = (
|
||||
video if video.startswith(("http://", "file://")) else f"file://{video}"
|
||||
)
|
||||
return {"type": "video_url", "video_url": {"url": video_url}}
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Random Dataset Implementation (Synthetic Data)
|
||||
# -----------------------------------------------------------------------------
|
||||
@ -354,7 +303,6 @@ class RandomDataset(BenchmarkDataset):
|
||||
range_ratio: float = DEFAULT_RANGE_RATIO,
|
||||
input_len: int = DEFAULT_INPUT_LEN,
|
||||
output_len: int = DEFAULT_OUTPUT_LEN,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list[SampleRequest]:
|
||||
# Enforce range_ratio < 1
|
||||
@ -415,10 +363,8 @@ class RandomDataset(BenchmarkDataset):
|
||||
prompt=prompt,
|
||||
prompt_len=total_input_len,
|
||||
expected_output_len=int(output_lens[i]),
|
||||
request_id=request_id_prefix + str(i),
|
||||
)
|
||||
)
|
||||
|
||||
return requests
|
||||
|
||||
|
||||
@ -460,11 +406,9 @@ class ShareGPTDataset(BenchmarkDataset):
|
||||
max_loras: Optional[int] = None,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list:
|
||||
samples: list = []
|
||||
ind = 0
|
||||
for entry in self.data:
|
||||
if len(samples) >= num_requests:
|
||||
break
|
||||
@ -486,26 +430,17 @@ class ShareGPTDataset(BenchmarkDataset):
|
||||
skip_min_output_len_check=output_len is not None,
|
||||
):
|
||||
continue
|
||||
if image_path := entry.get("image"):
|
||||
mm_content = process_image(image_path)
|
||||
elif video_path := entry.get("video"):
|
||||
mm_content = process_video(video_path)
|
||||
else:
|
||||
mm_content = None
|
||||
if enable_multimodal_chat:
|
||||
prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
|
||||
prompt = self.apply_multimodal_chat_transformation(prompt, None)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=new_output_len,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=mm_content,
|
||||
request_id=request_id_prefix + str(ind),
|
||||
)
|
||||
)
|
||||
ind += 1
|
||||
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
return samples
|
||||
|
||||
|
||||
@ -571,11 +506,10 @@ class CustomDataset(BenchmarkDataset):
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
skip_chat_template: bool = False,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list:
|
||||
sampled_requests = []
|
||||
for i, item in enumerate(self.data):
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item["prompt"]
|
||||
@ -594,12 +528,9 @@ class CustomDataset(BenchmarkDataset):
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
request_id=request_id_prefix + str(i),
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(
|
||||
sampled_requests, num_requests, request_id_prefix
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
|
||||
return sampled_requests
|
||||
|
||||
@ -641,7 +572,6 @@ class SonnetDataset(BenchmarkDataset):
|
||||
input_len: int = DEFAULT_INPUT_LEN,
|
||||
output_len: int = DEFAULT_OUTPUT_LEN,
|
||||
return_prompt_formatted: bool = False,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list:
|
||||
# Calculate average token length for a poem line.
|
||||
@ -667,7 +597,6 @@ class SonnetDataset(BenchmarkDataset):
|
||||
prefix_lines = self.data[:num_prefix_lines]
|
||||
|
||||
samples = []
|
||||
ind = 0
|
||||
while len(samples) < num_requests:
|
||||
extra_lines = random.choices(
|
||||
self.data, k=num_input_lines - num_prefix_lines
|
||||
@ -678,17 +607,14 @@ class SonnetDataset(BenchmarkDataset):
|
||||
msg, add_generation_prompt=True, tokenize=False
|
||||
)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
|
||||
if prompt_len <= input_len:
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt_formatted if return_prompt_formatted else prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
request_id=request_id_prefix + str(ind),
|
||||
)
|
||||
)
|
||||
ind += 1
|
||||
return samples
|
||||
|
||||
|
||||
@ -740,7 +666,6 @@ class BurstGPTDataset(BenchmarkDataset):
|
||||
num_requests: int,
|
||||
max_loras: Optional[int] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list[SampleRequest]:
|
||||
samples = []
|
||||
@ -762,7 +687,6 @@ class BurstGPTDataset(BenchmarkDataset):
|
||||
prompt_len=input_len,
|
||||
expected_output_len=output_len,
|
||||
lora_request=lora_req,
|
||||
request_id=request_id_prefix + str(i),
|
||||
)
|
||||
)
|
||||
return samples
|
||||
@ -822,14 +746,12 @@ class ConversationDataset(HuggingFaceDataset):
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list:
|
||||
# Filter examples with at least 2 conversations
|
||||
filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
|
||||
sampled_requests = []
|
||||
dynamic_output = output_len is None
|
||||
ind = 0
|
||||
|
||||
for item in filtered_data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
@ -857,13 +779,9 @@ class ConversationDataset(HuggingFaceDataset):
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
request_id=request_id_prefix + str(ind),
|
||||
)
|
||||
)
|
||||
ind += 1
|
||||
self.maybe_oversample_requests(
|
||||
sampled_requests, num_requests, request_id_prefix
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
@ -890,12 +808,11 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list:
|
||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
||||
sampled_requests = []
|
||||
for i, item in enumerate(self.data):
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
|
||||
@ -915,12 +832,9 @@ class VisionArenaDataset(HuggingFaceDataset):
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
request_id=request_id_prefix + str(i),
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(
|
||||
sampled_requests, num_requests, request_id_prefix
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
@ -950,18 +864,15 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list:
|
||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
||||
sampled_requests = []
|
||||
for i, item in enumerate(self.data):
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = (
|
||||
f"{item['input']}\n\n{item['instruction']} Just output "
|
||||
"the code, do not include any explanation."
|
||||
)
|
||||
prompt = f"{item['input']}\n\n{item['instruction']} Just output \
|
||||
the code, do not include any explanation."
|
||||
|
||||
# apply template
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
@ -975,12 +886,9 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
request_id=request_id_prefix + str(i),
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(
|
||||
sampled_requests, num_requests, request_id_prefix
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
@ -1010,13 +918,12 @@ class MTBenchDataset(HuggingFaceDataset):
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list:
|
||||
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
|
||||
sampled_requests = []
|
||||
|
||||
for i, item in enumerate(self.data):
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item["turns"][0]
|
||||
@ -1034,12 +941,9 @@ class MTBenchDataset(HuggingFaceDataset):
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
request_id=request_id_prefix + str(i),
|
||||
)
|
||||
)
|
||||
self.maybe_oversample_requests(
|
||||
sampled_requests, num_requests, request_id_prefix
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
@ -1064,12 +968,10 @@ class AIMODataset(HuggingFaceDataset):
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list:
|
||||
sampled_requests = []
|
||||
dynamic_output = output_len is None
|
||||
ind = 0
|
||||
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
@ -1092,13 +994,9 @@ class AIMODataset(HuggingFaceDataset):
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=None,
|
||||
request_id=request_id_prefix + str(ind),
|
||||
)
|
||||
)
|
||||
ind += 1
|
||||
self.maybe_oversample_requests(
|
||||
sampled_requests, num_requests, request_id_prefix
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
@ -1168,18 +1066,12 @@ class NextEditPredictionDataset(HuggingFaceDataset):
|
||||
"zed-industries/zeta": _format_zeta_prompt,
|
||||
}
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
|
||||
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
|
||||
if formatting_prompt_func is None:
|
||||
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
|
||||
samples = []
|
||||
for i, sample in enumerate(self.data):
|
||||
for sample in self.data:
|
||||
sample = formatting_prompt_func(sample)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
@ -1188,12 +1080,11 @@ class NextEditPredictionDataset(HuggingFaceDataset):
|
||||
expected_output_len=len(
|
||||
tokenizer(sample["expected_output"]).input_ids
|
||||
),
|
||||
request_id=request_id_prefix + str(i),
|
||||
)
|
||||
)
|
||||
if len(samples) >= num_requests:
|
||||
break
|
||||
self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
return samples
|
||||
|
||||
|
||||
@ -1242,7 +1133,6 @@ class ASRDataset(HuggingFaceDataset):
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
request_id_prefix: str = "",
|
||||
**kwargs,
|
||||
) -> list:
|
||||
import librosa
|
||||
@ -1252,7 +1142,6 @@ class ASRDataset(HuggingFaceDataset):
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
sampled_requests = []
|
||||
skipped = 0
|
||||
ind = 0
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
@ -1271,10 +1160,8 @@ class ASRDataset(HuggingFaceDataset):
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
request_id=request_id_prefix + str(ind),
|
||||
)
|
||||
)
|
||||
ind += 1
|
||||
if skipped:
|
||||
logger.warning(
|
||||
"%d samples discarded from dataset due to"
|
||||
@ -1282,7 +1169,5 @@ class ASRDataset(HuggingFaceDataset):
|
||||
" what Whisper supports.",
|
||||
skipped,
|
||||
)
|
||||
self.maybe_oversample_requests(
|
||||
sampled_requests, num_requests, request_id_prefix
|
||||
)
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
@ -375,12 +375,11 @@ async def benchmark(
|
||||
rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
|
||||
last_int_rps = current_int_rps
|
||||
|
||||
prompt, prompt_len, output_len, mm_content, request_id = (
|
||||
prompt, prompt_len, output_len, mm_content = (
|
||||
request.prompt,
|
||||
request.prompt_len,
|
||||
request.expected_output_len,
|
||||
request.multi_modal_data,
|
||||
request.request_id,
|
||||
)
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
if lora_modules:
|
||||
@ -398,7 +397,6 @@ async def benchmark(
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
request_id=request_id,
|
||||
)
|
||||
task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
tasks.append(asyncio.create_task(task))
|
||||
@ -667,7 +665,6 @@ def main(args: argparse.Namespace):
|
||||
tokenizer=tokenizer,
|
||||
output_len=args.custom_output_len,
|
||||
skip_chat_template=args.custom_skip_chat_template,
|
||||
request_id_prefix=args.request_id_prefix,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sonnet":
|
||||
@ -681,7 +678,6 @@ def main(args: argparse.Namespace):
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=False,
|
||||
request_id_prefix=args.request_id_prefix,
|
||||
)
|
||||
else:
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
@ -694,7 +690,6 @@ def main(args: argparse.Namespace):
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=True,
|
||||
request_id_prefix=args.request_id_prefix,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "hf":
|
||||
@ -756,7 +751,6 @@ def main(args: argparse.Namespace):
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
output_len=args.hf_output_len,
|
||||
request_id_prefix=args.request_id_prefix,
|
||||
)
|
||||
|
||||
else:
|
||||
@ -768,15 +762,10 @@ def main(args: argparse.Namespace):
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
output_len=args.sharegpt_output_len,
|
||||
request_id_prefix=args.request_id_prefix,
|
||||
),
|
||||
"burstgpt": lambda: BurstGPTDataset(
|
||||
random_seed=args.seed, dataset_path=args.dataset_path
|
||||
).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
request_id_prefix=args.request_id_prefix,
|
||||
),
|
||||
).sample(tokenizer=tokenizer, num_requests=args.num_prompts),
|
||||
"random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
@ -784,7 +773,6 @@ def main(args: argparse.Namespace):
|
||||
input_len=args.random_input_len,
|
||||
output_len=args.random_output_len,
|
||||
range_ratio=args.random_range_ratio,
|
||||
request_id_prefix=args.request_id_prefix,
|
||||
),
|
||||
}
|
||||
|
||||
@ -1130,13 +1118,6 @@ def create_argument_parser():
|
||||
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-id-prefix",
|
||||
type=str,
|
||||
required=False,
|
||||
default="benchmark-serving",
|
||||
help="Specify the prefix of request id.",
|
||||
)
|
||||
|
||||
# group for dataset specific arguments
|
||||
custom_group = parser.add_argument_group("custom dataset options")
|
||||
|
||||
@ -597,8 +597,8 @@ def validate_args(args):
|
||||
# https://github.com/vllm-project/vllm/issues/16222
|
||||
if args.data_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Data parallel is not supported in offline benchmark, "
|
||||
"please use benchmark serving instead"
|
||||
"Data parallel is not supported in offline benchmark, \
|
||||
please use benchmark serving instead"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,199 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
from quart import Quart, Response, make_response, request
|
||||
from rate_limiter import RateLimiter
|
||||
from request_queue import RequestQueue
|
||||
from quart import Quart, make_response, request
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
app = Quart(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""parse command line arguments"""
|
||||
parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server")
|
||||
|
||||
# Add args
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=300,
|
||||
help="Timeout for backend service requests in seconds (default: 300)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrent",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum concurrent requests to backend services (default: 100)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--queue-size",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Maximum number of requests in the queue (default: 500)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rate-limit",
|
||||
type=int,
|
||||
default=40,
|
||||
help="Maximum requests per second (default: 40)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port to run the server on (default: 8000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-url",
|
||||
type=str,
|
||||
default="http://localhost:8100/v1/completions",
|
||||
help="Prefill service endpoint URL",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode-url",
|
||||
type=str,
|
||||
default="http://localhost:8200/v1/completions",
|
||||
help="Decode service endpoint URL",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""parse command line arguments"""
|
||||
args = parse_args()
|
||||
|
||||
# Initialize configuration using command line parameters
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
|
||||
MAX_CONCURRENT_REQUESTS = args.max_concurrent
|
||||
REQUEST_QUEUE_SIZE = args.queue_size
|
||||
RATE_LIMIT = args.rate_limit
|
||||
PREFILL_SERVICE_URL = args.prefill_url
|
||||
DECODE_SERVICE_URL = args.decode_url
|
||||
PORT = args.port
|
||||
|
||||
app = Quart(__name__)
|
||||
|
||||
# Initialize the rate limiter and request queue
|
||||
rate_limiter = RateLimiter(RATE_LIMIT)
|
||||
request_queue = RequestQueue(MAX_CONCURRENT_REQUESTS, REQUEST_QUEUE_SIZE)
|
||||
|
||||
# Attach the configuration object to the application instance
|
||||
app.config.update(
|
||||
{
|
||||
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
|
||||
"rate_limiter": rate_limiter,
|
||||
"request_queue": request_queue,
|
||||
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
|
||||
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
|
||||
}
|
||||
)
|
||||
|
||||
# Start queue processing on app startup
|
||||
@app.before_serving
|
||||
async def startup():
|
||||
"""Start request processing task when app starts serving"""
|
||||
asyncio.create_task(request_queue.process())
|
||||
|
||||
async def forward_request(url, data):
|
||||
"""Forward request to backend service with rate limiting and error handling"""
|
||||
async def forward_request(url, data):
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||
async with session.post(url=url, json=data, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
# if response.headers.get('Transfer-Encoding') == 'chunked':
|
||||
if True:
|
||||
async for chunk_bytes in response.content.iter_chunked(1024):
|
||||
yield chunk_bytes
|
||||
else:
|
||||
content = await response.read()
|
||||
yield content
|
||||
|
||||
# Use rate limiter as context manager
|
||||
async with (
|
||||
rate_limiter,
|
||||
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
|
||||
|
||||
@app.route("/v1/completions", methods=["POST"])
|
||||
async def handle_request():
|
||||
try:
|
||||
original_request_data = await request.get_json()
|
||||
|
||||
prefill_request = original_request_data.copy()
|
||||
# change max_tokens = 1 to let it only do prefill
|
||||
prefill_request["max_tokens"] = 1
|
||||
|
||||
# finish prefill
|
||||
async for _ in forward_request(
|
||||
"http://localhost:8100/v1/completions", prefill_request
|
||||
):
|
||||
try:
|
||||
async with session.post(
|
||||
url=url, json=data, headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
# Stream response chunks
|
||||
async for chunk_bytes in response.content.iter_chunked(1024):
|
||||
yield chunk_bytes
|
||||
else:
|
||||
# Handle backend service errors
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
"Backend service error: %s - %s",
|
||||
response.status,
|
||||
error_text,
|
||||
)
|
||||
yield b'{"error": "Backend service error"}'
|
||||
except aiohttp.ClientError as e:
|
||||
# Handle connection errors
|
||||
logger.error("Connection error to %s: %s", url, str(e))
|
||||
yield b'{"error": "Service unavailable"}'
|
||||
except asyncio.TimeoutError:
|
||||
# Handle timeout errors
|
||||
logger.error("Timeout connecting to %s", url)
|
||||
yield b'{"error": "Service timeout"}'
|
||||
continue
|
||||
|
||||
async def process_request():
|
||||
"""Process a single request through prefill and decode stages"""
|
||||
try:
|
||||
original_request_data = await request.get_json()
|
||||
# return decode
|
||||
generator = forward_request(
|
||||
"http://localhost:8200/v1/completions", original_request_data
|
||||
)
|
||||
response = await make_response(generator)
|
||||
response.timeout = None
|
||||
|
||||
# Create prefill request (max_tokens=1)
|
||||
prefill_request = original_request_data.copy()
|
||||
prefill_request["max_tokens"] = 1
|
||||
return response
|
||||
|
||||
# Execute prefill stage
|
||||
async for _ in forward_request(PREFILL_SERVICE_URL, prefill_request):
|
||||
continue
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# Execute decode stage and stream response
|
||||
generator = forward_request(DECODE_SERVICE_URL, original_request_data)
|
||||
response = await make_response(generator)
|
||||
response.timeout = None # Disable timeout for streaming response
|
||||
return response
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error processing request")
|
||||
return Response(
|
||||
response=b'{"error": "Internal server error"}',
|
||||
status=500,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
@app.route("/v1/completions", methods=["POST"])
|
||||
async def handle_request():
|
||||
"""Handle incoming API requests with concurrency and rate limiting"""
|
||||
# Create task for request processing
|
||||
task = asyncio.create_task(process_request())
|
||||
|
||||
# Enqueue request or reject if queue is full
|
||||
if not await request_queue.enqueue(task):
|
||||
return Response(
|
||||
response=b'{"error": "Server busy, try again later"}',
|
||||
status=503,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
try:
|
||||
# Return the response from the processing task
|
||||
return await task
|
||||
except asyncio.CancelledError:
|
||||
# Handle task cancellation (timeout or queue full)
|
||||
logger.warning("Request cancelled due to timeout or queue full")
|
||||
return Response(
|
||||
response=b'{"error": "Request cancelled"}',
|
||||
status=503,
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
# Start the Quart server with host can be set to 0.0.0.0
|
||||
app.run(port=PORT)
|
||||
exc_info = sys.exc_info()
|
||||
print("Error occurred in disagg prefill proxy server")
|
||||
print(e)
|
||||
print("".join(traceback.format_exception(*exc_info)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
app.run(port=8000)
|
||||
|
||||
@ -1,45 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Token bucket rate limiter implementation"""
|
||||
|
||||
def __init__(self, rate_limit):
|
||||
self.rate_limit = rate_limit # Requests per second
|
||||
self.num_available_tokens = rate_limit # Available tokens
|
||||
self.last_refill = time.monotonic() # Last token refill time
|
||||
self.lock = asyncio.Lock() # Synchronization lock
|
||||
|
||||
async def acquire(self):
|
||||
"""Acquire a token from the rate limiter"""
|
||||
while True:
|
||||
async with self.lock:
|
||||
current_time = time.monotonic()
|
||||
elapsed = current_time - self.last_refill
|
||||
|
||||
# Refill num_available_tokens if more than 1 second has passed
|
||||
if elapsed > 1.0:
|
||||
self.num_available_tokens = self.rate_limit
|
||||
self.last_refill = current_time
|
||||
|
||||
# Check if num_available_tokens are available
|
||||
if self.num_available_tokens > 0:
|
||||
self.num_available_tokens -= 1
|
||||
return True
|
||||
|
||||
# Calculate wait time if no num_available_tokens available
|
||||
wait_time = 1.0 - elapsed
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter async context manager - acquire token"""
|
||||
await self.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
"""Exit async context manager - no cleanup needed"""
|
||||
pass
|
||||
@ -1,39 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
|
||||
|
||||
class RequestQueue:
|
||||
"""Request queue manager with concurrency control"""
|
||||
|
||||
def __init__(self, max_concurrent, max_queue_size):
|
||||
# Maximum concurrent requests
|
||||
self.max_concurrent = max_concurrent
|
||||
self.max_queue_size = max_queue_size # Maximum queue size
|
||||
# Concurrency control
|
||||
self.semaphore = asyncio.Semaphore(max_concurrent)
|
||||
self.queue = deque() # Request queue
|
||||
self.queue_size = 0 # Current queue size
|
||||
self.lock = asyncio.Lock() # Sync queue Lock
|
||||
|
||||
async def enqueue(self, task):
|
||||
"""Add a request task to the queue"""
|
||||
async with self.lock:
|
||||
if self.queue_size >= self.max_queue_size:
|
||||
return False
|
||||
|
||||
self.queue.append(task)
|
||||
self.queue_size += 1
|
||||
return True
|
||||
|
||||
async def process(self):
|
||||
"""Process queued requests using semaphore for concurrency control"""
|
||||
while True:
|
||||
if self.queue:
|
||||
async with self.semaphore, self.lock:
|
||||
task = self.queue.popleft()
|
||||
self.queue_size -= 1
|
||||
await task
|
||||
await asyncio.sleep(0.01) # Yield control to event loop
|
||||
345
benchmarks/kernels/benchmark_aqlm.py
Normal file
345
benchmarks/kernels/benchmark_aqlm.py
Normal file
@ -0,0 +1,345 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.aqlm import (
|
||||
dequantize_weight,
|
||||
generic_dequantize_gemm,
|
||||
get_int_dtype,
|
||||
optimized_dequantize_gemm,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
|
||||
def torch_mult(
|
||||
# [..., in_features]
|
||||
input: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
# [num_out_groups, 1, 1, 1]
|
||||
scales: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output = F.linear(input, weights)
|
||||
return output
|
||||
|
||||
|
||||
def dequant_out_scale(
|
||||
# [..., in_features]
|
||||
input: torch.Tensor,
|
||||
# [num_out_groups, num_in_groups, num_codebooks]
|
||||
codes: torch.IntTensor,
|
||||
# [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
codebooks: torch.Tensor,
|
||||
# [num_out_groups, 1, 1, 1]
|
||||
scales: torch.Tensor,
|
||||
output_partition_sizes: torch.IntTensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||
|
||||
if bias is None:
|
||||
output = F.linear(input, weights, bias)
|
||||
orig_shape = output.shape
|
||||
flattened_output = output.view(-1, output.size(-1))
|
||||
f_scales = scales.view(-1, scales.shape[0])
|
||||
b_scales = f_scales.expand(flattened_output.shape[0], -1)
|
||||
flattened_output *= b_scales
|
||||
return flattened_output.view(orig_shape)
|
||||
else:
|
||||
b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
|
||||
weights *= b_scales
|
||||
return F.linear(input, weights, bias)
|
||||
|
||||
|
||||
def dequant_weight_scale(
|
||||
# [..., in_features]
|
||||
input: torch.Tensor,
|
||||
# [num_out_groups, num_in_groups, num_codebooks]
|
||||
codes: torch.IntTensor,
|
||||
# [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
codebooks: torch.Tensor,
|
||||
# [num_out_groups, 1, 1, 1]
|
||||
scales: torch.Tensor,
|
||||
output_partition_sizes: torch.IntTensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||
|
||||
b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
|
||||
weights *= b_scales
|
||||
return F.linear(input, weights, bias)
|
||||
|
||||
|
||||
def dequant_no_scale(
|
||||
# [..., in_features]
|
||||
input: torch.Tensor,
|
||||
# [num_out_groups, num_in_groups, num_codebooks]
|
||||
codes: torch.IntTensor,
|
||||
# [num_codebooks, codebook_size, out_group_size, in_group_size]
|
||||
codebooks: torch.Tensor,
|
||||
# [num_out_groups, 1, 1, 1]
|
||||
scales: torch.Tensor,
|
||||
output_partition_sizes: torch.IntTensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
||||
|
||||
return F.linear(input, weights, bias)
|
||||
|
||||
|
||||
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
|
||||
# the generic pytorch version.
|
||||
# Just visual comparison.
|
||||
def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
|
||||
n = int(parts.sum().item())
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
code_range = (1 << bits) // 2
|
||||
ingroups = 8
|
||||
|
||||
codes = torch.randint(
|
||||
-code_range,
|
||||
code_range,
|
||||
size=(n, k // ingroups, nbooks),
|
||||
dtype=get_int_dtype(bits),
|
||||
device=device,
|
||||
)
|
||||
|
||||
codebooks = torch.randn(
|
||||
size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
count = 0
|
||||
for index in range(16):
|
||||
for i in range(8):
|
||||
for book in range(nbooks):
|
||||
codebooks[book, index, 0, i] = count * (10**book)
|
||||
count += 1
|
||||
|
||||
print("codes shape", codes.shape)
|
||||
|
||||
for i in range(16):
|
||||
for book in range(nbooks):
|
||||
codes[0, i, book] = i
|
||||
codes[0, -i, book] = i
|
||||
|
||||
weights = dequantize_weight(codes, codebooks, None)
|
||||
weights2 = ops.aqlm_dequant(codes, codebooks, parts)
|
||||
|
||||
print("weights shape:", weights.shape)
|
||||
print("weights2 shape:", weights2.shape)
|
||||
|
||||
print("weights are:", weights)
|
||||
print("weights2 are:", weights2)
|
||||
|
||||
print("first 128 weights are", weights[0, 0:128].to(torch.int32))
|
||||
print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))
|
||||
|
||||
print("last 128 weights are", weights[0, -128:])
|
||||
print("last 128 weights2 are:", weights2[0, -128:])
|
||||
|
||||
|
||||
def main():
|
||||
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
|
||||
|
||||
# Add arguments
|
||||
parser.add_argument(
|
||||
"--nbooks", type=int, default=1, help="Number of codebooks (default: 1)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bits",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of bits per code element (default: 16)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Run the decompression/dequant tester rather than benchmarking "
|
||||
"(default: False)",
|
||||
)
|
||||
|
||||
# Parse the arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Extract values
|
||||
nbooks = args.nbooks
|
||||
bits = args.bits
|
||||
|
||||
if args.test:
|
||||
dequant_test(4096, torch.tensor((4096,)), nbooks, bits)
|
||||
return
|
||||
|
||||
# Otherwise, benchmark.
|
||||
methods = [
|
||||
ops.aqlm_gemm,
|
||||
dequant_out_scale,
|
||||
generic_dequantize_gemm,
|
||||
optimized_dequantize_gemm,
|
||||
dequant_weight_scale,
|
||||
torch_mult,
|
||||
dequant_no_scale,
|
||||
]
|
||||
|
||||
filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
|
||||
print(f"writing benchmarks to file {filename}")
|
||||
with open(filename, "w") as f:
|
||||
sys.stdout = f
|
||||
|
||||
print("m | k | n | n parts", end="")
|
||||
for method in methods:
|
||||
print(f" | {method.__name__.replace('_', ' ')} (µs)", end="")
|
||||
print("")
|
||||
|
||||
# These are reasonable prefill sizes.
|
||||
ksandpartions = (
|
||||
(4096, (4096, 4096, 4096)),
|
||||
(4096, (4096,)),
|
||||
(4096, (11008, 11008)),
|
||||
(11008, (4096,)),
|
||||
)
|
||||
|
||||
# reasonable ranges for m.
|
||||
for m in [
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
10,
|
||||
12,
|
||||
14,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
48,
|
||||
52,
|
||||
56,
|
||||
64,
|
||||
96,
|
||||
112,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
]:
|
||||
print(f"{m}", file=sys.__stdout__)
|
||||
for ksp in ksandpartions:
|
||||
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods)
|
||||
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
|
||||
def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods):
|
||||
# I didn't see visible improvements from increasing these, but feel free :)
|
||||
num_warmup_trials = 1
|
||||
num_trials = 1
|
||||
|
||||
num_calls = 100
|
||||
|
||||
# warmup.
|
||||
for method in methods:
|
||||
for _ in range(num_warmup_trials):
|
||||
run_timing(
|
||||
num_calls=num_calls,
|
||||
m=m,
|
||||
k=k,
|
||||
parts=parts,
|
||||
nbooks=nbooks,
|
||||
bits=bits,
|
||||
method=method,
|
||||
)
|
||||
|
||||
n = parts.sum().item()
|
||||
print(f"{m} | {k} | {n} | {parts.tolist()}", end="")
|
||||
|
||||
for method in methods:
|
||||
best_time_us = 1e20
|
||||
for _ in range(num_trials):
|
||||
kernel_dur_ms = run_timing(
|
||||
num_calls=num_calls,
|
||||
m=m,
|
||||
k=k,
|
||||
parts=parts,
|
||||
nbooks=nbooks,
|
||||
bits=bits,
|
||||
method=method,
|
||||
)
|
||||
|
||||
kernel_dur_us = 1000 * kernel_dur_ms
|
||||
|
||||
if kernel_dur_us < best_time_us:
|
||||
best_time_us = kernel_dur_us
|
||||
|
||||
print(f" | {kernel_dur_us:.0f}", end="")
|
||||
|
||||
print("")
|
||||
|
||||
|
||||
def run_timing(
|
||||
num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
|
||||
) -> float:
|
||||
n = int(parts.sum().item())
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
|
||||
|
||||
code_range = (1 << bits) // 2
|
||||
ingroups = 8
|
||||
|
||||
codes = torch.randint(
|
||||
-code_range,
|
||||
code_range,
|
||||
size=(n, k // ingroups, nbooks),
|
||||
dtype=get_int_dtype(bits),
|
||||
device=device,
|
||||
)
|
||||
|
||||
codebooks = torch.randn(
|
||||
size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
||||
dtype=torch.float16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
|
||||
|
||||
# for comparison to just a pytorch mult.
|
||||
weights = torch.randn((n, k), dtype=torch.float16, device=device)
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
|
||||
if method is torch_mult:
|
||||
for i in range(num_calls):
|
||||
torch_mult(input, weights, scales)
|
||||
else:
|
||||
for i in range(num_calls):
|
||||
method(input, codes, codebooks, scales, parts, None)
|
||||
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
|
||||
dur_ms = start_event.elapsed_time(end_event) / num_calls
|
||||
return dur_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@ -80,11 +80,6 @@ def bench_run(
|
||||
a, score, topk, renormalize=False
|
||||
)
|
||||
|
||||
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
def run_triton_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@ -116,10 +111,6 @@ def bench_run(
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
@ -134,10 +125,6 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
@ -149,10 +136,6 @@ def bench_run(
|
||||
w2_q: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
):
|
||||
@ -167,10 +150,6 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
@ -215,10 +194,6 @@ def bench_run(
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
)
|
||||
@ -256,10 +231,6 @@ def bench_run(
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"per_act_token": per_act_token,
|
||||
"ab_strides1": ab_strides1,
|
||||
"ab_strides2": ab_strides2,
|
||||
"c_strides1": c_strides1,
|
||||
"c_strides2": c_strides2,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
@ -318,10 +289,6 @@ def bench_run(
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
@ -330,7 +297,7 @@ def bench_run(
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
|
||||
@ -236,7 +236,6 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
a=bt.a,
|
||||
c=None,
|
||||
b_q_weight=w_q,
|
||||
b_bias=None,
|
||||
b_scales=w_s,
|
||||
global_scale=None,
|
||||
b_zeros=w_zp,
|
||||
@ -253,7 +252,28 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
else:
|
||||
assert bt.a.dtype == torch.int8
|
||||
assert bt.wtype == scalar_types.uint4b8
|
||||
raise NotImplementedError("QQQ is not supported anymore")
|
||||
|
||||
if bt.w_ch_s is not None:
|
||||
s_ch = bt.w_ch_s.to(torch.float32)
|
||||
else:
|
||||
s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
|
||||
|
||||
if bt.w_tok_s is not None:
|
||||
s_tok = bt.w_tok_s.to(torch.float32)
|
||||
else:
|
||||
s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
|
||||
|
||||
fn = lambda: ops.marlin_qqq_gemm(
|
||||
a=bt.a,
|
||||
b_q_weight=w_q,
|
||||
s_group=w_s,
|
||||
s_tok=s_tok,
|
||||
s_ch=s_ch,
|
||||
workspace=workspace.scratch,
|
||||
size_m=bt.a.shape[0],
|
||||
size_n=bt.w_ref.shape[1],
|
||||
size_k=bt.w_ref.shape[0],
|
||||
)
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
@ -430,6 +429,7 @@ class BenchmarkWorker:
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype_str,
|
||||
is_marlin=False,
|
||||
)
|
||||
else:
|
||||
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
||||
@ -542,7 +542,6 @@ def save_configs(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_quant_shape: list[int],
|
||||
save_dir: str,
|
||||
) -> None:
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
@ -553,8 +552,7 @@ def save_configs(
|
||||
filename = get_config_file_name(
|
||||
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
|
||||
)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
filename = os.path.join(save_dir, filename)
|
||||
|
||||
print(f"Writing best config to {filename}...")
|
||||
with open(filename, "w") as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
@ -709,7 +707,6 @@ def main(args: argparse.Namespace):
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
block_quant_shape,
|
||||
args.save_dir,
|
||||
)
|
||||
end = time.time()
|
||||
print(f"Tuning took {end - start:.2f} seconds")
|
||||
@ -751,9 +748,6 @@ if __name__ == "__main__":
|
||||
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
||||
)
|
||||
parser.add_argument("--use-deep-gemm", action="store_true")
|
||||
parser.add_argument(
|
||||
"--save-dir", type=str, default="./", help="Directory to save tuned results"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
|
||||
@ -3,14 +3,16 @@
|
||||
|
||||
import csv
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
|
||||
# KV Cache Layout for TRT-LLM
|
||||
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@ -24,107 +26,65 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_decode(
|
||||
dtype: torch.dtype,
|
||||
quant_dtypes: tuple[
|
||||
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
|
||||
],
|
||||
batch_size: int,
|
||||
max_seq_len: int,
|
||||
num_heads: tuple[int, int] = (64, 8),
|
||||
head_size: int = 128,
|
||||
kv_layout: str = "HND",
|
||||
block_size: int = 16,
|
||||
warmup: int = 10,
|
||||
trials: int = 20,
|
||||
num_seqs,
|
||||
max_seq_len,
|
||||
page_size=16,
|
||||
dtype=torch.bfloat16,
|
||||
kv_layout="HND",
|
||||
num_kv_heads=8,
|
||||
kv_cache_dtype="auto",
|
||||
head_dim=128,
|
||||
warmup=10,
|
||||
trials=20,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
||||
q_quant_dtype = q_quant_dtype or dtype
|
||||
kv_quant_dtype = kv_quant_dtype or dtype
|
||||
o_quant_dtype = o_quant_dtype or dtype
|
||||
|
||||
num_qo_heads, num_kv_heads = num_heads
|
||||
assert num_qo_heads % num_kv_heads == 0
|
||||
|
||||
sm_scale = float(1.0 / (head_size**0.5))
|
||||
HEAD_GRP_SIZE = 8
|
||||
MAX_SEQ_LEN = max_seq_len
|
||||
|
||||
# large number to reduce kv_cache reuse
|
||||
NUM_BLOCKS = int(256000 / block_size)
|
||||
NUM_BLOCKS = int(256000 / page_size)
|
||||
|
||||
kv_cache_shape = None
|
||||
if kv_layout == "NHD":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
|
||||
elif kv_layout == "HND":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
|
||||
|
||||
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
# For decode, batch_size is num_decode_token
|
||||
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
|
||||
sm_scale = float(1.0 / (head_dim**0.5))
|
||||
q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
|
||||
kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
|
||||
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_seq_len
|
||||
max_kv_len = max(kv_lens)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
|
||||
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
|
||||
|
||||
seq_lens = kv_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(batch_size):
|
||||
seq_len = seq_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
|
||||
kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout,
|
||||
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
|
||||
)
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
)
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
kv_cache, _ = to_float8(kv_cache)
|
||||
|
||||
output_trtllm = torch.empty(q.shape, dtype=dtype)
|
||||
|
||||
# Benchmark TRT decode
|
||||
def trt_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
q,
|
||||
kv_cache,
|
||||
workspace_buffer,
|
||||
block_tables,
|
||||
kv_lens_tensor,
|
||||
max_kv_len,
|
||||
bmm1_scale=k_scale * sm_scale,
|
||||
bmm2_scale=v_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
torch.cuda.synchronize()
|
||||
@ -141,51 +101,74 @@ def benchmark_decode(
|
||||
times.append(start.elapsed_time(end)) # ms
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
o_scale = 1.0
|
||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
# TRT Decode
|
||||
trt_mean, trt_std = time_fn(trt_decode)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + page_size - 1) // page_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % page_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = page_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
output_baseline = torch.empty(q.shape, dtype=dtype)
|
||||
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout,
|
||||
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
|
||||
)
|
||||
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
|
||||
)
|
||||
|
||||
def baseline_decode():
|
||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
||||
|
||||
def trtllm_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
query=query,
|
||||
kv_cache=kv_cache,
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables,
|
||||
seq_lens=seq_lens,
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
|
||||
|
||||
baseline_mean, baseline_std = time_fn(baseline_decode)
|
||||
trtllm_mean, trtllm_std = time_fn(trtllm_decode)
|
||||
|
||||
# Calculate percentage speedup (positive means TRT is faster)
|
||||
speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
|
||||
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
|
||||
|
||||
print(
|
||||
f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
|
||||
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
|
||||
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
|
||||
)
|
||||
|
||||
# Return results for CSV writing
|
||||
return {
|
||||
"batch_size": batch_size,
|
||||
"trtllm_mean": trtllm_mean,
|
||||
"trtllm_std": trtllm_std.item(),
|
||||
"num_seqs": num_seqs,
|
||||
"trt_mean": trt_mean,
|
||||
"trt_std": trt_std.item(),
|
||||
"baseline_mean": baseline_mean,
|
||||
"baseline_std": baseline_std.item(),
|
||||
"speedup_percent": speedup_percent,
|
||||
"q_dtype": str(q_quant_dtype),
|
||||
"kv_cache_dtype": str(kv_quant_dtype),
|
||||
"output_dtype": str(o_quant_dtype),
|
||||
"block_size": block_size,
|
||||
"q_dtype": str(dtype),
|
||||
"kv_cache_dtype": kv_cache_dtype,
|
||||
"page_size": page_size,
|
||||
"num_kv_heads": num_kv_heads,
|
||||
"head_size": head_size,
|
||||
"head_dim": head_dim,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
|
||||
@ -197,18 +180,17 @@ def write_results_to_csv(results, filename=None):
|
||||
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
|
||||
|
||||
fieldnames = [
|
||||
"batch_size",
|
||||
"trtllm_mean",
|
||||
"trtllm_std",
|
||||
"num_seqs",
|
||||
"trt_mean",
|
||||
"trt_std",
|
||||
"baseline_mean",
|
||||
"baseline_std",
|
||||
"speedup_percent",
|
||||
"q_dtype",
|
||||
"kv_cache_dtype",
|
||||
"output_dtype",
|
||||
"block_size",
|
||||
"page_size",
|
||||
"num_kv_heads",
|
||||
"head_size",
|
||||
"head_dim",
|
||||
"max_seq_len",
|
||||
]
|
||||
|
||||
@ -227,42 +209,45 @@ def write_results_to_csv(results, filename=None):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
|
||||
all_results = []
|
||||
|
||||
dtype = torch.bfloat16
|
||||
quant_dtypes = [
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(None, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
]
|
||||
print(
|
||||
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
|
||||
"output_dtype: bfloat16"
|
||||
)
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
|
||||
"baseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs,
|
||||
max_seq_len,
|
||||
dtype=torch.bfloat16,
|
||||
kv_cache_dtype="auto",
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
for quant_dtype in quant_dtypes:
|
||||
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
|
||||
q_quant_dtype = q_quant_dtype or dtype
|
||||
kv_quant_dtype = kv_quant_dtype or dtype
|
||||
o_quant_dtype = o_quant_dtype or dtype
|
||||
|
||||
print(
|
||||
f"Running benchmark for q_dtype = {q_quant_dtype}, "
|
||||
f"kv_cache_dtype: {kv_quant_dtype}, "
|
||||
f"output_dtype: {o_quant_dtype}"
|
||||
)
|
||||
print(
|
||||
"\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
|
||||
"baseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in batch_sizes:
|
||||
result = benchmark_decode(
|
||||
dtype=dtype,
|
||||
quant_dtypes=quant_dtype,
|
||||
batch_size=bs,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
all_results.append(result)
|
||||
print(
|
||||
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
|
||||
"output_dtype: bfloat16"
|
||||
)
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
|
||||
"baseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs,
|
||||
max_seq_len,
|
||||
dtype=torch.bfloat16,
|
||||
kv_cache_dtype="fp8",
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
# Write all results to CSV
|
||||
write_results_to_csv(all_results)
|
||||
|
||||
@ -3,14 +3,16 @@
|
||||
|
||||
import csv
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
|
||||
# KV Cache Layout for TRT-LLM
|
||||
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@ -24,99 +26,84 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_prefill(
|
||||
dtype: torch.dtype,
|
||||
quant_dtypes: tuple[
|
||||
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
|
||||
],
|
||||
batch_size: int,
|
||||
max_seq_len: int,
|
||||
num_heads: tuple[int, int] = (64, 8),
|
||||
head_size: int = 128,
|
||||
kv_layout: str = "HND",
|
||||
block_size: int = 16,
|
||||
warmup: int = 10,
|
||||
trials: int = 20,
|
||||
num_seqs,
|
||||
max_seq_len,
|
||||
page_size=16,
|
||||
dtype=torch.bfloat16,
|
||||
kv_layout="HND",
|
||||
num_kv_heads=8,
|
||||
kv_cache_dtype="auto",
|
||||
head_dim=128,
|
||||
warmup=10,
|
||||
trials=20,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(0)
|
||||
|
||||
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
||||
q_quant_dtype = q_quant_dtype or dtype
|
||||
kv_quant_dtype = kv_quant_dtype or dtype
|
||||
o_quant_dtype = o_quant_dtype or dtype
|
||||
|
||||
max_q_len = max_kv_len = max_seq_len
|
||||
|
||||
num_qo_heads, num_kv_heads = num_heads
|
||||
assert num_qo_heads % num_kv_heads == 0
|
||||
|
||||
sm_scale = float(1.0 / (head_size**0.5))
|
||||
HEAD_GRP_SIZE = 8
|
||||
MAX_SEQ_LEN = max_seq_len
|
||||
|
||||
# large number to reduce kv_cache reuse
|
||||
NUM_BLOCKS = int(256000 / block_size)
|
||||
NUM_BLOCKS = int(256000 / page_size)
|
||||
|
||||
kv_cache_shape = None
|
||||
if kv_layout == "NHD":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
|
||||
elif kv_layout == "HND":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8)
|
||||
|
||||
q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
|
||||
q_lens[-1] = max_q_len
|
||||
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
|
||||
sm_scale = float(1.0 / (head_dim**0.5))
|
||||
|
||||
q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
q_lens[-1] = MAX_SEQ_LEN
|
||||
max_q_len = max(q_lens)
|
||||
q_indptr = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||
torch.cumsum(
|
||||
torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32
|
||||
),
|
||||
]
|
||||
)
|
||||
q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype)
|
||||
|
||||
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
kv_lens[-1] = MAX_SEQ_LEN
|
||||
|
||||
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)]
|
||||
max_seq_len = max(seq_lens)
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
|
||||
seq_lens = kv_lens + q_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
|
||||
kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype)
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
kv_cache, _ = to_float8(kv_cache)
|
||||
|
||||
output_trtllm = torch.empty(q.shape, dtype=dtype)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(batch_size):
|
||||
for i in range(num_seqs):
|
||||
seq_len = seq_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
num_blocks = (seq_len + page_size - 1) // page_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
kv_last_page_len = seq_len % page_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_len = page_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
|
||||
|
||||
output_baseline = torch.empty(q.shape, dtype=dtype)
|
||||
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, kv_layout
|
||||
@ -128,12 +115,12 @@ def benchmark_prefill(
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
head_dim,
|
||||
page_size,
|
||||
causal=True,
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
kv_data_type=kv_cache.dtype,
|
||||
)
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
@ -151,55 +138,52 @@ def benchmark_prefill(
|
||||
times.append(start.elapsed_time(end)) # ms
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
o_scale = 1.0
|
||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
def baseline_prefill():
|
||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
||||
return wrapper.run(
|
||||
q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline
|
||||
)
|
||||
|
||||
def trtllm_prefill():
|
||||
def trt_prefill():
|
||||
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||
query=query,
|
||||
query=q,
|
||||
kv_cache=kv_cache,
|
||||
workspace_buffer=workspace_buffer,
|
||||
block_tables=block_tables,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens=seq_lens_tensor,
|
||||
max_q_len=max_q_len,
|
||||
max_kv_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
batch_size=batch_size,
|
||||
bmm1_scale=k_scale * sm_scale,
|
||||
bmm2_scale=v_scale,
|
||||
batch_size=num_seqs,
|
||||
cum_seq_lens_q=q_indptr,
|
||||
cum_seq_lens_kv=kv_indptr,
|
||||
out=output_trtllm,
|
||||
)
|
||||
|
||||
trt_mean, trt_std = time_fn(trt_prefill)
|
||||
baseline_mean, baseline_std = time_fn(baseline_prefill)
|
||||
trtllm_mean, trtllm_std = time_fn(trtllm_prefill)
|
||||
|
||||
# Calculate percentage speedup (positive means TRT is faster)
|
||||
speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
|
||||
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
|
||||
|
||||
print(
|
||||
f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
|
||||
f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
|
||||
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}"
|
||||
f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}"
|
||||
)
|
||||
|
||||
# Return results for CSV writing
|
||||
return {
|
||||
"batch_size": batch_size,
|
||||
"trtllm_mean": trtllm_mean,
|
||||
"trtllm_std": trtllm_std.item(),
|
||||
"num_seqs": num_seqs,
|
||||
"trt_mean": trt_mean,
|
||||
"trt_std": trt_std.item(),
|
||||
"baseline_mean": baseline_mean,
|
||||
"baseline_std": baseline_std.item(),
|
||||
"speedup_percent": speedup_percent,
|
||||
"q_dtype": str(q_quant_dtype),
|
||||
"kv_cache_dtype": str(kv_quant_dtype),
|
||||
"output_dtype": str(o_quant_dtype),
|
||||
"block_size": block_size,
|
||||
"q_dtype": str(dtype),
|
||||
"kv_cache_dtype": kv_cache_dtype,
|
||||
"page_size": page_size,
|
||||
"num_kv_heads": num_kv_heads,
|
||||
"head_size": head_size,
|
||||
"head_dim": head_dim,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
|
||||
@ -211,18 +195,17 @@ def write_results_to_csv(results, filename=None):
|
||||
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
|
||||
|
||||
fieldnames = [
|
||||
"batch_size",
|
||||
"trtllm_mean",
|
||||
"trtllm_std",
|
||||
"num_seqs",
|
||||
"trt_mean",
|
||||
"trt_std",
|
||||
"baseline_mean",
|
||||
"baseline_std",
|
||||
"speedup_percent",
|
||||
"q_dtype",
|
||||
"kv_cache_dtype",
|
||||
"output_dtype",
|
||||
"block_size",
|
||||
"page_size",
|
||||
"num_kv_heads",
|
||||
"head_size",
|
||||
"head_dim",
|
||||
"max_seq_len",
|
||||
]
|
||||
|
||||
@ -241,41 +224,27 @@ def write_results_to_csv(results, filename=None):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
|
||||
all_results = []
|
||||
|
||||
dtype = torch.bfloat16
|
||||
quant_dtypes = [
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
]
|
||||
|
||||
for quant_dtype in quant_dtypes:
|
||||
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
|
||||
q_quant_dtype = q_quant_dtype or dtype
|
||||
kv_quant_dtype = kv_quant_dtype or dtype
|
||||
o_quant_dtype = o_quant_dtype or dtype
|
||||
|
||||
print(
|
||||
f"Running benchmark for q_dtype = {q_quant_dtype}, "
|
||||
f"kv_cache_dtype: {kv_quant_dtype}, "
|
||||
f"output_dtype: {o_quant_dtype}"
|
||||
)
|
||||
print(
|
||||
"\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
|
||||
"baseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in batch_sizes:
|
||||
result = benchmark_prefill(
|
||||
dtype=dtype,
|
||||
quant_dtypes=quant_dtype,
|
||||
batch_size=bs,
|
||||
max_seq_len=max_seq_len,
|
||||
)
|
||||
all_results.append(result)
|
||||
print(
|
||||
"Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
|
||||
"output_dtype: bfloat16"
|
||||
)
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
|
||||
"baseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_prefill(
|
||||
bs,
|
||||
max_seq_len,
|
||||
dtype=torch.bfloat16,
|
||||
kv_cache_dtype="auto",
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
# Write all results to CSV
|
||||
write_results_to_csv(all_results)
|
||||
|
||||
@ -5,13 +5,11 @@ The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `re
|
||||
First start serving your model
|
||||
|
||||
```bash
|
||||
export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||
export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||
|
||||
vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests
|
||||
vllm serve $MODEL_NAME --disable-log-requests
|
||||
```
|
||||
|
||||
The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface).
|
||||
|
||||
## Synthetic Multi-Turn Conversations
|
||||
|
||||
Download the following text file (used for generation of synthetic conversations)
|
||||
@ -28,10 +26,10 @@ But you may use other text files if you prefer (using this specific file is not
|
||||
Then run the benchmarking script
|
||||
|
||||
```bash
|
||||
export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||
export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||
|
||||
python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \
|
||||
--input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6
|
||||
python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \
|
||||
--num-clients 2 --max-active-conversations 6
|
||||
```
|
||||
|
||||
You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.).
|
||||
|
||||
@ -825,11 +825,9 @@ def get_client_config(
|
||||
|
||||
# Arguments for API requests
|
||||
chat_url = f"{args.url}/v1/chat/completions"
|
||||
model_name = args.served_model_name if args.served_model_name else args.model
|
||||
|
||||
req_args = RequestArgs(
|
||||
chat_url=chat_url,
|
||||
model=model_name,
|
||||
model=args.model,
|
||||
stream=not args.no_stream,
|
||||
limit_min_tokens=args.limit_min_tokens,
|
||||
limit_max_tokens=args.limit_max_tokens,
|
||||
@ -1249,19 +1247,9 @@ async def main() -> None:
|
||||
default=0,
|
||||
help="Seed for random number generators (default: 0)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-m", "--model", type=str, required=True, help="Path of the LLM model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--served-model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. "
|
||||
"If not specified, the model name will be the "
|
||||
"same as the ``--model`` argument. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-u",
|
||||
"--url",
|
||||
|
||||
@ -182,17 +182,17 @@ endif()
|
||||
#
|
||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
|
||||
# Flag to enable ACL kernels for AARCH64 platforms
|
||||
if (VLLM_BUILD_ACL STREQUAL "ON")
|
||||
if ( VLLM_BUILD_ACL STREQUAL "ON")
|
||||
set(USE_ACL ON)
|
||||
else()
|
||||
set(USE_ACL OFF)
|
||||
endif()
|
||||
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.9
|
||||
GIT_TAG v3.8.1
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
@ -204,7 +204,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POW
|
||||
endif()
|
||||
set(ONEDNN_AARCH64_USE_ACL "ON")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
||||
set(ONEDNN_BUILD_DOC "OFF")
|
||||
@ -217,23 +217,38 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POW
|
||||
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
|
||||
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
|
||||
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
|
||||
set(ONEDNN_VERBOSE "OFF")
|
||||
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
||||
|
||||
FetchContent_MakeAvailable(oneDNN)
|
||||
add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
|
||||
target_include_directories(
|
||||
dnnl_ext
|
||||
PUBLIC ${oneDNN_SOURCE_DIR}/include
|
||||
PUBLIC ${oneDNN_BINARY_DIR}/include
|
||||
PRIVATE ${oneDNN_SOURCE_DIR}/src
|
||||
|
||||
list(APPEND LIBS dnnl)
|
||||
elseif(POWER10_FOUND)
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.7.2
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
target_link_libraries(dnnl_ext dnnl)
|
||||
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
|
||||
list(APPEND LIBS dnnl_ext)
|
||||
set(USE_ONEDNN ON)
|
||||
else()
|
||||
set(USE_ONEDNN OFF)
|
||||
|
||||
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
||||
set(ONEDNN_BUILD_DOC "OFF")
|
||||
set(ONEDNN_BUILD_EXAMPLES "OFF")
|
||||
set(ONEDNN_BUILD_TESTS "OFF")
|
||||
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
|
||||
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
|
||||
set(ONEDNN_BUILD_GRAPH "OFF")
|
||||
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
|
||||
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
|
||||
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
|
||||
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
|
||||
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
||||
|
||||
set(DNNL_CPU_RUNTIME "OMP")
|
||||
|
||||
FetchContent_MakeAvailable(oneDNN)
|
||||
|
||||
list(APPEND LIBS dnnl)
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||
@ -260,6 +275,7 @@ set(VLLM_EXT_SRC
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
"csrc/cpu/shm.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
||||
@ -273,11 +289,14 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
${VLLM_EXT_SRC})
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(USE_ONEDNN)
|
||||
elseif(POWER10_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/dnnl_kernels.cpp"
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
if (ASIMD_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f
|
||||
GIT_TAG 93cf5a08f421a3efd0c4a7e005ef8f742b578ce0
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
@ -128,45 +128,6 @@ __global__ void act_and_mul_kernel_with_param(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
|
||||
float alpha, float limit) {
|
||||
// clamp gate: min=None, max=limit
|
||||
const float gate_f = (float)gate;
|
||||
const float clamped_gate = gate_f > limit ? limit : gate_f;
|
||||
|
||||
// clamp up: min=-limit, max=limit
|
||||
const float up_f = (float)up;
|
||||
const float clamped_up =
|
||||
up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
|
||||
|
||||
// glu = gate * sigmoid(gate * alpha)
|
||||
const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
|
||||
const float glu = clamped_gate * sigmoid_val;
|
||||
|
||||
// (up + 1) * glu
|
||||
return (T)((clamped_up + 1.0f) * glu);
|
||||
}
|
||||
|
||||
template <typename scalar_t,
|
||||
scalar_t (*ACT_FN)(const scalar_t&, const scalar_t&, const float,
|
||||
const float)>
|
||||
__global__ void swigluoai_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d, const float alpha, const float limit) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
// TODO: Vectorize loads and stores.
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
// gate = x[..., ::2] (even indices)
|
||||
const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
|
||||
// up = x[..., 1::2] (odd indices)
|
||||
const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
|
||||
|
||||
out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
|
||||
@ -184,31 +145,11 @@ __global__ void swigluoai_and_mul_kernel(
|
||||
PARAM); \
|
||||
});
|
||||
|
||||
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
|
||||
int d = input.size(-1) / 2; \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
|
||||
vllm::swigluoai_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
|
||||
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), d, ALPHA, \
|
||||
LIMIT); \
|
||||
});
|
||||
|
||||
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
double threshold) {
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
|
||||
}
|
||||
void swigluoai_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., 2 * d]
|
||||
double alpha, double limit) {
|
||||
LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
|
||||
}
|
||||
namespace vllm {
|
||||
|
||||
// Element-wise activation kernel template.
|
||||
|
||||
@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options(
|
||||
// TODO(trevor-m): Change split_kv back to -1 when
|
||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||
// perform worse with larger context length and smaller batch sizes.
|
||||
static_cast<int>(num_kv_splits), // split_kv
|
||||
num_kv_splits, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
|
||||
// Assumes device 0 when getting sm_count.
|
||||
arguments.hw_info.sm_count =
|
||||
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
||||
arguments.split_kv = static_cast<int>(num_kv_splits);
|
||||
arguments.split_kv = num_kv_splits;
|
||||
MlaSm100Type::Fmha::set_split_kv(arguments);
|
||||
|
||||
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
|
||||
@ -321,8 +321,6 @@ static inline constexpr auto kFE3M2f =
|
||||
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn =
|
||||
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE8M0fnu =
|
||||
ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
|
||||
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
|
||||
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
|
||||
|
||||
@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
(__m128i)vec8_data.reg, 1)) {}
|
||||
|
||||
void save(void* ptr) const {
|
||||
_mm256_storeu_si256((__m256i*)ptr, reg_low);
|
||||
_mm256_storeu_si256((__m256i*)ptr + 1, reg_high);
|
||||
*reinterpret_cast<__m256i*>(ptr) = reg_low;
|
||||
*reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -1,346 +0,0 @@
|
||||
#include <list>
|
||||
#include <optional>
|
||||
|
||||
#include "common/memory_desc.hpp"
|
||||
#include "common/memory.hpp"
|
||||
|
||||
#include "dnnl_helper.h"
|
||||
|
||||
static dnnl::engine& default_engine() {
|
||||
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||
return engine;
|
||||
}
|
||||
|
||||
static dnnl::stream& default_stream() {
|
||||
static dnnl::stream stream(default_engine());
|
||||
return stream;
|
||||
}
|
||||
|
||||
void release_dnnl_matmul_handler(int64_t handler) {
|
||||
DNNLMatMulPrimitiveHandler* ptr =
|
||||
reinterpret_cast<DNNLMatMulPrimitiveHandler*>(handler);
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
template <typename KT, typename VT>
|
||||
class DNNLPrimitiveCache {
|
||||
public:
|
||||
using cache_value_t = std::pair<KT, VT>;
|
||||
using result_value_t = VT;
|
||||
using container_t = std::list<cache_value_t>;
|
||||
using value_iterator_t = typename container_t::iterator;
|
||||
using map_t = std::unordered_map<KT, value_iterator_t>;
|
||||
using creator_t = VT (*)();
|
||||
|
||||
public:
|
||||
DNNLPrimitiveCache(size_t capacity)
|
||||
: capacity_(capacity),
|
||||
values_(),
|
||||
key_to_value_(std::min(256lu, capacity)) {
|
||||
assert(capacity > 0);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
result_value_t get_or_create(const KT& key, F&& creator) {
|
||||
std::optional<value_iterator_t> value = get_value(key);
|
||||
if (value.has_value()) {
|
||||
return value.value()->second;
|
||||
} else {
|
||||
return add_value({key, creator()})->second;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size() const { return values_.size(); }
|
||||
|
||||
private:
|
||||
void dump_data() {
|
||||
std::stringstream ss;
|
||||
ss << "table_id: " << std::hex << reinterpret_cast<size_t>(this) << std::dec
|
||||
<< "\n";
|
||||
ss << "container: [";
|
||||
for (auto&& iter : values_) {
|
||||
ss << "(" << iter.first << ", " << std::hex
|
||||
<< reinterpret_cast<size_t>(iter.second.get()) << "), " << std::dec;
|
||||
}
|
||||
ss << "]\n";
|
||||
|
||||
ss << "map: [";
|
||||
for (auto&& iter : key_to_value_) {
|
||||
ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex
|
||||
<< reinterpret_cast<size_t>(iter.second->second.get()) << std::dec
|
||||
<< "), ";
|
||||
}
|
||||
ss << "]\n";
|
||||
std::printf("%s\n", ss.str().c_str());
|
||||
}
|
||||
|
||||
value_iterator_t add_value(cache_value_t&& new_value) {
|
||||
if (size() == capacity_) {
|
||||
cache_value_t& last_item = values_.back();
|
||||
key_to_value_.erase(last_item.first);
|
||||
values_.pop_back();
|
||||
}
|
||||
|
||||
auto& added_value_ = values_.emplace_front(std::move(new_value));
|
||||
key_to_value_.emplace(added_value_.first, values_.begin());
|
||||
return values_.begin();
|
||||
}
|
||||
|
||||
std::optional<value_iterator_t> get_value(const KT& key) {
|
||||
if (key_to_value_.size() > 0 && key == values_.begin()->first) {
|
||||
return values_.begin();
|
||||
}
|
||||
|
||||
auto value_map_iterator = key_to_value_.find(key);
|
||||
if (value_map_iterator != key_to_value_.end()) {
|
||||
values_.splice(values_.begin(), values_, value_map_iterator->second);
|
||||
return value_map_iterator->second;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const size_t capacity_;
|
||||
container_t values_;
|
||||
map_t key_to_value_;
|
||||
};
|
||||
|
||||
DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
|
||||
const Args& args, dnnl::memory::data_type b_type)
|
||||
: b_n_size_(args.b_n_size),
|
||||
b_n_stride_(args.b_n_stride),
|
||||
b_k_size_(args.b_k_size),
|
||||
b_k_stride_(args.b_k_stride),
|
||||
b_type_(b_type),
|
||||
c_type_(args.c_type),
|
||||
runtime_memory_ptrs_(8),
|
||||
primitive_cache_size_(args.primitive_cache_size) {
|
||||
assert(primitive_cache_size_ > 0);
|
||||
}
|
||||
|
||||
void DNNLMatMulPrimitiveHandler::prepack_weight(
|
||||
void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
|
||||
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
|
||||
{b_k_stride_, b_n_stride_});
|
||||
dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
|
||||
dnnl::memory packed_weight(b_target_mem_desc, default_engine());
|
||||
{
|
||||
dnnl::reorder(original_weight, packed_weight)
|
||||
.execute(default_stream(), original_weight, packed_weight);
|
||||
default_stream().wait();
|
||||
}
|
||||
memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight;
|
||||
b_target_mem_desc_ = b_target_mem_desc;
|
||||
}
|
||||
|
||||
void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr(
|
||||
size_t index, dnnl_memory* memory_ptr) {
|
||||
dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage();
|
||||
dnnl_memory_desc* mem_desc = const_cast<dnnl_memory_desc*>(memory_ptr->md());
|
||||
runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc};
|
||||
}
|
||||
|
||||
std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
|
||||
DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) {
|
||||
return runtime_memory_ptrs_[index];
|
||||
}
|
||||
|
||||
namespace std {
|
||||
template <>
|
||||
struct hash<W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey> {
|
||||
size_t operator()(
|
||||
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
|
||||
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
|
||||
hash<int>()(static_cast<int>(val.a_qs)) ^
|
||||
hash<int>()(static_cast<int>(val.b_qs)) ^ hash<bool>()(val.use_azp) ^
|
||||
hash<int>()(static_cast<int>(val.c_type));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
|
||||
size_t operator()(
|
||||
const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const {
|
||||
return hash<dnnl_dim_t>()(val.a_m_size) ^ hash<bool>()(val.use_bias) ^
|
||||
hash<int>()(static_cast<int>(val.bias_type));
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
|
||||
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
|
||||
return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
|
||||
l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp &&
|
||||
l.c_type == r.c_type;
|
||||
}
|
||||
|
||||
bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
|
||||
const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) {
|
||||
return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size &&
|
||||
l.bias_type == r.bias_type;
|
||||
}
|
||||
|
||||
static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
|
||||
get_w8a8_class_primitive_cache(
|
||||
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
|
||||
int64_t cache_size) {
|
||||
static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128);
|
||||
assert(cache_size > 0);
|
||||
return cache.get_or_create(key, [&]() {
|
||||
return std::make_shared<W8A8MatMulPrimitiveHandler::MSizeCache>(cache_size);
|
||||
});
|
||||
}
|
||||
|
||||
W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
|
||||
: DNNLMatMulPrimitiveHandler(
|
||||
static_cast<const DNNLMatMulPrimitiveHandler::Args&>(args),
|
||||
dnnl::memory::data_type::s8),
|
||||
use_azp_(args.use_a_zero_point),
|
||||
a_qs_(args.a_quantization_strategy),
|
||||
b_qs_(args.b_quantization_strategy),
|
||||
m_size_cache_(nullptr) {
|
||||
assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL);
|
||||
assert(b_qs_ != QuantizationStrategy::PER_TOKEN);
|
||||
if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
|
||||
assert(!use_azp_);
|
||||
};
|
||||
prepack_weight(args.b_ptr,
|
||||
create_primitive_desc(
|
||||
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
|
||||
.use_bias = false,
|
||||
.bias_type = dnnl::memory::data_type::undef},
|
||||
true)
|
||||
.weights_desc());
|
||||
init_runtime_memory_cache(args);
|
||||
}
|
||||
|
||||
void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
|
||||
auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
|
||||
auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
|
||||
a_storage->set_data_handle((void*)args.a_ptr);
|
||||
a_mem_desc->dims[0] = args.a_m_size;
|
||||
c_storage->set_data_handle((void*)args.c_ptr);
|
||||
c_mem_desc->dims[0] = args.a_m_size;
|
||||
|
||||
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2);
|
||||
a_scale_storage->set_data_handle((void*)args.a_scales_ptr);
|
||||
}
|
||||
if (use_azp_) {
|
||||
auto&& [a_zero_point_storage, a_zero_point_mem_desc] =
|
||||
get_runtime_memory_ptr(3);
|
||||
a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr);
|
||||
}
|
||||
|
||||
if (args.use_bias) {
|
||||
auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4);
|
||||
bias_storage->set_data_handle((void*)args.bias_ptr);
|
||||
}
|
||||
|
||||
dnnl::matmul matmul = get_matmul_cache(args);
|
||||
matmul.execute(default_stream(), memory_cache_);
|
||||
default_stream().wait();
|
||||
}
|
||||
|
||||
dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
|
||||
const MSizeCacheKey& key) {
|
||||
if (m_size_cache_.get() == nullptr) {
|
||||
ClassMatmulCacheKey key = {.b_n_size = b_n_size_,
|
||||
.b_k_size = b_k_size_,
|
||||
.a_qs = a_qs_,
|
||||
.b_qs = b_qs_,
|
||||
.use_azp = use_azp_,
|
||||
.c_type = c_type_};
|
||||
m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_);
|
||||
}
|
||||
|
||||
return m_size_cache_->get_or_create(key, [&]() {
|
||||
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
|
||||
return dnnl::matmul(desc);
|
||||
});
|
||||
}
|
||||
|
||||
void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
|
||||
memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_},
|
||||
dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::ab},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
|
||||
memory_cache_[DNNL_ARG_DST] =
|
||||
dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
|
||||
|
||||
// For PER_TOKEN, scales will be applied in outside epilogue
|
||||
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory(
|
||||
{{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(
|
||||
2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get());
|
||||
if (use_azp_) {
|
||||
memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory(
|
||||
{{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(
|
||||
3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get());
|
||||
}
|
||||
}
|
||||
|
||||
if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
|
||||
dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(),
|
||||
(void*)args.b_scales_ptr);
|
||||
} else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
|
||||
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
|
||||
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
|
||||
default_engine(), (void*)args.b_scales_ptr);
|
||||
}
|
||||
|
||||
memory_cache_[DNNL_ARG_BIAS] =
|
||||
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
|
||||
const MSizeCacheKey& key, bool first_time) {
|
||||
dnnl::memory::desc a_md({key.a_m_size, b_k_size_},
|
||||
dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::ab);
|
||||
dnnl::memory::desc b_md;
|
||||
if (first_time) {
|
||||
b_md =
|
||||
dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
} else {
|
||||
b_md = b_target_mem_desc_;
|
||||
}
|
||||
dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
|
||||
dnnl::memory::format_tag::ab);
|
||||
|
||||
dnnl::primitive_attr attr;
|
||||
// For PER_TOKEN, scales will be applied in outside epilogue
|
||||
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
attr.set_scales_mask(DNNL_ARG_SRC, 0);
|
||||
if (use_azp_) {
|
||||
attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
|
||||
}
|
||||
}
|
||||
|
||||
if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
|
||||
} else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
|
||||
}
|
||||
|
||||
if (key.use_bias) {
|
||||
// For PER_TOKEN, bias will be applied in epilogue
|
||||
assert(a_qs_ == QuantizationStrategy::PER_TENSOR);
|
||||
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
|
||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
|
||||
c_md, attr);
|
||||
} else {
|
||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
|
||||
attr);
|
||||
}
|
||||
}
|
||||
@ -1,169 +0,0 @@
|
||||
#ifndef DNNL_HELPER_H
|
||||
#define DNNL_HELPER_H
|
||||
|
||||
#include <optional>
|
||||
#include <cassert>
|
||||
|
||||
#include "oneapi/dnnl/dnnl.hpp"
|
||||
|
||||
namespace c10 {
|
||||
struct BFloat16;
|
||||
struct Half;
|
||||
} // namespace c10
|
||||
|
||||
namespace dnnl {
|
||||
namespace impl {
|
||||
struct memory_storage_t;
|
||||
struct matmul_pd_t;
|
||||
struct matmul_desc_t;
|
||||
} // namespace impl
|
||||
} // namespace dnnl
|
||||
struct dnnl_memory_desc;
|
||||
|
||||
template <typename KT, typename VT>
|
||||
class DNNLPrimitiveCache;
|
||||
|
||||
template <typename T>
|
||||
struct DNNLType {
|
||||
static constexpr dnnl::memory::data_type type =
|
||||
dnnl::memory::data_type::undef;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int8_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int32_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<float> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::BFloat16> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::Half> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||
return DNNLType<std::decay_t<T>>::type;
|
||||
}
|
||||
|
||||
class DNNLMatMulPrimitiveHandler {
|
||||
public:
|
||||
virtual ~DNNLMatMulPrimitiveHandler() = default;
|
||||
|
||||
protected:
|
||||
struct Args {
|
||||
dnnl_dim_t b_n_size;
|
||||
dnnl_dim_t b_n_stride;
|
||||
dnnl_dim_t b_k_size;
|
||||
dnnl_dim_t b_k_stride;
|
||||
void* b_ptr;
|
||||
dnnl::memory::data_type c_type;
|
||||
size_t primitive_cache_size;
|
||||
};
|
||||
|
||||
protected:
|
||||
DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);
|
||||
|
||||
void prepack_weight(void* original_b_ptr,
|
||||
dnnl::memory::desc b_target_mem_desc);
|
||||
|
||||
void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);
|
||||
|
||||
std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
|
||||
get_runtime_memory_ptr(size_t index);
|
||||
|
||||
protected:
|
||||
const dnnl_dim_t b_n_size_;
|
||||
const dnnl_dim_t b_n_stride_;
|
||||
const dnnl_dim_t b_k_size_;
|
||||
const dnnl_dim_t b_k_stride_;
|
||||
dnnl::memory::data_type b_type_;
|
||||
dnnl::memory::data_type c_type_;
|
||||
std::unordered_map<int, dnnl::memory> memory_cache_;
|
||||
std::vector<std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>>
|
||||
runtime_memory_ptrs_;
|
||||
dnnl::memory::desc b_target_mem_desc_;
|
||||
int64_t primitive_cache_size_;
|
||||
};
|
||||
|
||||
class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
|
||||
public:
|
||||
enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL };
|
||||
|
||||
struct Args : public DNNLMatMulPrimitiveHandler::Args {
|
||||
bool use_a_zero_point;
|
||||
QuantizationStrategy a_quantization_strategy;
|
||||
QuantizationStrategy b_quantization_strategy;
|
||||
float* b_scales_ptr;
|
||||
};
|
||||
|
||||
struct ClassMatmulCacheKey {
|
||||
dnnl_dim_t b_n_size;
|
||||
dnnl_dim_t b_k_size;
|
||||
QuantizationStrategy a_qs;
|
||||
QuantizationStrategy b_qs;
|
||||
bool use_azp;
|
||||
dnnl::memory::data_type c_type;
|
||||
|
||||
friend bool operator==(const ClassMatmulCacheKey& l,
|
||||
const ClassMatmulCacheKey& r);
|
||||
};
|
||||
|
||||
struct MSizeCacheKey {
|
||||
dnnl_dim_t a_m_size;
|
||||
bool use_bias;
|
||||
dnnl::memory::data_type bias_type;
|
||||
|
||||
friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
|
||||
};
|
||||
|
||||
using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
|
||||
using ClassMatmulCache =
|
||||
DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;
|
||||
|
||||
struct ExecArgs : public MSizeCacheKey {
|
||||
const int8_t* a_ptr;
|
||||
const float* a_scales_ptr;
|
||||
const int32_t* a_zero_points_ptr;
|
||||
const void* bias_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
|
||||
public:
|
||||
W8A8MatMulPrimitiveHandler(const Args& args);
|
||||
|
||||
QuantizationStrategy get_input_scale_strategy() const { return a_qs_; }
|
||||
|
||||
bool get_input_use_zero_point() const { return use_azp_; }
|
||||
|
||||
void execute(ExecArgs& args);
|
||||
|
||||
private:
|
||||
dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
|
||||
bool first_time);
|
||||
|
||||
void init_runtime_memory_cache(const Args& args);
|
||||
|
||||
dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);
|
||||
|
||||
private:
|
||||
const bool use_azp_;
|
||||
const QuantizationStrategy a_qs_;
|
||||
const QuantizationStrategy b_qs_;
|
||||
std::shared_ptr<MSizeCache> m_size_cache_;
|
||||
};
|
||||
|
||||
#endif
|
||||
206
csrc/cpu/dnnl_helper.hpp
Normal file
206
csrc/cpu/dnnl_helper.hpp
Normal file
@ -0,0 +1,206 @@
|
||||
#ifndef DNNL_HELPER_HPP
|
||||
#define DNNL_HELPER_HPP
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
#include "oneapi/dnnl/dnnl.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
struct DNNLType {
|
||||
static constexpr dnnl::memory::data_type type =
|
||||
dnnl::memory::data_type::undef;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int8_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int32_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<float> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::BFloat16> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::Half> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||
return DNNLType<std::decay_t<T>>::type;
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <bool InputNoScale>
|
||||
class DNNLPrimitiveHelper {
|
||||
public:
|
||||
// I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
|
||||
// A: [M, K], row-major
|
||||
// B: [K, N], column-major
|
||||
// C: [M, N], row-major
|
||||
// bias: [N], row-major, optional
|
||||
// a_scales: [MS]
|
||||
// b_scales: [NS]
|
||||
// Note: Due to the limitation of oneDNN
|
||||
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||
// not supported.
|
||||
|
||||
template <typename OutputT, typename BiasT>
|
||||
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||
dnnl_dim_t K, const float* a_scales,
|
||||
const float* b_scales, dnnl_dim_t MS,
|
||||
dnnl_dim_t NS) {
|
||||
auto&& OutputType = get_dnnl_type<OutputT>();
|
||||
auto&& BiasType = get_dnnl_type<BiasT>();
|
||||
|
||||
dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
|
||||
dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
|
||||
dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
|
||||
|
||||
dnnl::primitive_attr attr;
|
||||
if constexpr (!InputNoScale) {
|
||||
if (MS == 1) {
|
||||
// per-tensor
|
||||
attr.set_scales_mask(DNNL_ARG_SRC, 0);
|
||||
} else {
|
||||
// per-token
|
||||
TORCH_CHECK(false, "per-token quantization is unsupported.");
|
||||
}
|
||||
}
|
||||
|
||||
if (NS == 1) {
|
||||
// per-tensor
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
|
||||
} else {
|
||||
// per-channel
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc matmul_pd;
|
||||
// Create memory descriptors with format_tag::any for the primitive. This
|
||||
// enables the matmul primitive to choose memory layouts for an
|
||||
// optimized primitive implementation, and these layouts may differ from the
|
||||
// ones provided by the user.
|
||||
#ifdef __aarch64__
|
||||
auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
auto mat_weights_md = dnnl::memory::desc(
|
||||
{K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
|
||||
auto mat_dst_md =
|
||||
dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
|
||||
mat_weights_md, bias_md,
|
||||
mat_dst_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(
|
||||
default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
|
||||
}
|
||||
#else
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
bias_md, c_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
c_md, attr);
|
||||
}
|
||||
#endif
|
||||
dnnl::matmul matmul(matmul_pd);
|
||||
|
||||
auto& engine = default_engine();
|
||||
|
||||
dnnl::memory a_m(a_md, engine, (void*)a);
|
||||
dnnl::memory b_m(b_md, engine, (void*)b);
|
||||
dnnl::memory c_m(c_md, engine, (void*)c);
|
||||
dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
|
||||
(void*)a_scales);
|
||||
dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
|
||||
(void*)b_scales);
|
||||
|
||||
auto& stream = default_stream();
|
||||
|
||||
auto mat_src_mem = a_m;
|
||||
auto mat_weights_mem = b_m;
|
||||
auto mat_dst_mem = c_m;
|
||||
#ifdef __aarch64__
|
||||
if (matmul_pd.weights_desc() != b_m.get_desc()) {
|
||||
mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
|
||||
dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
|
||||
}
|
||||
#endif
|
||||
if constexpr (InputNoScale) {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
}
|
||||
stream.wait();
|
||||
}
|
||||
|
||||
private:
|
||||
static dnnl::engine& default_engine() {
|
||||
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||
return engine;
|
||||
}
|
||||
|
||||
static dnnl::stream& default_stream() {
|
||||
static dnnl::stream stream(default_engine());
|
||||
return stream;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@ -1,494 +0,0 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include "dnnl_helper.h"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using load_vec_type = void;
|
||||
using cvt_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using load_vec_type = vec_op::BF16Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power architecture-specific vector type
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
// Fallback for other architectures
|
||||
using load_vec_type = vec_op::FP16Vec16;
|
||||
#endif
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int64_t num_tokens,
|
||||
const int64_t input_stride,
|
||||
const int64_t hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
cvt_vec_t zp_vec;
|
||||
if constexpr (AZP) {
|
||||
zp_vec = cvt_vec_t(static_cast<float>(*azp));
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
int8_t* output_ptr = output + i * hidden_size;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int64_t num_tokens,
|
||||
const int64_t input_stride,
|
||||
const int64_t hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
|
||||
cvt_vec_t min_value(std::numeric_limits<float>::max());
|
||||
{
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
} else {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32, hidden_size - j);
|
||||
min_value = min_value.min(elems_fp32, hidden_size - j);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float scale_val, azp_val;
|
||||
if constexpr (AZP) {
|
||||
float max_scalar = max_value.reduce_max();
|
||||
float min_scalar = min_value.reduce_min();
|
||||
scale_val = (max_scalar - min_scalar) / 255.0f;
|
||||
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
|
||||
azp[i] = azp_val;
|
||||
scale[i] = scale_val;
|
||||
} else {
|
||||
scale_val = max_value.reduce_max() / 127.0f;
|
||||
scale[i] = scale_val;
|
||||
}
|
||||
|
||||
const cvt_vec_t inv_scale(1.0 / scale_val);
|
||||
const cvt_vec_t azp_vec(azp_val);
|
||||
|
||||
{
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
int8_t* output_ptr = output + i * hidden_size;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, bool Bias, typename scalar_t>
|
||||
void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float* a_scale, const int32_t* azp,
|
||||
const float* azp_adj, const scalar_t* bias,
|
||||
const int64_t num_tokens,
|
||||
const int64_t hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
const int64_t thread_num = omp_get_max_threads();
|
||||
if (num_tokens > thread_num) {
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
const float* input_ptr = input + i * hidden_size;
|
||||
scalar_t* output_ptr = output + i * hidden_size;
|
||||
int64_t j = 0;
|
||||
cvt_vec_t token_scale_vec(a_scale[i]);
|
||||
cvt_vec_t token_zp_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
|
||||
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
|
||||
}
|
||||
for (; j < hidden_size - vec_elem_num; ++j) {
|
||||
cvt_vec_t elems_fp32(input_ptr + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + j);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + j);
|
||||
}
|
||||
cvt_vec_t elems_fp32(input_ptr + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + j);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
} else {
|
||||
const int64_t vec_iteration =
|
||||
(hidden_size + vec_elem_num - 1) / vec_elem_num;
|
||||
const int64_t vec_iteration_per_thread =
|
||||
(vec_iteration + thread_num - 1) / thread_num;
|
||||
const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num;
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int64_t i = 0; i < thread_num; ++i) {
|
||||
const int64_t start = elem_num_per_thread * i;
|
||||
const int64_t end = std::min(hidden_size, elem_num_per_thread + start);
|
||||
for (int64_t j = 0; j < num_tokens; ++j) {
|
||||
cvt_vec_t token_scale_vec(a_scale[j]);
|
||||
cvt_vec_t token_zp_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
float zp_scale_val = a_scale[j] * static_cast<float>(azp[j]);
|
||||
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
|
||||
}
|
||||
int64_t k = start;
|
||||
const float* input_ptr = input + j * hidden_size;
|
||||
scalar_t* output_ptr = output + j * hidden_size;
|
||||
for (; k < end - vec_elem_num; k += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input_ptr + k);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + k);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + k);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + k);
|
||||
}
|
||||
if (k < end) {
|
||||
cvt_vec_t elems_fp32(input_ptr + k);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + k);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + k);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + k, end - k);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int64_t create_onednn_scaled_mm_handler(
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
at::ScalarType output_type, bool dynamic_act_quant, bool use_azp,
|
||||
int64_t primitive_cache_size) {
|
||||
TORCH_CHECK(b.dim() == 2);
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(b_scales.is_contiguous());
|
||||
|
||||
W8A8MatMulPrimitiveHandler::Args args;
|
||||
args.primitive_cache_size = primitive_cache_size;
|
||||
|
||||
if (b_scales.numel() == 1) {
|
||||
args.b_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
|
||||
} else {
|
||||
TORCH_CHECK_EQ(b_scales.numel(), b.size(1));
|
||||
args.b_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL;
|
||||
}
|
||||
args.b_scales_ptr = b_scales.data_ptr<float>();
|
||||
args.b_k_size = b.size(0);
|
||||
args.b_k_stride = b.stride(0);
|
||||
args.b_n_size = b.size(1);
|
||||
args.b_n_stride = b.stride(1);
|
||||
args.b_ptr = b.data_ptr<int8_t>();
|
||||
|
||||
if (dynamic_act_quant) {
|
||||
// dynamic per-token, bias, A scales and A zps will be applied in outside.
|
||||
args.a_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN;
|
||||
args.use_a_zero_point = false;
|
||||
} else {
|
||||
// static per-tensor
|
||||
args.a_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
|
||||
args.use_a_zero_point = use_azp;
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler",
|
||||
[&] {
|
||||
if (dynamic_act_quant) {
|
||||
args.c_type = get_dnnl_type<float>();
|
||||
} else {
|
||||
args.c_type = get_dnnl_type<scalar_t>();
|
||||
}
|
||||
});
|
||||
|
||||
return reinterpret_cast<int64_t>(new W8A8MatMulPrimitiveHandler(args));
|
||||
}
|
||||
|
||||
void onednn_scaled_mm(
|
||||
torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a, // [M, IC], row-major
|
||||
const torch::Tensor& a_scales, // [M] or [1]
|
||||
const std::optional<torch::Tensor>& azp, // [M] or [1]
|
||||
const std::optional<torch::Tensor>& azp_adj, // [M] or [1]
|
||||
const std::optional<torch::Tensor>& bias, // [N]
|
||||
int64_t handler) {
|
||||
CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
|
||||
TORCH_CHECK(a.dim() == 2);
|
||||
TORCH_CHECK(a.is_contiguous());
|
||||
TORCH_CHECK(c.is_contiguous());
|
||||
W8A8MatMulPrimitiveHandler* ptr =
|
||||
reinterpret_cast<W8A8MatMulPrimitiveHandler*>(handler);
|
||||
const int32_t* azp_ptr = nullptr;
|
||||
if (azp.has_value()) {
|
||||
azp_ptr = azp->data_ptr<int32_t>();
|
||||
}
|
||||
if (ptr->get_input_scale_strategy() ==
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
|
||||
TORCH_CHECK_EQ(a_scales.numel(), 1);
|
||||
}
|
||||
|
||||
W8A8MatMulPrimitiveHandler::ExecArgs exec_args;
|
||||
exec_args.a_ptr = a.data_ptr<int8_t>();
|
||||
exec_args.a_m_size = a.size(0);
|
||||
exec_args.bias_ptr = nullptr;
|
||||
exec_args.use_bias = false;
|
||||
exec_args.a_scales_ptr = nullptr;
|
||||
exec_args.a_zero_points_ptr = nullptr;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] {
|
||||
if (ptr->get_input_scale_strategy() ==
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
|
||||
if (bias.has_value()) {
|
||||
exec_args.bias_ptr = bias->data_ptr<scalar_t>();
|
||||
exec_args.bias_type = get_dnnl_type<scalar_t>();
|
||||
exec_args.use_bias = true;
|
||||
}
|
||||
exec_args.a_scales_ptr = a_scales.data_ptr<float>();
|
||||
exec_args.a_zero_points_ptr = azp_ptr;
|
||||
exec_args.c_ptr = c.data_ptr<scalar_t>();
|
||||
ptr->execute(exec_args);
|
||||
} else if (ptr->get_input_scale_strategy() ==
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) {
|
||||
torch::Tensor tmp_fp32_out =
|
||||
torch::empty_like(c, ::at::ScalarType::Float);
|
||||
exec_args.c_ptr = tmp_fp32_out.data_ptr<float>();
|
||||
ptr->execute(exec_args);
|
||||
if (bias.has_value()) {
|
||||
if (azp.has_value()) {
|
||||
dynamic_quant_epilogue<true, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(),
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
} else {
|
||||
dynamic_quant_epilogue<false, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, nullptr,
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
}
|
||||
} else {
|
||||
if (azp.has_value()) {
|
||||
dynamic_quant_epilogue<true, false>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(),
|
||||
(scalar_t*)nullptr, c.size(0), c.size(1));
|
||||
} else {
|
||||
dynamic_quant_epilogue<false, false>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, nullptr, (scalar_t*)nullptr,
|
||||
c.size(0), c.size(1));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "invalid act quant type.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// static-per-tensor quantization.
|
||||
void static_scaled_int8_quant(
|
||||
torch::Tensor& out, // [batch, hidden_size]
|
||||
const torch::Tensor& input, // [batch, hidden_size]
|
||||
const torch::Tensor& scale, std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK_EQ(input.dim(), 2);
|
||||
TORCH_CHECK_EQ(input.stride(1), 1);
|
||||
TORCH_CHECK(scale.numel() == 1);
|
||||
TORCH_CHECK(!azp.has_value() || azp->numel() == 1);
|
||||
|
||||
const int64_t stride = input.stride(0);
|
||||
const int64_t hidden_size = input.size(1);
|
||||
const int64_t num_tokens = input.size(0);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
|
||||
if (azp.has_value()) {
|
||||
static_scaled_int8_quant_impl<true>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
|
||||
stride, hidden_size);
|
||||
} else {
|
||||
static_scaled_int8_quant_impl<false>(input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), nullptr,
|
||||
num_tokens, stride, hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// dynamic-per-token quantization.
|
||||
void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [batch, hidden_size]
|
||||
const torch::Tensor& input, // [batch, hidden_size]
|
||||
torch::Tensor& scale, // [batch, 1]
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK_EQ(input.dim(), 2);
|
||||
TORCH_CHECK_EQ(input.stride(1), 1);
|
||||
|
||||
const int64_t hidden_size = input.size(1);
|
||||
const int64_t num_tokens = input.size(0);
|
||||
const int64_t stride = input.stride(0);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
|
||||
if (azp.has_value()) {
|
||||
dynamic_scaled_int8_quant_impl<true>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
|
||||
stride, hidden_size);
|
||||
} else {
|
||||
dynamic_scaled_int8_quant_impl<false>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), nullptr, num_tokens, stride,
|
||||
hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
951
csrc/cpu/quant.cpp
Normal file
951
csrc/cpu/quant.cpp
Normal file
@ -0,0 +1,951 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include "dnnl_helper.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using load_vec_type = void;
|
||||
using azp_adj_load_vec_type = void;
|
||||
using cvt_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
using azp_adj_load_vec_type = vec_op::INT32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using load_vec_type = vec_op::BF16Vec16;
|
||||
using azp_adj_load_vec_type = vec_op::INT32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power architecture-specific vector type
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
// Fallback for other architectures
|
||||
using load_vec_type = vec_op::FP16Vec16;
|
||||
#endif
|
||||
using azp_adj_load_vec_type = vec_op::INT32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
cvt_vec_t zp_vec;
|
||||
if constexpr (AZP) {
|
||||
zp_vec = cvt_vec_t(static_cast<float>(*azp));
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
|
||||
cvt_vec_t min_value(std::numeric_limits<float>::max());
|
||||
{
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
} else {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32, hidden_size - j);
|
||||
min_value = min_value.min(elems_fp32, hidden_size - j);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float scale_val, azp_val;
|
||||
if constexpr (AZP) {
|
||||
float max_scalar = max_value.reduce_max();
|
||||
float min_scalar = min_value.reduce_min();
|
||||
scale_val = (max_scalar - min_scalar) / 255.0f;
|
||||
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
|
||||
azp[i] = static_cast<int32_t>(azp_val);
|
||||
scale[i] = scale_val;
|
||||
} else {
|
||||
scale_val = max_value.reduce_max() / 127.0f;
|
||||
scale[i] = scale_val;
|
||||
}
|
||||
|
||||
const cvt_vec_t inv_scale(1.0 / scale_val);
|
||||
const cvt_vec_t azp_vec(azp_val);
|
||||
|
||||
{
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool PerChannel, typename scalar_t>
|
||||
void static_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float a_scale, const float* b_scale,
|
||||
const int32_t* azp_with_adj, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using azp_adj_load_vec_t =
|
||||
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t a_scale_vec(a_scale);
|
||||
cvt_vec_t b_scale_vec(*b_scale);
|
||||
cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
|
||||
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
b_scale_vec = cvt_vec_t(b_scale + j);
|
||||
scale_vec = b_scale_vec * a_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
b_scale_vec = cvt_vec_t(b_scale + j);
|
||||
scale_vec = b_scale_vec * a_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
|
||||
void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float* a_scale, const float* b_scale,
|
||||
const int32_t* azp, const int32_t* azp_adj,
|
||||
const scalar_t* bias, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using azp_adj_load_vec_t =
|
||||
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
cvt_vec_t token_scale_vec(a_scale[i]);
|
||||
cvt_vec_t token_zp_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
|
||||
if constexpr (!PerChannel) {
|
||||
zp_scale_val *= *b_scale;
|
||||
}
|
||||
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
|
||||
}
|
||||
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
|
||||
if constexpr (AZP) {
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
cvt_vec_t b_scale_vec(b_scale + j);
|
||||
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32;
|
||||
}
|
||||
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
|
||||
if constexpr (AZP) {
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
cvt_vec_t b_scale_vec(b_scale + j);
|
||||
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32;
|
||||
}
|
||||
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
#elif defined(__powerpc64__)
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
|
||||
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
cvt_vec_t zp_vec;
|
||||
if constexpr (AZP) {
|
||||
zp_vec = cvt_vec_t(static_cast<float>(*azp));
|
||||
}
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
template <bool AZP, typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
|
||||
cvt_vec_t min_value(std::numeric_limits<float>::max());
|
||||
{
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
} else {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32, hidden_size - j);
|
||||
min_value = min_value.min(elems_fp32, hidden_size - j);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float scale_val, azp_val;
|
||||
if constexpr (AZP) {
|
||||
float max_scalar = max_value.reduce_max();
|
||||
float min_scalar = min_value.reduce_min();
|
||||
scale_val = (max_scalar - min_scalar) / 255.0f;
|
||||
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
|
||||
azp[i] = static_cast<int32_t>(azp_val);
|
||||
scale[i] = scale_val;
|
||||
} else {
|
||||
scale_val = max_value.reduce_max() / 127.0f;
|
||||
scale[i] = scale_val;
|
||||
}
|
||||
|
||||
const cvt_vec_t inv_scale(1.0 / scale_val);
|
||||
const cvt_vec_t azp_vec(azp_val);
|
||||
|
||||
{
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
template <bool PerChannel, typename scalar_t>
|
||||
void static_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float a_scale, const float* b_scale,
|
||||
const int32_t* azp_with_adj, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using azp_adj_load_vec_t =
|
||||
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t a_scale_vec(a_scale);
|
||||
cvt_vec_t b_scale_vec(*b_scale);
|
||||
cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
|
||||
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
b_scale_vec = cvt_vec_t(b_scale + j);
|
||||
scale_vec = b_scale_vec * a_scale_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
b_scale_vec = cvt_vec_t(b_scale + j);
|
||||
scale_vec = b_scale_vec * a_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
|
||||
void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float* a_scale, const float* b_scale,
|
||||
const int32_t* azp, const int32_t* azp_adj,
|
||||
const scalar_t* bias, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using azp_adj_load_vec_t =
|
||||
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
cvt_vec_t token_scale_vec(a_scale[i]);
|
||||
cvt_vec_t token_zp_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
|
||||
if constexpr (!PerChannel) {
|
||||
zp_scale_val *= *b_scale;
|
||||
}
|
||||
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
|
||||
}
|
||||
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
|
||||
if constexpr (AZP) {
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
cvt_vec_t b_scale_vec(b_scale + j);
|
||||
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32;
|
||||
}
|
||||
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
|
||||
if constexpr (AZP) {
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
cvt_vec_t b_scale_vec(b_scale + j);
|
||||
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32;
|
||||
}
|
||||
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false,
|
||||
"static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
|
||||
"support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false,
|
||||
"dynamic_scaled_int8_quant_impl requires "
|
||||
"AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <bool PerChannel, typename scalar_t>
|
||||
void static_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float a_scale, const float* b_scale,
|
||||
const int32_t* azp_with_adj, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float* a_scale, const float* b_scale,
|
||||
const int32_t* azp, const int32_t* azp_with_adj,
|
||||
const scalar_t* bias, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a, // [M, IC], row-major
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& a_scales, // [1] or [M]
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
const std::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
|
||||
"int8_scaled_mm only supports INT8 inputs.")
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
|
||||
bias->dim() == 1);
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] {
|
||||
if (a_scales.numel() != 1) {
|
||||
// per-token
|
||||
// Note: oneDNN doesn't support per-token activation quantization
|
||||
// Ideally we want to fuse the GEMM and the scale procedure with oneDNN
|
||||
// JIT, the intermediate data is cached in registers or L1. But for now
|
||||
// the oneDNN GEMM code generation only supports two quantization
|
||||
// patterns: per-tensor or per-output-channel of weight.
|
||||
// So we have to apply the per-token scale with a 'epilogue'. In C=s_a *
|
||||
// s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN
|
||||
// GEMM, then the per-token scale (and bias) is applied with the epilogue
|
||||
// C=s_a * C_inter + bias.
|
||||
torch::Tensor tmp_fp32_out =
|
||||
torch::empty_like(c, ::at::ScalarType::Float);
|
||||
// Compute C_inter=s_b * (A@B)
|
||||
DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
|
||||
a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
|
||||
if (bias.has_value()) {
|
||||
// Compute C=s_a * C_inter + bias
|
||||
dynamic_quant_epilogue<false, true, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
} else {
|
||||
// Compute C=s_a * C_inter
|
||||
dynamic_quant_epilogue<false, true, false, scalar_t>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
|
||||
c.size(0), c.size(1));
|
||||
}
|
||||
} else {
|
||||
// per-tensor
|
||||
if (bias.has_value()) {
|
||||
// Compute C=s_a * s_b * (A@B) + bias
|
||||
DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
|
||||
bias->data_ptr<scalar_t>(), a.size(0), b.size(1), a.size(1),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
a_scales.numel(), b_scales.numel());
|
||||
} else {
|
||||
// Compute C=s_a * s_b * (A@B)
|
||||
DNNLPrimitiveHelper<false>::gemm_s8s8_jit<scalar_t, void>(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
|
||||
nullptr, a.size(0), b.size(1), a.size(1),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
a_scales.numel(), b_scales.numel());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a, // [M, IC], row-major
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& a_scales, // [1] or [M]
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
const torch::Tensor& azp_adj, // [OC]
|
||||
const std::optional<torch::Tensor>& azp, // [1] or [M]
|
||||
const std::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
|
||||
"int8_scaled_mm_azp only supports INT8 inputs.")
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
|
||||
}
|
||||
if (azp) {
|
||||
TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
|
||||
}
|
||||
TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
|
||||
|
||||
// azp & bias types
|
||||
TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
|
||||
TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
|
||||
"currently bias dtype must match output dtype ", c.dtype());
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] {
|
||||
torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
|
||||
if (a_scales.numel() != 1) {
|
||||
// per-token
|
||||
// Note: oneDNN doesn't support per-token activation quantization
|
||||
// Compute C_inter=s_b * (A@B)
|
||||
DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
|
||||
a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
|
||||
if (bias.has_value()) {
|
||||
// Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias
|
||||
if (b_scales.numel() != 1) {
|
||||
// Per-Channel
|
||||
dynamic_quant_epilogue<true, true, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
} else {
|
||||
// Per-Tensor
|
||||
dynamic_quant_epilogue<true, false, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
}
|
||||
} else {
|
||||
// Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj
|
||||
if (b_scales.numel() != 1) {
|
||||
// Per-Channel
|
||||
dynamic_quant_epilogue<true, true, false, scalar_t>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
|
||||
c.size(0), c.size(1));
|
||||
} else {
|
||||
// Per-Tensor
|
||||
dynamic_quant_epilogue<true, false, false, scalar_t>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
|
||||
c.size(0), c.size(1));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// per-tensor
|
||||
if (bias.has_value()) {
|
||||
// Compute C_inter=s_a * s_b * (A@B) + bias
|
||||
DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), bias->data_ptr<scalar_t>(),
|
||||
a.size(0), b.size(1), a.size(1), a_scales.data_ptr<float>(),
|
||||
b_scales.data_ptr<float>(), a_scales.numel(), b_scales.numel());
|
||||
} else {
|
||||
// Compute C_inter=s_a * s_b * (A@B)
|
||||
DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, void>(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
|
||||
a.size(1), a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
a_scales.numel(), b_scales.numel());
|
||||
}
|
||||
|
||||
// Compute C=C_inter - s_a * s_b * azp_adj
|
||||
if (b_scales.numel() != 1) {
|
||||
// Per-Channel
|
||||
static_quant_epilogue<true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
*a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
|
||||
} else {
|
||||
// Per-Tensor
|
||||
static_quant_epilogue<false>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
*a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// static-per-tensor quantization.
|
||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
const torch::Tensor& input, // [..., hidden_size]
|
||||
const torch::Tensor& scale,
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(scale.numel() == 1);
|
||||
TORCH_CHECK(!azp.has_value() || azp->numel() == 1);
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
|
||||
if (azp.has_value()) {
|
||||
static_scaled_int8_quant_impl<true>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
|
||||
hidden_size);
|
||||
} else {
|
||||
static_scaled_int8_quant_impl<false>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// dynamic-per-token quantization.
|
||||
void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
const torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& scale, // [..., 1]
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
|
||||
if (azp.has_value()) {
|
||||
dynamic_scaled_int8_quant_impl<true>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
|
||||
hidden_size);
|
||||
} else {
|
||||
dynamic_scaled_int8_quant_impl<false>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(__powerpc64__)
|
||||
void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a, // [M, IC], row-major
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const std::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
|
||||
"int8_scaled_mm_ppc64le only supports INT8 inputs.");
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
// We dont need this
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
|
||||
bias->dim() == 1);
|
||||
}
|
||||
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
|
||||
torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
|
||||
// Compute C_inter=s_b * (A@B)
|
||||
DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
|
||||
a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
|
||||
if (bias.has_value()) {
|
||||
// Compute C=s_a * C_inter + bias
|
||||
dynamic_quant_epilogue<false, true, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
} else {
|
||||
// Compute C=s_a * C_inter
|
||||
dynamic_quant_epilogue<false, true, false, scalar_t>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
|
||||
c.size(0), c.size(1));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -6,20 +6,25 @@
|
||||
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids);
|
||||
|
||||
void release_dnnl_matmul_handler(int64_t handler);
|
||||
void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b,
|
||||
const torch::Tensor& b_scales,
|
||||
at::ScalarType output_type,
|
||||
bool dynamic_act_quant, bool use_azp,
|
||||
int64_t primitive_cache_size);
|
||||
void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const torch::Tensor& azp_adj,
|
||||
const std::optional<torch::Tensor>& azp,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& a_scales,
|
||||
const std::optional<torch::Tensor>& azp,
|
||||
const std::optional<torch::Tensor>& azp_adj,
|
||||
const std::optional<torch::Tensor>& bias,
|
||||
int64_t handler);
|
||||
#if defined(__powerpc64__)
|
||||
void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
#endif
|
||||
|
||||
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& kv_cache, double scale,
|
||||
@ -146,25 +151,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||
|
||||
// Quantization
|
||||
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
|
||||
defined(__powerpc64__)
|
||||
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
|
||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||
// Helper function to release oneDNN handlers
|
||||
ops.def("release_dnnl_matmul_handler(int handler) -> ()",
|
||||
&release_dnnl_matmul_handler);
|
||||
|
||||
// Create oneDNN W8A8 handler
|
||||
ops.def(
|
||||
"create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
|
||||
"output_type, bool dynamic_act_quant, bool use_azp, int "
|
||||
"primitive_cache_size) -> int",
|
||||
&create_onednn_scaled_mm_handler);
|
||||
|
||||
// oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
|
||||
ops.def(
|
||||
"onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
|
||||
"Tensor? azp_adj, Tensor? bias, int handler) -> ()");
|
||||
ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);
|
||||
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
@ -180,6 +168,50 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
{stride_tag});
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
|
||||
&dynamic_scaled_int8_quant);
|
||||
// W8A8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
|
||||
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
|
||||
#elif defined(__powerpc64__)
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||
"Tensor? azp) -> ()");
|
||||
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
|
||||
|
||||
// Compute int8 quantized tensor and scaling factor
|
||||
ops.def(
|
||||
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
|
||||
"Tensor!? azp) -> ()");
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
|
||||
&dynamic_scaled_int8_quant);
|
||||
// W8A8 GEMM, supporting symmetric quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
|
||||
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
|
||||
#endif
|
||||
|
||||
// SHM CCL
|
||||
|
||||
@ -20,7 +20,6 @@ namespace MARLIN_NAMESPACE_NAME {
|
||||
TEMPLATE = ("template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{s_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
@ -78,7 +77,6 @@ def generate_new_kernels():
|
||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||
continue
|
||||
# nvfp4 only supports group_size == 16
|
||||
# mxfp4 only supports group_size == 32
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
||||
continue
|
||||
# other quantization methods don't support group_size = 16
|
||||
@ -91,22 +89,9 @@ def generate_new_kernels():
|
||||
|
||||
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
|
||||
s_type = "vllm::kFE4M3fn"
|
||||
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
|
||||
s_type = "vllm::kFE8M0fnu"
|
||||
if dtype == "fp16":
|
||||
# we cannot safely dequantize e8m0 to fp16, so skip this
|
||||
continue
|
||||
elif dtype == "fp16":
|
||||
s_type = "vllm::kFloat16"
|
||||
elif dtype == "bf16":
|
||||
s_type = "vllm::kBFloat16"
|
||||
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
scalar_t=c_dtype,
|
||||
w_type_id=scalar_type + ".id()",
|
||||
s_type_id=s_type + ".id()",
|
||||
threads=threads,
|
||||
thread_m_blocks=max(m_blocks, 1),
|
||||
thread_n_blocks=n_blocks,
|
||||
|
||||
@ -7,25 +7,23 @@
|
||||
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ b_bias_ptr, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ scale2_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, \
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ scale2_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||
const int32_t *__restrict__ expert_ids_ptr, \
|
||||
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||
bool use_fp32_reduce, int max_shared_mem
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
|
||||
@ -280,7 +280,6 @@ __device__ inline void wait_negative_and_add(int* lock) {
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
@ -300,7 +299,6 @@ __global__ void Marlin(
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ b_bias_ptr,
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
||||
@ -320,9 +318,8 @@ __global__ void Marlin(
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks, // extra global storage for barrier synchronization
|
||||
bool has_bias,
|
||||
bool use_atomic_add, // whether to use atomic add to reduce
|
||||
bool use_fp32_reduce, // whether to use fp32 global reduce
|
||||
bool use_atomic_add, // whether to use atomic add to reduce
|
||||
bool use_fp32_reduce, // whether to use fp32 global reduce
|
||||
int max_shared_mem) {
|
||||
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
||||
// same size, which might involve multiple column "slices" (of width 16 *
|
||||
@ -345,23 +342,12 @@ __global__ void Marlin(
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
static_assert(s_type == vllm::kBFloat16);
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
static_assert(s_type == vllm::kFloat16);
|
||||
}
|
||||
|
||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
||||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
w_type == vllm::kFE4M3fn ||
|
||||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||
!is_int_type ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
||||
|
||||
@ -379,7 +365,6 @@ __global__ void Marlin(
|
||||
const int zp_expert_stride =
|
||||
is_zp_float ? prob_n * prob_k / group_size / 8
|
||||
: prob_n * prob_k / group_size / (pack_factor * 4);
|
||||
const int b_bias_expert_stride = prob_n / 8;
|
||||
|
||||
// parallel: num valid moe blocks
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
@ -490,7 +475,7 @@ __global__ void Marlin(
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int idx = tid4 * 4 + i;
|
||||
idx = idx < block_num_valid_tokens ? idx : 0;
|
||||
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
sh_block_topk_weights[idx] = __hmul2(
|
||||
global_scale, Dtype::num2num2(Dtype::float2num(
|
||||
topk_weights_ptr[sh_block_sorted_ids[idx]])));
|
||||
@ -528,7 +513,7 @@ __global__ void Marlin(
|
||||
expert_id = expert_ids_ptr[block_id];
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
uint16_t val = scale2_ptr[expert_id];
|
||||
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
||||
}
|
||||
@ -541,9 +526,6 @@ __global__ void Marlin(
|
||||
if constexpr (has_act_order) {
|
||||
g_idx += (expert_id - old_expert_id) * prob_k;
|
||||
}
|
||||
if (has_bias) {
|
||||
b_bias_ptr += (expert_id - old_expert_id) * b_bias_expert_stride;
|
||||
}
|
||||
|
||||
read_moe_block_data(block_id);
|
||||
};
|
||||
@ -739,7 +721,7 @@ __global__ void Marlin(
|
||||
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
|
||||
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
|
||||
|
||||
} else if constexpr (group_blocks != -1)
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
@ -752,18 +734,6 @@ __global__ void Marlin(
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) % 4;
|
||||
|
||||
int bias_sh_rd;
|
||||
if constexpr (m_block_size_8) {
|
||||
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 8;
|
||||
} else {
|
||||
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) % 4;
|
||||
}
|
||||
|
||||
int bias_sh_wr = threadIdx.x;
|
||||
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
|
||||
|
||||
// Zero-points have the same read layout as the scales
|
||||
// (without column-wise case)
|
||||
constexpr int num_col_threads = 8;
|
||||
@ -823,19 +793,7 @@ __global__ void Marlin(
|
||||
constexpr int sh_b_size = stages * b_sh_stage;
|
||||
int4* sh_b = sh_new;
|
||||
int4* sh_red = sh_new;
|
||||
|
||||
constexpr int sh_size_b_red_min =
|
||||
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
|
||||
constexpr int sh_size_b_red_max =
|
||||
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
|
||||
constexpr int sh_b_red_bias_size =
|
||||
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
|
||||
? sh_size_b_red_max
|
||||
: (sh_size_b_red_min + sh_bias_size);
|
||||
|
||||
int4* sh_bias = sh_new + sh_size_b_red_min;
|
||||
int4* sh_g_idx = sh_new + sh_b_red_bias_size;
|
||||
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
@ -845,9 +803,9 @@ __global__ void Marlin(
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
constexpr int shm_size_used = moe_block_size +
|
||||
stages * (g_idx_stage + zp_sh_stage) +
|
||||
sh_s_size + sh_b_red_bias_size;
|
||||
constexpr int shm_size_used =
|
||||
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
|
||||
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||
|
||||
// all remaining shared memory is used to cache A (input)
|
||||
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
|
||||
@ -858,8 +816,7 @@ __global__ void Marlin(
|
||||
FragA frag_a[2][thread_m_blocks];
|
||||
I4 frag_b_quant[2][b_thread_vecs];
|
||||
FragC frag_c[thread_m_blocks][4][2];
|
||||
FragS frag_s[2][4]; // No act-order
|
||||
FragS frag_bias[2][4];
|
||||
FragS frag_s[2][4]; // No act-order
|
||||
FragS act_frag_s[2][4][4]; // For act-order
|
||||
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
||||
FragZP frag_zp; // Zero-points in fp16
|
||||
@ -1108,15 +1065,10 @@ __global__ void Marlin(
|
||||
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
|
||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||
reinterpret_cast<int2*>(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
} else {
|
||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||
reinterpret_cast<int2*>(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
|
||||
k % 2];
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1329,9 +1281,9 @@ __global__ void Marlin(
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
dequant_fp8_scales<scalar_t2, s_type_id>(
|
||||
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||
dequant_fp8_scales<scalar_t2, s_type_id>(
|
||||
dequant_fp8_scales<scalar_t2>(s_quant_0,
|
||||
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||
dequant_fp8_scales<scalar_t2>(
|
||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
||||
}
|
||||
|
||||
@ -1614,7 +1566,7 @@ __global__ void Marlin(
|
||||
// Write out the reduce final result in the correct layout. We only actually
|
||||
// reshuffle matrix fragments in this step, the reduction above is performed
|
||||
// in fragment layout.
|
||||
auto write_result = [&](bool last) {
|
||||
auto write_result = [&]() {
|
||||
int c_gl_stride = prob_n / 8;
|
||||
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
||||
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
||||
@ -1640,7 +1592,7 @@ __global__ void Marlin(
|
||||
|
||||
// We first reorder in shared memory to guarantee the most efficient final
|
||||
// global write patterns
|
||||
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
||||
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
||||
scalar_t2 res =
|
||||
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
||||
|
||||
@ -1649,27 +1601,14 @@ __global__ void Marlin(
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 4 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
scalar_t2 tmp_scale = s[0];
|
||||
if constexpr (m_block_size_8) {
|
||||
tmp_scale = Dtype::num2num2(
|
||||
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
|
||||
}
|
||||
res = __hmul2(res, tmp_scale);
|
||||
res = __hmul2(res, s[0]);
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
if (!mul_topk_weights) {
|
||||
res = __hmul2(res, global_scale);
|
||||
}
|
||||
}
|
||||
if (has_bias && last) {
|
||||
scalar_t2 tmp_bias = b_bias[0];
|
||||
if constexpr (m_block_size_8) {
|
||||
tmp_bias = Dtype::num2num2(
|
||||
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
|
||||
}
|
||||
res = __hadd2(res, tmp_bias);
|
||||
}
|
||||
|
||||
if constexpr (m_block_size_8) {
|
||||
((scalar_t*)sh_red)[idx] = res.x;
|
||||
@ -1687,25 +1626,19 @@ __global__ void Marlin(
|
||||
if constexpr (m_block_size_8) {
|
||||
int wr = c_sh_wr + 16 * j;
|
||||
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
|
||||
frag_s[j / 2][2 * (j % 2) + 0],
|
||||
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||
frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
|
||||
frag_s[j / 2][2 * (j % 2) + 1],
|
||||
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||
frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
} else {
|
||||
int wr = c_sh_wr + 8 * j;
|
||||
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
||||
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
|
||||
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
||||
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
|
||||
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
||||
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
|
||||
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
||||
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
|
||||
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
}
|
||||
}
|
||||
c_sh_wr += 16 * (4 * c_sh_stride);
|
||||
@ -1872,14 +1805,6 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
thread_block_reduce();
|
||||
|
||||
if (has_bias && last) {
|
||||
__syncthreads();
|
||||
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
|
||||
threadIdx.x < 16 * thread_n_blocks / 8);
|
||||
cp_async_fence();
|
||||
}
|
||||
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
@ -1942,20 +1867,11 @@ __global__ void Marlin(
|
||||
}
|
||||
barrier_release(&locks[locks_off], last);
|
||||
}
|
||||
|
||||
if (has_bias && last) {
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
|
||||
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
|
||||
wait_negative_and_add(&locks[locks_off]);
|
||||
if (last || use_atomic_add)
|
||||
// only the last block in a slice actually writes the result
|
||||
write_result(last);
|
||||
write_result();
|
||||
int old_slice_row = slice_row;
|
||||
slice_row = 0;
|
||||
slice_col_par++;
|
||||
@ -1988,7 +1904,6 @@ __global__ void Marlin(
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||
}
|
||||
|
||||
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
|
||||
// Update slice k/n for scales loading
|
||||
if constexpr (has_act_order) {
|
||||
slice_k_start = tb_k * slice_row;
|
||||
|
||||
@ -51,9 +51,8 @@ __global__ void permute_cols_kernel(
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
@ -213,7 +212,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
|
||||
|
||||
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
@ -221,11 +220,6 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
|
||||
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
|
||||
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
@ -240,8 +234,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
|
||||
sh_zp_size = sh_s_size / 2;
|
||||
}
|
||||
|
||||
int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size +
|
||||
sh_g_idx_size + sh_block_meta_size;
|
||||
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
|
||||
sh_zp_size + sh_g_idx_size + sh_block_meta_size;
|
||||
|
||||
return total_size;
|
||||
}
|
||||
@ -276,25 +270,20 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
return cache_size + 512 <= max_shared_mem;
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
constexpr auto S_TYPE = \
|
||||
W_TYPE == vllm::kFE2M1f \
|
||||
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
|
||||
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
|
||||
: vllm::kBFloat16); \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
|
||||
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
||||
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||
}
|
||||
|
||||
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
||||
@ -346,45 +335,31 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
|
||||
\
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define FP4_GET_IF(W_TYPE) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
FP4_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||
|
||||
#define BIGGROUP_GET_IF(W_TYPE) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||
|
||||
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define NVFP4_GET_IF(W_TYPE) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||
|
||||
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||
|
||||
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||
|
||||
#define MXFP4_GET_IF(W_TYPE) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
|
||||
@ -433,17 +408,12 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
COMMON_GET_IF(vllm::kU4B8)
|
||||
COMMON_GET_IF(vllm::kU8B128)
|
||||
|
||||
NVFP4_GET_IF(vllm::kFE2M1f)
|
||||
|
||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||
|
||||
FP4_GET_IF(vllm::kFE2M1f)
|
||||
|
||||
ACT_GET_IF(vllm::kU4B8)
|
||||
ACT_GET_IF(vllm::kU8B128)
|
||||
if (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
if (false) {
|
||||
}
|
||||
MXFP4_GET_IF(vllm::kFE2M1f)
|
||||
}
|
||||
|
||||
return kernel;
|
||||
}
|
||||
@ -512,16 +482,16 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
void* s, void* s2, void* zp, void* g_idx, void* perm,
|
||||
void* a_tmp, void* sorted_token_ids, void* expert_ids,
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
void* sorted_token_ids, void* expert_ids,
|
||||
void* num_tokens_past_padded, void* topk_weights,
|
||||
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
||||
int prob_m, int prob_n, int prob_k, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_bias,
|
||||
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k,
|
||||
int thread_n, int sms, bool use_atomic_add, bool use_fp32_reduce,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k, int thread_n,
|
||||
int sms, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
int thread_m_blocks = div_ceil(moe_block_size, 16);
|
||||
bool m_block_size_8 = moe_block_size == 8;
|
||||
@ -568,7 +538,6 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* bias_ptr = (const int4*)b_bias;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
@ -679,10 +648,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr,
|
||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||
prob_n, prob_k, locks, has_bias, use_atomic_add, use_fp32_reduce, max_shared_mem);
|
||||
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@ -690,8 +659,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
@ -798,6 +766,7 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
num_groups = b_scales.size(1);
|
||||
|
||||
torch::Tensor g_idx, perm, a_tmp;
|
||||
;
|
||||
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
|
||||
g_idx = g_idx_or_none.value();
|
||||
perm = perm_or_none.value();
|
||||
@ -846,24 +815,12 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor global_scale;
|
||||
if (global_scale_or_none.has_value()) {
|
||||
global_scale = global_scale_or_none.value();
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
|
||||
"global_scale can only be used for nvfp4 format.");
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
|
||||
"global_scale can only be used for float4_e2m1f.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
|
||||
"the global_scale parameter must be passed for nvfp4 format.");
|
||||
}
|
||||
|
||||
bool has_bias = b_bias_or_none.has_value();
|
||||
torch::Tensor b_bias;
|
||||
if (has_bias) {
|
||||
b_bias = b_bias_or_none.value();
|
||||
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
|
||||
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
|
||||
TORCH_CHECK(b_bias.size(1) == size_n, "b_bias.size(0) != size_n");
|
||||
TORCH_CHECK(b_bias.stride(1) == 1, "b_bias.stride(1) != 1");
|
||||
} else {
|
||||
b_bias = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
|
||||
"the global_scale parameter must be passed for float4_e2m1f.");
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
@ -875,6 +832,7 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
b_zeros = torch::empty({0}, options);
|
||||
}
|
||||
bool has_zp = b_zeros.size(-1) > 0;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4 || b_q_type == vllm::kU8,
|
||||
@ -932,58 +890,41 @@ torch::Tensor moe_wna16_marlin_gemm(
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||
}
|
||||
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
|
||||
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::Half>(),
|
||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
|
||||
has_zp, num_groups, group_size, dev,
|
||||
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
|
||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
|
||||
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
|
||||
is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||
}
|
||||
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
|
||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||
workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full,
|
||||
has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
|
||||
|
||||
@ -45,6 +45,8 @@ void moe_permute(
|
||||
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
|
||||
auto permuted_experts_id = torch::empty_like(topk_ids);
|
||||
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
|
||||
auto align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
|
||||
CubKeyValueSorter sorter{};
|
||||
int64_t* valid_num_ptr = nullptr;
|
||||
@ -83,14 +85,12 @@ void moe_permute(
|
||||
});
|
||||
|
||||
// get m_indices and update expert_first_token_offset with align block
|
||||
// this is only required for DeepGemm and not required for CUTLASS group gemm
|
||||
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
|
||||
get_ptr<int64_t>(align_expert_first_token_offset),
|
||||
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
|
||||
stream);
|
||||
if (align_block_size.has_value()) {
|
||||
auto align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
|
||||
get_ptr<int64_t>(align_expert_first_token_offset),
|
||||
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
|
||||
stream);
|
||||
// update align_expert_first_token_offset
|
||||
expert_first_token_offset.copy_(align_expert_first_token_offset);
|
||||
}
|
||||
}
|
||||
@ -195,14 +195,19 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
|
||||
torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& src_row_id2dst_row_id_map,
|
||||
torch::Tensor& m_indices) {
|
||||
TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
|
||||
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
|
||||
}
|
||||
|
||||
void moe_unpermute(
|
||||
const torch::Tensor& permuted_hidden_states,
|
||||
const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
|
||||
const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
|
||||
torch::Tensor& hidden_states) {
|
||||
void moe_unpermute(const torch::Tensor& input,
|
||||
const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
|
||||
const torch::Tensor& token_expert_indices,
|
||||
const std::optional<torch::Tensor>& expert_map,
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const std::optional<int64_t>& align_block_size,
|
||||
torch::Tensor& permuted_input,
|
||||
torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& src_row_id2dst_row_id_map,
|
||||
torch::Tensor& m_indices) {
|
||||
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
|
||||
}
|
||||
|
||||
@ -219,4 +224,4 @@ bool moe_permute_unpermute_supported() {
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_permute", &moe_permute);
|
||||
m.impl("moe_unpermute", &moe_unpermute);
|
||||
}
|
||||
}
|
||||
|
||||
@ -35,8 +35,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor? b_bias_or_none,"
|
||||
"Tensor! b_scales, Tensor? global_scale, Tensor? "
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? "
|
||||
"b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
|
||||
32
csrc/ops.h
32
csrc/ops.h
@ -138,8 +138,6 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
|
||||
double threshold);
|
||||
void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
|
||||
double alpha = 1.702, double limit = 7.0);
|
||||
|
||||
void gelu_new(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
@ -147,6 +145,22 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
|
||||
int64_t block_size, torch::Tensor& input_tokens,
|
||||
torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions,
|
||||
torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping,
|
||||
torch::Tensor& block_tables);
|
||||
|
||||
void advance_step_flashinfer(
|
||||
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
||||
|
||||
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
@ -156,6 +170,15 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const std::vector<int64_t>& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
torch::Tensor aqlm_dequant(
|
||||
const torch::Tensor& codes, const torch::Tensor& codebooks,
|
||||
const std::vector<int64_t>& codebook_partition_sizes);
|
||||
|
||||
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||
@ -229,11 +252,6 @@ void get_cutlass_moe_mm_data(
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
|
||||
336
csrc/prepare_inputs/advance_step.cu
Normal file
336
csrc/prepare_inputs/advance_step.cu
Normal file
@ -0,0 +1,336 @@
|
||||
/*
|
||||
* The goal of this GPU kernel is to advance input tensors on the GPU directly
|
||||
* PR: https://github.com/vllm-project/vllm/pull/6338
|
||||
* Current restrictions:
|
||||
* 1. Specialized for DraftModelRunner
|
||||
* 2. Supports flash_attn only
|
||||
*/
|
||||
|
||||
#include "advance_step.cuh"
|
||||
|
||||
namespace prepare_inputs {
|
||||
|
||||
//
|
||||
template <int const num_threads>
|
||||
__global__ void advance_step_flashattn_kernel(
|
||||
int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
|
||||
long const* sampled_token_ids_ptr, long* input_positions_ptr,
|
||||
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
|
||||
int64_t const block_tables_stride) {
|
||||
int const n_pad = num_seqs - num_queries;
|
||||
if (n_pad && blockIdx.x == 0) {
|
||||
// Handle cuda graph padding
|
||||
int const offset = num_queries;
|
||||
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
|
||||
input_tokens_ptr[offset + i] = 0;
|
||||
input_positions_ptr[offset + i] = 0;
|
||||
slot_mapping_ptr[offset + i] = -1;
|
||||
}
|
||||
}
|
||||
|
||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||
|
||||
if (blockIdx.x >= num_query_blocks) {
|
||||
return;
|
||||
}
|
||||
|
||||
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
|
||||
|
||||
if (cur_query_id >= num_queries) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Update input_tokens
|
||||
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
|
||||
|
||||
int seq_len = seq_lens_ptr[cur_query_id];
|
||||
int next_seq_len = seq_len + 1;
|
||||
int next_input_pos = next_seq_len - 1;
|
||||
|
||||
// Update seq_lens
|
||||
seq_lens_ptr[cur_query_id] = next_seq_len;
|
||||
// Update input_positions
|
||||
input_positions_ptr[cur_query_id] = next_input_pos;
|
||||
|
||||
int const* seq_block_tables_ptr =
|
||||
block_tables_ptr + block_tables_stride * cur_query_id;
|
||||
|
||||
int block_index = next_input_pos / block_size;
|
||||
int block_offset = next_input_pos % block_size;
|
||||
|
||||
int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
|
||||
// Update slot_mapping
|
||||
slot_mapping_ptr[cur_query_id] = slot_num;
|
||||
}
|
||||
|
||||
inline void verify_tensor(std::string const& name, torch::Tensor const& t,
|
||||
int64_t const size_0, int64_t const size_1,
|
||||
c10::ScalarType const type) {
|
||||
bool size_0_cond = true;
|
||||
if (size_0 != -1) {
|
||||
size_0_cond = t.size(0) == size_0;
|
||||
}
|
||||
|
||||
bool size_1_cond = true;
|
||||
if (size_1 != -1) {
|
||||
size_1_cond = t.size(1) == size_1;
|
||||
}
|
||||
|
||||
bool is_contiguous = t.is_contiguous();
|
||||
bool same_type = t.dtype() == type;
|
||||
|
||||
bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
|
||||
if (!pass) {
|
||||
TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
|
||||
" is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
|
||||
" is not as expected: shape = [", size_0, ", ", size_1,
|
||||
"], type = ", type);
|
||||
}
|
||||
}
|
||||
|
||||
/// each thread processes a block per query
|
||||
__global__ void advance_step_flashinfer_kernel(
|
||||
int num_threads, int num_seqs, int num_queries, int block_size,
|
||||
long* input_tokens_ptr, long const* sampled_token_ids_ptr,
|
||||
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
|
||||
int const* block_tables_ptr, int64_t const block_tables_stride,
|
||||
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
|
||||
int const n_pad = num_seqs - num_queries;
|
||||
if (n_pad && blockIdx.x == 0) {
|
||||
// Handle cuda graph padding
|
||||
int const offset = num_queries;
|
||||
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
|
||||
input_tokens_ptr[offset + i] = 0;
|
||||
input_positions_ptr[offset + i] = 0;
|
||||
slot_mapping_ptr[offset + i] = -1;
|
||||
}
|
||||
}
|
||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||
|
||||
if (blockIdx.x < num_query_blocks) {
|
||||
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
|
||||
|
||||
if (cur_query_id < num_queries) {
|
||||
// Update input_tokens
|
||||
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
|
||||
|
||||
int seq_len = seq_lens_ptr[cur_query_id];
|
||||
int next_seq_len = seq_len + 1;
|
||||
int next_input_pos = next_seq_len - 1;
|
||||
|
||||
// Update seq_lens
|
||||
seq_lens_ptr[cur_query_id] = next_seq_len;
|
||||
// Update input_positions
|
||||
input_positions_ptr[cur_query_id] = next_input_pos;
|
||||
|
||||
int const* seq_block_tables_ptr =
|
||||
block_tables_ptr + block_tables_stride * cur_query_id;
|
||||
|
||||
int block_index = next_input_pos / block_size;
|
||||
int block_offset = next_input_pos % block_size;
|
||||
|
||||
// Update paged_kv_last_page_len
|
||||
paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
|
||||
|
||||
int slot_num =
|
||||
seq_block_tables_ptr[block_index] * block_size + block_offset;
|
||||
// Update slot_mapping
|
||||
slot_mapping_ptr[cur_query_id] = slot_num;
|
||||
block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void advance_step_flashinfer_indptr_kernel(
|
||||
int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
|
||||
int* block_table_bound_ptr) {
|
||||
int idx = blockIdx.x * num_threads + threadIdx.x;
|
||||
// Update paged_kv_indptr
|
||||
if (idx == 0) {
|
||||
paged_kv_indptr_ptr[idx] = 0;
|
||||
}
|
||||
if (idx < num_queries) {
|
||||
int sum = 0;
|
||||
for (int i = 0; i <= idx; ++i) {
|
||||
sum += block_table_bound_ptr[i];
|
||||
}
|
||||
paged_kv_indptr_ptr[idx + 1] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void advance_step_flashinfer_indices_kernel(
|
||||
int num_seqs, int num_queries, int const* block_tables_ptr,
|
||||
int64_t const max_num_blocks_per_seq, int* paged_kv_indices_ptr,
|
||||
int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
|
||||
// note: max_num_blocks_per_seq = block_tables.stride(0)
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// when cuda graphs are enabled, paged_kv_indptr tensor
|
||||
// has to be updated for the padded queries
|
||||
// tid represents a query# for paged_kv_indptr tensor
|
||||
if (num_queries < tid && tid <= num_seqs) {
|
||||
paged_kv_indptr_ptr[tid] = paged_kv_indptr_ptr[num_queries];
|
||||
}
|
||||
|
||||
// each thread processes a block_ptr in block_tables
|
||||
// block_tables shape: [num_queries, max_num_blocks_per_seq]
|
||||
// paged_kv_indices is flattened block_tables.
|
||||
for (int idx = tid; idx < (num_seqs * max_num_blocks_per_seq);
|
||||
idx += (gridDim.x * blockDim.x)) {
|
||||
// block_tables-row = paged_kv_indptr[queryNum]
|
||||
int queryNum = idx / max_num_blocks_per_seq;
|
||||
int col = idx % max_num_blocks_per_seq;
|
||||
if (queryNum < num_queries && col < block_table_bound_ptr[queryNum]) {
|
||||
int indices_arr_idx = paged_kv_indptr_ptr[queryNum] + col;
|
||||
int block_tables_idx = queryNum * max_num_blocks_per_seq + col;
|
||||
paged_kv_indices_ptr[indices_arr_idx] =
|
||||
block_tables_ptr[block_tables_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
|
||||
torch::Tensor& input_tokens, // type: long
|
||||
torch::Tensor& sampled_token_ids, // type: long
|
||||
torch::Tensor& input_positions, // type: long
|
||||
torch::Tensor& seq_lens, // type: int
|
||||
torch::Tensor& slot_mapping, // type: long
|
||||
torch::Tensor& block_tables) { // type: int
|
||||
|
||||
if (logging) {
|
||||
printf("advance_step_flashattn:\n");
|
||||
printf(" num_seqs = %d\n", num_seqs);
|
||||
printf(" num_queries = %d\n", num_queries);
|
||||
printf(" block_size = %d\n", block_size);
|
||||
}
|
||||
// Verify all tensors
|
||||
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
|
||||
verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
|
||||
at::kLong);
|
||||
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
|
||||
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
|
||||
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
|
||||
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
|
||||
|
||||
int dev = sampled_token_ids.get_device();
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
|
||||
int blocks;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
advance_step_flashattn_kernel<max_threads>
|
||||
<<<blocks, max_threads, 0, stream>>>(
|
||||
num_seqs, num_queries, block_size,
|
||||
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
||||
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
||||
reinterpret_cast<long*>(input_positions.data_ptr()),
|
||||
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
||||
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||
block_tables.stride(0));
|
||||
}
|
||||
|
||||
void advance_step_flashinfer(
|
||||
int num_seqs, int num_queries, int block_size,
|
||||
torch::Tensor& input_tokens, // type: long
|
||||
torch::Tensor& sampled_token_ids, // type: long
|
||||
torch::Tensor& input_positions, // type: long
|
||||
torch::Tensor& seq_lens, // type: int
|
||||
torch::Tensor& slot_mapping, // type: long
|
||||
torch::Tensor& block_tables, // type: int
|
||||
torch::Tensor& paged_kv_indices, // type: int
|
||||
torch::Tensor& paged_kv_indptr, // type: int
|
||||
torch::Tensor& paged_kv_last_page_len, // type: int
|
||||
torch::Tensor& block_table_bound) { // type: int
|
||||
|
||||
if (logging) {
|
||||
printf("advance_step_flashinfer:\n");
|
||||
printf(" num_seqs = %d\n", num_seqs);
|
||||
printf(" num_queries = %d\n", num_queries);
|
||||
printf(" block_size = %d\n", block_size);
|
||||
printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0));
|
||||
}
|
||||
// Verify all tensors
|
||||
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
|
||||
// verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
|
||||
// at::kLong);
|
||||
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
|
||||
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
|
||||
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
|
||||
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
|
||||
|
||||
verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
|
||||
verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
|
||||
verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
|
||||
at::kInt);
|
||||
|
||||
verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
|
||||
|
||||
int dev = sampled_token_ids.get_device();
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
|
||||
int blocks;
|
||||
int threads;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
|
||||
|
||||
TORCH_CHECK((blocks * threads > num_queries),
|
||||
"multi-step: not enough threads to map to num_queries = ",
|
||||
num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
|
||||
" blocks = ", blocks, " max_threads = ", threads);
|
||||
if (logging) {
|
||||
printf("launching kernels with %d blocks and %d threads\n", blocks,
|
||||
threads);
|
||||
}
|
||||
advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
|
||||
threads, num_seqs, num_queries, block_size,
|
||||
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
||||
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
||||
reinterpret_cast<long*>(input_positions.data_ptr()),
|
||||
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
||||
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||
block_tables.stride(0),
|
||||
reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
|
||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||
|
||||
advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
|
||||
threads, num_seqs, num_queries,
|
||||
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
|
||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||
|
||||
advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
|
||||
num_seqs, num_queries,
|
||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||
block_tables.stride(0),
|
||||
reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
|
||||
reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
|
||||
reinterpret_cast<int*>(block_table_bound.data_ptr()));
|
||||
}
|
||||
|
||||
} // namespace prepare_inputs
|
||||
|
||||
void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
|
||||
int64_t block_size, torch::Tensor& input_tokens,
|
||||
torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions,
|
||||
torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping,
|
||||
torch::Tensor& block_tables) {
|
||||
prepare_inputs::advance_step_flashattn(
|
||||
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
||||
input_positions, seq_lens, slot_mapping, block_tables);
|
||||
}
|
||||
|
||||
void advance_step_flashinfer(
|
||||
int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables,
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
|
||||
prepare_inputs::advance_step_flashinfer(
|
||||
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
|
||||
input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
|
||||
paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
|
||||
}
|
||||
19
csrc/prepare_inputs/advance_step.cuh
Normal file
19
csrc/prepare_inputs/advance_step.cuh
Normal file
@ -0,0 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace prepare_inputs {
|
||||
|
||||
static constexpr int max_threads = 256;
|
||||
static constexpr bool logging = false;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
} // namespace prepare_inputs
|
||||
597
csrc/quantization/aqlm/gemm_kernels.cu
Normal file
597
csrc/quantization/aqlm/gemm_kernels.cu
Normal file
@ -0,0 +1,597 @@
|
||||
/*
|
||||
* Modified by Neural Magic
|
||||
* Adapted from https://github.com/Vahe1994/AQLM
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
namespace vllm {
|
||||
namespace aqlm {
|
||||
|
||||
__global__ void Code1x16MatVec(
|
||||
const int4* __restrict__ A, const int4* __restrict__ B,
|
||||
int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
|
||||
const int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
int b_gl_rd = 0;
|
||||
int c_gl_wr = a_gl_rd;
|
||||
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||
|
||||
__shared__ int4 sh_b[32 * 9];
|
||||
float res = 0;
|
||||
|
||||
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
|
||||
while (iters--) {
|
||||
// We pad shared memory to avoid bank conflicts during reads
|
||||
__syncthreads();
|
||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
}
|
||||
__syncthreads();
|
||||
b_gl_rd += 32 * 8;
|
||||
|
||||
int b_sh_rd = 9 * (threadIdx.x % 32);
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
uint32_t dec[4];
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming
|
||||
// that doesn't actually help us; this brings > 2x speedup.
|
||||
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||
: "l"((void*)&codebook[enc[i]]));
|
||||
half2* a = reinterpret_cast<half2*>(&dec);
|
||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||
half2 res2 = {};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
|
||||
res += __half2float(res2.x) + __half2float(res2.y);
|
||||
b_sh_rd++;
|
||||
}
|
||||
a_gl_rd += 32;
|
||||
}
|
||||
}
|
||||
|
||||
if (pred) {
|
||||
#pragma unroll
|
||||
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
|
||||
if (threadIdx.x % 32 == 0)
|
||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void Code2x8MatVec(
|
||||
const int4* __restrict__ A, const int4* __restrict__ B,
|
||||
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
|
||||
int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
int b_gl_rd = 0;
|
||||
int c_gl_wr = a_gl_rd;
|
||||
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||
int lane = threadIdx.x % 8;
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
int4* sh_b = sh;
|
||||
int4* sh_code = sh_b + 32 * 9;
|
||||
int4* sh_code0 = sh_code;
|
||||
int4* sh_code1 = sh_code + 256 * 8;
|
||||
|
||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||
int4 dec = codebook[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float res = 0;
|
||||
|
||||
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
|
||||
while (iters--) {
|
||||
// We pad shared memory to avoid bank conflicts during reads
|
||||
__syncthreads();
|
||||
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
|
||||
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
|
||||
}
|
||||
__syncthreads();
|
||||
b_gl_rd += 32 * 8;
|
||||
|
||||
int b_sh_rd = 9 * (threadIdx.x % 32);
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
half2* a0 =
|
||||
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 =
|
||||
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
|
||||
half2 res2 = {};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
|
||||
res += __half2float(res2.x) + __half2float(res2.y);
|
||||
b_sh_rd++;
|
||||
}
|
||||
a_gl_rd += 32;
|
||||
}
|
||||
}
|
||||
|
||||
if (pred) {
|
||||
#pragma unroll
|
||||
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
|
||||
if (threadIdx.x % 32 == 0)
|
||||
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void Code1x16Dequant(
|
||||
const int4* __restrict__ A, int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long, sums to m.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||
|
||||
int c_gl_stride = prob_k / 8;
|
||||
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
|
||||
|
||||
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
|
||||
while (iters--) {
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int4 chunk;
|
||||
auto dec = reinterpret_cast<uint32_t*>(&chunk);
|
||||
// We bypass the L1 cache to avoid massive amounts of memory streaming
|
||||
// that doesn't actually help us; this brings > 2x speedup.
|
||||
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
|
||||
: "l"((void*)&codebook[enc[i]]));
|
||||
|
||||
C[a_gl_rd * 8 + i] = chunk;
|
||||
}
|
||||
}
|
||||
a_gl_rd += 32;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void Code2x8Dequant(
|
||||
const int4* __restrict__ A, int4* __restrict__ C,
|
||||
const int4* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||
// most 3 long, corresponds to cols.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int a_gl_stride = prob_k / 8 / 8;
|
||||
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
bool pred = a_gl_rd < prob_m;
|
||||
|
||||
if (pred) {
|
||||
// advance to the correct codebook, this easy because we only multiply one
|
||||
// column of the codebook.
|
||||
auto codebook_size = &codebook_a_sizes.x;
|
||||
while (a_gl_rd >= *codebook_size) {
|
||||
codebook += codebook_stride;
|
||||
++codebook_size;
|
||||
}
|
||||
}
|
||||
|
||||
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
|
||||
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
|
||||
int lane = threadIdx.x % 8;
|
||||
|
||||
int c_gl_stride = prob_k / 8;
|
||||
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
|
||||
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
int4* sh_code = sh;
|
||||
int4* sh_code0 = sh_code;
|
||||
int4* sh_code1 = sh_code + 256 * 8;
|
||||
|
||||
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
|
||||
int4 dec = codebook[i];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
|
||||
while (iters--) {
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
int4 chunk;
|
||||
half2* a0 =
|
||||
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
|
||||
half2* a1 =
|
||||
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
|
||||
C[a_gl_rd * 8 + i] = chunk;
|
||||
}
|
||||
}
|
||||
a_gl_rd += 32;
|
||||
}
|
||||
}
|
||||
|
||||
inline int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
const int THREAD_M = 16;
|
||||
|
||||
void code1x16_matvec_cuda(const void* __restrict__ A,
|
||||
const void* __restrict__ B, void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m,
|
||||
int prob_k, const int4 codebook_a_sizes,
|
||||
const int codebook_stride) {
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
int thread_m;
|
||||
do {
|
||||
waves++;
|
||||
thread_m = ceildiv(prob_m, waves * sms);
|
||||
} while (thread_m > THREAD_M);
|
||||
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
|
||||
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
|
||||
prob_k, codebook_a_sizes, codebook_stride);
|
||||
}
|
||||
|
||||
void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
|
||||
void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m,
|
||||
int prob_k, const int4 codebook_a_sizes,
|
||||
const int codebook_stride) {
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
int thread_m;
|
||||
do {
|
||||
waves++;
|
||||
thread_m = ceildiv(prob_m, waves * sms);
|
||||
} while (thread_m > THREAD_M);
|
||||
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||
cudaFuncSetAttribute(Code2x8MatVec,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
|
||||
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
|
||||
prob_k, codebook_a_sizes, codebook_stride);
|
||||
}
|
||||
|
||||
void code1x16_dequant_cuda(
|
||||
const void* __restrict__ A, void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
const int codebook_stride // as int4.
|
||||
) {
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
int thread_m;
|
||||
do {
|
||||
waves++;
|
||||
thread_m = ceildiv(prob_m, waves * sms);
|
||||
} while (thread_m > THREAD_M);
|
||||
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
|
||||
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||
// most 3 long.
|
||||
codebook_stride // as int4.
|
||||
);
|
||||
}
|
||||
|
||||
// Dequantizes the code and codebook into weights.
|
||||
void code2x8_dequant_cuda(
|
||||
const void* __restrict__ A, void* __restrict__ C,
|
||||
const void* __restrict__ codebook, int prob_m, int prob_k,
|
||||
const int4
|
||||
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
|
||||
// most 3 long, corresponds to cols.
|
||||
const int codebook_stride // as int4
|
||||
) {
|
||||
int sms;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
|
||||
int waves = 0;
|
||||
int thread_m;
|
||||
do {
|
||||
waves++;
|
||||
thread_m = ceildiv(prob_m, waves * sms);
|
||||
} while (thread_m > THREAD_M);
|
||||
|
||||
int blocks = ceildiv(prob_m, thread_m);
|
||||
int threads = 32 * thread_m;
|
||||
int shared = 16 * (2 * 256 * 8 + 32 * 9);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
cudaFuncSetAttribute(Code2x8Dequant,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
|
||||
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
|
||||
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
|
||||
codebook_a_sizes, codebook_stride);
|
||||
}
|
||||
|
||||
int codebook_stride(const torch::Tensor& codebooks) {
|
||||
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
|
||||
}
|
||||
|
||||
void code1x16_matvec(
|
||||
const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
|
||||
const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes // cumulative sizes of A spanning each
|
||||
// codebook, at most 3 long.
|
||||
) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
int prob_m = C.size(0);
|
||||
int prob_k = B.size(0);
|
||||
|
||||
code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
|
||||
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
|
||||
codebook_stride(codebook));
|
||||
}
|
||||
|
||||
torch::Tensor code1x16_matmat(const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const int4 codebook_a_sizes,
|
||||
const std::optional<torch::Tensor>& bias) {
|
||||
auto input_sizes = input.sizes();
|
||||
auto out_features = codes.size(0) * codebooks.size(2);
|
||||
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||
auto flat_output = torch::empty(
|
||||
{flat_input.size(0), out_features},
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
|
||||
|
||||
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||
auto input_vec = flat_input.index({i});
|
||||
auto output_vec = flat_output.index({i});
|
||||
code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
|
||||
codebook_a_sizes);
|
||||
}
|
||||
flat_output *= scales.flatten().unsqueeze(0);
|
||||
|
||||
if (bias.has_value()) {
|
||||
flat_output += bias->unsqueeze(0);
|
||||
}
|
||||
|
||||
auto output_sizes = input_sizes.vec();
|
||||
output_sizes.pop_back();
|
||||
output_sizes.push_back(-1);
|
||||
auto output = flat_output.reshape(output_sizes);
|
||||
return output;
|
||||
}
|
||||
|
||||
void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
|
||||
torch::Tensor& C, const torch::Tensor& codebook,
|
||||
const int4 codebook_a_sizes) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
int prob_m = C.size(0);
|
||||
int prob_k = B.size(0);
|
||||
code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
|
||||
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
|
||||
2 * codebook_stride(codebook));
|
||||
}
|
||||
|
||||
torch::Tensor code2x8_matmat(const torch::Tensor& input,
|
||||
const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const int4 codebook_a_sizes,
|
||||
const std::optional<torch::Tensor>& bias) {
|
||||
auto input_sizes = input.sizes();
|
||||
auto out_features = codes.size(0) * codebooks.size(2);
|
||||
auto flat_input = input.reshape({-1, input.size(-1)});
|
||||
auto flat_output = torch::empty(
|
||||
{flat_input.size(0), out_features},
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
|
||||
|
||||
for (int i = 0; i < flat_input.size(0); ++i) {
|
||||
auto input_vec = flat_input.index({i});
|
||||
auto output_vec = flat_output.index({i});
|
||||
code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
|
||||
codebook_a_sizes);
|
||||
}
|
||||
flat_output *= scales.flatten().unsqueeze(0);
|
||||
if (bias.has_value()) {
|
||||
flat_output += bias->unsqueeze(0);
|
||||
}
|
||||
|
||||
auto output_sizes = input_sizes.vec();
|
||||
output_sizes.pop_back();
|
||||
output_sizes.push_back(-1);
|
||||
auto output = flat_output.reshape(output_sizes);
|
||||
return output;
|
||||
}
|
||||
|
||||
// Accumulate the partition sizes.
|
||||
int4 accumulate_sizes(const std::vector<int64_t>& codebook_partition_sizes) {
|
||||
int4 cumulative_sizes;
|
||||
auto cumulative_size = &cumulative_sizes.x;
|
||||
size_t i = 0;
|
||||
int last = 0;
|
||||
assert(codebook_partition_sizes.size() <= 4);
|
||||
for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) {
|
||||
*cumulative_size = codebook_partition_sizes[i] + last;
|
||||
last = *cumulative_size;
|
||||
}
|
||||
// fill in the rest with unreachable.
|
||||
for (; i < 4; ++i, ++cumulative_size) {
|
||||
*cumulative_size = last * 10;
|
||||
}
|
||||
return cumulative_sizes;
|
||||
}
|
||||
|
||||
} // namespace aqlm
|
||||
} // namespace vllm
|
||||
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const std::vector<int64_t>& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias) {
|
||||
int4 cumulative_sizes =
|
||||
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||
|
||||
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
|
||||
int const entries = codebooks.size(1);
|
||||
|
||||
if (nbooks == 1 && entries == (1 << 16)) {
|
||||
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
|
||||
cumulative_sizes, bias);
|
||||
}
|
||||
if (nbooks == 2 && entries == (1 << 8)) {
|
||||
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
|
||||
cumulative_sizes, bias);
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
|
||||
" entries is not currently supported.")
|
||||
return {};
|
||||
}
|
||||
|
||||
torch::Tensor aqlm_dequant(
|
||||
const torch::Tensor& codes, const torch::Tensor& codebooks,
|
||||
const std::vector<int64_t>& codebook_partition_sizes) {
|
||||
int4 cumulative_sizes =
|
||||
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
|
||||
|
||||
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size();
|
||||
int const entries = codebooks.size(1);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
|
||||
int rows = codes.size(1);
|
||||
int cols = codes.size(0);
|
||||
|
||||
auto in_features = codes.size(1) * 8;
|
||||
auto out_features = codes.size(0);
|
||||
|
||||
assert(out_features == std::accumulate(codebook_partition_sizes.begin(),
|
||||
codebook_partition_sizes.end(), 0));
|
||||
|
||||
auto weights = torch::empty({out_features, in_features},
|
||||
torch::TensorOptions()
|
||||
.dtype(codebooks.dtype())
|
||||
.device(codebooks.device()));
|
||||
|
||||
if (nbooks == 1 && entries == (1 << 16)) {
|
||||
vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
|
||||
codebooks.data_ptr(), out_features,
|
||||
in_features, cumulative_sizes,
|
||||
vllm::aqlm::codebook_stride(codebooks));
|
||||
|
||||
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
|
||||
// and not consistent with gemv implementation.) weights *=
|
||||
// scales.index({"...", 0, 0});
|
||||
|
||||
return weights;
|
||||
}
|
||||
|
||||
if (nbooks == 2 && entries == (1 << 8)) {
|
||||
vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
|
||||
codebooks.data_ptr(), out_features,
|
||||
in_features, cumulative_sizes,
|
||||
vllm::aqlm::codebook_stride(codebooks));
|
||||
|
||||
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
|
||||
// and not consistent with gemv implementation) weights *=
|
||||
// scales.index({"...", 0, 0});
|
||||
|
||||
return weights;
|
||||
}
|
||||
|
||||
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
|
||||
" entries is not currently supported.")
|
||||
return {};
|
||||
}
|
||||
@ -10,7 +10,7 @@
|
||||
|
||||
template <typename ElementAB, typename ElementC, typename ElementAccumulator>
|
||||
__global__ void get_group_gemm_starts(
|
||||
int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
|
||||
ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int, ElementC* out_base_as_int,
|
||||
@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int64_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||
@ -61,8 +61,6 @@ void run_get_group_gemm_starts(
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
// expect int64_t to avoid overflow during offset calculations
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
bool per_act_token = a_scales.numel() != 1;
|
||||
|
||||
@ -104,53 +104,6 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n,
|
||||
int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab) {
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
|
||||
int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
|
||||
int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
|
||||
int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
|
||||
|
||||
if (swap_ab) {
|
||||
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
} else {
|
||||
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||
auto options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||
|
||||
// Swap-AB should be disabled for FP4 path
|
||||
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
|
||||
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
|
||||
|
||||
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
|
||||
atomic_buffer, num_experts, n, k, stream,
|
||||
may_swap_ab);
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
@ -168,9 +121,21 @@ void get_cutlass_moe_mm_data_caller(
|
||||
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
|
||||
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
|
||||
|
||||
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
|
||||
atomic_buffer, num_experts, n, k, stream,
|
||||
may_swap_ab);
|
||||
if (may_swap_ab) {
|
||||
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
|
||||
k);
|
||||
} else {
|
||||
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
|
||||
k);
|
||||
}
|
||||
|
||||
if (blockscale_offsets.has_value()) {
|
||||
// fp4 path
|
||||
@ -196,7 +161,6 @@ void get_cutlass_moe_mm_data_caller(
|
||||
topk_ids.size(1));
|
||||
}
|
||||
|
||||
template <bool SWAP_AB>
|
||||
__global__ void compute_pplx_data(int32_t* expert_offsets,
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
@ -204,23 +168,14 @@ __global__ void compute_pplx_data(int32_t* expert_offsets,
|
||||
const int padded_m, const int n,
|
||||
const int k) {
|
||||
int expert_idx = threadIdx.x;
|
||||
expert_offsets[expert_idx] = expert_idx * padded_m;
|
||||
|
||||
if constexpr (!SWAP_AB) {
|
||||
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
|
||||
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_idx * 3 + 2] = k;
|
||||
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
|
||||
problem_sizes2[expert_idx * 3 + 1] = k;
|
||||
problem_sizes2[expert_idx * 3 + 2] = n;
|
||||
} else {
|
||||
problem_sizes1[expert_idx * 3] = 2 * n;
|
||||
problem_sizes1[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
|
||||
problem_sizes1[expert_idx * 3 + 2] = k;
|
||||
problem_sizes2[expert_idx * 3] = k;
|
||||
problem_sizes2[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
|
||||
problem_sizes2[expert_idx * 3 + 2] = n;
|
||||
}
|
||||
expert_offsets[expert_idx] = expert_idx * padded_m;
|
||||
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
|
||||
problem_sizes1[expert_idx * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_idx * 3 + 2] = k;
|
||||
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
|
||||
problem_sizes2[expert_idx * 3 + 1] = k;
|
||||
problem_sizes2[expert_idx * 3 + 2] = n;
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||
@ -232,19 +187,10 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||
const int64_t n, const int64_t k) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
|
||||
|
||||
if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
|
||||
compute_pplx_data<false><<<1, num_local_experts, 0, stream>>>(
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||
k);
|
||||
} else {
|
||||
compute_pplx_data<true><<<1, num_local_experts, 0, stream>>>(
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||
k);
|
||||
}
|
||||
compute_pplx_data<<<1, num_local_experts, 0, stream>>>(
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||
k);
|
||||
}
|
||||
@ -76,11 +76,6 @@ void get_cutlass_moe_mm_data_caller(
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
@ -298,25 +293,6 @@ void get_cutlass_moe_mm_data(
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
||||
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
|
||||
problem_sizes2, num_experts, n, k,
|
||||
blockscale_offsets);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
|
||||
"kernel for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
|
||||
@ -470,12 +470,11 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
|
||||
template <typename scalar_t2>
|
||||
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<half2, vllm::kFE4M3fn.id()>(
|
||||
int q, half2* frag_b) {
|
||||
__device__ inline void dequant_fp8_scales<half2>(int q, half2* frag_b) {
|
||||
int Out1 = (q & 0xFF00FF00) >> 1;
|
||||
;
|
||||
q <<= 8;
|
||||
@ -487,8 +486,8 @@ __device__ inline void dequant_fp8_scales<half2, vllm::kFE4M3fn.id()>(
|
||||
};
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE4M3fn.id()>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
__device__ inline void dequant_fp8_scales<nv_bfloat162>(int q,
|
||||
nv_bfloat162* frag_b) {
|
||||
constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
|
||||
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
|
||||
constexpr int MASK = 0x7F007F00;
|
||||
@ -503,20 +502,6 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE4M3fn.id()>(
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
|
||||
int q, nv_bfloat162* frag_b) {
|
||||
// In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,
|
||||
// but we assume that such a extreme value would not occur in real models.
|
||||
int Out1 = (q & 0xFF00FF00) >> 1;
|
||||
q <<= 7;
|
||||
int Out2 = q & 0x7F807F80;
|
||||
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
@ -20,7 +20,6 @@ namespace MARLIN_NAMESPACE_NAME {
|
||||
TEMPLATE = ("template __global__ void Marlin<"
|
||||
"{{scalar_t}}, "
|
||||
"{{w_type_id}}, "
|
||||
"{{s_type_id}}, "
|
||||
"{{threads}}, "
|
||||
"{{thread_m_blocks}}, "
|
||||
"{{thread_n_blocks}}, "
|
||||
@ -79,8 +78,7 @@ def generate_new_kernels():
|
||||
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
|
||||
continue
|
||||
# nvfp4 only supports group_size == 16
|
||||
# mxfp4 only supports group_size == 32
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]:
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks != 1:
|
||||
continue
|
||||
# other quantization methods don't support group_size = 16
|
||||
if scalar_type != "vllm::kFE2M1f" and group_blocks == 1:
|
||||
@ -99,23 +97,10 @@ def generate_new_kernels():
|
||||
# 4bit quantization and fp16
|
||||
is_zp_float_list.append(True)
|
||||
|
||||
if scalar_type == "vllm::kFE2M1f" and group_blocks == 1:
|
||||
s_type = "vllm::kFE4M3fn"
|
||||
elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2:
|
||||
s_type = "vllm::kFE8M0fnu"
|
||||
if dtype == "fp16":
|
||||
# we cannot safely dequantize e8m0 to fp16, so skip this
|
||||
continue
|
||||
elif dtype == "fp16":
|
||||
s_type = "vllm::kFloat16"
|
||||
elif dtype == "bf16":
|
||||
s_type = "vllm::kBFloat16"
|
||||
|
||||
for is_zp_float in is_zp_float_list:
|
||||
template_str = jinja2.Template(TEMPLATE).render(
|
||||
scalar_t=c_dtype,
|
||||
w_type_id=scalar_type + ".id()",
|
||||
s_type_id=s_type + ".id()",
|
||||
threads=threads,
|
||||
thread_m_blocks=max(m_blocks, 1),
|
||||
thread_n_blocks=n_blocks,
|
||||
|
||||
@ -48,8 +48,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
@ -188,12 +187,7 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_red_size = tb_m * (tb_n + 8) * 2;
|
||||
int sh_bias_size = tb_n * 2;
|
||||
int tmp_size =
|
||||
(sh_b_size > sh_red_size ? sh_red_size : sh_b_size) + sh_bias_size;
|
||||
tmp_size = max(max(sh_b_size, sh_red_size), tmp_size);
|
||||
|
||||
int sh_red_size = tb_m * (tb_n + 8);
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
@ -208,8 +202,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
sh_zp_size = sh_s_size / 2;
|
||||
}
|
||||
|
||||
int total_size =
|
||||
tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size;
|
||||
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
|
||||
sh_zp_size + sh_g_idx_size;
|
||||
|
||||
return total_size;
|
||||
}
|
||||
@ -243,25 +237,20 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
return cache_size + 512 <= max_shared_mem;
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
constexpr auto S_TYPE = \
|
||||
W_TYPE == vllm::kFE2M1f \
|
||||
? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \
|
||||
: (std::is_same<scalar_t, half>::value ? vllm::kFloat16 \
|
||||
: vllm::kBFloat16); \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), S_TYPE.id(), NUM_THREADS, \
|
||||
THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
||||
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
|
||||
}
|
||||
|
||||
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
|
||||
@ -326,39 +315,22 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false)
|
||||
|
||||
#define NVFP4_GET_IF(W_TYPE) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||
|
||||
#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
|
||||
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false)
|
||||
|
||||
#define MXFP4_GET_IF(W_TYPE) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
#define FP4_GET_IF(W_TYPE) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \
|
||||
FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \
|
||||
FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \
|
||||
FP4_GET_IF_M234(W_TYPE, 4, 8, 128)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
@ -412,7 +384,7 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
COMMON_GET_IF(vllm::kU4B8)
|
||||
COMMON_GET_IF(vllm::kU8B128)
|
||||
|
||||
NVFP4_GET_IF(vllm::kFE2M1f)
|
||||
FP4_GET_IF(vllm::kFE2M1f)
|
||||
|
||||
BIGGROUP_GET_IF(vllm::kFE4M3fn)
|
||||
|
||||
@ -424,11 +396,6 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
}
|
||||
FZP_GET_IF(vllm::kU4)
|
||||
}
|
||||
if (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
if (false) {
|
||||
}
|
||||
MXFP4_GET_IF(vllm::kFE2M1f)
|
||||
}
|
||||
|
||||
return kernel;
|
||||
}
|
||||
@ -486,12 +453,12 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
void* s, void* s2, void* zp, void* g_idx, void* perm,
|
||||
void* a_tmp, int prob_m, int prob_n, int prob_k, int lda,
|
||||
void* workspace, vllm::ScalarType const& q_type, bool has_bias,
|
||||
bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k_init,
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* s2, void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
int prob_m, int prob_n, int prob_k, int lda, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k_init,
|
||||
int thread_n_init, int sms, bool use_atomic_add,
|
||||
bool use_fp32_reduce, bool is_zp_float) {
|
||||
if (has_zp) {
|
||||
@ -536,7 +503,6 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* bias_ptr = (const int4*)b_bias;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const uint16_t* s2_ptr = (const uint16_t*)s2;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
@ -657,9 +623,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem_new, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, s_ptr, s2_ptr, zp_ptr,
|
||||
g_idx_ptr, num_groups,
|
||||
prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add,
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups,
|
||||
prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add,
|
||||
use_fp32_reduce, max_shared_mem_new);
|
||||
// clang-format on
|
||||
|
||||
@ -673,8 +638,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
|
||||
torch::Tensor& b_q_weight,
|
||||
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& global_scale_or_none,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
@ -821,24 +785,12 @@ torch::Tensor gptq_marlin_gemm(
|
||||
torch::Tensor global_scale;
|
||||
if (global_scale_or_none.has_value()) {
|
||||
global_scale = global_scale_or_none.value();
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16,
|
||||
"global_scale can only be used for nvfp4 format.");
|
||||
TORCH_CHECK(b_q_type == vllm::kFE2M1f,
|
||||
"global_scale can only be used for float4_e2m1f.");
|
||||
} else {
|
||||
global_scale = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16),
|
||||
"the global_scale parameter must be passed for nvfp4 format.");
|
||||
}
|
||||
|
||||
bool has_bias = b_bias_or_none.has_value();
|
||||
torch::Tensor b_bias;
|
||||
if (has_bias) {
|
||||
b_bias = b_bias_or_none.value();
|
||||
TORCH_CHECK(b_bias.device().is_cuda(), "b_bias is not on GPU");
|
||||
TORCH_CHECK(b_bias.is_contiguous(), "b_bias is not contiguous");
|
||||
TORCH_CHECK(b_bias.size(0) == size_n, "b_bias.size(0) != size_n");
|
||||
TORCH_CHECK(b_bias.stride(0) == 1, "b_bias.stride(0) != 1");
|
||||
} else {
|
||||
b_bias = torch::empty({0}, options);
|
||||
TORCH_CHECK(!(b_q_type == vllm::kFE2M1f),
|
||||
"the global_scale parameter must be passed for float4_e2m1f.");
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
@ -905,50 +857,34 @@ torch::Tensor gptq_marlin_gemm(
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::Half>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_bias.data_ptr<at::Half>(), scales_ptr,
|
||||
global_scale.data_ptr<at::Half>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order,
|
||||
is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
c_tmp.data_ptr<float>(), scales_ptr, global_scale.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0),
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
void* scales_ptr;
|
||||
if (b_q_type == vllm::kFE2M1f) {
|
||||
if (group_size == 16)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
else if (group_size == 32)
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e8m0fnu>();
|
||||
else
|
||||
TORCH_CHECK(false,
|
||||
"float4_e2m1f only supports group_size == 16 (NVFP4) ",
|
||||
"and group_size == 32 (MXFP4)");
|
||||
scales_ptr = b_scales.data_ptr<at::Float8_e4m3fn>();
|
||||
} else {
|
||||
scales_ptr = b_scales.data_ptr<at::BFloat16>();
|
||||
}
|
||||
|
||||
marlin::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_bias.data_ptr<at::BFloat16>(), scales_ptr,
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(), scales_ptr,
|
||||
global_scale.data_ptr<at::BFloat16>(), b_zeros.data_ptr(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type,
|
||||
has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||
has_act_order, is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
|
||||
@ -10,18 +10,15 @@
|
||||
#define MARLIN_KERNEL_PARAMS \
|
||||
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||
const int4 *__restrict__ b_bias_ptr, \
|
||||
const int4 *__restrict__ scales_ptr, \
|
||||
const uint16_t *__restrict__ scale2_ptr, \
|
||||
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
|
||||
int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
|
||||
bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \
|
||||
int max_shared_mem
|
||||
bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // weight ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
|
||||
@ -39,7 +39,6 @@ namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
@ -272,7 +271,6 @@ __device__ inline void wait_negative_and_add(int* lock) {
|
||||
|
||||
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||
const vllm::ScalarTypeId s_type_id, // weight scale ScalarType id
|
||||
const int threads, // number of threads in a threadblock
|
||||
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||
// dimension (batchsize) of the
|
||||
@ -292,7 +290,6 @@ __global__ void Marlin(
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ b_bias_ptr,
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4
|
||||
@ -300,13 +297,12 @@ __global__ void Marlin(
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int lda, // A.stride(0), equal to prob_k is A is contiguous
|
||||
int* locks, // extra global storage for barrier synchronization
|
||||
bool has_bias,
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int lda, // A.stride(0), equal to prob_k is A is contiguous
|
||||
int* locks, // extra global storage for barrier synchronization
|
||||
bool use_atomic_add, // whether to use atomic add to reduce
|
||||
bool use_fp32_reduce, // whether to use fp32 global reduce
|
||||
int max_shared_mem) {
|
||||
@ -330,29 +326,18 @@ __global__ void Marlin(
|
||||
using FragZP = typename ScalarType<scalar_t>::FragZP;
|
||||
|
||||
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
|
||||
static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
|
||||
s_type == vllm::kFE8M0fnu && group_blocks == 2);
|
||||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
static_assert(s_type == vllm::kBFloat16);
|
||||
} else if constexpr (std::is_same<scalar_t, half>::value) {
|
||||
static_assert(s_type == vllm::kFloat16);
|
||||
}
|
||||
|
||||
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
|
||||
constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 ||
|
||||
w_type == vllm::kU4B8 || w_type == vllm::kU8B128;
|
||||
// see comments of dequant.h for more details
|
||||
constexpr bool dequant_skip_flop =
|
||||
w_type == vllm::kFE4M3fn ||
|
||||
w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
|
||||
!is_int_type ||
|
||||
has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
|
||||
has_zp && !is_zp_float && !(w_type == vllm::kU8);
|
||||
|
||||
scalar_t2 global_scale;
|
||||
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
// NVFP4 format requires global scale
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
uint16_t val = scale2_ptr[0];
|
||||
global_scale = Dtype::num2num2(*reinterpret_cast<scalar_t*>(&val));
|
||||
}
|
||||
@ -604,7 +589,7 @@ __global__ void Marlin(
|
||||
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 4;
|
||||
s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2;
|
||||
s_sh_rd = s_sh_rd * 2 + warp_row % 2;
|
||||
|
||||
} else if constexpr (group_blocks != -1)
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
@ -617,18 +602,6 @@ __global__ void Marlin(
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) % 4;
|
||||
|
||||
int bias_sh_rd;
|
||||
if constexpr (m_block_size_8) {
|
||||
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) / 8;
|
||||
} else {
|
||||
bias_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) % 4;
|
||||
}
|
||||
|
||||
int bias_sh_wr = threadIdx.x;
|
||||
int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
|
||||
|
||||
// Zero-points have the same read layout as the scales
|
||||
// (without column-wise case)
|
||||
constexpr int num_col_threads = 8;
|
||||
@ -697,19 +670,7 @@ __global__ void Marlin(
|
||||
constexpr int sh_b_size = stages * b_sh_stage;
|
||||
int4* sh_b = sh;
|
||||
int4* sh_red = sh;
|
||||
|
||||
constexpr int sh_size_b_red_min =
|
||||
(sh_red_size < sh_b_size ? sh_red_size : sh_b_size);
|
||||
constexpr int sh_size_b_red_max =
|
||||
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||
constexpr int sh_bias_size = (thread_n_blocks * 16 / 8);
|
||||
constexpr int sh_b_red_bias_size =
|
||||
sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size)
|
||||
? sh_size_b_red_max
|
||||
: (sh_size_b_red_min + sh_bias_size);
|
||||
|
||||
int4* sh_bias = sh + sh_size_b_red_min;
|
||||
int4* sh_g_idx = sh + sh_b_red_bias_size;
|
||||
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
|
||||
: (stages * s_sh_stage);
|
||||
@ -719,13 +680,15 @@ __global__ void Marlin(
|
||||
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
|
||||
stages * b_sh_stage);
|
||||
int4* sh_a = sh_s + sh_s_size;
|
||||
// constexpr int shm_size_used =
|
||||
// stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
|
||||
// (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
FragA frag_a[2][thread_m_blocks];
|
||||
I4 frag_b_quant[2][b_thread_vecs];
|
||||
FragC frag_c[thread_m_blocks][4][2];
|
||||
FragS frag_s[2][4]; // No act-order
|
||||
FragS frag_bias[2][4];
|
||||
FragS frag_s[2][4]; // No act-order
|
||||
FragS act_frag_s[2][4][4]; // For act-order
|
||||
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
||||
FragZP frag_zp; // Zero-points in fp16
|
||||
@ -960,15 +923,10 @@ __global__ void Marlin(
|
||||
if constexpr (w_type_id != vllm::kFE2M1f.id()) {
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] =
|
||||
sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
|
||||
} else if constexpr (group_blocks == 1 || thread_k_blocks > 4) {
|
||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||
reinterpret_cast<int2*>(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
} else {
|
||||
reinterpret_cast<int2*>(&frag_s[k % 2])[0] =
|
||||
reinterpret_cast<int2*>(
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) +
|
||||
k % 2];
|
||||
sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)];
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1181,9 +1139,9 @@ __global__ void Marlin(
|
||||
int s_quant_0 = reinterpret_cast<int*>(frag_s[k2])[0];
|
||||
int s_quant_1 = reinterpret_cast<int*>(frag_s[k2])[1];
|
||||
|
||||
dequant_fp8_scales<scalar_t2, s_type_id>(
|
||||
s_quant_0, reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||
dequant_fp8_scales<scalar_t2, s_type_id>(
|
||||
dequant_fp8_scales<scalar_t2>(s_quant_0,
|
||||
reinterpret_cast<scalar_t2*>(&frag_s[k2]));
|
||||
dequant_fp8_scales<scalar_t2>(
|
||||
s_quant_1, reinterpret_cast<scalar_t2*>(&frag_s[k2]) + 2);
|
||||
}
|
||||
|
||||
@ -1453,7 +1411,7 @@ __global__ void Marlin(
|
||||
// Write out the reduce final result in the correct layout. We only actually
|
||||
// reshuffle matrix fragments in this step, the reduction above is performed
|
||||
// in fragment layout.
|
||||
auto write_result = [&](bool last) {
|
||||
auto write_result = [&]() {
|
||||
int c_gl_stride = prob_n / 8;
|
||||
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
|
||||
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
|
||||
@ -1480,7 +1438,7 @@ __global__ void Marlin(
|
||||
int c_gl_wr_end = c_gl_stride * prob_m;
|
||||
// We first reorder in shared memory to guarantee the most efficient final
|
||||
// global write patterns
|
||||
auto write = [&](int idx, float c0, float c1, FragS& s, FragS& b_bias) {
|
||||
auto write = [&](int idx, float c0, float c1, FragS& s) {
|
||||
scalar_t2 res =
|
||||
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
|
||||
|
||||
@ -1489,25 +1447,12 @@ __global__ void Marlin(
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
w_type.size_bits() == 4 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
scalar_t2 tmp_scale = s[0];
|
||||
if constexpr (m_block_size_8) {
|
||||
tmp_scale = Dtype::num2num2(
|
||||
reinterpret_cast<scalar_t*>(&s[0])[(threadIdx.x % 8) / 4]);
|
||||
}
|
||||
res = __hmul2(res, tmp_scale);
|
||||
res = __hmul2(res, s[0]);
|
||||
}
|
||||
|
||||
if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
|
||||
if constexpr (w_type == vllm::kFE2M1f) {
|
||||
res = __hmul2(res, global_scale);
|
||||
}
|
||||
if (has_bias && last) {
|
||||
scalar_t2 tmp_bias = b_bias[0];
|
||||
if constexpr (m_block_size_8) {
|
||||
tmp_bias = Dtype::num2num2(
|
||||
reinterpret_cast<scalar_t*>(&b_bias[0])[(threadIdx.x % 8) / 4]);
|
||||
}
|
||||
res = __hadd2(res, tmp_bias);
|
||||
}
|
||||
|
||||
if constexpr (m_block_size_8) {
|
||||
((scalar_t*)sh_red)[idx] = res.x;
|
||||
@ -1525,25 +1470,19 @@ __global__ void Marlin(
|
||||
if constexpr (m_block_size_8) {
|
||||
int wr = c_sh_wr + 16 * j;
|
||||
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1],
|
||||
frag_s[j / 2][2 * (j % 2) + 0],
|
||||
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||
frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3],
|
||||
frag_s[j / 2][2 * (j % 2) + 1],
|
||||
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||
frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
} else {
|
||||
int wr = c_sh_wr + 8 * j;
|
||||
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
|
||||
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0],
|
||||
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
|
||||
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0],
|
||||
frag_bias[j / 2][2 * (j % 2) + 0]);
|
||||
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
|
||||
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
|
||||
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1],
|
||||
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
|
||||
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1],
|
||||
frag_bias[j / 2][2 * (j % 2) + 1]);
|
||||
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
|
||||
}
|
||||
}
|
||||
c_sh_wr += 16 * (4 * c_sh_stride);
|
||||
@ -1683,14 +1622,6 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
thread_block_reduce();
|
||||
|
||||
if (has_bias && last) {
|
||||
__syncthreads();
|
||||
cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
|
||||
threadIdx.x < 16 * thread_n_blocks / 8);
|
||||
cp_async_fence();
|
||||
}
|
||||
|
||||
if constexpr (!has_act_order && group_blocks == -1 &&
|
||||
(has_zp && dequant_skip_flop || !has_zp)) {
|
||||
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
|
||||
@ -1753,20 +1684,11 @@ __global__ void Marlin(
|
||||
}
|
||||
barrier_release(&locks[locks_off], last);
|
||||
}
|
||||
|
||||
if (has_bias && last) {
|
||||
cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
reinterpret_cast<int4*>(&frag_bias)[0] = sh_bias[bias_sh_rd];
|
||||
reinterpret_cast<int4*>(&frag_bias)[1] = sh_bias[bias_sh_rd + 4];
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (use_atomic_add && slice_count > 1 && slice_idx != 0)
|
||||
wait_negative_and_add(&locks[locks_off]);
|
||||
if (last || use_atomic_add)
|
||||
// only the last block in a slice actually writes the result
|
||||
write_result(last);
|
||||
write_result();
|
||||
slice_row = 0;
|
||||
slice_col_par++;
|
||||
slice_col++;
|
||||
@ -1784,7 +1706,6 @@ __global__ void Marlin(
|
||||
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
|
||||
}
|
||||
|
||||
bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x;
|
||||
// Update slice k/n for scales loading
|
||||
if constexpr (has_act_order) {
|
||||
slice_k_start = tb_k * slice_row;
|
||||
|
||||
@ -349,12 +349,9 @@ def to_cute_constant(value: list[int]):
|
||||
|
||||
|
||||
def unique_schedules(impl_configs: list[ImplConfig]):
|
||||
# Use dict over set for deterministic ordering
|
||||
return list({
|
||||
sch: None
|
||||
for impl_config in impl_configs
|
||||
for sch in impl_config.schedules
|
||||
}.keys())
|
||||
return list(
|
||||
set(sch for impl_config in impl_configs
|
||||
for sch in impl_config.schedules))
|
||||
|
||||
|
||||
def unsigned_type_with_bitwidth(num_bits):
|
||||
@ -571,79 +568,78 @@ def generate():
|
||||
itertools.repeat(default_heuristic))
|
||||
]
|
||||
|
||||
# TODO: Support W4A8 when ready
|
||||
# # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
|
||||
# # TODO (LucasWilkinson): Further tuning required
|
||||
# qqq_tile_heuristic_config = {
|
||||
# #### M = 257+
|
||||
# # ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# # TODO (LucasWilkinson): Investigate further
|
||||
# # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
|
||||
# # "M > 256": ((128, 256), (2, 1, 1)),
|
||||
# "M > 256": ((128, 128), (2, 1, 1)),
|
||||
# #### M = 129-256
|
||||
# "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
|
||||
# "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
|
||||
# # ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# # TODO (LucasWilkinson): Investigate further
|
||||
# # "M > 128": ((128, 256), (2, 1, 1)),
|
||||
# "M > 128": ((128, 128), (2, 1, 1)),
|
||||
# #### M = 65-128
|
||||
# "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
|
||||
# "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
|
||||
# "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
|
||||
# "M > 64": ((128, 128), (2, 1, 1)),
|
||||
# #### M = 33-64
|
||||
# "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
|
||||
# # Broken for QQQ types
|
||||
# # TODO (LucasWilkinson): Investigate further
|
||||
# #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
|
||||
# "M > 32": ((128, 64), (2, 1, 1)),
|
||||
# #### M = 17-32
|
||||
# "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
|
||||
# "M > 16": ((256, 32), (2, 1, 1)),
|
||||
# #### M = 1-16
|
||||
# "N >= 26624": ((256, 16), (1, 1, 1)),
|
||||
# None: ((128, 16), (1, 1, 1)),
|
||||
# }
|
||||
# Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
|
||||
# TODO (LucasWilkinson): Further tuning required
|
||||
qqq_tile_heuristic_config = {
|
||||
#### M = 257+
|
||||
# ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
# "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
|
||||
# "M > 256": ((128, 256), (2, 1, 1)),
|
||||
"M > 256": ((128, 128), (2, 1, 1)),
|
||||
#### M = 129-256
|
||||
"M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
|
||||
"M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
|
||||
# ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
# "M > 128": ((128, 256), (2, 1, 1)),
|
||||
"M > 128": ((128, 128), (2, 1, 1)),
|
||||
#### M = 65-128
|
||||
"M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
|
||||
"M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
|
||||
"M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
|
||||
"M > 64": ((128, 128), (2, 1, 1)),
|
||||
#### M = 33-64
|
||||
"M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
|
||||
# Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
#"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
|
||||
"M > 32": ((128, 64), (2, 1, 1)),
|
||||
#### M = 17-32
|
||||
"M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
|
||||
"M > 16": ((256, 32), (2, 1, 1)),
|
||||
#### M = 1-16
|
||||
"N >= 26624": ((256, 16), (1, 1, 1)),
|
||||
None: ((128, 16), (1, 1, 1)),
|
||||
}
|
||||
|
||||
# # For now we use the same heuristic for all types
|
||||
# # Heuristic is currently tuned for H100s
|
||||
# qqq_heuristic = [
|
||||
# (cond, ScheduleConfig(*tile_config,
|
||||
# **sch_common_params)) # type: ignore
|
||||
# for cond, tile_config in qqq_tile_heuristic_config.items()
|
||||
# ]
|
||||
# For now we use the same heuristic for all types
|
||||
# Heuristic is currently tuned for H100s
|
||||
qqq_heuristic = [
|
||||
(cond, ScheduleConfig(*tile_config,
|
||||
**sch_common_params)) # type: ignore
|
||||
for cond, tile_config in qqq_tile_heuristic_config.items()
|
||||
]
|
||||
|
||||
# QQQ_kernel_types = [
|
||||
# *(TypeConfig(
|
||||
# a=DataType.s8,
|
||||
# b=VLLMDataType.u4b8,
|
||||
# b_group_scale=b_group_scale,
|
||||
# b_group_zeropoint=DataType.void,
|
||||
# b_channel_scale=DataType.f32,
|
||||
# a_token_scale=DataType.f32,
|
||||
# out=DataType.f16,
|
||||
# accumulator=DataType.s32,
|
||||
# ) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
# *(TypeConfig(
|
||||
# a=DataType.e4m3,
|
||||
# b=VLLMDataType.u4b8,
|
||||
# b_group_scale=b_group_scale,
|
||||
# b_group_zeropoint=DataType.void,
|
||||
# b_channel_scale=DataType.f32,
|
||||
# a_token_scale=DataType.f32,
|
||||
# out=DataType.f16,
|
||||
# accumulator=DataType.f32,
|
||||
# ) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
# ]
|
||||
QQQ_kernel_types = [
|
||||
*(TypeConfig(
|
||||
a=DataType.s8,
|
||||
b=VLLMDataType.u4b8,
|
||||
b_group_scale=b_group_scale,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.f32,
|
||||
a_token_scale=DataType.f32,
|
||||
out=DataType.f16,
|
||||
accumulator=DataType.s32,
|
||||
) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
*(TypeConfig(
|
||||
a=DataType.e4m3,
|
||||
b=VLLMDataType.u4b8,
|
||||
b_group_scale=b_group_scale,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.f32,
|
||||
a_token_scale=DataType.f32,
|
||||
out=DataType.f16,
|
||||
accumulator=DataType.f32,
|
||||
) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
]
|
||||
|
||||
# impl_configs += [
|
||||
# ImplConfig(x[0], x[1], x[2])
|
||||
# for x in zip(QQQ_kernel_types,
|
||||
# itertools.repeat(get_unique_schedules(qqq_heuristic)),
|
||||
# itertools.repeat(qqq_heuristic))
|
||||
# ]
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(QQQ_kernel_types,
|
||||
itertools.repeat(get_unique_schedules(qqq_heuristic)),
|
||||
itertools.repeat(qqq_heuristic))
|
||||
]
|
||||
|
||||
output_dir = os.path.join(SCRIPT_DIR, "generated")
|
||||
|
||||
|
||||
209
csrc/quantization/marlin/dense/LICENSE
Normal file
209
csrc/quantization/marlin/dense/LICENSE
Normal file
@ -0,0 +1,209 @@
|
||||
Contains code from https://github.com/IST-DASLab/marlin
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
|
||||
This product bundles various third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses. See licenses/
|
||||
for text of these licenses.
|
||||
32
csrc/quantization/marlin/dense/common/base.h
Normal file
32
csrc/quantization/marlin/dense/common/base.h
Normal file
@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Modified by HandH1998
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
|
||||
// for instance as inputs to tensor core operations. Consequently, all
|
||||
// corresponding index accesses must be compile-time constants, which is why we
|
||||
// extensively use `#pragma unroll` throughout the kernel code to guarantee
|
||||
// this.
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) { return elems[i]; }
|
||||
};
|
||||
89
csrc/quantization/marlin/dense/common/mem.h
Normal file
89
csrc/quantization/marlin/dense/common/mem.h
Normal file
@ -0,0 +1,89 @@
|
||||
/*
|
||||
* Modified by HandH1998
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// Predicated asynchronous global->shared copy; used for inputs A where we apply
|
||||
// predication to handle batchsizes that are not multiples of 16.
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Asynchronous global->shared copy
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Async copy fence.
|
||||
__device__ inline void cp_async_fence() {
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
}
|
||||
|
||||
// Wait until at most `n` async copy stages are still pending.
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
// Wait until barrier reaches `count`, then lock for current threadblock.
|
||||
__device__ inline void barrier_acquire(int* lock, int count) {
|
||||
if (threadIdx.x == 0) {
|
||||
int state = -1;
|
||||
do
|
||||
// Guarantee that subsequent writes by this threadblock will be visible
|
||||
// globally.
|
||||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
||||
: "=r"(state)
|
||||
: "l"(lock));
|
||||
while (state != count);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Release barrier and increment visitation count.
|
||||
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
if (reset) {
|
||||
lock[0] = 0;
|
||||
return;
|
||||
}
|
||||
int val = 1;
|
||||
// Make sure that all writes since acquiring this barrier are visible
|
||||
// globally, while releasing the barrier.
|
||||
asm volatile("fence.acq_rel.gpu;\n");
|
||||
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
||||
:
|
||||
: "l"(lock), "r"(val));
|
||||
}
|
||||
}
|
||||
1073
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
Normal file
1073
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
Normal file
File diff suppressed because it is too large
Load Diff
1248
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
Normal file
1248
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
Normal file
File diff suppressed because it is too large
Load Diff
@ -130,12 +130,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
|
||||
ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
|
||||
|
||||
ops.def(
|
||||
"swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float "
|
||||
"limit=7.0) "
|
||||
"-> ()");
|
||||
ops.impl("swigluoai_and_mul", torch::kCUDA, &swigluoai_and_mul);
|
||||
|
||||
// GELU implementation used in GPT-2.
|
||||
ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("gelu_new", torch::kCUDA, &gelu_new);
|
||||
@ -148,6 +142,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
|
||||
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
|
||||
|
||||
// prepare_inputs advance_step
|
||||
ops.def(
|
||||
"advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
|
||||
"Tensor! input_tokens, Tensor sampled_token_ids, "
|
||||
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
|
||||
"Tensor block_tables) -> ()");
|
||||
ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);
|
||||
|
||||
ops.def(
|
||||
"advance_step_flashinfer("
|
||||
" int num_seqs, int num_queries, int block_size,"
|
||||
" Tensor! input_tokens, Tensor sampled_token_ids,"
|
||||
" Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
|
||||
" Tensor block_tables, Tensor! paged_kv_indices,"
|
||||
" Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
|
||||
" Tensor! block_table_bounds"
|
||||
") -> ()");
|
||||
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);
|
||||
|
||||
// Layernorm
|
||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||
ops.def(
|
||||
@ -213,6 +226,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
// Quantized GEMM for AQLM.
|
||||
ops.def(
|
||||
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
|
||||
"Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
|
||||
"-> Tensor",
|
||||
{stride_tag});
|
||||
ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
|
||||
|
||||
// Decompression method for AQLM.
|
||||
ops.def(
|
||||
"aqlm_dequant(Tensor codes, Tensor codebooks, "
|
||||
"int[] codebook_partition_sizes) -> Tensor",
|
||||
{stride_tag});
|
||||
ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
|
||||
|
||||
// Quantized GEMM for AWQ.
|
||||
ops.def(
|
||||
"awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
|
||||
@ -241,6 +269,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// custom types:
|
||||
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
|
||||
|
||||
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
|
||||
"Tensor",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
|
||||
@ -290,7 +326,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// gptq_marlin Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
|
||||
"Tensor? b_bias_or_none,"
|
||||
"Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? "
|
||||
"g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, "
|
||||
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
|
||||
@ -345,6 +380,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// marlin_qqq_gemm for QQQ.
|
||||
ops.def(
|
||||
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
|
||||
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
|
||||
"Tensor! workspace, SymInt size_m, SymInt size_n, "
|
||||
"SymInt size_k) -> Tensor",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// CUTLASS nvfp4 block scaled GEMM
|
||||
ops.def(
|
||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
@ -423,19 +467,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
{stride_tag});
|
||||
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
|
||||
|
||||
// A function that computes problem sizes for each expert's multiplication
|
||||
// used by the two mms called from fused MoE operation. It takes topk_ids as
|
||||
// an input, and computes problem_sizes1 and problem_sizes2 only.
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" int num_experts, int n, int k, "
|
||||
" Tensor? blockscale_offsets) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
|
||||
&get_cutlass_moe_mm_problem_sizes);
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
|
||||
// as an input, and computes expert_offsets (token start indices of each
|
||||
|
||||
@ -139,6 +139,21 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
WORKDIR /workspace
|
||||
|
||||
# install build and runtime dependencies
|
||||
|
||||
# arm64 (GH200) build follows the practice of "use existing pytorch" build,
|
||||
# we need to install torch and torchvision from the nightly builds first,
|
||||
# pytorch will not appear as a vLLM dependency in all of the following steps
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system \
|
||||
--index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
"torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
|
||||
uv pip install --system \
|
||||
--index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
--pre pytorch_triton==3.3.0+gitab727c40; \
|
||||
fi
|
||||
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY requirements/cuda.txt requirements/cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
@ -219,8 +234,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
&& sccache --show-stats; \
|
||||
fi
|
||||
|
||||
ARG vllm_target_device="cuda"
|
||||
ENV VLLM_TARGET_DEVICE=${vllm_target_device}
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
@ -372,45 +385,31 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
|
||||
# Install FlashInfer from source
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
# Keep this in sync with "flashinfer" extra in setup.py
|
||||
ARG FLASHINFER_GIT_REF="v0.2.12"
|
||||
# Flag to control whether to compile FlashInfer AOT kernels
|
||||
# Set to "true" to enable AOT compilation:
|
||||
# docker build --build-arg FLASHINFER_AOT_COMPILE=true ...
|
||||
ARG FLASHINFER_AOT_COMPILE=false
|
||||
# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt
|
||||
# We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel.
|
||||
ARG FLASHINFER_GIT_REF="v0.2.11"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch ${FLASHINFER_GIT_REF} \
|
||||
${FLASHINFER_GIT_REPO} flashinfer
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
# Needed to build AOT kernels
|
||||
pushd flashinfer
|
||||
if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
# Build AOT kernels
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer.aot
|
||||
# Install with no-build-isolation since we already built AOT kernels
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
uv pip install --system --no-build-isolation . \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
# Download pre-compiled cubins
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins."
|
||||
else
|
||||
echo "🏗️ Installing FlashInfer without AOT compilation in JIT mode"
|
||||
uv pip install --system . \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
fi
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer.aot
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
uv pip install --system --no-build-isolation --force-reinstall --no-deps .
|
||||
popd
|
||||
rm -rf flashinfer
|
||||
BASH
|
||||
@ -498,11 +497,14 @@ ENV HF_HUB_ENABLE_HF_TRANSFER 1
|
||||
# Copy in the v1 package for testing (it isn't distributed yet)
|
||||
COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
|
||||
|
||||
# Source code is used in the `python_only_compile.sh` test
|
||||
# We hide it inside `src/` so that this source code
|
||||
# doc requires source code
|
||||
# we hide them inside `test_docs/` , so that this source code
|
||||
# will not be imported by other tests
|
||||
RUN mkdir src
|
||||
RUN mv vllm src/vllm
|
||||
RUN mkdir test_docs
|
||||
RUN mv docs test_docs/
|
||||
RUN cp -r examples test_docs/
|
||||
RUN mv vllm test_docs/
|
||||
RUN mv mkdocs.yaml test_docs/
|
||||
#################### TEST IMAGE ####################
|
||||
|
||||
#################### OPENAI API SERVER ####################
|
||||
|
||||
@ -16,7 +16,7 @@ ENV LANG=C.UTF-8 \
|
||||
RUN microdnf install -y \
|
||||
which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \
|
||||
libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \
|
||||
openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile && \
|
||||
openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy && \
|
||||
microdnf clean all
|
||||
|
||||
# Python Installation
|
||||
@ -136,71 +136,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
mkdir -p /tmp/hf-xet/dist && \
|
||||
cp dist/*.whl /tmp/hf-xet/dist/
|
||||
|
||||
# Build numba
|
||||
FROM python-install AS numba-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG NUMBA_VERSION=0.61.2
|
||||
|
||||
WORKDIR /tmp
|
||||
|
||||
# Clone all required dependencies
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
microdnf install ninja-build gcc gcc-c++ -y && \
|
||||
git clone --recursive https://github.com/llvm/llvm-project.git -b llvmorg-15.0.7 && \
|
||||
git clone --recursive https://github.com/numba/llvmlite.git -b v0.44.0 && \
|
||||
git clone --recursive https://github.com/numba/numba.git -b ${NUMBA_VERSION} && \
|
||||
cd llvm-project && mkdir build && cd build && \
|
||||
uv pip install 'cmake<4' setuptools numpy && \
|
||||
export PREFIX=/usr/local && CMAKE_ARGS="${CMAKE_ARGS} -DLLVM_ENABLE_PROJECTS=lld;libunwind;compiler-rt" \
|
||||
CFLAGS="$(echo $CFLAGS | sed 's/-fno-plt //g')" \
|
||||
CXXFLAGS="$(echo $CXXFLAGS | sed 's/-fno-plt //g')" \
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DFFI_INCLUDE_DIR=$PREFIX/include" \
|
||||
CMAKE_ARGS="${CMAKE_ARGS} -DFFI_LIBRARY_DIR=$PREFIX/lib" \
|
||||
cmake -DCMAKE_INSTALL_PREFIX="${PREFIX}" \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_LIBRARY_PATH="${PREFIX}" \
|
||||
-DLLVM_ENABLE_LIBEDIT=OFF \
|
||||
-DLLVM_ENABLE_LIBXML2=OFF \
|
||||
-DLLVM_ENABLE_RTTI=ON \
|
||||
-DLLVM_ENABLE_TERMINFO=OFF \
|
||||
-DLLVM_INCLUDE_BENCHMARKS=OFF \
|
||||
-DLLVM_INCLUDE_DOCS=OFF \
|
||||
-DLLVM_INCLUDE_EXAMPLES=OFF \
|
||||
-DLLVM_INCLUDE_GO_TESTS=OFF \
|
||||
-DLLVM_INCLUDE_TESTS=OFF \
|
||||
-DLLVM_INCLUDE_UTILS=ON \
|
||||
-DLLVM_INSTALL_UTILS=ON \
|
||||
-DLLVM_UTILS_INSTALL_DIR=libexec/llvm \
|
||||
-DLLVM_BUILD_LLVM_DYLIB=OFF \
|
||||
-DLLVM_LINK_LLVM_DYLIB=OFF \
|
||||
-DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD=WebAssembly \
|
||||
-DLLVM_ENABLE_FFI=ON \
|
||||
-DLLVM_ENABLE_Z3_SOLVER=OFF \
|
||||
-DLLVM_OPTIMIZED_TABLEGEN=ON \
|
||||
-DCMAKE_POLICY_DEFAULT_CMP0111=NEW \
|
||||
-DCOMPILER_RT_BUILD_BUILTINS=ON \
|
||||
-DCOMPILER_RT_BUILTINS_HIDE_SYMBOLS=OFF \
|
||||
-DCOMPILER_RT_BUILD_LIBFUZZER=OFF \
|
||||
-DCOMPILER_RT_BUILD_CRT=OFF \
|
||||
-DCOMPILER_RT_BUILD_MEMPROF=OFF \
|
||||
-DCOMPILER_RT_BUILD_PROFILE=OFF \
|
||||
-DCOMPILER_RT_BUILD_SANITIZERS=OFF \
|
||||
-DCOMPILER_RT_BUILD_XRAY=OFF \
|
||||
-DCOMPILER_RT_BUILD_GWP_ASAN=OFF \
|
||||
-DCOMPILER_RT_BUILD_ORC=OFF \
|
||||
-DCOMPILER_RT_INCLUDE_TESTS=OFF \
|
||||
${CMAKE_ARGS} -GNinja ../llvm \
|
||||
|
||||
&& ninja install . && \
|
||||
# build llvmlite
|
||||
cd ../../llvmlite && python setup.py bdist_wheel && \
|
||||
cd ../numba && \
|
||||
if ! grep '#include "dynamic_annotations.h"' numba/_dispatcher.cpp; then \
|
||||
sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \
|
||||
fi && python setup.py bdist_wheel
|
||||
|
||||
|
||||
# Final build stage
|
||||
FROM python-install AS vllm-cpu
|
||||
ARG PYTHON_VERSION
|
||||
@ -228,30 +163,23 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \
|
||||
--mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \
|
||||
--mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \
|
||||
--mount=type=bind,from=numba-builder,source=/tmp/llvmlite/dist,target=/tmp/llvmlite-wheels/ \
|
||||
--mount=type=bind,from=numba-builder,source=/tmp/numba/dist,target=/tmp/numba-wheels/ \
|
||||
sed -i '/^torch/d' requirements/build.txt && \
|
||||
ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl) && \
|
||||
VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl) && \
|
||||
HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl) && \
|
||||
TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl) && \
|
||||
LLVM_WHL_FILE=$(ls /tmp/llvmlite-wheels/*.whl) && \
|
||||
NUMBA_WHL_FILE=$(ls /tmp/numba-wheels/*.whl) && \
|
||||
ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \
|
||||
VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \
|
||||
HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \
|
||||
TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \
|
||||
uv pip install -v \
|
||||
$ARROW_WHL_FILE \
|
||||
$VISION_WHL_FILE \
|
||||
$HF_XET_WHL_FILE \
|
||||
$TORCH_WHL_FILE \
|
||||
$LLVM_WHL_FILE \
|
||||
$NUMBA_WHL_FILE \
|
||||
--index-strategy unsafe-best-match \
|
||||
-r requirements/build.txt \
|
||||
-r requirements/cpu.txt
|
||||
|
||||
-r requirements/cpu.txt
|
||||
|
||||
# Build and install vllm
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
VLLM_TARGET_DEVICE=cpu VLLM_CPU_MOE_PREPACK=0 python setup.py bdist_wheel && \
|
||||
VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \
|
||||
uv pip install "$(echo dist/*.whl)[tensorizer]"
|
||||
|
||||
# setup non-root user for vllm
|
||||
@ -268,3 +196,4 @@ WORKDIR /home/vllm
|
||||
|
||||
# Set the default entrypoint
|
||||
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
|
||||
|
||||
@ -77,7 +77,6 @@ Internal data structures.
|
||||
- [vllm.multimodal.inputs.MultiModalFieldElem][]
|
||||
- [vllm.multimodal.inputs.MultiModalFieldConfig][]
|
||||
- [vllm.multimodal.inputs.MultiModalKwargsItem][]
|
||||
- [vllm.multimodal.inputs.MultiModalKwargsItems][]
|
||||
- [vllm.multimodal.inputs.MultiModalKwargs][]
|
||||
- [vllm.multimodal.inputs.MultiModalInputs][]
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ You can tune the performance by adjusting `max_num_batched_tokens`:
|
||||
|
||||
- Smaller values (e.g., 2048) achieve better inter-token latency (ITL) because there are fewer prefills slowing down decodes.
|
||||
- Higher values achieve better time to first token (TTFT) as you can process more prefill tokens in a batch.
|
||||
- For optimal throughput, we recommend setting `max_num_batched_tokens > 8192` especially for smaller models on large GPUs.
|
||||
- For optimal throughput, we recommend setting `max_num_batched_tokens > 8096` especially for smaller models on large GPUs.
|
||||
- If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the V0 default scheduling policy (except that it still prioritizes decodes).
|
||||
|
||||
```python
|
||||
@ -129,52 +129,6 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
|
||||
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
|
||||
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
|
||||
|
||||
### Batch-level DP for Multi-Modal Encoders
|
||||
|
||||
By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
|
||||
in order to reduce the memory and compute load on each GPU.
|
||||
|
||||
However, since the size of multi-modal encoders is very small compared to language decoders,
|
||||
there is relatively little gain from TP. On the other hand, TP incurs significant communication
|
||||
overhead because of all-reduce being performed after every layer.
|
||||
|
||||
Given this, it may be advantageous to instead shard the batched input data using TP, essentially
|
||||
performing batch-level DP. This has been shown to improve the throughput by around 10% for
|
||||
`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations,
|
||||
batch-level DP can provide another 40% increase to throughput compared to regular TP.
|
||||
|
||||
Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank,
|
||||
there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already.
|
||||
|
||||
You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen2.5-VL-72B-Instruct",
|
||||
tensor_parallel_size=4,
|
||||
# When mm_encoder_tp_mode="data",
|
||||
# the vision encoder uses TP=4 (not DP=1) to shard the input data,
|
||||
# so the TP size becomes the effective DP size.
|
||||
# Note that this is independent of the DP size for language decoder which is used in expert parallel setting.
|
||||
mm_encoder_tp_mode="data",
|
||||
# The language decoder uses TP=4 to shard the weights regardless
|
||||
# of the setting of mm_encoder_tp_mode
|
||||
)
|
||||
```
|
||||
|
||||
!! important
|
||||
Batch-level DP is not to be confused with API request-level DP
|
||||
(which is instead controlled by `data_parallel_size`).
|
||||
|
||||
The availablilty of batch-level DP is based on model implementation.
|
||||
Currently, the following models support `mm_encoder_tp_mode="data"`:
|
||||
|
||||
- Llama4 (<gh-pr:18368>)
|
||||
- Qwen2.5-VL (<gh-pr:22742>)
|
||||
- Step3 (<gh-pr:22697>)
|
||||
|
||||
## Input Processing
|
||||
|
||||
### Parallel Processing
|
||||
|
||||
@ -11,7 +11,7 @@ vLLM contains two sets of benchmarks:
|
||||
|
||||
The performance benchmarks are used for development to confirm whether new changes improve performance under various workloads. They are triggered on every commit with both the `perf-benchmarks` and `ready` labels, and when a PR is merged into vLLM.
|
||||
|
||||
The latest performance results are hosted on the public [vLLM Performance Dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm).
|
||||
The latest performance results are hosted on the public [vLLM Performance Dashboard](https://perf.vllm.ai).
|
||||
|
||||
More information on the performance benchmarks and their parameters can be found [here](gh-file:.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md).
|
||||
|
||||
|
||||
@ -629,7 +629,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
@ -778,7 +778,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
bos_token_id = hf_config.bos_token_id
|
||||
|
||||
@ -9,7 +9,7 @@ vLLM can be run on a cloud based GPU machine with [dstack](https://dstack.ai/),
|
||||
To install dstack client, run:
|
||||
|
||||
```bash
|
||||
pip install dstack[all]
|
||||
pip install "dstack[all]
|
||||
dstack server
|
||||
```
|
||||
|
||||
|
||||
@ -175,19 +175,11 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking
|
||||
|
||||
### FusedMoEModularKernel Initialization
|
||||
|
||||
`FusedMoEMethodBase` class has 3 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
|
||||
`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
|
||||
|
||||
* maybe_make_prepare_finalize,
|
||||
* select_gemm_impl, and
|
||||
* init_prepare_finalize
|
||||
|
||||
#### maybe_make_prepare_finalize
|
||||
|
||||
The `maybe_make_prepare_finalize` method is responsbile for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
|
||||
Please refer to the implementations in,
|
||||
|
||||
* `ModelOptNvFp4FusedMoE`
|
||||
|
||||
#### select_gemm_impl
|
||||
|
||||
The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.
|
||||
|
||||
@ -216,7 +216,7 @@ Instead of NumPy arrays, you can also pass `'torch.Tensor'` instances, as shown
|
||||
from vllm import LLM, SamplingParams
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
model_path = "Qwen/Qwen2.5-VL-3B-Instruct/"
|
||||
video_path = "https://content.pexels.com/videos/free-videos.mp4"
|
||||
|
||||
llm = LLM(
|
||||
|
||||
@ -17,6 +17,7 @@ th {
|
||||
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
|
||||
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ |
|
||||
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
@ -18,7 +18,7 @@ vLLM supports the following hardware platforms:
|
||||
## Hardware Plugins
|
||||
|
||||
The backends below live **outside** the main `vllm` repository and follow the
|
||||
[Hardware-Pluggable RFC](../../design/plugin_system.md).
|
||||
[Hardware-Pluggable RFC](../design/plugin_system.md).
|
||||
|
||||
| Accelerator | PyPI / package | Repository |
|
||||
|-------------|----------------|------------|
|
||||
|
||||
@ -8,7 +8,7 @@ This guide will help you quickly get started with vLLM to perform:
|
||||
## Prerequisites
|
||||
|
||||
- OS: Linux
|
||||
- Python: 3.9 -- 3.13
|
||||
- Python: 3.9 -- 3.12
|
||||
|
||||
## Installation
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@ def fix_case(text: str) -> str:
|
||||
"llm": "LLM",
|
||||
"mae": "MAE",
|
||||
"tpu": "TPU",
|
||||
"aqlm": "AQLM",
|
||||
"gguf": "GGUF",
|
||||
"lora": "LoRA",
|
||||
"rlhf": "RLHF",
|
||||
|
||||
@ -2,5 +2,4 @@ Loading Model weights with fastsafetensors
|
||||
===================================================================
|
||||
|
||||
Using fastsafetensors library enables loading model weights to GPU memory by leveraging GPU direct storage. See [their GitHub repository](https://github.com/foundation-model-stack/fastsafetensors) for more details.
|
||||
|
||||
To enable this feature, use the ``--load-format fastsafetensors`` command-line argument
|
||||
For enabling this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``
|
||||
|
||||
@ -330,7 +330,6 @@ th {
|
||||
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
|
||||
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
|
||||
| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | |
|
||||
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -363,7 +362,7 @@ th {
|
||||
| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
|
||||
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ |
|
||||
| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ |
|
||||
@ -373,7 +372,6 @@ th {
|
||||
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
|
||||
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -385,8 +383,8 @@ th {
|
||||
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ |
|
||||
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -420,9 +418,6 @@ Some models are supported only via the [Transformers backend](#transformers). Th
|
||||
!!! note
|
||||
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
|
||||
|
||||
!!! note
|
||||
Some mBART models' config files do not have an `architecture` defined. Therefore, you need to use `--hf-overrides '{"architectures": ["MBartForConditionalGeneration"]}'` to explicitly specify the use of the `MBartForConditionalGeneration` architecture.
|
||||
|
||||
### Pooling Models
|
||||
|
||||
See [this page](./pooling_models.md) for more information on how to use pooling models.
|
||||
@ -437,17 +432,17 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ |
|
||||
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ |
|
||||
| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ |
|
||||
| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | ✅︎ |
|
||||
| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | ✅︎ |
|
||||
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | |
|
||||
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ |
|
||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
|
||||
| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | |
|
||||
| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | |
|
||||
| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | |
|
||||
| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | |
|
||||
| `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | ✅︎ |
|
||||
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
|
||||
|
||||
<sup>C</sup> Automatically converted into an embedding model via `--convert embed`. ([details](./pooling_models.md#model-conversion))
|
||||
@ -477,7 +472,7 @@ These models primarily support the [`LLM.classify`](./pooling_models.md#llmclass
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
|
||||
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
|
||||
|
||||
@ -494,12 +489,12 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
|
||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | |
|
||||
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
|
||||
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | ✅︎ |
|
||||
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | |
|
||||
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
|
||||
|
||||
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
|
||||
@ -620,14 +615,14 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
|
||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
|
||||
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
|
||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ |
|
||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ |
|
||||
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -642,7 +637,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Ovis2_5` | Ovis2.5 | T + I<sup>+</sup> + V | `AIDC-AI/Ovis2.5-9B`, etc. | | | ✅︎ |
|
||||
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
|
||||
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -653,7 +647,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
|
||||
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ |
|
||||
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
|
||||
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
|
||||
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
|
||||
|
||||
@ -35,7 +35,6 @@ You can check if this is happening by trying the old defaults with `--generation
|
||||
If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue:
|
||||
|
||||
- `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging.
|
||||
- `export VLLM_LOG_STATS_INTERVAL=1.` to get log statistics more frequently for tracking running queue, waiting queue and cache hit states.
|
||||
- `export CUDA_LAUNCH_BLOCKING=1` to identify which CUDA kernel is causing the problem.
|
||||
- `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL.
|
||||
- `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs. Do not use this flag unless absolutely needed for debugging, it will cause significant delays in startup time.
|
||||
|
||||
@ -107,7 +107,7 @@ to enable simultaneous generation and embedding using the same engine instance i
|
||||
#### Mamba Models
|
||||
|
||||
Models using selective state-space mechanisms instead of standard transformer attention are supported.
|
||||
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1.
|
||||
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.
|
||||
|
||||
Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
|
||||
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that
|
||||
@ -154,15 +154,12 @@ differences compared to V0:
|
||||
|
||||
##### Logprobs Calculation
|
||||
|
||||
By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
|
||||
Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
|
||||
before applying any logits post-processing such as temperature scaling or penalty
|
||||
adjustments). As a result, the returned logprobs do not reflect the final adjusted
|
||||
probabilities used during sampling.
|
||||
|
||||
You can adjust this behavior by setting the `--logprobs-mode` flag.
|
||||
Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
|
||||
Raw means the values before applying any logit processors, like bad words.
|
||||
Processed means the values after applying all processors, including temperature and top_k/top_p.
|
||||
Support for logprobs with post-sampling adjustments is in progress and will be added in future updates.
|
||||
|
||||
##### Prompt Logprobs with Prefix Caching
|
||||
|
||||
|
||||
@ -52,6 +52,20 @@ Try it yourself with the following argument:
|
||||
|
||||
### Quantization
|
||||
|
||||
#### AQLM
|
||||
|
||||
vLLM supports models that are quantized using AQLM.
|
||||
|
||||
Try one yourself by passing one of the following models to the `--model` argument:
|
||||
|
||||
- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf`
|
||||
- `ISTA-DASLab/Llama-2-7b-AQLM-2Bit-2x8-hf`
|
||||
- `ISTA-DASLab/Llama-2-13b-AQLM-2Bit-1x16-hf`
|
||||
- `ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf`
|
||||
- `BlackSamorez/TinyLlama-1_1B-Chat-v1_0-AQLM-2Bit-1x16-hf`
|
||||
|
||||
> Some of these models are likely to be too large for a single GPU. You can split them across multiple GPUs by setting `--tensor-parallel-size` to the number of required GPUs.
|
||||
|
||||
#### GGUF
|
||||
|
||||
vLLM supports models that are quantized using GGUF.
|
||||
|
||||
@ -70,27 +70,12 @@ def parse_args():
|
||||
default=64,
|
||||
help=("Maximum number of sequences to be processed in a single iteration."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-model-len",
|
||||
type=int,
|
||||
help=("Maximum number of tokens to be processed in a single iteration."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=300,
|
||||
help=("Number of seconds before unresponsive process is killed."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quantization",
|
||||
type=str,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -105,9 +90,7 @@ def main(
|
||||
enforce_eager,
|
||||
trust_remote_code,
|
||||
max_num_seqs,
|
||||
max_model_len,
|
||||
gpu_memory_utilization,
|
||||
quantization,
|
||||
):
|
||||
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
|
||||
@ -159,9 +142,7 @@ def main(
|
||||
enable_expert_parallel=True,
|
||||
trust_remote_code=trust_remote_code,
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
quantization=quantization,
|
||||
)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
@ -217,16 +198,14 @@ if __name__ == "__main__":
|
||||
args.enforce_eager,
|
||||
args.trust_remote_code,
|
||||
args.max_num_seqs,
|
||||
args.max_model_len,
|
||||
args.gpu_memory_utilization,
|
||||
args.quantization,
|
||||
),
|
||||
)
|
||||
proc.start()
|
||||
procs.append(proc)
|
||||
exit_code = 0
|
||||
for proc in procs:
|
||||
proc.join(timeout=args.timeout)
|
||||
proc.join(timeout=300)
|
||||
if proc.exitcode is None:
|
||||
print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
|
||||
proc.kill()
|
||||
|
||||
@ -2,14 +2,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Demonstrate prompting of text-to-text
|
||||
encoder/decoder models, specifically BART and mBART.
|
||||
|
||||
This script is refactored to allow model selection via command-line arguments.
|
||||
encoder/decoder models, specifically BART
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.inputs import (
|
||||
ExplicitEncoderDecoderPrompt,
|
||||
@ -19,175 +14,119 @@ from vllm.inputs import (
|
||||
)
|
||||
|
||||
|
||||
class ModelRequestData(NamedTuple):
|
||||
"""
|
||||
Holds the configuration for a specific model, including its
|
||||
HuggingFace ID and the prompts to use for the demo.
|
||||
"""
|
||||
|
||||
model_id: str
|
||||
encoder_prompts: list
|
||||
decoder_prompts: list
|
||||
hf_overrides: Optional[dict] = None
|
||||
|
||||
|
||||
def get_bart_config() -> ModelRequestData:
|
||||
"""
|
||||
Returns the configuration for facebook/bart-large-cnn.
|
||||
This uses the exact test cases from the original script.
|
||||
"""
|
||||
encoder_prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"An encoder prompt",
|
||||
]
|
||||
decoder_prompts = [
|
||||
"A decoder prompt",
|
||||
"Another decoder prompt",
|
||||
]
|
||||
return ModelRequestData(
|
||||
model_id="facebook/bart-large-cnn",
|
||||
encoder_prompts=encoder_prompts,
|
||||
decoder_prompts=decoder_prompts,
|
||||
)
|
||||
|
||||
|
||||
def get_mbart_config() -> ModelRequestData:
|
||||
"""
|
||||
Returns the configuration for facebook/mbart-large-en-ro.
|
||||
This uses prompts suitable for an English-to-Romanian translation task.
|
||||
"""
|
||||
encoder_prompts = [
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
"How are you today?",
|
||||
]
|
||||
decoder_prompts = ["", ""]
|
||||
hf_overrides = {"architectures": ["MBartForConditionalGeneration"]}
|
||||
return ModelRequestData(
|
||||
model_id="facebook/mbart-large-en-ro",
|
||||
encoder_prompts=encoder_prompts,
|
||||
decoder_prompts=decoder_prompts,
|
||||
hf_overrides=hf_overrides,
|
||||
)
|
||||
|
||||
|
||||
MODEL_GETTERS = {
|
||||
"bart": get_bart_config,
|
||||
"mbart": get_mbart_config,
|
||||
}
|
||||
|
||||
|
||||
def create_all_prompt_types(
|
||||
encoder_prompts_raw: list,
|
||||
decoder_prompts_raw: list,
|
||||
tokenizer,
|
||||
) -> list:
|
||||
"""
|
||||
Generates a list of diverse prompt types for demonstration.
|
||||
This function is generic and uses the provided raw prompts
|
||||
to create various vLLM input objects.
|
||||
"""
|
||||
text_prompt_raw = encoder_prompts_raw[0]
|
||||
text_prompt = TextPrompt(prompt=encoder_prompts_raw[1 % len(encoder_prompts_raw)])
|
||||
def create_prompts(tokenizer):
|
||||
# Test prompts
|
||||
#
|
||||
# This section shows all of the valid ways to prompt an
|
||||
# encoder/decoder model.
|
||||
#
|
||||
# - Helpers for building prompts
|
||||
text_prompt_raw = "Hello, my name is"
|
||||
text_prompt = TextPrompt(prompt="The president of the United States is")
|
||||
tokens_prompt = TokensPrompt(
|
||||
prompt_token_ids=tokenizer.encode(
|
||||
encoder_prompts_raw[2 % len(encoder_prompts_raw)]
|
||||
)
|
||||
prompt_token_ids=tokenizer.encode(prompt="The capital of France is")
|
||||
)
|
||||
# - Pass a single prompt to encoder/decoder model
|
||||
# (implicitly encoder input prompt);
|
||||
# decoder input prompt is assumed to be None
|
||||
|
||||
single_text_prompt_raw = text_prompt_raw # Pass a string directly
|
||||
single_text_prompt = text_prompt # Pass a TextPrompt
|
||||
single_tokens_prompt = tokens_prompt # Pass a TokensPrompt
|
||||
|
||||
# ruff: noqa: E501
|
||||
# - Pass explicit encoder and decoder input prompts within one data structure.
|
||||
# Encoder and decoder prompts can both independently be text or tokens, with
|
||||
# no requirement that they be the same prompt type. Some example prompt-type
|
||||
# combinations are shown below, note that these are not exhaustive.
|
||||
|
||||
enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
|
||||
# Pass encoder prompt string directly, &
|
||||
# pass decoder prompt tokens
|
||||
encoder_prompt=single_text_prompt_raw,
|
||||
decoder_prompt=single_tokens_prompt,
|
||||
)
|
||||
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
|
||||
# Pass TextPrompt to encoder, and
|
||||
# pass decoder prompt string directly
|
||||
encoder_prompt=single_text_prompt,
|
||||
decoder_prompt=single_text_prompt_raw,
|
||||
)
|
||||
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
|
||||
# Pass encoder prompt tokens directly, and
|
||||
# pass TextPrompt to decoder
|
||||
encoder_prompt=single_tokens_prompt,
|
||||
decoder_prompt=single_text_prompt,
|
||||
)
|
||||
|
||||
decoder_tokens_prompt = TokensPrompt(
|
||||
prompt_token_ids=tokenizer.encode(decoder_prompts_raw[0])
|
||||
)
|
||||
single_prompt_examples = [
|
||||
text_prompt_raw,
|
||||
text_prompt,
|
||||
tokens_prompt,
|
||||
]
|
||||
explicit_pair_examples = [
|
||||
ExplicitEncoderDecoderPrompt(
|
||||
encoder_prompt=text_prompt_raw,
|
||||
decoder_prompt=decoder_tokens_prompt,
|
||||
),
|
||||
ExplicitEncoderDecoderPrompt(
|
||||
encoder_prompt=text_prompt,
|
||||
decoder_prompt=decoder_prompts_raw[1 % len(decoder_prompts_raw)],
|
||||
),
|
||||
ExplicitEncoderDecoderPrompt(
|
||||
encoder_prompt=tokens_prompt,
|
||||
decoder_prompt=text_prompt,
|
||||
),
|
||||
]
|
||||
# - Finally, here's a useful helper function for zipping encoder and
|
||||
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
|
||||
# instances
|
||||
zipped_prompt_list = zip_enc_dec_prompts(
|
||||
encoder_prompts_raw,
|
||||
decoder_prompts_raw,
|
||||
["An encoder prompt", "Another encoder prompt"],
|
||||
["A decoder prompt", "Another decoder prompt"],
|
||||
)
|
||||
return single_prompt_examples + explicit_pair_examples + zipped_prompt_list
|
||||
|
||||
# - Let's put all of the above example prompts together into one list
|
||||
# which we will pass to the encoder/decoder LLM.
|
||||
return [
|
||||
single_text_prompt_raw,
|
||||
single_text_prompt,
|
||||
single_tokens_prompt,
|
||||
enc_dec_prompt1,
|
||||
enc_dec_prompt2,
|
||||
enc_dec_prompt3,
|
||||
] + zipped_prompt_list
|
||||
|
||||
|
||||
def create_sampling_params() -> SamplingParams:
|
||||
"""Create a sampling params object."""
|
||||
# Create a sampling params object.
|
||||
def create_sampling_params():
|
||||
return SamplingParams(
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
min_tokens=0,
|
||||
max_tokens=30,
|
||||
max_tokens=20,
|
||||
)
|
||||
|
||||
|
||||
def print_outputs(outputs: list):
|
||||
"""Formats and prints the generation outputs."""
|
||||
print("-" * 80)
|
||||
# Print the outputs.
|
||||
def print_outputs(outputs):
|
||||
print("-" * 50)
|
||||
for i, output in enumerate(outputs):
|
||||
prompt = output.prompt
|
||||
encoder_prompt = output.encoder_prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Output {i + 1}:")
|
||||
print(f"Encoder Prompt: {encoder_prompt!r}")
|
||||
print(f"Decoder Prompt: {prompt!r}")
|
||||
print(f"Generated Text: {generated_text!r}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
def main(args):
|
||||
"""Main execution function."""
|
||||
model_key = args.model
|
||||
if model_key not in MODEL_GETTERS:
|
||||
raise ValueError(
|
||||
f"Unknown model: {model_key}. "
|
||||
f"Available models: {list(MODEL_GETTERS.keys())}"
|
||||
print(
|
||||
f"Encoder prompt: {encoder_prompt!r}\n"
|
||||
f"Decoder prompt: {prompt!r}\n"
|
||||
f"Generated text: {generated_text!r}"
|
||||
)
|
||||
config_getter = MODEL_GETTERS[model_key]
|
||||
model_config = config_getter()
|
||||
print("-" * 50)
|
||||
|
||||
print(f"🚀 Running demo for model: {model_config.model_id}")
|
||||
|
||||
def main():
|
||||
dtype = "float"
|
||||
|
||||
# Create a BART encoder/decoder model instance
|
||||
llm = LLM(
|
||||
model=model_config.model_id,
|
||||
dtype="float",
|
||||
hf_overrides=model_config.hf_overrides,
|
||||
model="facebook/bart-large-cnn",
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get BART tokenizer
|
||||
tokenizer = llm.llm_engine.get_tokenizer_group()
|
||||
prompts = create_all_prompt_types(
|
||||
encoder_prompts_raw=model_config.encoder_prompts,
|
||||
decoder_prompts_raw=model_config.decoder_prompts,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
prompts = create_prompts(tokenizer)
|
||||
sampling_params = create_sampling_params()
|
||||
|
||||
# Generate output tokens from the prompts. The output is a list of
|
||||
# RequestOutput objects that contain the prompt, generated
|
||||
# text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
print_outputs(outputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A flexible demo for vLLM encoder-decoder models."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
type=str,
|
||||
default="bart",
|
||||
choices=MODEL_GETTERS.keys(),
|
||||
help="The short name of the model to run.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
main()
|
||||
|
||||
@ -1,147 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""This example demonstrates instantiating vLLM with a custom logits processor
|
||||
class object.
|
||||
|
||||
For a basic example of implementing a custom logits processor, see
|
||||
the `DummyLogitsProcessor` implementation in `vllm/test_utils.py`.
|
||||
|
||||
For testing purposes, a dummy logits processor is employed which, if
|
||||
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
|
||||
will mask out all tokens except `target_token`.
|
||||
|
||||
A batch is constructed with `temperature=0.0` and 50% of requests specifying
|
||||
`target_token`, and for these requests - and *only* these requests - we
|
||||
expect the `target_token` to be decoded in each step, yielding an output
|
||||
similar to that shown below:
|
||||
|
||||
Generated Outputs:
|
||||
------------------------------------------------------------
|
||||
Prompt: 'Hello, my name is'
|
||||
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The president of the United States is'
|
||||
Output: " not a racist. He is a racist.\nHe's a racist because he"
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The capital of France is'
|
||||
Output: ' also also also also also also also also also also also also also
|
||||
also also also'
|
||||
------------------------------------------------------------
|
||||
Prompt: 'The future of AI is'
|
||||
Output: ' in the hands of the people.\n\nThe future of AI is in the'
|
||||
------------------------------------------------------------
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
BatchUpdate,
|
||||
LogitsProcessor,
|
||||
MoveDirectionality,
|
||||
)
|
||||
|
||||
|
||||
# Hypothetical custom logits processor
|
||||
class DummyLogitsProcessor(LogitsProcessor):
|
||||
"""Fake logit processor to support unit testing and examples"""
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
self.req_info: dict[int, SamplingParams] = {}
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
"""Never impacts greedy sampling"""
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: Optional[BatchUpdate]):
|
||||
if not batch_update:
|
||||
return
|
||||
|
||||
# Process added requests.
|
||||
for index, params, _, _ in batch_update.added:
|
||||
assert params is not None
|
||||
if params.extra_args and (
|
||||
target_token := params.extra_args.get("target_token")
|
||||
):
|
||||
self.req_info[index] = target_token
|
||||
|
||||
if self.req_info:
|
||||
# Process removed requests.
|
||||
for index in batch_update.removed:
|
||||
self.req_info.pop(index, None)
|
||||
|
||||
# Process moved requests, unidirectional move (a->b) and swap
|
||||
# (a<->b)
|
||||
for adx, bdx, direct in batch_update.moved:
|
||||
a_val = self.req_info.pop(adx, None)
|
||||
b_val = self.req_info.pop(bdx, None)
|
||||
if a_val is not None:
|
||||
self.req_info[bdx] = a_val
|
||||
if direct == MoveDirectionality.SWAP and b_val is not None:
|
||||
self.req_info[adx] = b_val
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
if not self.req_info:
|
||||
return logits
|
||||
|
||||
# Save target values before modification
|
||||
rows_list = list(self.req_info.keys())
|
||||
cols = torch.tensor(
|
||||
[self.req_info[i] for i in rows_list],
|
||||
dtype=torch.long,
|
||||
device=logits.device,
|
||||
)
|
||||
rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
|
||||
values_to_keep = logits[rows, cols].clone()
|
||||
|
||||
# Mask all but target tokens
|
||||
logits[rows] = float("-inf")
|
||||
logits[rows, cols] = values_to_keep
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a mixture of requests which do and don't utilize the dummy logitproc
|
||||
sampling_params_list = [
|
||||
SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
|
||||
SamplingParams(temperature=0.0),
|
||||
SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
|
||||
SamplingParams(temperature=0.0),
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(
|
||||
model="facebook/opt-125m",
|
||||
logits_processors=[DummyLogitsProcessor],
|
||||
)
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params_list)
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Output: {generated_text!r}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -15,8 +15,6 @@ from pydantic import BaseModel
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
MAX_TOKENS = 50
|
||||
|
||||
# Guided decoding by Choice (list of possible options)
|
||||
guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
|
||||
sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
|
||||
@ -25,9 +23,7 @@ prompt_choice = "Classify this sentiment: vLLM is wonderful!"
|
||||
# Guided decoding by Regex
|
||||
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
|
||||
sampling_params_regex = SamplingParams(
|
||||
guided_decoding=guided_decoding_params_regex,
|
||||
stop=["\n"],
|
||||
max_tokens=MAX_TOKENS,
|
||||
guided_decoding=guided_decoding_params_regex, stop=["\n"]
|
||||
)
|
||||
prompt_regex = (
|
||||
"Generate an email address for Alan Turing, who works in Enigma."
|
||||
@ -52,10 +48,7 @@ class CarDescription(BaseModel):
|
||||
|
||||
json_schema = CarDescription.model_json_schema()
|
||||
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
|
||||
sampling_params_json = SamplingParams(
|
||||
guided_decoding=guided_decoding_params_json,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json)
|
||||
prompt_json = (
|
||||
"Generate a JSON with the brand, model and car_type of"
|
||||
"the most iconic car from the 90's"
|
||||
@ -71,10 +64,7 @@ condition ::= column "= " number
|
||||
number ::= "1 " | "2 "
|
||||
"""
|
||||
guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
|
||||
sampling_params_grammar = SamplingParams(
|
||||
guided_decoding=guided_decoding_params_grammar,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar)
|
||||
prompt_grammar = (
|
||||
"Generate an SQL query to show the 'username' and 'email'from the 'users' table."
|
||||
)
|
||||
|
||||
@ -283,10 +283,8 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
prompts = [
|
||||
(
|
||||
"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>"
|
||||
f"{question}<|assistant|>"
|
||||
)
|
||||
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
|
||||
{question}<|assistant|>"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
@ -335,80 +333,6 @@ def run_glm4_1v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# GLM-4.5V
|
||||
def run_glm4_5v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model_name = "zai-org/GLM-4.5V"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
mm_processor_kwargs={
|
||||
"size": {"shortest_edge": 12544, "longest_edge": 47040000},
|
||||
"fps": 1,
|
||||
},
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=4,
|
||||
)
|
||||
|
||||
if modality == "image":
|
||||
placeholder = "<|begin_of_image|><|image|><|end_of_image|>"
|
||||
elif modality == "video":
|
||||
placeholder = "<|begin_of_video|><|video|><|end_of_video|>"
|
||||
|
||||
prompts = [
|
||||
(
|
||||
"[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n"
|
||||
f"{placeholder}"
|
||||
f"{question}<|assistant|>assistant\n"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# GLM-4.5V-FP8
|
||||
def run_glm4_5v_fp8(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model_name = "zai-org/GLM-4.5V-FP8"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
mm_processor_kwargs={
|
||||
"size": {"shortest_edge": 12544, "longest_edge": 47040000},
|
||||
"fps": 1,
|
||||
},
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=4,
|
||||
)
|
||||
|
||||
if modality == "image":
|
||||
placeholder = "<|begin_of_image|><|image|><|end_of_image|>"
|
||||
elif modality == "video":
|
||||
placeholder = "<|begin_of_video|><|video|><|end_of_video|>"
|
||||
|
||||
prompts = [
|
||||
(
|
||||
"[gMASK]<sop><|system|>\nYou are a helpful assistant.<|user|>\n"
|
||||
f"{placeholder}"
|
||||
f"{question}<|assistant|>assistant\n"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# H2OVL-Mississippi
|
||||
def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -459,8 +383,8 @@ def run_hyperclovax_seed_vision(
|
||||
for question in questions:
|
||||
if modality == "image":
|
||||
"""
|
||||
ocr: List the words in the image in raster order.
|
||||
Even if the word order feels unnatural for reading,
|
||||
ocr: List the words in the image in raster order.
|
||||
Even if the word order feels unnatural for reading,
|
||||
the model will handle it as long as it follows raster order.
|
||||
e.g. "Naver, CLOVA, bigshane"
|
||||
lens_keywords: List the entity names in the image.
|
||||
@ -769,13 +693,15 @@ def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestDat
|
||||
def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData:
|
||||
if modality == "video":
|
||||
prompts = [
|
||||
f"<|im_start|>user <video>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
f"<|im_start|>user <video>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
elif modality == "image":
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
@ -889,39 +815,6 @@ def run_minicpmv(questions: list[str], modality: str) -> ModelRequestData:
|
||||
return run_minicpmv_base(questions, modality, "openbmb/MiniCPM-V-2_6")
|
||||
|
||||
|
||||
def run_minimax_vl_01(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "MiniMaxAI/MiniMax-VL-01"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=8,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
messages = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "image"}, {"type": "text", "text": question}],
|
||||
}
|
||||
]
|
||||
for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=False
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Mistral-3 HF-format
|
||||
def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -998,7 +891,8 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
@ -1104,38 +998,6 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# Ovis2_5
|
||||
def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model_name = "AIDC-AI/Ovis2.5-2B"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
trust_remote_code=True,
|
||||
dtype="half",
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
if modality == "image":
|
||||
placeholder = "<image>"
|
||||
elif modality == "video":
|
||||
placeholder = "<video>"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
messages = [
|
||||
[{"role": "user", "content": f"{placeholder}\n{question}"}]
|
||||
for question in questions
|
||||
]
|
||||
prompts = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# PaliGemma
|
||||
def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -1435,28 +1297,6 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
|
||||
)
|
||||
|
||||
|
||||
# R-4B
|
||||
def run_r_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "YannQi/R-4B"
|
||||
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=16384,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# SkyworkR1V
|
||||
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -1608,8 +1448,6 @@ model_example_map = {
|
||||
"gemma3n": run_gemma3n,
|
||||
"glm4v": run_glm4v,
|
||||
"glm4_1v": run_glm4_1v,
|
||||
"glm4_5v": run_glm4_5v,
|
||||
"glm4_5v_fp8": run_glm4_5v_fp8,
|
||||
"h2ovl_chat": run_h2ovl,
|
||||
"hyperclovax_seed_vision": run_hyperclovax_seed_vision,
|
||||
"idefics3": run_idefics3,
|
||||
@ -1625,14 +1463,12 @@ model_example_map = {
|
||||
"mantis": run_mantis,
|
||||
"minicpmo": run_minicpmo,
|
||||
"minicpmv": run_minicpmv,
|
||||
"minimax_vl_01": run_minimax_vl_01,
|
||||
"mistral3": run_mistral3,
|
||||
"mllama": run_mllama,
|
||||
"molmo": run_molmo,
|
||||
"nemotron_vl": run_nemotron_vl,
|
||||
"NVLM_D": run_nvlm_d,
|
||||
"ovis": run_ovis,
|
||||
"ovis2_5": run_ovis2_5,
|
||||
"paligemma": run_paligemma,
|
||||
"paligemma2": run_paligemma2,
|
||||
"phi3_v": run_phi3v,
|
||||
@ -1643,7 +1479,6 @@ model_example_map = {
|
||||
"qwen2_vl": run_qwen2_vl,
|
||||
"qwen2_5_vl": run_qwen2_5_vl,
|
||||
"qwen2_5_omni": run_qwen2_5_omni,
|
||||
"rvl": run_r_vl,
|
||||
"skywork_chat": run_skyworkr1v,
|
||||
"smolvlm": run_smolvlm,
|
||||
"step3": run_step3,
|
||||
|
||||
@ -680,36 +680,6 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# ovis2_5
|
||||
def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "AIDC-AI/Ovis2.5-2B"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
trust_remote_code=True,
|
||||
dtype="half",
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = "\n".join(
|
||||
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
|
||||
)
|
||||
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
)
|
||||
|
||||
|
||||
def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "mistral-community/pixtral-12b"
|
||||
|
||||
@ -992,39 +962,6 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "YannQi/R-4B"
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=16384,
|
||||
max_num_seqs=16,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
prompt = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
)
|
||||
|
||||
|
||||
def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
|
||||
|
||||
@ -1127,76 +1064,6 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# GLM-4.5V
|
||||
def load_glm4_5v(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "zai-org/GLM-4.5V"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=32768,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=4,
|
||||
)
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
prompt = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_data = [fetch_image(url) for url in image_urls]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image_data=image_data,
|
||||
)
|
||||
|
||||
|
||||
# GLM-4.5V-FP8
|
||||
def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "zai-org/GLM-4.5V-FP8"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=32768,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=4,
|
||||
)
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
prompt = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
image_data = [fetch_image(url) for url in image_urls]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image_data=image_data,
|
||||
)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"aria": load_aria,
|
||||
"aya_vision": load_aya_vision,
|
||||
@ -1218,7 +1085,6 @@ model_example_map = {
|
||||
"mllama": load_mllama,
|
||||
"NVLM_D": load_nvlm_d,
|
||||
"ovis": load_ovis,
|
||||
"ovis2_5": load_ovis2_5,
|
||||
"phi3_v": load_phi3v,
|
||||
"phi4_mm": load_phi4mm,
|
||||
"phi4_multimodal": load_phi4_multimodal,
|
||||
@ -1226,13 +1092,10 @@ model_example_map = {
|
||||
"qwen_vl_chat": load_qwen_vl_chat,
|
||||
"qwen2_vl": load_qwen2_vl,
|
||||
"qwen2_5_vl": load_qwen2_5_vl,
|
||||
"rvl": load_r_vl,
|
||||
"smolvlm": load_smolvlm,
|
||||
"step3": load_step3,
|
||||
"tarsier": load_tarsier,
|
||||
"tarsier2": load_tarsier2,
|
||||
"glm4_5v": load_glm4_5v,
|
||||
"glm4_5v_fp8": load_glm4_5v_fp8,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -1,14 +1,10 @@
|
||||
{%- if messages and messages[0]['role'] == 'system' %}
|
||||
{%- set system_message = messages[0]['content']|trim %}
|
||||
{%- set messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set system_message = "You are a helpful assistant." %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if messages %}
|
||||
{%- if system_message or tools %}
|
||||
<|system|>
|
||||
|
||||
{%- if system_message %}
|
||||
{{ system_message }}
|
||||
{%- if tools %}
|
||||
{%- endif %}
|
||||
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||
|
||||
Use the following rule to decide when to call a function:
|
||||
@ -23,11 +19,13 @@ If you decide to call functions:
|
||||
* make sure you pick the right functions that match the user intent
|
||||
|
||||
|
||||
{%- if tools %}
|
||||
{%- for t in tools %}
|
||||
{{- t | tojson(indent=4) }}
|
||||
{{- "\n\n" }}
|
||||
{%- endfor %}
|
||||
{%- endif %}<|end|>
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in messages %}
|
||||
{%- if message.role != "system" %}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user