Compare commits
81 Commits
torch-2.8
...
prune-samp
| Author | SHA1 | Date | |
|---|---|---|---|
| ce24bc8a9e | |||
| 48bfb0c9b7 | |||
| f8ce022948 | |||
| 0278f1ac3a | |||
| a482e4e769 | |||
| e0b056e443 | |||
| 79f05e4436 | |||
| f8daddcc4c | |||
| c8e33c72c6 | |||
| d70a16625d | |||
| 5cc54f7c5b | |||
| 0c6e40bbaa | |||
| 2e2000f352 | |||
| 31282401b6 | |||
| 0c31e28e95 | |||
| f571ff8eb6 | |||
| f64ee61d9e | |||
| 8993073dc1 | |||
| 655a09f653 | |||
| f94bf9b924 | |||
| 3663870c72 | |||
| 2461d9e562 | |||
| 7be5d113d8 | |||
| b029de9902 | |||
| bbea1cefdd | |||
| f5aa307d77 | |||
| 4b795020ed | |||
| c86af22f31 | |||
| 10cc12ba66 | |||
| a4fbb32fab | |||
| 1b125004be | |||
| 4fbda0b20c | |||
| 4e51fa8cba | |||
| bf7c99dfc4 | |||
| b95697d731 | |||
| 582bbe6bd7 | |||
| 0cdbf5e61c | |||
| ebe56a0064 | |||
| f77a0802b7 | |||
| c4477f55e5 | |||
| dfd2382039 | |||
| 3b11b26b50 | |||
| d6d13bd49e | |||
| 5efd6905bc | |||
| b17109beea | |||
| 4449235843 | |||
| 38217877aa | |||
| c6d80a7a96 | |||
| 7cd17e22d7 | |||
| 50df09fe13 | |||
| 68fcd3fa73 | |||
| 83e69a09d6 | |||
| 3aa8c10038 | |||
| 103f1ec8d3 | |||
| d983769c41 | |||
| 8fd920924c | |||
| de7b67a023 | |||
| f729023272 | |||
| 1a3079a15e | |||
| 941f56858a | |||
| a634733f67 | |||
| 64ab3c7253 | |||
| e58c5a9768 | |||
| d46d417b58 | |||
| 0167efe20d | |||
| c32e6ad1f6 | |||
| 1630cc8d0f | |||
| 14e2b0730b | |||
| 0f4f0191d8 | |||
| a38b8af4c3 | |||
| 21dce80ea9 | |||
| e61bac87ee | |||
| 80141bbf2f | |||
| b94faf9d50 | |||
| 5b5f350d67 | |||
| f7cf5b512e | |||
| 03d4235fd2 | |||
| d6a1a20973 | |||
| 2c4101068c | |||
| 91d7dedc06 | |||
| 806a377171 |
@ -8,7 +8,8 @@ template = """<!DOCTYPE html>
|
||||
<html>
|
||||
<body>
|
||||
<h1>Links for vLLM</h1/>
|
||||
<a href="../{wheel_html_escaped}">{wheel}</a><br/>
|
||||
<a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
|
||||
<a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
@ -21,7 +22,25 @@ filename = os.path.basename(args.wheel)
|
||||
|
||||
with open("index.html", "w") as f:
|
||||
print(f"Generated index.html for {args.wheel}")
|
||||
# sync the abi tag with .buildkite/scripts/upload-wheels.sh
|
||||
if "x86_64" in filename:
|
||||
x86_wheel = filename
|
||||
arm_wheel = filename.replace("x86_64", "aarch64").replace(
|
||||
"manylinux1", "manylinux2014"
|
||||
)
|
||||
elif "aarch64" in filename:
|
||||
x86_wheel = filename.replace("aarch64", "x86_64").replace(
|
||||
"manylinux2014", "manylinux1"
|
||||
)
|
||||
arm_wheel = filename
|
||||
else:
|
||||
raise ValueError(f"Unsupported wheel: {filename}")
|
||||
# cloudfront requires escaping the '+' character
|
||||
f.write(
|
||||
template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B"))
|
||||
template.format(
|
||||
x86_wheel=x86_wheel,
|
||||
x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"),
|
||||
arm_wheel=arm_wheel,
|
||||
arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"),
|
||||
)
|
||||
)
|
||||
|
||||
@ -1,12 +0,0 @@
|
||||
# For vllm script, with -t option (tensor parallel size).
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
|
||||
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.419
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.416
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
||||
@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml
|
||||
Mixtral-8x7B-Instruct-v0.1.yaml
|
||||
Qwen2-57B-A14-Instruct.yaml
|
||||
DeepSeek-V2-Lite-Chat.yaml
|
||||
Meta-Llama-3-8B-QQQ.yaml
|
||||
|
||||
@ -3,44 +3,129 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from importlib import util
|
||||
|
||||
import pandas as pd
|
||||
|
||||
plotly_found = util.find_spec("plotly.express") is not None
|
||||
|
||||
|
||||
def compare_data_columns(
|
||||
files, name_column, data_column, info_cols, drop_column, debug=False
|
||||
):
|
||||
print("\ncompare_data_column: " + data_column)
|
||||
"""
|
||||
Align concatenation by keys derived from info_cols instead of row order.
|
||||
- Pick one canonical key list: subset of info_cols present in ALL files.
|
||||
- For each file: set index to those keys, aggregate duplicates
|
||||
- (mean for metric, first for names).
|
||||
- Concat along axis=1 (indexes align), then reset_index so callers can
|
||||
- group by columns.
|
||||
- If --debug, add a <file_label>_name column per file.
|
||||
"""
|
||||
print("\ncompare_data_column:", data_column)
|
||||
|
||||
frames = []
|
||||
raw_data_cols = []
|
||||
compare_frames = []
|
||||
|
||||
# 1) choose a canonical key list from info_cols that exists in ALL files
|
||||
cols_per_file = []
|
||||
for f in files:
|
||||
try:
|
||||
df_tmp = pd.read_json(f, orient="records")
|
||||
except Exception as err:
|
||||
raise ValueError(f"Failed to read {f}") from err
|
||||
cols_per_file.append(set(df_tmp.columns))
|
||||
|
||||
key_cols = [c for c in info_cols if all(c in cset for cset in cols_per_file)]
|
||||
if not key_cols:
|
||||
# soft fallback: use any info_cols present in the first file
|
||||
key_cols = [c for c in info_cols if c in list(cols_per_file[0])]
|
||||
if not key_cols:
|
||||
raise ValueError(
|
||||
"No common key columns found from info_cols across the input files."
|
||||
)
|
||||
|
||||
# 2) build a single "meta" block (keys as columns) once, aligned by the key index
|
||||
meta_added = False
|
||||
|
||||
for file in files:
|
||||
data_df = pd.read_json(file)
|
||||
serving_df = data_df.dropna(subset=[drop_column], ignore_index=True)
|
||||
# Show all info columns in the first couple columns
|
||||
if not frames:
|
||||
for col in info_cols:
|
||||
if col not in serving_df.columns:
|
||||
print(f"Skipping missing column: {col}")
|
||||
continue
|
||||
frames.append(serving_df[col])
|
||||
# only show test name under debug mode
|
||||
if debug is True:
|
||||
serving_df = serving_df.rename(columns={name_column: file + "_name"})
|
||||
frames.append(serving_df[file + "_name"])
|
||||
df = pd.read_json(file, orient="records")
|
||||
|
||||
file = "/".join(file.split("/")[:-1])
|
||||
serving_df = serving_df.rename(columns={data_column: file})
|
||||
frames.append(serving_df[file])
|
||||
raw_data_cols.append(file)
|
||||
compare_frames.append(serving_df[file])
|
||||
# Keep rows that actually have the compared metric (same as original behavior)
|
||||
if drop_column in df.columns:
|
||||
df = df.dropna(subset=[drop_column], ignore_index=True)
|
||||
|
||||
# Stabilize numeric key columns (harmless if missing)
|
||||
for c in (
|
||||
"Input Len",
|
||||
"Output Len",
|
||||
"TP Size",
|
||||
"PP Size",
|
||||
"# of max concurrency.",
|
||||
"qps",
|
||||
):
|
||||
if c in df.columns:
|
||||
df[c] = pd.to_numeric(df[c], errors="coerce")
|
||||
|
||||
# Ensure all key columns exist
|
||||
for c in key_cols:
|
||||
if c not in df.columns:
|
||||
df[c] = pd.NA
|
||||
|
||||
# Set index = key_cols and aggregate duplicates → unique MultiIndex
|
||||
df_idx = df.set_index(key_cols, drop=False)
|
||||
|
||||
# meta (key columns), unique per key
|
||||
meta = df_idx[key_cols]
|
||||
if not meta.index.is_unique:
|
||||
meta = meta.groupby(level=key_cols, dropna=False).first()
|
||||
|
||||
# metric series for this file, aggregated to one row per key
|
||||
file_label = "/".join(file.split("/")[:-1]) or os.path.basename(file)
|
||||
s = df_idx[data_column]
|
||||
if not s.index.is_unique:
|
||||
s = s.groupby(level=key_cols, dropna=False).mean()
|
||||
s.name = file_label # column label like original
|
||||
|
||||
# add meta once (from first file) so keys are the leftmost columns
|
||||
if not meta_added:
|
||||
frames.append(meta)
|
||||
meta_added = True
|
||||
|
||||
# (NEW) debug: aligned test-name column per file
|
||||
if debug and name_column in df_idx.columns:
|
||||
name_s = df_idx[name_column]
|
||||
if not name_s.index.is_unique:
|
||||
name_s = name_s.groupby(level=key_cols, dropna=False).first()
|
||||
name_s.name = f"{file_label}_name"
|
||||
frames.append(name_s)
|
||||
|
||||
frames.append(s)
|
||||
raw_data_cols.append(file_label)
|
||||
compare_frames.append(s)
|
||||
|
||||
# Generalize ratio: for any file N>=2, add ratio (fileN / file1)
|
||||
if len(compare_frames) >= 2:
|
||||
# Compare numbers among two files
|
||||
ratio_df = compare_frames[1] / compare_frames[0]
|
||||
frames.append(ratio_df)
|
||||
compare_frames.pop(1)
|
||||
base = compare_frames[0]
|
||||
current = compare_frames[-1]
|
||||
ratio = current / base
|
||||
ratio = ratio.mask(base == 0) # avoid inf when baseline is 0
|
||||
ratio.name = f"Ratio 1 vs {len(compare_frames)}"
|
||||
frames.append(ratio)
|
||||
|
||||
# 4) concat on columns with aligned MultiIndex;
|
||||
# then reset_index to return keys as columns
|
||||
concat_df = pd.concat(frames, axis=1)
|
||||
concat_df = concat_df.reset_index(drop=True).reset_index()
|
||||
if "index" in concat_df.columns:
|
||||
concat_df = concat_df.drop(columns=["index"])
|
||||
|
||||
# Ensure key/info columns appear first (in your info_cols order)
|
||||
front = [c for c in info_cols if c in concat_df.columns]
|
||||
rest = [c for c in concat_df.columns if c not in front]
|
||||
concat_df = concat_df[front + rest]
|
||||
|
||||
print(raw_data_cols)
|
||||
return concat_df, raw_data_cols
|
||||
|
||||
@ -67,6 +152,15 @@ def split_json_by_tp_pp(
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
|
||||
# Keep only "serving" tests
|
||||
name_col = next(
|
||||
(c for c in ["Test name", "test_name", "Test Name"] if c in df.columns), None
|
||||
)
|
||||
if name_col:
|
||||
df = df[
|
||||
df[name_col].astype(str).str.contains(r"serving", case=False, na=False)
|
||||
].copy()
|
||||
|
||||
# Handle alias column names
|
||||
rename_map = {
|
||||
"tp_size": "TP Size",
|
||||
@ -181,7 +275,6 @@ if __name__ == "__main__":
|
||||
f"Expected subset: {filtered_info_cols}, "
|
||||
f"but DataFrame has: {list(output_df.columns)}"
|
||||
)
|
||||
|
||||
output_df_sorted = output_df.sort_values(by=existing_group_cols)
|
||||
output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False)
|
||||
for name, group in output_groups:
|
||||
@ -189,8 +282,7 @@ if __name__ == "__main__":
|
||||
text_file.write(html_msgs_for_data_cols[i])
|
||||
text_file.write(html)
|
||||
|
||||
if plot is True:
|
||||
import pandas as pd
|
||||
if plot and plotly_found:
|
||||
import plotly.express as px
|
||||
|
||||
df = group[raw_data_cols]
|
||||
|
||||
@ -27,7 +27,12 @@ steps:
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- block: "Build CUDA 12.6 wheel"
|
||||
key: block-build-cu126-wheel
|
||||
depends_on: ~
|
||||
|
||||
- label: "Build wheel - CUDA 12.6"
|
||||
depends_on: block-build-cu126-wheel
|
||||
id: build-wheel-cuda-12-6
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
@ -68,7 +73,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
|
||||
|
||||
- label: "Annotate release workflow"
|
||||
|
||||
@ -46,6 +46,11 @@ function cpu_tests() {
|
||||
set -e
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
|
||||
|
||||
# Run kernel tests
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -v -s tests/kernels/test_onednn.py"
|
||||
|
||||
# Run basic model test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
@ -99,4 +104,4 @@ function cpu_tests() {
|
||||
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
export -f cpu_tests
|
||||
timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
|
||||
@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then
|
||||
# Remove dangling images (those that are not tagged and not used by any container)
|
||||
docker image prune -f
|
||||
# Remove unused volumes / force the system prune for old images as well.
|
||||
docker volume prune -f && docker system prune --force --filter "until=72h" --all
|
||||
docker volume prune -f && docker system prune --force --filter "until=24h" --all
|
||||
echo "Docker images and volumes cleanup completed."
|
||||
else
|
||||
echo "Disk usage is below $threshold%. No cleanup needed."
|
||||
|
||||
@ -14,8 +14,19 @@ fi
|
||||
# Get the single wheel file
|
||||
wheel="${wheel_files[0]}"
|
||||
|
||||
# Rename 'linux' to 'manylinux1' in the wheel filename
|
||||
new_wheel="${wheel/linux/manylinux1}"
|
||||
# Detect architecture and rename 'linux' to appropriate manylinux version
|
||||
arch=$(uname -m)
|
||||
if [[ $arch == "x86_64" ]]; then
|
||||
manylinux_version="manylinux1"
|
||||
elif [[ $arch == "aarch64" ]]; then
|
||||
manylinux_version="manylinux2014"
|
||||
else
|
||||
echo "Warning: Unknown architecture $arch, using manylinux1 as default"
|
||||
manylinux_version="manylinux1"
|
||||
fi
|
||||
|
||||
# Rename 'linux' to the appropriate manylinux version in the wheel filename
|
||||
new_wheel="${wheel/linux/$manylinux_version}"
|
||||
mv -- "$wheel" "$new_wheel"
|
||||
wheel="$new_wheel"
|
||||
|
||||
|
||||
@ -126,7 +126,8 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
@ -303,7 +304,6 @@ steps:
|
||||
- tests/conftest.py
|
||||
commands:
|
||||
- pytest -v -s samplers
|
||||
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
||||
|
||||
- label: LoRA Test %N # 15min each
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -327,6 +327,7 @@ steps:
|
||||
- pytest -v -s compile/test_sequence_parallelism.py
|
||||
- pytest -v -s compile/test_async_tp.py
|
||||
- pytest -v -s compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s compile/test_decorator.py
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -340,6 +341,7 @@ steps:
|
||||
- pytest -v -s compile/piecewise/test_simple.py
|
||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||
- pytest -v -s compile/piecewise/test_full_cudagraph.py
|
||||
- pytest -v -s compile/piecewise/test_multiple_graphs.py
|
||||
|
||||
- label: PyTorch Fullgraph Test # 18min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -450,13 +452,11 @@ steps:
|
||||
|
||||
- label: LM Eval Small Models # 53min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
|
||||
source_file_dependencies:
|
||||
- csrc/
|
||||
- vllm/model_executor/layers/quantization
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
|
||||
|
||||
- label: OpenAI API correctness
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -544,6 +544,15 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s models/language/pooling -m 'not core_model'
|
||||
|
||||
- label: Multi-Modal Processor Test
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
|
||||
- pytest -v -s models/multimodal/processing/test_tensor_schema.py
|
||||
|
||||
- label: Multi-Modal Models Test (Standard)
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
@ -553,9 +562,7 @@ steps:
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/multimodal/processing
|
||||
- pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model
|
||||
- pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn"
|
||||
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
|
||||
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 1
|
||||
@ -566,7 +573,7 @@ steps:
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
|
||||
- pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 2
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -629,6 +636,7 @@ steps:
|
||||
- vllm/model_executor/layers/fused_moe/cutlass_moe.py
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
|
||||
- vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
|
||||
- vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
- vllm/compilation/fusion.py
|
||||
- vllm/compilation/fusion_attn.py
|
||||
@ -646,9 +654,11 @@ steps:
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
|
||||
- pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
|
||||
# Fusion
|
||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
|
||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
@ -30,7 +30,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
||||
# Supported python versions. These versions will be searched in order, the
|
||||
# first match will be selected. These should be kept in sync with setup.py.
|
||||
#
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12", "3.13")
|
||||
|
||||
# Supported AMD GPU architectures.
|
||||
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
|
||||
@ -357,9 +357,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
|
||||
@ -32,6 +32,14 @@ become available.
|
||||
<div>Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:</div>
|
||||
<code>wget http://images.cocodataset.org/zips/train2017.zip</code>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>ShareGPT4Video (Video)</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td>
|
||||
<code>git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video</code>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>BurstGPT</strong></td>
|
||||
@ -194,6 +202,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--endpoint-type openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
@ -230,6 +239,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--endpoint-type openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
@ -244,6 +254,7 @@ vllm bench serve \
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--endpoint-type openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
@ -609,7 +620,7 @@ vllm bench serve \
|
||||
--prefix-repetition-prefix-len 512 \
|
||||
--prefix-repetition-suffix-len 128 \
|
||||
--prefix-repetition-num-prefixes 5 \
|
||||
--prefix-repetition-output-len 128
|
||||
--prefix-repetition-output-len 128
|
||||
```
|
||||
|
||||
</details>
|
||||
@ -684,4 +695,31 @@ python benchmarks/benchmark_serving.py \
|
||||
--endpoint /v1/chat/completion
|
||||
```
|
||||
|
||||
### Videos (ShareGPT4Video)
|
||||
|
||||
Start vLLM:
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dtype bfloat16 \
|
||||
--limit-mm-per-prompt '{"video": 1}' \
|
||||
--allowed-local-media-path /path/to/sharegpt4video/videos
|
||||
```
|
||||
|
||||
Send requests with videos:
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \
|
||||
--num-prompts 100 \
|
||||
--save-result \
|
||||
--result-dir ~/vllm_benchmark_results \
|
||||
--save-detailed \
|
||||
--endpoint /v1/chat/completion
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
@ -293,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]:
|
||||
)
|
||||
|
||||
|
||||
def process_video(video: Any) -> Mapping[str, Any]:
|
||||
"""
|
||||
Process a single video input and return a multimedia content dictionary.
|
||||
|
||||
Supports the following input types:
|
||||
|
||||
1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key
|
||||
containing raw video data.
|
||||
|
||||
2. String input: - Treats the string as a URL or local file path. -
|
||||
Prepends "file://" if the string doesn't start with "http://" or
|
||||
"file://". - Returns a dictionary with the image URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input is not a supported type.
|
||||
"""
|
||||
if isinstance(video, dict) and "bytes" in video:
|
||||
video_bytes = video["bytes"]
|
||||
video_base64 = base64.b64encode(video_bytes).decode("utf-8")
|
||||
return {
|
||||
"type": "video_url",
|
||||
"video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
|
||||
}
|
||||
|
||||
if isinstance(video, str):
|
||||
video_url = (
|
||||
video if video.startswith(("http://", "file://")) else f"file://{video}"
|
||||
)
|
||||
return {"type": "video_url", "video_url": {"url": video_url}}
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Random Dataset Implementation (Synthetic Data)
|
||||
# -----------------------------------------------------------------------------
|
||||
@ -451,9 +486,10 @@ class ShareGPTDataset(BenchmarkDataset):
|
||||
skip_min_output_len_check=output_len is not None,
|
||||
):
|
||||
continue
|
||||
# TODO: Also support ShareGPT4Video.
|
||||
if image_path := entry.get("image"):
|
||||
mm_content = process_image(image_path)
|
||||
elif video_path := entry.get("video"):
|
||||
mm_content = process_video(video_path)
|
||||
else:
|
||||
mm_content = None
|
||||
if enable_multimodal_chat:
|
||||
@ -922,8 +958,10 @@ class InstructCoderDataset(HuggingFaceDataset):
|
||||
for i, item in enumerate(self.data):
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = f"{item['input']}\n\n{item['instruction']} Just output \
|
||||
the code, do not include any explanation."
|
||||
prompt = (
|
||||
f"{item['input']}\n\n{item['instruction']} Just output "
|
||||
"the code, do not include any explanation."
|
||||
)
|
||||
|
||||
# apply template
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
|
||||
@ -597,8 +597,8 @@ def validate_args(args):
|
||||
# https://github.com/vllm-project/vllm/issues/16222
|
||||
if args.data_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Data parallel is not supported in offline benchmark, \
|
||||
please use benchmark serving instead"
|
||||
"Data parallel is not supported in offline benchmark, "
|
||||
"please use benchmark serving instead"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -80,6 +80,11 @@ def bench_run(
|
||||
a, score, topk, renormalize=False
|
||||
)
|
||||
|
||||
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
def run_triton_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@ -111,6 +116,10 @@ def bench_run(
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
@ -125,6 +134,10 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
@ -136,6 +149,10 @@ def bench_run(
|
||||
w2_q: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
):
|
||||
@ -150,6 +167,10 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
@ -194,6 +215,10 @@ def bench_run(
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
)
|
||||
@ -231,6 +256,10 @@ def bench_run(
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"per_act_token": per_act_token,
|
||||
"ab_strides1": ab_strides1,
|
||||
"ab_strides2": ab_strides2,
|
||||
"c_strides1": c_strides1,
|
||||
"c_strides2": c_strides2,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
@ -289,6 +318,10 @@ def bench_run(
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
@ -297,7 +330,7 @@ def bench_run(
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
|
||||
@ -253,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
else:
|
||||
assert bt.a.dtype == torch.int8
|
||||
assert bt.wtype == scalar_types.uint4b8
|
||||
|
||||
if bt.w_ch_s is not None:
|
||||
s_ch = bt.w_ch_s.to(torch.float32)
|
||||
else:
|
||||
s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
|
||||
|
||||
if bt.w_tok_s is not None:
|
||||
s_tok = bt.w_tok_s.to(torch.float32)
|
||||
else:
|
||||
s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
|
||||
|
||||
fn = lambda: ops.marlin_qqq_gemm(
|
||||
a=bt.a,
|
||||
b_q_weight=w_q,
|
||||
s_group=w_s,
|
||||
s_tok=s_tok,
|
||||
s_ch=s_ch,
|
||||
workspace=workspace.scratch,
|
||||
size_m=bt.a.shape[0],
|
||||
size_n=bt.w_ref.shape[1],
|
||||
size_k=bt.w_ref.shape[0],
|
||||
)
|
||||
raise NotImplementedError("QQQ is not supported anymore")
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
@ -182,17 +182,17 @@ endif()
|
||||
#
|
||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
|
||||
# Flag to enable ACL kernels for AARCH64 platforms
|
||||
if ( VLLM_BUILD_ACL STREQUAL "ON")
|
||||
if (VLLM_BUILD_ACL STREQUAL "ON")
|
||||
set(USE_ACL ON)
|
||||
else()
|
||||
set(USE_ACL OFF)
|
||||
endif()
|
||||
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.8.1
|
||||
GIT_TAG v3.9
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
@ -204,7 +204,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
|
||||
endif()
|
||||
set(ONEDNN_AARCH64_USE_ACL "ON")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
||||
set(ONEDNN_BUILD_DOC "OFF")
|
||||
@ -217,38 +217,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
|
||||
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
|
||||
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
|
||||
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
|
||||
set(ONEDNN_VERBOSE "OFF")
|
||||
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
||||
|
||||
FetchContent_MakeAvailable(oneDNN)
|
||||
|
||||
list(APPEND LIBS dnnl)
|
||||
elseif(POWER10_FOUND)
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.7.2
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
|
||||
target_include_directories(
|
||||
dnnl_ext
|
||||
PUBLIC ${oneDNN_SOURCE_DIR}/include
|
||||
PUBLIC ${oneDNN_BINARY_DIR}/include
|
||||
PRIVATE ${oneDNN_SOURCE_DIR}/src
|
||||
)
|
||||
|
||||
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
||||
set(ONEDNN_BUILD_DOC "OFF")
|
||||
set(ONEDNN_BUILD_EXAMPLES "OFF")
|
||||
set(ONEDNN_BUILD_TESTS "OFF")
|
||||
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
|
||||
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
|
||||
set(ONEDNN_BUILD_GRAPH "OFF")
|
||||
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
|
||||
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
|
||||
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
|
||||
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
|
||||
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
||||
|
||||
set(DNNL_CPU_RUNTIME "OMP")
|
||||
|
||||
FetchContent_MakeAvailable(oneDNN)
|
||||
|
||||
list(APPEND LIBS dnnl)
|
||||
target_link_libraries(dnnl_ext dnnl)
|
||||
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
|
||||
list(APPEND LIBS dnnl_ext)
|
||||
set(USE_ONEDNN ON)
|
||||
else()
|
||||
set(USE_ONEDNN OFF)
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||
@ -275,7 +260,6 @@ set(VLLM_EXT_SRC
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
"csrc/cpu/shm.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
||||
@ -289,14 +273,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
${VLLM_EXT_SRC})
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
|
||||
endif()
|
||||
elseif(POWER10_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
if (ASIMD_FOUND)
|
||||
|
||||
if(USE_ONEDNN)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
"csrc/cpu/dnnl_kernels.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
|
||||
@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options(
|
||||
// TODO(trevor-m): Change split_kv back to -1 when
|
||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||
// perform worse with larger context length and smaller batch sizes.
|
||||
num_kv_splits, // split_kv
|
||||
static_cast<int>(num_kv_splits), // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
|
||||
// Assumes device 0 when getting sm_count.
|
||||
arguments.hw_info.sm_count =
|
||||
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
||||
arguments.split_kv = num_kv_splits;
|
||||
arguments.split_kv = static_cast<int>(num_kv_splits);
|
||||
MlaSm100Type::Fmha::set_split_kv(arguments);
|
||||
|
||||
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
|
||||
@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
(__m128i)vec8_data.reg, 1)) {}
|
||||
|
||||
void save(void* ptr) const {
|
||||
*reinterpret_cast<__m256i*>(ptr) = reg_low;
|
||||
*reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
|
||||
_mm256_storeu_si256((__m256i*)ptr, reg_low);
|
||||
_mm256_storeu_si256((__m256i*)ptr + 1, reg_high);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
346
csrc/cpu/dnnl_helper.cpp
Normal file
346
csrc/cpu/dnnl_helper.cpp
Normal file
@ -0,0 +1,346 @@
|
||||
#include <list>
|
||||
#include <optional>
|
||||
|
||||
#include "common/memory_desc.hpp"
|
||||
#include "common/memory.hpp"
|
||||
|
||||
#include "dnnl_helper.h"
|
||||
|
||||
static dnnl::engine& default_engine() {
|
||||
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||
return engine;
|
||||
}
|
||||
|
||||
static dnnl::stream& default_stream() {
|
||||
static dnnl::stream stream(default_engine());
|
||||
return stream;
|
||||
}
|
||||
|
||||
void release_dnnl_matmul_handler(int64_t handler) {
|
||||
DNNLMatMulPrimitiveHandler* ptr =
|
||||
reinterpret_cast<DNNLMatMulPrimitiveHandler*>(handler);
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
template <typename KT, typename VT>
|
||||
class DNNLPrimitiveCache {
|
||||
public:
|
||||
using cache_value_t = std::pair<KT, VT>;
|
||||
using result_value_t = VT;
|
||||
using container_t = std::list<cache_value_t>;
|
||||
using value_iterator_t = typename container_t::iterator;
|
||||
using map_t = std::unordered_map<KT, value_iterator_t>;
|
||||
using creator_t = VT (*)();
|
||||
|
||||
public:
|
||||
DNNLPrimitiveCache(size_t capacity)
|
||||
: capacity_(capacity),
|
||||
values_(),
|
||||
key_to_value_(std::min(256lu, capacity)) {
|
||||
assert(capacity > 0);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
result_value_t get_or_create(const KT& key, F&& creator) {
|
||||
std::optional<value_iterator_t> value = get_value(key);
|
||||
if (value.has_value()) {
|
||||
return value.value()->second;
|
||||
} else {
|
||||
return add_value({key, creator()})->second;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size() const { return values_.size(); }
|
||||
|
||||
private:
|
||||
void dump_data() {
|
||||
std::stringstream ss;
|
||||
ss << "table_id: " << std::hex << reinterpret_cast<size_t>(this) << std::dec
|
||||
<< "\n";
|
||||
ss << "container: [";
|
||||
for (auto&& iter : values_) {
|
||||
ss << "(" << iter.first << ", " << std::hex
|
||||
<< reinterpret_cast<size_t>(iter.second.get()) << "), " << std::dec;
|
||||
}
|
||||
ss << "]\n";
|
||||
|
||||
ss << "map: [";
|
||||
for (auto&& iter : key_to_value_) {
|
||||
ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex
|
||||
<< reinterpret_cast<size_t>(iter.second->second.get()) << std::dec
|
||||
<< "), ";
|
||||
}
|
||||
ss << "]\n";
|
||||
std::printf("%s\n", ss.str().c_str());
|
||||
}
|
||||
|
||||
value_iterator_t add_value(cache_value_t&& new_value) {
|
||||
if (size() == capacity_) {
|
||||
cache_value_t& last_item = values_.back();
|
||||
key_to_value_.erase(last_item.first);
|
||||
values_.pop_back();
|
||||
}
|
||||
|
||||
auto& added_value_ = values_.emplace_front(std::move(new_value));
|
||||
key_to_value_.emplace(added_value_.first, values_.begin());
|
||||
return values_.begin();
|
||||
}
|
||||
|
||||
std::optional<value_iterator_t> get_value(const KT& key) {
|
||||
if (key_to_value_.size() > 0 && key == values_.begin()->first) {
|
||||
return values_.begin();
|
||||
}
|
||||
|
||||
auto value_map_iterator = key_to_value_.find(key);
|
||||
if (value_map_iterator != key_to_value_.end()) {
|
||||
values_.splice(values_.begin(), values_, value_map_iterator->second);
|
||||
return value_map_iterator->second;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const size_t capacity_;
|
||||
container_t values_;
|
||||
map_t key_to_value_;
|
||||
};
|
||||
|
||||
DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
|
||||
const Args& args, dnnl::memory::data_type b_type)
|
||||
: b_n_size_(args.b_n_size),
|
||||
b_n_stride_(args.b_n_stride),
|
||||
b_k_size_(args.b_k_size),
|
||||
b_k_stride_(args.b_k_stride),
|
||||
b_type_(b_type),
|
||||
c_type_(args.c_type),
|
||||
runtime_memory_ptrs_(8),
|
||||
primitive_cache_size_(args.primitive_cache_size) {
|
||||
assert(primitive_cache_size_ > 0);
|
||||
}
|
||||
|
||||
void DNNLMatMulPrimitiveHandler::prepack_weight(
|
||||
void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
|
||||
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
|
||||
{b_k_stride_, b_n_stride_});
|
||||
dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
|
||||
dnnl::memory packed_weight(b_target_mem_desc, default_engine());
|
||||
{
|
||||
dnnl::reorder(original_weight, packed_weight)
|
||||
.execute(default_stream(), original_weight, packed_weight);
|
||||
default_stream().wait();
|
||||
}
|
||||
memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight;
|
||||
b_target_mem_desc_ = b_target_mem_desc;
|
||||
}
|
||||
|
||||
void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr(
|
||||
size_t index, dnnl_memory* memory_ptr) {
|
||||
dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage();
|
||||
dnnl_memory_desc* mem_desc = const_cast<dnnl_memory_desc*>(memory_ptr->md());
|
||||
runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc};
|
||||
}
|
||||
|
||||
std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
|
||||
DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) {
|
||||
return runtime_memory_ptrs_[index];
|
||||
}
|
||||
|
||||
namespace std {
|
||||
template <>
|
||||
struct hash<W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey> {
|
||||
size_t operator()(
|
||||
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
|
||||
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
|
||||
hash<int>()(static_cast<int>(val.a_qs)) ^
|
||||
hash<int>()(static_cast<int>(val.b_qs)) ^ hash<bool>()(val.use_azp) ^
|
||||
hash<int>()(static_cast<int>(val.c_type));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
|
||||
size_t operator()(
|
||||
const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const {
|
||||
return hash<dnnl_dim_t>()(val.a_m_size) ^ hash<bool>()(val.use_bias) ^
|
||||
hash<int>()(static_cast<int>(val.bias_type));
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
|
||||
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
|
||||
return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
|
||||
l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp &&
|
||||
l.c_type == r.c_type;
|
||||
}
|
||||
|
||||
bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
|
||||
const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) {
|
||||
return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size &&
|
||||
l.bias_type == r.bias_type;
|
||||
}
|
||||
|
||||
static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
|
||||
get_w8a8_class_primitive_cache(
|
||||
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
|
||||
int64_t cache_size) {
|
||||
static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128);
|
||||
assert(cache_size > 0);
|
||||
return cache.get_or_create(key, [&]() {
|
||||
return std::make_shared<W8A8MatMulPrimitiveHandler::MSizeCache>(cache_size);
|
||||
});
|
||||
}
|
||||
|
||||
W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
|
||||
: DNNLMatMulPrimitiveHandler(
|
||||
static_cast<const DNNLMatMulPrimitiveHandler::Args&>(args),
|
||||
dnnl::memory::data_type::s8),
|
||||
use_azp_(args.use_a_zero_point),
|
||||
a_qs_(args.a_quantization_strategy),
|
||||
b_qs_(args.b_quantization_strategy),
|
||||
m_size_cache_(nullptr) {
|
||||
assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL);
|
||||
assert(b_qs_ != QuantizationStrategy::PER_TOKEN);
|
||||
if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
|
||||
assert(!use_azp_);
|
||||
};
|
||||
prepack_weight(args.b_ptr,
|
||||
create_primitive_desc(
|
||||
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
|
||||
.use_bias = false,
|
||||
.bias_type = dnnl::memory::data_type::undef},
|
||||
true)
|
||||
.weights_desc());
|
||||
init_runtime_memory_cache(args);
|
||||
}
|
||||
|
||||
void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
|
||||
auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
|
||||
auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
|
||||
a_storage->set_data_handle((void*)args.a_ptr);
|
||||
a_mem_desc->dims[0] = args.a_m_size;
|
||||
c_storage->set_data_handle((void*)args.c_ptr);
|
||||
c_mem_desc->dims[0] = args.a_m_size;
|
||||
|
||||
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2);
|
||||
a_scale_storage->set_data_handle((void*)args.a_scales_ptr);
|
||||
}
|
||||
if (use_azp_) {
|
||||
auto&& [a_zero_point_storage, a_zero_point_mem_desc] =
|
||||
get_runtime_memory_ptr(3);
|
||||
a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr);
|
||||
}
|
||||
|
||||
if (args.use_bias) {
|
||||
auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4);
|
||||
bias_storage->set_data_handle((void*)args.bias_ptr);
|
||||
}
|
||||
|
||||
dnnl::matmul matmul = get_matmul_cache(args);
|
||||
matmul.execute(default_stream(), memory_cache_);
|
||||
default_stream().wait();
|
||||
}
|
||||
|
||||
dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
|
||||
const MSizeCacheKey& key) {
|
||||
if (m_size_cache_.get() == nullptr) {
|
||||
ClassMatmulCacheKey key = {.b_n_size = b_n_size_,
|
||||
.b_k_size = b_k_size_,
|
||||
.a_qs = a_qs_,
|
||||
.b_qs = b_qs_,
|
||||
.use_azp = use_azp_,
|
||||
.c_type = c_type_};
|
||||
m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_);
|
||||
}
|
||||
|
||||
return m_size_cache_->get_or_create(key, [&]() {
|
||||
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
|
||||
return dnnl::matmul(desc);
|
||||
});
|
||||
}
|
||||
|
||||
void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
|
||||
memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_},
|
||||
dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::ab},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
|
||||
memory_cache_[DNNL_ARG_DST] =
|
||||
dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
|
||||
|
||||
// For PER_TOKEN, scales will be applied in outside epilogue
|
||||
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory(
|
||||
{{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(
|
||||
2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get());
|
||||
if (use_azp_) {
|
||||
memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory(
|
||||
{{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(
|
||||
3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get());
|
||||
}
|
||||
}
|
||||
|
||||
if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
|
||||
dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(),
|
||||
(void*)args.b_scales_ptr);
|
||||
} else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
|
||||
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
|
||||
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
|
||||
default_engine(), (void*)args.b_scales_ptr);
|
||||
}
|
||||
|
||||
memory_cache_[DNNL_ARG_BIAS] =
|
||||
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
|
||||
const MSizeCacheKey& key, bool first_time) {
|
||||
dnnl::memory::desc a_md({key.a_m_size, b_k_size_},
|
||||
dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::ab);
|
||||
dnnl::memory::desc b_md;
|
||||
if (first_time) {
|
||||
b_md =
|
||||
dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
} else {
|
||||
b_md = b_target_mem_desc_;
|
||||
}
|
||||
dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
|
||||
dnnl::memory::format_tag::ab);
|
||||
|
||||
dnnl::primitive_attr attr;
|
||||
// For PER_TOKEN, scales will be applied in outside epilogue
|
||||
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
attr.set_scales_mask(DNNL_ARG_SRC, 0);
|
||||
if (use_azp_) {
|
||||
attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
|
||||
}
|
||||
}
|
||||
|
||||
if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
|
||||
} else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
|
||||
}
|
||||
|
||||
if (key.use_bias) {
|
||||
// For PER_TOKEN, bias will be applied in epilogue
|
||||
assert(a_qs_ == QuantizationStrategy::PER_TENSOR);
|
||||
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
|
||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
|
||||
c_md, attr);
|
||||
} else {
|
||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
|
||||
attr);
|
||||
}
|
||||
}
|
||||
169
csrc/cpu/dnnl_helper.h
Normal file
169
csrc/cpu/dnnl_helper.h
Normal file
@ -0,0 +1,169 @@
|
||||
#ifndef DNNL_HELPER_H
|
||||
#define DNNL_HELPER_H
|
||||
|
||||
#include <optional>
|
||||
#include <cassert>
|
||||
|
||||
#include "oneapi/dnnl/dnnl.hpp"
|
||||
|
||||
namespace c10 {
|
||||
struct BFloat16;
|
||||
struct Half;
|
||||
} // namespace c10
|
||||
|
||||
namespace dnnl {
|
||||
namespace impl {
|
||||
struct memory_storage_t;
|
||||
struct matmul_pd_t;
|
||||
struct matmul_desc_t;
|
||||
} // namespace impl
|
||||
} // namespace dnnl
|
||||
struct dnnl_memory_desc;
|
||||
|
||||
template <typename KT, typename VT>
|
||||
class DNNLPrimitiveCache;
|
||||
|
||||
template <typename T>
|
||||
struct DNNLType {
|
||||
static constexpr dnnl::memory::data_type type =
|
||||
dnnl::memory::data_type::undef;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int8_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int32_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<float> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::BFloat16> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::Half> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||
return DNNLType<std::decay_t<T>>::type;
|
||||
}
|
||||
|
||||
class DNNLMatMulPrimitiveHandler {
|
||||
public:
|
||||
virtual ~DNNLMatMulPrimitiveHandler() = default;
|
||||
|
||||
protected:
|
||||
struct Args {
|
||||
dnnl_dim_t b_n_size;
|
||||
dnnl_dim_t b_n_stride;
|
||||
dnnl_dim_t b_k_size;
|
||||
dnnl_dim_t b_k_stride;
|
||||
void* b_ptr;
|
||||
dnnl::memory::data_type c_type;
|
||||
size_t primitive_cache_size;
|
||||
};
|
||||
|
||||
protected:
|
||||
DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);
|
||||
|
||||
void prepack_weight(void* original_b_ptr,
|
||||
dnnl::memory::desc b_target_mem_desc);
|
||||
|
||||
void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);
|
||||
|
||||
std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
|
||||
get_runtime_memory_ptr(size_t index);
|
||||
|
||||
protected:
|
||||
const dnnl_dim_t b_n_size_;
|
||||
const dnnl_dim_t b_n_stride_;
|
||||
const dnnl_dim_t b_k_size_;
|
||||
const dnnl_dim_t b_k_stride_;
|
||||
dnnl::memory::data_type b_type_;
|
||||
dnnl::memory::data_type c_type_;
|
||||
std::unordered_map<int, dnnl::memory> memory_cache_;
|
||||
std::vector<std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>>
|
||||
runtime_memory_ptrs_;
|
||||
dnnl::memory::desc b_target_mem_desc_;
|
||||
int64_t primitive_cache_size_;
|
||||
};
|
||||
|
||||
class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
|
||||
public:
|
||||
enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL };
|
||||
|
||||
struct Args : public DNNLMatMulPrimitiveHandler::Args {
|
||||
bool use_a_zero_point;
|
||||
QuantizationStrategy a_quantization_strategy;
|
||||
QuantizationStrategy b_quantization_strategy;
|
||||
float* b_scales_ptr;
|
||||
};
|
||||
|
||||
struct ClassMatmulCacheKey {
|
||||
dnnl_dim_t b_n_size;
|
||||
dnnl_dim_t b_k_size;
|
||||
QuantizationStrategy a_qs;
|
||||
QuantizationStrategy b_qs;
|
||||
bool use_azp;
|
||||
dnnl::memory::data_type c_type;
|
||||
|
||||
friend bool operator==(const ClassMatmulCacheKey& l,
|
||||
const ClassMatmulCacheKey& r);
|
||||
};
|
||||
|
||||
struct MSizeCacheKey {
|
||||
dnnl_dim_t a_m_size;
|
||||
bool use_bias;
|
||||
dnnl::memory::data_type bias_type;
|
||||
|
||||
friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
|
||||
};
|
||||
|
||||
using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
|
||||
using ClassMatmulCache =
|
||||
DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;
|
||||
|
||||
struct ExecArgs : public MSizeCacheKey {
|
||||
const int8_t* a_ptr;
|
||||
const float* a_scales_ptr;
|
||||
const int32_t* a_zero_points_ptr;
|
||||
const void* bias_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
|
||||
public:
|
||||
W8A8MatMulPrimitiveHandler(const Args& args);
|
||||
|
||||
QuantizationStrategy get_input_scale_strategy() const { return a_qs_; }
|
||||
|
||||
bool get_input_use_zero_point() const { return use_azp_; }
|
||||
|
||||
void execute(ExecArgs& args);
|
||||
|
||||
private:
|
||||
dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
|
||||
bool first_time);
|
||||
|
||||
void init_runtime_memory_cache(const Args& args);
|
||||
|
||||
dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);
|
||||
|
||||
private:
|
||||
const bool use_azp_;
|
||||
const QuantizationStrategy a_qs_;
|
||||
const QuantizationStrategy b_qs_;
|
||||
std::shared_ptr<MSizeCache> m_size_cache_;
|
||||
};
|
||||
|
||||
#endif
|
||||
@ -1,206 +0,0 @@
|
||||
#ifndef DNNL_HELPER_HPP
|
||||
#define DNNL_HELPER_HPP
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
#include "oneapi/dnnl/dnnl.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
struct DNNLType {
|
||||
static constexpr dnnl::memory::data_type type =
|
||||
dnnl::memory::data_type::undef;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int8_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int32_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<float> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::BFloat16> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::Half> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||
return DNNLType<std::decay_t<T>>::type;
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <bool InputNoScale>
|
||||
class DNNLPrimitiveHelper {
|
||||
public:
|
||||
// I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
|
||||
// A: [M, K], row-major
|
||||
// B: [K, N], column-major
|
||||
// C: [M, N], row-major
|
||||
// bias: [N], row-major, optional
|
||||
// a_scales: [MS]
|
||||
// b_scales: [NS]
|
||||
// Note: Due to the limitation of oneDNN
|
||||
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||
// not supported.
|
||||
|
||||
template <typename OutputT, typename BiasT>
|
||||
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||
dnnl_dim_t K, const float* a_scales,
|
||||
const float* b_scales, dnnl_dim_t MS,
|
||||
dnnl_dim_t NS) {
|
||||
auto&& OutputType = get_dnnl_type<OutputT>();
|
||||
auto&& BiasType = get_dnnl_type<BiasT>();
|
||||
|
||||
dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
|
||||
dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
|
||||
dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
|
||||
|
||||
dnnl::primitive_attr attr;
|
||||
if constexpr (!InputNoScale) {
|
||||
if (MS == 1) {
|
||||
// per-tensor
|
||||
attr.set_scales_mask(DNNL_ARG_SRC, 0);
|
||||
} else {
|
||||
// per-token
|
||||
TORCH_CHECK(false, "per-token quantization is unsupported.");
|
||||
}
|
||||
}
|
||||
|
||||
if (NS == 1) {
|
||||
// per-tensor
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
|
||||
} else {
|
||||
// per-channel
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc matmul_pd;
|
||||
// Create memory descriptors with format_tag::any for the primitive. This
|
||||
// enables the matmul primitive to choose memory layouts for an
|
||||
// optimized primitive implementation, and these layouts may differ from the
|
||||
// ones provided by the user.
|
||||
#ifdef __aarch64__
|
||||
auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
auto mat_weights_md = dnnl::memory::desc(
|
||||
{K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
|
||||
auto mat_dst_md =
|
||||
dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
|
||||
mat_weights_md, bias_md,
|
||||
mat_dst_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(
|
||||
default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
|
||||
}
|
||||
#else
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
bias_md, c_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
c_md, attr);
|
||||
}
|
||||
#endif
|
||||
dnnl::matmul matmul(matmul_pd);
|
||||
|
||||
auto& engine = default_engine();
|
||||
|
||||
dnnl::memory a_m(a_md, engine, (void*)a);
|
||||
dnnl::memory b_m(b_md, engine, (void*)b);
|
||||
dnnl::memory c_m(c_md, engine, (void*)c);
|
||||
dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
|
||||
(void*)a_scales);
|
||||
dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
|
||||
(void*)b_scales);
|
||||
|
||||
auto& stream = default_stream();
|
||||
|
||||
auto mat_src_mem = a_m;
|
||||
auto mat_weights_mem = b_m;
|
||||
auto mat_dst_mem = c_m;
|
||||
#ifdef __aarch64__
|
||||
if (matmul_pd.weights_desc() != b_m.get_desc()) {
|
||||
mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
|
||||
dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
|
||||
}
|
||||
#endif
|
||||
if constexpr (InputNoScale) {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
}
|
||||
stream.wait();
|
||||
}
|
||||
|
||||
private:
|
||||
static dnnl::engine& default_engine() {
|
||||
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||
return engine;
|
||||
}
|
||||
|
||||
static dnnl::stream& default_stream() {
|
||||
static dnnl::stream stream(default_engine());
|
||||
return stream;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
494
csrc/cpu/dnnl_kernels.cpp
Normal file
494
csrc/cpu/dnnl_kernels.cpp
Normal file
@ -0,0 +1,494 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include "dnnl_helper.h"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using load_vec_type = void;
|
||||
using cvt_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using load_vec_type = vec_op::BF16Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power architecture-specific vector type
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
// Fallback for other architectures
|
||||
using load_vec_type = vec_op::FP16Vec16;
|
||||
#endif
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int64_t num_tokens,
|
||||
const int64_t input_stride,
|
||||
const int64_t hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
cvt_vec_t zp_vec;
|
||||
if constexpr (AZP) {
|
||||
zp_vec = cvt_vec_t(static_cast<float>(*azp));
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
int8_t* output_ptr = output + i * hidden_size;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int64_t num_tokens,
|
||||
const int64_t input_stride,
|
||||
const int64_t hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
|
||||
cvt_vec_t min_value(std::numeric_limits<float>::max());
|
||||
{
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
} else {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32, hidden_size - j);
|
||||
min_value = min_value.min(elems_fp32, hidden_size - j);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float scale_val, azp_val;
|
||||
if constexpr (AZP) {
|
||||
float max_scalar = max_value.reduce_max();
|
||||
float min_scalar = min_value.reduce_min();
|
||||
scale_val = (max_scalar - min_scalar) / 255.0f;
|
||||
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
|
||||
azp[i] = azp_val;
|
||||
scale[i] = scale_val;
|
||||
} else {
|
||||
scale_val = max_value.reduce_max() / 127.0f;
|
||||
scale[i] = scale_val;
|
||||
}
|
||||
|
||||
const cvt_vec_t inv_scale(1.0 / scale_val);
|
||||
const cvt_vec_t azp_vec(azp_val);
|
||||
|
||||
{
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
int8_t* output_ptr = output + i * hidden_size;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, bool Bias, typename scalar_t>
|
||||
void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float* a_scale, const int32_t* azp,
|
||||
const float* azp_adj, const scalar_t* bias,
|
||||
const int64_t num_tokens,
|
||||
const int64_t hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
const int64_t thread_num = omp_get_max_threads();
|
||||
if (num_tokens > thread_num) {
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
const float* input_ptr = input + i * hidden_size;
|
||||
scalar_t* output_ptr = output + i * hidden_size;
|
||||
int64_t j = 0;
|
||||
cvt_vec_t token_scale_vec(a_scale[i]);
|
||||
cvt_vec_t token_zp_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
|
||||
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
|
||||
}
|
||||
for (; j < hidden_size - vec_elem_num; ++j) {
|
||||
cvt_vec_t elems_fp32(input_ptr + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + j);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + j);
|
||||
}
|
||||
cvt_vec_t elems_fp32(input_ptr + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + j);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
} else {
|
||||
const int64_t vec_iteration =
|
||||
(hidden_size + vec_elem_num - 1) / vec_elem_num;
|
||||
const int64_t vec_iteration_per_thread =
|
||||
(vec_iteration + thread_num - 1) / thread_num;
|
||||
const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num;
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int64_t i = 0; i < thread_num; ++i) {
|
||||
const int64_t start = elem_num_per_thread * i;
|
||||
const int64_t end = std::min(hidden_size, elem_num_per_thread + start);
|
||||
for (int64_t j = 0; j < num_tokens; ++j) {
|
||||
cvt_vec_t token_scale_vec(a_scale[j]);
|
||||
cvt_vec_t token_zp_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
float zp_scale_val = a_scale[j] * static_cast<float>(azp[j]);
|
||||
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
|
||||
}
|
||||
int64_t k = start;
|
||||
const float* input_ptr = input + j * hidden_size;
|
||||
scalar_t* output_ptr = output + j * hidden_size;
|
||||
for (; k < end - vec_elem_num; k += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input_ptr + k);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + k);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + k);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + k);
|
||||
}
|
||||
if (k < end) {
|
||||
cvt_vec_t elems_fp32(input_ptr + k);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + k);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + k);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + k, end - k);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int64_t create_onednn_scaled_mm_handler(
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
at::ScalarType output_type, bool dynamic_act_quant, bool use_azp,
|
||||
int64_t primitive_cache_size) {
|
||||
TORCH_CHECK(b.dim() == 2);
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(b_scales.is_contiguous());
|
||||
|
||||
W8A8MatMulPrimitiveHandler::Args args;
|
||||
args.primitive_cache_size = primitive_cache_size;
|
||||
|
||||
if (b_scales.numel() == 1) {
|
||||
args.b_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
|
||||
} else {
|
||||
TORCH_CHECK_EQ(b_scales.numel(), b.size(1));
|
||||
args.b_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL;
|
||||
}
|
||||
args.b_scales_ptr = b_scales.data_ptr<float>();
|
||||
args.b_k_size = b.size(0);
|
||||
args.b_k_stride = b.stride(0);
|
||||
args.b_n_size = b.size(1);
|
||||
args.b_n_stride = b.stride(1);
|
||||
args.b_ptr = b.data_ptr<int8_t>();
|
||||
|
||||
if (dynamic_act_quant) {
|
||||
// dynamic per-token, bias, A scales and A zps will be applied in outside.
|
||||
args.a_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN;
|
||||
args.use_a_zero_point = false;
|
||||
} else {
|
||||
// static per-tensor
|
||||
args.a_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
|
||||
args.use_a_zero_point = use_azp;
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler",
|
||||
[&] {
|
||||
if (dynamic_act_quant) {
|
||||
args.c_type = get_dnnl_type<float>();
|
||||
} else {
|
||||
args.c_type = get_dnnl_type<scalar_t>();
|
||||
}
|
||||
});
|
||||
|
||||
return reinterpret_cast<int64_t>(new W8A8MatMulPrimitiveHandler(args));
|
||||
}
|
||||
|
||||
void onednn_scaled_mm(
|
||||
torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a, // [M, IC], row-major
|
||||
const torch::Tensor& a_scales, // [M] or [1]
|
||||
const std::optional<torch::Tensor>& azp, // [M] or [1]
|
||||
const std::optional<torch::Tensor>& azp_adj, // [M] or [1]
|
||||
const std::optional<torch::Tensor>& bias, // [N]
|
||||
int64_t handler) {
|
||||
CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
|
||||
TORCH_CHECK(a.dim() == 2);
|
||||
TORCH_CHECK(a.is_contiguous());
|
||||
TORCH_CHECK(c.is_contiguous());
|
||||
W8A8MatMulPrimitiveHandler* ptr =
|
||||
reinterpret_cast<W8A8MatMulPrimitiveHandler*>(handler);
|
||||
const int32_t* azp_ptr = nullptr;
|
||||
if (azp.has_value()) {
|
||||
azp_ptr = azp->data_ptr<int32_t>();
|
||||
}
|
||||
if (ptr->get_input_scale_strategy() ==
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
|
||||
TORCH_CHECK_EQ(a_scales.numel(), 1);
|
||||
}
|
||||
|
||||
W8A8MatMulPrimitiveHandler::ExecArgs exec_args;
|
||||
exec_args.a_ptr = a.data_ptr<int8_t>();
|
||||
exec_args.a_m_size = a.size(0);
|
||||
exec_args.bias_ptr = nullptr;
|
||||
exec_args.use_bias = false;
|
||||
exec_args.a_scales_ptr = nullptr;
|
||||
exec_args.a_zero_points_ptr = nullptr;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] {
|
||||
if (ptr->get_input_scale_strategy() ==
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
|
||||
if (bias.has_value()) {
|
||||
exec_args.bias_ptr = bias->data_ptr<scalar_t>();
|
||||
exec_args.bias_type = get_dnnl_type<scalar_t>();
|
||||
exec_args.use_bias = true;
|
||||
}
|
||||
exec_args.a_scales_ptr = a_scales.data_ptr<float>();
|
||||
exec_args.a_zero_points_ptr = azp_ptr;
|
||||
exec_args.c_ptr = c.data_ptr<scalar_t>();
|
||||
ptr->execute(exec_args);
|
||||
} else if (ptr->get_input_scale_strategy() ==
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) {
|
||||
torch::Tensor tmp_fp32_out =
|
||||
torch::empty_like(c, ::at::ScalarType::Float);
|
||||
exec_args.c_ptr = tmp_fp32_out.data_ptr<float>();
|
||||
ptr->execute(exec_args);
|
||||
if (bias.has_value()) {
|
||||
if (azp.has_value()) {
|
||||
dynamic_quant_epilogue<true, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(),
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
} else {
|
||||
dynamic_quant_epilogue<false, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, nullptr,
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
}
|
||||
} else {
|
||||
if (azp.has_value()) {
|
||||
dynamic_quant_epilogue<true, false>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(),
|
||||
(scalar_t*)nullptr, c.size(0), c.size(1));
|
||||
} else {
|
||||
dynamic_quant_epilogue<false, false>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, nullptr, (scalar_t*)nullptr,
|
||||
c.size(0), c.size(1));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "invalid act quant type.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// static-per-tensor quantization.
|
||||
void static_scaled_int8_quant(
|
||||
torch::Tensor& out, // [batch, hidden_size]
|
||||
const torch::Tensor& input, // [batch, hidden_size]
|
||||
const torch::Tensor& scale, std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK_EQ(input.dim(), 2);
|
||||
TORCH_CHECK_EQ(input.stride(1), 1);
|
||||
TORCH_CHECK(scale.numel() == 1);
|
||||
TORCH_CHECK(!azp.has_value() || azp->numel() == 1);
|
||||
|
||||
const int64_t stride = input.stride(0);
|
||||
const int64_t hidden_size = input.size(1);
|
||||
const int64_t num_tokens = input.size(0);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
|
||||
if (azp.has_value()) {
|
||||
static_scaled_int8_quant_impl<true>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
|
||||
stride, hidden_size);
|
||||
} else {
|
||||
static_scaled_int8_quant_impl<false>(input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), nullptr,
|
||||
num_tokens, stride, hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// dynamic-per-token quantization.
|
||||
void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [batch, hidden_size]
|
||||
const torch::Tensor& input, // [batch, hidden_size]
|
||||
torch::Tensor& scale, // [batch, 1]
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK_EQ(input.dim(), 2);
|
||||
TORCH_CHECK_EQ(input.stride(1), 1);
|
||||
|
||||
const int64_t hidden_size = input.size(1);
|
||||
const int64_t num_tokens = input.size(0);
|
||||
const int64_t stride = input.stride(0);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
|
||||
if (azp.has_value()) {
|
||||
dynamic_scaled_int8_quant_impl<true>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
|
||||
stride, hidden_size);
|
||||
} else {
|
||||
dynamic_scaled_int8_quant_impl<false>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), nullptr, num_tokens, stride,
|
||||
hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -1,951 +0,0 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include "dnnl_helper.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using load_vec_type = void;
|
||||
using azp_adj_load_vec_type = void;
|
||||
using cvt_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
using azp_adj_load_vec_type = vec_op::INT32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using load_vec_type = vec_op::BF16Vec16;
|
||||
using azp_adj_load_vec_type = vec_op::INT32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power architecture-specific vector type
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
// Fallback for other architectures
|
||||
using load_vec_type = vec_op::FP16Vec16;
|
||||
#endif
|
||||
using azp_adj_load_vec_type = vec_op::INT32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
cvt_vec_t zp_vec;
|
||||
if constexpr (AZP) {
|
||||
zp_vec = cvt_vec_t(static_cast<float>(*azp));
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
|
||||
cvt_vec_t min_value(std::numeric_limits<float>::max());
|
||||
{
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
} else {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32, hidden_size - j);
|
||||
min_value = min_value.min(elems_fp32, hidden_size - j);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float scale_val, azp_val;
|
||||
if constexpr (AZP) {
|
||||
float max_scalar = max_value.reduce_max();
|
||||
float min_scalar = min_value.reduce_min();
|
||||
scale_val = (max_scalar - min_scalar) / 255.0f;
|
||||
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
|
||||
azp[i] = static_cast<int32_t>(azp_val);
|
||||
scale[i] = scale_val;
|
||||
} else {
|
||||
scale_val = max_value.reduce_max() / 127.0f;
|
||||
scale[i] = scale_val;
|
||||
}
|
||||
|
||||
const cvt_vec_t inv_scale(1.0 / scale_val);
|
||||
const cvt_vec_t azp_vec(azp_val);
|
||||
|
||||
{
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool PerChannel, typename scalar_t>
|
||||
void static_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float a_scale, const float* b_scale,
|
||||
const int32_t* azp_with_adj, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using azp_adj_load_vec_t =
|
||||
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t a_scale_vec(a_scale);
|
||||
cvt_vec_t b_scale_vec(*b_scale);
|
||||
cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
|
||||
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
b_scale_vec = cvt_vec_t(b_scale + j);
|
||||
scale_vec = b_scale_vec * a_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
b_scale_vec = cvt_vec_t(b_scale + j);
|
||||
scale_vec = b_scale_vec * a_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
|
||||
void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float* a_scale, const float* b_scale,
|
||||
const int32_t* azp, const int32_t* azp_adj,
|
||||
const scalar_t* bias, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using azp_adj_load_vec_t =
|
||||
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
cvt_vec_t token_scale_vec(a_scale[i]);
|
||||
cvt_vec_t token_zp_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
|
||||
if constexpr (!PerChannel) {
|
||||
zp_scale_val *= *b_scale;
|
||||
}
|
||||
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
|
||||
}
|
||||
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
|
||||
if constexpr (AZP) {
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
cvt_vec_t b_scale_vec(b_scale + j);
|
||||
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32;
|
||||
}
|
||||
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
|
||||
if constexpr (AZP) {
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
cvt_vec_t b_scale_vec(b_scale + j);
|
||||
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32;
|
||||
}
|
||||
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
#elif defined(__powerpc64__)
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
|
||||
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
cvt_vec_t zp_vec;
|
||||
if constexpr (AZP) {
|
||||
zp_vec = cvt_vec_t(static_cast<float>(*azp));
|
||||
}
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
template <bool AZP, typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
|
||||
cvt_vec_t min_value(std::numeric_limits<float>::max());
|
||||
{
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
} else {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32, hidden_size - j);
|
||||
min_value = min_value.min(elems_fp32, hidden_size - j);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float scale_val, azp_val;
|
||||
if constexpr (AZP) {
|
||||
float max_scalar = max_value.reduce_max();
|
||||
float min_scalar = min_value.reduce_min();
|
||||
scale_val = (max_scalar - min_scalar) / 255.0f;
|
||||
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
|
||||
azp[i] = static_cast<int32_t>(azp_val);
|
||||
scale[i] = scale_val;
|
||||
} else {
|
||||
scale_val = max_value.reduce_max() / 127.0f;
|
||||
scale[i] = scale_val;
|
||||
}
|
||||
|
||||
const cvt_vec_t inv_scale(1.0 / scale_val);
|
||||
const cvt_vec_t azp_vec(azp_val);
|
||||
|
||||
{
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input + i * hidden_size + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
template <bool PerChannel, typename scalar_t>
|
||||
void static_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float a_scale, const float* b_scale,
|
||||
const int32_t* azp_with_adj, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using azp_adj_load_vec_t =
|
||||
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t a_scale_vec(a_scale);
|
||||
cvt_vec_t b_scale_vec(*b_scale);
|
||||
cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
|
||||
|
||||
int j = 0;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
b_scale_vec = cvt_vec_t(b_scale + j);
|
||||
scale_vec = b_scale_vec * a_scale_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
b_scale_vec = cvt_vec_t(b_scale + j);
|
||||
scale_vec = b_scale_vec * a_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
|
||||
void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float* a_scale, const float* b_scale,
|
||||
const int32_t* azp, const int32_t* azp_adj,
|
||||
const scalar_t* bias, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using azp_adj_load_vec_t =
|
||||
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < num_tokens; ++i) {
|
||||
int j = 0;
|
||||
cvt_vec_t token_scale_vec(a_scale[i]);
|
||||
cvt_vec_t token_zp_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
|
||||
if constexpr (!PerChannel) {
|
||||
zp_scale_val *= *b_scale;
|
||||
}
|
||||
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
|
||||
}
|
||||
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
|
||||
if constexpr (AZP) {
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
cvt_vec_t b_scale_vec(b_scale + j);
|
||||
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32;
|
||||
}
|
||||
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j);
|
||||
}
|
||||
|
||||
cvt_vec_t elems_fp32(input + i * hidden_size + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
|
||||
if constexpr (AZP) {
|
||||
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
|
||||
cvt_vec_t azp_adj_fp32(azp_adj_vec);
|
||||
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
|
||||
|
||||
if constexpr (PerChannel) {
|
||||
cvt_vec_t b_scale_vec(b_scale + j);
|
||||
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32;
|
||||
}
|
||||
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output + i * hidden_size + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false,
|
||||
"static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
|
||||
"support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false,
|
||||
"dynamic_scaled_int8_quant_impl requires "
|
||||
"AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <bool PerChannel, typename scalar_t>
|
||||
void static_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float a_scale, const float* b_scale,
|
||||
const int32_t* azp_with_adj, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float* a_scale, const float* b_scale,
|
||||
const int32_t* azp, const int32_t* azp_with_adj,
|
||||
const scalar_t* bias, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a, // [M, IC], row-major
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& a_scales, // [1] or [M]
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
const std::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
|
||||
"int8_scaled_mm only supports INT8 inputs.")
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
|
||||
bias->dim() == 1);
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] {
|
||||
if (a_scales.numel() != 1) {
|
||||
// per-token
|
||||
// Note: oneDNN doesn't support per-token activation quantization
|
||||
// Ideally we want to fuse the GEMM and the scale procedure with oneDNN
|
||||
// JIT, the intermediate data is cached in registers or L1. But for now
|
||||
// the oneDNN GEMM code generation only supports two quantization
|
||||
// patterns: per-tensor or per-output-channel of weight.
|
||||
// So we have to apply the per-token scale with a 'epilogue'. In C=s_a *
|
||||
// s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN
|
||||
// GEMM, then the per-token scale (and bias) is applied with the epilogue
|
||||
// C=s_a * C_inter + bias.
|
||||
torch::Tensor tmp_fp32_out =
|
||||
torch::empty_like(c, ::at::ScalarType::Float);
|
||||
// Compute C_inter=s_b * (A@B)
|
||||
DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
|
||||
a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
|
||||
if (bias.has_value()) {
|
||||
// Compute C=s_a * C_inter + bias
|
||||
dynamic_quant_epilogue<false, true, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
} else {
|
||||
// Compute C=s_a * C_inter
|
||||
dynamic_quant_epilogue<false, true, false, scalar_t>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
|
||||
c.size(0), c.size(1));
|
||||
}
|
||||
} else {
|
||||
// per-tensor
|
||||
if (bias.has_value()) {
|
||||
// Compute C=s_a * s_b * (A@B) + bias
|
||||
DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
|
||||
bias->data_ptr<scalar_t>(), a.size(0), b.size(1), a.size(1),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
a_scales.numel(), b_scales.numel());
|
||||
} else {
|
||||
// Compute C=s_a * s_b * (A@B)
|
||||
DNNLPrimitiveHelper<false>::gemm_s8s8_jit<scalar_t, void>(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
|
||||
nullptr, a.size(0), b.size(1), a.size(1),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
a_scales.numel(), b_scales.numel());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a, // [M, IC], row-major
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& a_scales, // [1] or [M]
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
const torch::Tensor& azp_adj, // [OC]
|
||||
const std::optional<torch::Tensor>& azp, // [1] or [M]
|
||||
const std::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
|
||||
"int8_scaled_mm_azp only supports INT8 inputs.")
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
|
||||
}
|
||||
if (azp) {
|
||||
TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
|
||||
}
|
||||
TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
|
||||
|
||||
// azp & bias types
|
||||
TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
|
||||
TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
|
||||
"currently bias dtype must match output dtype ", c.dtype());
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] {
|
||||
torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
|
||||
if (a_scales.numel() != 1) {
|
||||
// per-token
|
||||
// Note: oneDNN doesn't support per-token activation quantization
|
||||
// Compute C_inter=s_b * (A@B)
|
||||
DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
|
||||
a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
|
||||
if (bias.has_value()) {
|
||||
// Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias
|
||||
if (b_scales.numel() != 1) {
|
||||
// Per-Channel
|
||||
dynamic_quant_epilogue<true, true, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
} else {
|
||||
// Per-Tensor
|
||||
dynamic_quant_epilogue<true, false, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
}
|
||||
} else {
|
||||
// Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj
|
||||
if (b_scales.numel() != 1) {
|
||||
// Per-Channel
|
||||
dynamic_quant_epilogue<true, true, false, scalar_t>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
|
||||
c.size(0), c.size(1));
|
||||
} else {
|
||||
// Per-Tensor
|
||||
dynamic_quant_epilogue<true, false, false, scalar_t>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
|
||||
c.size(0), c.size(1));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// per-tensor
|
||||
if (bias.has_value()) {
|
||||
// Compute C_inter=s_a * s_b * (A@B) + bias
|
||||
DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), bias->data_ptr<scalar_t>(),
|
||||
a.size(0), b.size(1), a.size(1), a_scales.data_ptr<float>(),
|
||||
b_scales.data_ptr<float>(), a_scales.numel(), b_scales.numel());
|
||||
} else {
|
||||
// Compute C_inter=s_a * s_b * (A@B)
|
||||
DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, void>(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
|
||||
a.size(1), a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
a_scales.numel(), b_scales.numel());
|
||||
}
|
||||
|
||||
// Compute C=C_inter - s_a * s_b * azp_adj
|
||||
if (b_scales.numel() != 1) {
|
||||
// Per-Channel
|
||||
static_quant_epilogue<true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
*a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
|
||||
} else {
|
||||
// Per-Tensor
|
||||
static_quant_epilogue<false>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
*a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
|
||||
azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// static-per-tensor quantization.
|
||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
const torch::Tensor& input, // [..., hidden_size]
|
||||
const torch::Tensor& scale,
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(scale.numel() == 1);
|
||||
TORCH_CHECK(!azp.has_value() || azp->numel() == 1);
|
||||
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = input.numel() / hidden_size;
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
|
||||
if (azp.has_value()) {
|
||||
static_scaled_int8_quant_impl<true>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
|
||||
hidden_size);
|
||||
} else {
|
||||
static_scaled_int8_quant_impl<false>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// dynamic-per-token quantization.
|
||||
void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
const torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& scale, // [..., 1]
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
|
||||
if (azp.has_value()) {
|
||||
dynamic_scaled_int8_quant_impl<true>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
|
||||
hidden_size);
|
||||
} else {
|
||||
dynamic_scaled_int8_quant_impl<false>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(__powerpc64__)
|
||||
void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a, // [M, IC], row-major
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const std::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
|
||||
"int8_scaled_mm_ppc64le only supports INT8 inputs.");
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
b.size(1) == c.size(1));
|
||||
// We dont need this
|
||||
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
|
||||
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
|
||||
|
||||
// Check for strides and alignment
|
||||
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(c.stride(0) % 16 == 0 &&
|
||||
b.stride(1) % 16 == 0); // 16 Byte Alignment
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
|
||||
bias->dim() == 1);
|
||||
}
|
||||
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
|
||||
torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
|
||||
// Compute C_inter=s_b * (A@B)
|
||||
DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
|
||||
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
|
||||
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
|
||||
a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
|
||||
if (bias.has_value()) {
|
||||
// Compute C=s_a * C_inter + bias
|
||||
dynamic_quant_epilogue<false, true, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
} else {
|
||||
// Compute C=s_a * C_inter
|
||||
dynamic_quant_epilogue<false, true, false, scalar_t>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
|
||||
c.size(0), c.size(1));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -6,25 +6,20 @@
|
||||
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids);
|
||||
|
||||
void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
void release_dnnl_matmul_handler(int64_t handler);
|
||||
|
||||
void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const torch::Tensor& azp_adj,
|
||||
const std::optional<torch::Tensor>& azp,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b,
|
||||
const torch::Tensor& b_scales,
|
||||
at::ScalarType output_type,
|
||||
bool dynamic_act_quant, bool use_azp,
|
||||
int64_t primitive_cache_size);
|
||||
|
||||
#if defined(__powerpc64__)
|
||||
void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
#endif
|
||||
void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& a_scales,
|
||||
const std::optional<torch::Tensor>& azp,
|
||||
const std::optional<torch::Tensor>& azp_adj,
|
||||
const std::optional<torch::Tensor>& bias,
|
||||
int64_t handler);
|
||||
|
||||
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& kv_cache, double scale,
|
||||
@ -151,8 +146,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||
|
||||
// Quantization
|
||||
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
|
||||
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
|
||||
defined(__powerpc64__)
|
||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||
// Helper function to release oneDNN handlers
|
||||
ops.def("release_dnnl_matmul_handler(int handler) -> ()",
|
||||
&release_dnnl_matmul_handler);
|
||||
|
||||
// Create oneDNN W8A8 handler
|
||||
ops.def(
|
||||
"create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
|
||||
"output_type, bool dynamic_act_quant, bool use_azp, int "
|
||||
"primitive_cache_size) -> int",
|
||||
&create_onednn_scaled_mm_handler);
|
||||
|
||||
// oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
|
||||
ops.def(
|
||||
"onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
|
||||
"Tensor? azp_adj, Tensor? bias, int handler) -> ()");
|
||||
ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);
|
||||
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
@ -168,50 +180,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
{stride_tag});
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
|
||||
&dynamic_scaled_int8_quant);
|
||||
// W8A8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
|
||||
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
|
||||
#elif defined(__powerpc64__)
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||
"Tensor? azp) -> ()");
|
||||
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
|
||||
|
||||
// Compute int8 quantized tensor and scaling factor
|
||||
ops.def(
|
||||
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
|
||||
"Tensor!? azp) -> ()");
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
|
||||
&dynamic_scaled_int8_quant);
|
||||
// W8A8 GEMM, supporting symmetric quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
|
||||
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
|
||||
#endif
|
||||
|
||||
// SHM CCL
|
||||
|
||||
@ -45,8 +45,6 @@ void moe_permute(
|
||||
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
|
||||
auto permuted_experts_id = torch::empty_like(topk_ids);
|
||||
auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
|
||||
auto align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
|
||||
CubKeyValueSorter sorter{};
|
||||
int64_t* valid_num_ptr = nullptr;
|
||||
@ -85,12 +83,14 @@ void moe_permute(
|
||||
});
|
||||
|
||||
// get m_indices and update expert_first_token_offset with align block
|
||||
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
|
||||
get_ptr<int64_t>(align_expert_first_token_offset),
|
||||
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
|
||||
stream);
|
||||
// this is only required for DeepGemm and not required for CUTLASS group gemm
|
||||
if (align_block_size.has_value()) {
|
||||
// update align_expert_first_token_offset
|
||||
auto align_expert_first_token_offset =
|
||||
torch::zeros_like(expert_first_token_offset);
|
||||
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
|
||||
get_ptr<int64_t>(align_expert_first_token_offset),
|
||||
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
|
||||
stream);
|
||||
expert_first_token_offset.copy_(align_expert_first_token_offset);
|
||||
}
|
||||
}
|
||||
@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
|
||||
torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& src_row_id2dst_row_id_map,
|
||||
torch::Tensor& m_indices) {
|
||||
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
|
||||
TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
|
||||
}
|
||||
|
||||
void moe_unpermute(const torch::Tensor& input,
|
||||
const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
|
||||
const torch::Tensor& token_expert_indices,
|
||||
const std::optional<torch::Tensor>& expert_map,
|
||||
int64_t n_expert, int64_t n_local_expert, int64_t topk,
|
||||
const std::optional<int64_t>& align_block_size,
|
||||
torch::Tensor& permuted_input,
|
||||
torch::Tensor& expert_first_token_offset,
|
||||
torch::Tensor& src_row_id2dst_row_id_map,
|
||||
torch::Tensor& m_indices) {
|
||||
void moe_unpermute(
|
||||
const torch::Tensor& permuted_hidden_states,
|
||||
const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
|
||||
const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
|
||||
torch::Tensor& hidden_states) {
|
||||
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
|
||||
}
|
||||
|
||||
@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() {
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_permute", &moe_permute);
|
||||
m.impl("moe_unpermute", &moe_unpermute);
|
||||
}
|
||||
}
|
||||
@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
|
||||
template <typename ElementAB, typename ElementC, typename ElementAccumulator>
|
||||
__global__ void get_group_gemm_starts(
|
||||
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
|
||||
ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int, ElementC* out_base_as_int,
|
||||
@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<int64_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||
@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
// expect int64_t to avoid overflow during offset calculations
|
||||
TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
bool per_act_token = a_scales.numel() != 1;
|
||||
|
||||
@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& atomic_buffer,
|
||||
int64_t num_experts, int64_t n,
|
||||
int64_t k, cudaStream_t stream,
|
||||
const bool swap_ab) {
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
|
||||
const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
|
||||
int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
|
||||
int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
|
||||
int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());
|
||||
|
||||
if (swap_ab) {
|
||||
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
} else {
|
||||
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
||||
topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
|
||||
static_cast<int>(topk_ids.numel()), static_cast<int>(n),
|
||||
static_cast<int>(k));
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||
auto options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||
|
||||
// Swap-AB should be disabled for FP4 path
|
||||
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
|
||||
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
|
||||
|
||||
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
|
||||
atomic_buffer, num_experts, n, k, stream,
|
||||
may_swap_ab);
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
|
||||
bool may_swap_ab = (!blockscale_offsets.has_value()) &&
|
||||
(topk_ids.numel() <= SWAP_AB_THRESHOLD);
|
||||
|
||||
if (may_swap_ab) {
|
||||
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
|
||||
k);
|
||||
} else {
|
||||
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
|
||||
k);
|
||||
}
|
||||
launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
|
||||
atomic_buffer, num_experts, n, k, stream,
|
||||
may_swap_ab);
|
||||
|
||||
if (blockscale_offsets.has_value()) {
|
||||
// fp4 path
|
||||
|
||||
@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
|
||||
const int64_t num_experts, const int64_t n, const int64_t k,
|
||||
const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_problem_sizes(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
|
||||
const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
||||
get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
|
||||
problem_sizes2, num_experts, n, k,
|
||||
blockscale_offsets);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
|
||||
"kernel for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
|
||||
@ -571,78 +571,79 @@ def generate():
|
||||
itertools.repeat(default_heuristic))
|
||||
]
|
||||
|
||||
# Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
|
||||
# TODO (LucasWilkinson): Further tuning required
|
||||
qqq_tile_heuristic_config = {
|
||||
#### M = 257+
|
||||
# ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
# "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
|
||||
# "M > 256": ((128, 256), (2, 1, 1)),
|
||||
"M > 256": ((128, 128), (2, 1, 1)),
|
||||
#### M = 129-256
|
||||
"M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
|
||||
"M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
|
||||
# ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
# "M > 128": ((128, 256), (2, 1, 1)),
|
||||
"M > 128": ((128, 128), (2, 1, 1)),
|
||||
#### M = 65-128
|
||||
"M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
|
||||
"M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
|
||||
"M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
|
||||
"M > 64": ((128, 128), (2, 1, 1)),
|
||||
#### M = 33-64
|
||||
"M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
|
||||
# Broken for QQQ types
|
||||
# TODO (LucasWilkinson): Investigate further
|
||||
#"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
|
||||
"M > 32": ((128, 64), (2, 1, 1)),
|
||||
#### M = 17-32
|
||||
"M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
|
||||
"M > 16": ((256, 32), (2, 1, 1)),
|
||||
#### M = 1-16
|
||||
"N >= 26624": ((256, 16), (1, 1, 1)),
|
||||
None: ((128, 16), (1, 1, 1)),
|
||||
}
|
||||
# TODO: Support W4A8 when ready
|
||||
# # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
|
||||
# # TODO (LucasWilkinson): Further tuning required
|
||||
# qqq_tile_heuristic_config = {
|
||||
# #### M = 257+
|
||||
# # ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# # TODO (LucasWilkinson): Investigate further
|
||||
# # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
|
||||
# # "M > 256": ((128, 256), (2, 1, 1)),
|
||||
# "M > 256": ((128, 128), (2, 1, 1)),
|
||||
# #### M = 129-256
|
||||
# "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
|
||||
# "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
|
||||
# # ((128, 256), (2, 1, 1)) Broken for QQQ types
|
||||
# # TODO (LucasWilkinson): Investigate further
|
||||
# # "M > 128": ((128, 256), (2, 1, 1)),
|
||||
# "M > 128": ((128, 128), (2, 1, 1)),
|
||||
# #### M = 65-128
|
||||
# "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
|
||||
# "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
|
||||
# "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
|
||||
# "M > 64": ((128, 128), (2, 1, 1)),
|
||||
# #### M = 33-64
|
||||
# "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
|
||||
# # Broken for QQQ types
|
||||
# # TODO (LucasWilkinson): Investigate further
|
||||
# #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
|
||||
# "M > 32": ((128, 64), (2, 1, 1)),
|
||||
# #### M = 17-32
|
||||
# "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
|
||||
# "M > 16": ((256, 32), (2, 1, 1)),
|
||||
# #### M = 1-16
|
||||
# "N >= 26624": ((256, 16), (1, 1, 1)),
|
||||
# None: ((128, 16), (1, 1, 1)),
|
||||
# }
|
||||
|
||||
# For now we use the same heuristic for all types
|
||||
# Heuristic is currently tuned for H100s
|
||||
qqq_heuristic = [
|
||||
(cond, ScheduleConfig(*tile_config,
|
||||
**sch_common_params)) # type: ignore
|
||||
for cond, tile_config in qqq_tile_heuristic_config.items()
|
||||
]
|
||||
# # For now we use the same heuristic for all types
|
||||
# # Heuristic is currently tuned for H100s
|
||||
# qqq_heuristic = [
|
||||
# (cond, ScheduleConfig(*tile_config,
|
||||
# **sch_common_params)) # type: ignore
|
||||
# for cond, tile_config in qqq_tile_heuristic_config.items()
|
||||
# ]
|
||||
|
||||
QQQ_kernel_types = [
|
||||
*(TypeConfig(
|
||||
a=DataType.s8,
|
||||
b=VLLMDataType.u4b8,
|
||||
b_group_scale=b_group_scale,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.f32,
|
||||
a_token_scale=DataType.f32,
|
||||
out=DataType.f16,
|
||||
accumulator=DataType.s32,
|
||||
) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
*(TypeConfig(
|
||||
a=DataType.e4m3,
|
||||
b=VLLMDataType.u4b8,
|
||||
b_group_scale=b_group_scale,
|
||||
b_group_zeropoint=DataType.void,
|
||||
b_channel_scale=DataType.f32,
|
||||
a_token_scale=DataType.f32,
|
||||
out=DataType.f16,
|
||||
accumulator=DataType.f32,
|
||||
) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
]
|
||||
# QQQ_kernel_types = [
|
||||
# *(TypeConfig(
|
||||
# a=DataType.s8,
|
||||
# b=VLLMDataType.u4b8,
|
||||
# b_group_scale=b_group_scale,
|
||||
# b_group_zeropoint=DataType.void,
|
||||
# b_channel_scale=DataType.f32,
|
||||
# a_token_scale=DataType.f32,
|
||||
# out=DataType.f16,
|
||||
# accumulator=DataType.s32,
|
||||
# ) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
# *(TypeConfig(
|
||||
# a=DataType.e4m3,
|
||||
# b=VLLMDataType.u4b8,
|
||||
# b_group_scale=b_group_scale,
|
||||
# b_group_zeropoint=DataType.void,
|
||||
# b_channel_scale=DataType.f32,
|
||||
# a_token_scale=DataType.f32,
|
||||
# out=DataType.f16,
|
||||
# accumulator=DataType.f32,
|
||||
# ) for b_group_scale in (DataType.f16, DataType.void)),
|
||||
# ]
|
||||
|
||||
impl_configs += [
|
||||
ImplConfig(x[0], x[1], x[2])
|
||||
for x in zip(QQQ_kernel_types,
|
||||
itertools.repeat(get_unique_schedules(qqq_heuristic)),
|
||||
itertools.repeat(qqq_heuristic))
|
||||
]
|
||||
# impl_configs += [
|
||||
# ImplConfig(x[0], x[1], x[2])
|
||||
# for x in zip(QQQ_kernel_types,
|
||||
# itertools.repeat(get_unique_schedules(qqq_heuristic)),
|
||||
# itertools.repeat(qqq_heuristic))
|
||||
# ]
|
||||
|
||||
output_dir = os.path.join(SCRIPT_DIR, "generated")
|
||||
|
||||
|
||||
@ -1,209 +0,0 @@
|
||||
Contains code from https://github.com/IST-DASLab/marlin
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
|
||||
This product bundles various third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses. See licenses/
|
||||
for text of these licenses.
|
||||
@ -1,32 +0,0 @@
|
||||
/*
|
||||
* Modified by HandH1998
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
|
||||
// for instance as inputs to tensor core operations. Consequently, all
|
||||
// corresponding index accesses must be compile-time constants, which is why we
|
||||
// extensively use `#pragma unroll` throughout the kernel code to guarantee
|
||||
// this.
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) { return elems[i]; }
|
||||
};
|
||||
@ -1,89 +0,0 @@
|
||||
/*
|
||||
* Modified by HandH1998
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// Predicated asynchronous global->shared copy; used for inputs A where we apply
|
||||
// predication to handle batchsizes that are not multiples of 16.
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Asynchronous global->shared copy
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Async copy fence.
|
||||
__device__ inline void cp_async_fence() {
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
}
|
||||
|
||||
// Wait until at most `n` async copy stages are still pending.
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
// Wait until barrier reaches `count`, then lock for current threadblock.
|
||||
__device__ inline void barrier_acquire(int* lock, int count) {
|
||||
if (threadIdx.x == 0) {
|
||||
int state = -1;
|
||||
do
|
||||
// Guarantee that subsequent writes by this threadblock will be visible
|
||||
// globally.
|
||||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
||||
: "=r"(state)
|
||||
: "l"(lock));
|
||||
while (state != count);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Release barrier and increment visitation count.
|
||||
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
if (reset) {
|
||||
lock[0] = 0;
|
||||
return;
|
||||
}
|
||||
int val = 1;
|
||||
// Make sure that all writes since acquiring this barrier are visible
|
||||
// globally, while releasing the barrier.
|
||||
asm volatile("fence.acq_rel.gpu;\n");
|
||||
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
||||
:
|
||||
: "l"(lock), "r"(val));
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -241,14 +241,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// custom types:
|
||||
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
|
||||
|
||||
// Marlin (Dense) Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
|
||||
"Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
|
||||
"Tensor",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
|
||||
ops.def(
|
||||
"gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
|
||||
@ -353,15 +345,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// marlin_qqq_gemm for QQQ.
|
||||
ops.def(
|
||||
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
|
||||
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
|
||||
"Tensor! workspace, SymInt size_m, SymInt size_n, "
|
||||
"SymInt size_k) -> Tensor",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// CUTLASS nvfp4 block scaled GEMM
|
||||
ops.def(
|
||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
@ -440,6 +423,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
{stride_tag});
|
||||
ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);
|
||||
|
||||
// A function that computes problem sizes for each expert's multiplication
|
||||
// used by the two mms called from fused MoE operation. It takes topk_ids as
|
||||
// an input, and computes problem_sizes1 and problem_sizes2 only.
|
||||
ops.def(
|
||||
"get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, "
|
||||
" Tensor! problem_sizes1, "
|
||||
" Tensor! problem_sizes2, "
|
||||
" int num_experts, int n, int k, "
|
||||
" Tensor? blockscale_offsets) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
|
||||
&get_cutlass_moe_mm_problem_sizes);
|
||||
|
||||
// A function that computes data required to run fused MoE with w8a8 grouped
|
||||
// GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
|
||||
// as an input, and computes expert_offsets (token start indices of each
|
||||
|
||||
@ -372,31 +372,45 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
|
||||
# Install FlashInfer from source
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt
|
||||
# We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel.
|
||||
ARG FLASHINFER_GIT_REF="v0.2.11"
|
||||
# Keep this in sync with "flashinfer" extra in setup.py
|
||||
ARG FLASHINFER_GIT_REF="v0.2.12"
|
||||
# Flag to control whether to compile FlashInfer AOT kernels
|
||||
# Set to "true" to enable AOT compilation:
|
||||
# docker build --build-arg FLASHINFER_AOT_COMPILE=true ...
|
||||
ARG FLASHINFER_AOT_COMPILE=false
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch ${FLASHINFER_GIT_REF} \
|
||||
${FLASHINFER_GIT_REPO} flashinfer
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
# Needed to build AOT kernels
|
||||
pushd flashinfer
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer.aot
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
uv pip install --system --no-build-isolation --force-reinstall --no-deps .
|
||||
if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
# Build AOT kernels
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer.aot
|
||||
# Install with no-build-isolation since we already built AOT kernels
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
uv pip install --system --no-build-isolation . \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
# Download pre-compiled cubins
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins."
|
||||
else
|
||||
echo "🏗️ Installing FlashInfer without AOT compilation in JIT mode"
|
||||
uv pip install --system . \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
fi
|
||||
popd
|
||||
rm -rf flashinfer
|
||||
BASH
|
||||
|
||||
@ -129,6 +129,52 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
|
||||
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
|
||||
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.
|
||||
|
||||
### Batch-level DP for Multi-Modal Encoders
|
||||
|
||||
By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
|
||||
in order to reduce the memory and compute load on each GPU.
|
||||
|
||||
However, since the size of multi-modal encoders is very small compared to language decoders,
|
||||
there is relatively little gain from TP. On the other hand, TP incurs significant communication
|
||||
overhead because of all-reduce being performed after every layer.
|
||||
|
||||
Given this, it may be advantageous to instead shard the batched input data using TP, essentially
|
||||
performing batch-level DP. This has been shown to improve the throughput by around 10% for
|
||||
`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations,
|
||||
batch-level DP can provide another 40% increase to throughput compared to regular TP.
|
||||
|
||||
Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank,
|
||||
there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already.
|
||||
|
||||
You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
llm = LLM(
|
||||
model="Qwen/Qwen2.5-VL-72B-Instruct",
|
||||
tensor_parallel_size=4,
|
||||
# When mm_encoder_tp_mode="data",
|
||||
# the vision encoder uses TP=4 (not DP=1) to shard the input data,
|
||||
# so the TP size becomes the effective DP size.
|
||||
# Note that this is independent of the DP size for language decoder which is used in expert parallel setting.
|
||||
mm_encoder_tp_mode="data",
|
||||
# The language decoder uses TP=4 to shard the weights regardless
|
||||
# of the setting of mm_encoder_tp_mode
|
||||
)
|
||||
```
|
||||
|
||||
!! important
|
||||
Batch-level DP is not to be confused with API request-level DP
|
||||
(which is instead controlled by `data_parallel_size`).
|
||||
|
||||
The availablilty of batch-level DP is based on model implementation.
|
||||
Currently, the following models support `mm_encoder_tp_mode="data"`:
|
||||
|
||||
- Llama4 (<gh-pr:18368>)
|
||||
- Qwen2.5-VL (<gh-pr:22742>)
|
||||
- Step3 (<gh-pr:22697>)
|
||||
|
||||
## Input Processing
|
||||
|
||||
### Parallel Processing
|
||||
|
||||
@ -9,7 +9,7 @@ vLLM can be run on a cloud based GPU machine with [dstack](https://dstack.ai/),
|
||||
To install dstack client, run:
|
||||
|
||||
```bash
|
||||
pip install "dstack[all]
|
||||
pip install dstack[all]
|
||||
dstack server
|
||||
```
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ This guide will help you quickly get started with vLLM to perform:
|
||||
## Prerequisites
|
||||
|
||||
- OS: Linux
|
||||
- Python: 3.9 -- 3.12
|
||||
- Python: 3.9 -- 3.13
|
||||
|
||||
## Installation
|
||||
|
||||
|
||||
@ -363,7 +363,7 @@ th {
|
||||
| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
|
||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ |
|
||||
| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ |
|
||||
@ -373,6 +373,7 @@ th {
|
||||
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
|
||||
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -384,8 +385,8 @@ th {
|
||||
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ |
|
||||
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -436,17 +437,17 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | |
|
||||
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ |
|
||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
|
||||
| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | |
|
||||
| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | |
|
||||
| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | |
|
||||
| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | |
|
||||
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | ✅︎ |
|
||||
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GteModel`<sup>C</sup> | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | ✅︎ |
|
||||
| `GteNewModel`<sup>C</sup> | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | ✅︎ |
|
||||
| `ModernBertModel`<sup>C</sup> | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | ✅︎ |
|
||||
| `NomicBertModel`<sup>C</sup> | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | ✅︎ |
|
||||
| `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2Model`<sup>C</sup>, `Qwen2ForCausalLM`<sup>C</sup> | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen3Model`<sup>C</sup>, `Qwen3ForCausalLM`<sup>C</sup> | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | |
|
||||
| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | ✅︎ |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
|
||||
|
||||
<sup>C</sup> Automatically converted into an embedding model via `--convert embed`. ([details](./pooling_models.md#model-conversion))
|
||||
@ -476,7 +477,7 @@ These models primarily support the [`LLM.classify`](./pooling_models.md#llmclass
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
|
||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
|
||||
|
||||
@ -493,12 +494,12 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | |
|
||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | ✅︎ |
|
||||
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | |
|
||||
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | |
|
||||
| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | | ✅︎ |
|
||||
| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | | ✅︎ |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* | \* |
|
||||
|
||||
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
|
||||
@ -652,6 +653,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
|
||||
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ |
|
||||
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
|
||||
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
|
||||
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
|
||||
|
||||
@ -840,8 +840,3 @@ Key capabilities:
|
||||
The following example shows how to deploy a large model like DeepSeek R1 with Ray Serve LLM: <gh-file:examples/online_serving/ray_serve_deepseek.py>.
|
||||
|
||||
Learn more about Ray Serve LLM with the official [Ray Serve LLM documentation](https://docs.ray.io/en/latest/serve/llm/serving-llms.html).
|
||||
|
||||
curl http://localhost:8002/v1/rerank -H "Content-Type: application/json" -d '{
|
||||
"query": "What is the capital of France?",
|
||||
"documents": ["The capital of France is Paris.", "The capital of Germany is Berlin."]
|
||||
}'
|
||||
@ -107,7 +107,7 @@ to enable simultaneous generation and embedding using the same engine instance i
|
||||
#### Mamba Models
|
||||
|
||||
Models using selective state-space mechanisms instead of standard transformer attention are supported.
|
||||
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.
|
||||
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1.
|
||||
|
||||
Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
|
||||
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that
|
||||
@ -154,12 +154,15 @@ differences compared to V0:
|
||||
|
||||
##### Logprobs Calculation
|
||||
|
||||
Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
|
||||
By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
|
||||
before applying any logits post-processing such as temperature scaling or penalty
|
||||
adjustments). As a result, the returned logprobs do not reflect the final adjusted
|
||||
probabilities used during sampling.
|
||||
|
||||
Support for logprobs with post-sampling adjustments is in progress and will be added in future updates.
|
||||
You can adjust this behavior by setting the `--logprobs-mode` flag.
|
||||
Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
|
||||
Raw means the values before applying any logit processors, like bad words.
|
||||
Processed means the values after applying all processors, including temperature and top_k/top_p.
|
||||
|
||||
##### Prompt Logprobs with Prefix Caching
|
||||
|
||||
|
||||
@ -283,8 +283,10 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
prompts = [
|
||||
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
|
||||
{question}<|assistant|>"
|
||||
(
|
||||
"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>"
|
||||
f"{question}<|assistant|>"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
@ -767,15 +769,13 @@ def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestDat
|
||||
def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData:
|
||||
if modality == "video":
|
||||
prompts = [
|
||||
f"<|im_start|>user <video>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n"
|
||||
f"<|im_start|>user <video>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
elif modality == "image":
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n"
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
@ -998,8 +998,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n"
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
@ -1436,6 +1435,28 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
|
||||
)
|
||||
|
||||
|
||||
# R-4B
|
||||
def run_r_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "YannQi/R-4B"
|
||||
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=16384,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# SkyworkR1V
|
||||
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -1622,6 +1643,7 @@ model_example_map = {
|
||||
"qwen2_vl": run_qwen2_vl,
|
||||
"qwen2_5_vl": run_qwen2_5_vl,
|
||||
"qwen2_5_omni": run_qwen2_5_omni,
|
||||
"rvl": run_r_vl,
|
||||
"skywork_chat": run_skyworkr1v,
|
||||
"smolvlm": run_smolvlm,
|
||||
"step3": run_step3,
|
||||
|
||||
@ -992,6 +992,39 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "YannQi/R-4B"
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=16384,
|
||||
max_num_seqs=16,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
prompt = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
)
|
||||
|
||||
|
||||
def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
|
||||
|
||||
@ -1193,6 +1226,7 @@ model_example_map = {
|
||||
"qwen_vl_chat": load_qwen_vl_chat,
|
||||
"qwen2_vl": load_qwen2_vl,
|
||||
"qwen2_5_vl": load_qwen2_5_vl,
|
||||
"rvl": load_r_vl,
|
||||
"smolvlm": load_smolvlm,
|
||||
"step3": load_step3,
|
||||
"tarsier": load_tarsier,
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
{%- if messages %}
|
||||
{%- if system_message or tools %}
|
||||
<|system|>
|
||||
|
||||
{%- if system_message %}
|
||||
{{ system_message }}
|
||||
{%- if messages and messages[0]['role'] == 'system' %}
|
||||
{%- set system_message = messages[0]['content']|trim %}
|
||||
{%- set messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set system_message = "You are a helpful assistant." %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if messages %}
|
||||
<|system|>
|
||||
{{ system_message }}
|
||||
{%- if tools %}
|
||||
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||
|
||||
Use the following rule to decide when to call a function:
|
||||
@ -19,13 +23,11 @@ If you decide to call functions:
|
||||
* make sure you pick the right functions that match the user intent
|
||||
|
||||
|
||||
{%- if tools %}
|
||||
{%- for t in tools %}
|
||||
{{- t | tojson(indent=4) }}
|
||||
{{- "\n\n" }}
|
||||
{%- endfor %}
|
||||
{%- endif %}<|end|>
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in messages %}
|
||||
{%- if message.role != "system" %}
|
||||
|
||||
@ -24,13 +24,14 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Information Technology",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Scientific/Engineering :: Information Analysis",
|
||||
]
|
||||
requires-python = ">=3.9,<3.13"
|
||||
requires-python = ">=3.9,<3.14"
|
||||
dynamic = [ "version", "dependencies", "optional-dependencies"]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@ -7,7 +7,7 @@ requests >= 2.26.0
|
||||
tqdm
|
||||
blake3
|
||||
py-cpuinfo
|
||||
transformers >= 4.55.0
|
||||
transformers >= 4.55.2
|
||||
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
|
||||
protobuf # Required by LlamaTokenizer.
|
||||
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
||||
@ -39,7 +39,7 @@ pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
compressed-tensors == 0.10.2 # required for compressed-tensors
|
||||
compressed-tensors == 0.11.0 # required for compressed-tensors
|
||||
depyf==0.19.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
watchfiles # required for http server to monitor the updates of TLS files
|
||||
|
||||
@ -5,4 +5,10 @@ numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Req
|
||||
numba == 0.61.2; python_version > '3.9'
|
||||
|
||||
# Dependencies for NVIDIA GPUs
|
||||
torch==2.8.0
|
||||
ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
|
||||
torch==2.7.1
|
||||
torchaudio==2.7.1
|
||||
# These must be updated alongside torch
|
||||
torchvision==0.22.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
# https://github.com/facebookresearch/xformers/releases/tag/v0.0.31
|
||||
xformers==0.0.31; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.7
|
||||
@ -34,7 +34,7 @@ opencv-python-headless >= 4.11.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
lm-eval[api]==0.4.8 # required for model evaluation test
|
||||
mteb[bm25s]>=1.38.11, <2 # required for mteb test
|
||||
transformers==4.55.0
|
||||
transformers==4.55.2
|
||||
tokenizers==0.21.1
|
||||
schemathesis>=3.39.15 # Required for openai schema test.
|
||||
# quantization
|
||||
|
||||
@ -1139,7 +1139,7 @@ tqdm==4.66.6
|
||||
# transformers
|
||||
tqdm-multiprocess==0.0.11
|
||||
# via lm-eval
|
||||
transformers==4.55.0
|
||||
transformers==4.55.2
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# genai-perf
|
||||
|
||||
17
setup.py
17
setup.py
@ -643,16 +643,25 @@ if envs.VLLM_USE_PRECOMPILED:
|
||||
if wheel_location is not None:
|
||||
wheel_url = wheel_location
|
||||
else:
|
||||
import platform
|
||||
arch = platform.machine()
|
||||
if arch == "x86_64":
|
||||
wheel_tag = "manylinux1_x86_64"
|
||||
elif arch == "aarch64":
|
||||
wheel_tag = "manylinux2014_aarch64"
|
||||
else:
|
||||
raise ValueError(f"Unsupported architecture: {arch}")
|
||||
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
|
||||
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||
nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||
from urllib.request import urlopen
|
||||
try:
|
||||
with urlopen(wheel_url) as resp:
|
||||
if resp.status != 200:
|
||||
wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
wheel_url = nightly_wheel_url
|
||||
except Exception as e:
|
||||
print(f"[warn] Falling back to nightly wheel: {e}")
|
||||
wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
wheel_url = nightly_wheel_url
|
||||
|
||||
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
|
||||
wheel_url)
|
||||
@ -685,7 +694,7 @@ setup(
|
||||
"mistral_common[audio]"], # Required for audio processing
|
||||
"video": [], # Kept for backwards compatibility
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
"flashinfer": ["flashinfer-python==0.2.11"],
|
||||
"flashinfer": ["flashinfer-python==0.2.12"],
|
||||
},
|
||||
cmdclass=cmdclass,
|
||||
package_data=package_data,
|
||||
|
||||
@ -12,10 +12,9 @@ from vllm.compilation.backends import set_model_tag
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
support_torch_compile)
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
@ -164,104 +163,34 @@ class SimpleModelWithTwoGraphs(ParentModel):
|
||||
return x
|
||||
|
||||
|
||||
def test_ignore_torch_compile_decorator():
|
||||
assert VLLM_USE_V1
|
||||
|
||||
# piecewise
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
|
||||
@support_torch_compile
|
||||
class A(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + x
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = x * 3
|
||||
return x
|
||||
|
||||
@ignore_torch_compile
|
||||
class B(A):
|
||||
...
|
||||
|
||||
@support_torch_compile
|
||||
class C(B):
|
||||
...
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# A has support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
), set_forward_context({}, vllm_config=vllm_config):
|
||||
# first run is for compile
|
||||
mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||
# run cudagraph captured sizes
|
||||
mod_A(torch.randn(2, MLP_SIZE).cuda())
|
||||
mod_A(torch.randn(1, MLP_SIZE).cuda())
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# B's ignore_torch_compile should override A's support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
), set_forward_context({}, vllm_config=vllm_config):
|
||||
mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||
mod_B(torch.randn(2, MLP_SIZE).cuda())
|
||||
mod_B(torch.randn(1, MLP_SIZE).cuda())
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# C's support_torch_compile should override B's ignore_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
), set_forward_context({}, vllm_config=vllm_config):
|
||||
mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||
mod_C(torch.randn(2, MLP_SIZE).cuda())
|
||||
mod_C(torch.randn(1, MLP_SIZE).cuda())
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor):
|
||||
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
|
||||
cudagraph_runtime_mode: CUDAGraphMode):
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
# First run is for compile
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(inputs)
|
||||
|
||||
# Run CUDAGraph captured sizes
|
||||
model(inputs[:2])
|
||||
model(inputs[:1])
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
model(inputs[:2])
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1, )):
|
||||
model(inputs[:1])
|
||||
|
||||
output = model(inputs[:2])
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
output = model(inputs[:2])
|
||||
|
||||
output = output.cpu()
|
||||
return output.cpu()
|
||||
@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
outputs.append(run_model(vllm_config, model, inputs))
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# no compile or cudagraph
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.NO_COMPILATION, ))
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
@ -318,7 +250,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
outputs.append(run_model(vllm_config, model, inputs))
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# piecewise compile without CUDA graph
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
@ -326,6 +259,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
use_cudagraph=False,
|
||||
splitting_ops=["silly.attention"],
|
||||
))
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=0, # no cudagraph captured
|
||||
):
|
||||
outputs.append(run_model(vllm_config, model, inputs))
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# Generally don't expect outputs with and without inductor
|
||||
# to be bitwise equivalent
|
||||
|
||||
251
tests/compile/test_decorator.py
Normal file
251
tests/compile/test_decorator.py
Normal file
@ -0,0 +1,251 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
support_torch_compile)
|
||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
|
||||
BATCH_SIZE = 32
|
||||
MLP_SIZE = 128
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
||||
cudagraph_runtime_mode: CUDAGraphMode):
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
model(torch.randn(2, MLP_SIZE).cuda())
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1, )):
|
||||
model(torch.randn(1, MLP_SIZE).cuda())
|
||||
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
output = model(torch.randn(2, MLP_SIZE).cuda())
|
||||
|
||||
output = output.cpu()
|
||||
return output.cpu()
|
||||
|
||||
|
||||
def test_ignore_torch_compile_decorator():
|
||||
# piecewise
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
@support_torch_compile
|
||||
class A(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + x
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = x * 3
|
||||
return x
|
||||
|
||||
@ignore_torch_compile
|
||||
class B(A):
|
||||
...
|
||||
|
||||
@support_torch_compile
|
||||
class C(B):
|
||||
...
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# A has support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# B's ignore_torch_compile should override A's support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# C's support_torch_compile should override B's ignore_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
|
||||
|
||||
|
||||
# Only enable torch.compile if
|
||||
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||
kv_sharing_fast_prefill)
|
||||
class B(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + x
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = x + x
|
||||
return x
|
||||
|
||||
|
||||
# Only enable torch.compile if
|
||||
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
||||
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
||||
cache_config.kv_sharing_fast_prefill)
|
||||
class A(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.mod1(x)
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = self.mod2(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_conditional_compile_enable_if():
|
||||
vllm_config = VllmConfig(cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=True, ),
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# A has support_torch_compile but enable_if fn returns False
|
||||
# enalbe_if will be True for B, so we expect mod1 and mod2
|
||||
# to be compiled
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=6,
|
||||
# 3 piecewise graphs per instance of B()
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
# Set kv_sharing_fast_prefill=False
|
||||
# which will cause A to be compiled and B to not be compiled
|
||||
vllm_config = VllmConfig(cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=False, ),
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=7,
|
||||
# 3 attn ops and 4 non-attn ops
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
@ -53,12 +53,6 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
|
||||
"quantization": "gptq_marlin_24"
|
||||
}))
|
||||
|
||||
if is_quant_method_supported("marlin"):
|
||||
TEST_MODELS.append(
|
||||
("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
|
||||
"quantization": "marlin"
|
||||
}))
|
||||
|
||||
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
|
||||
"quantization": "AWQ"
|
||||
|
||||
@ -148,7 +148,7 @@ class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [8])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for HF_HUB_OFFLINE mode"""
|
||||
import dataclasses
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
@ -9,6 +10,7 @@ import urllib3
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
MODEL_CONFIGS = [
|
||||
{
|
||||
@ -108,3 +110,36 @@ def _re_import_modules():
|
||||
# Error this test if reloading a module failed
|
||||
if reload_exception is not None:
|
||||
raise reload_exception
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.usefixtures("cache_models")
|
||||
def test_model_from_huggingface_offline(monkeypatch: pytest.MonkeyPatch):
|
||||
# Set HF to offline mode and ensure we can still construct an LLM
|
||||
with monkeypatch.context() as m:
|
||||
try:
|
||||
m.setenv("HF_HUB_OFFLINE", "1")
|
||||
m.setenv("VLLM_NO_USAGE_STATS", "1")
|
||||
|
||||
def disable_connect(*args, **kwargs):
|
||||
raise RuntimeError("No http calls allowed")
|
||||
|
||||
m.setattr(
|
||||
urllib3.connection.HTTPConnection,
|
||||
"connect",
|
||||
disable_connect,
|
||||
)
|
||||
m.setattr(
|
||||
urllib3.connection.HTTPSConnection,
|
||||
"connect",
|
||||
disable_connect,
|
||||
)
|
||||
# Need to re-import huggingface_hub
|
||||
# and friends to setup offline mode
|
||||
_re_import_modules()
|
||||
engine_args = EngineArgs(model="facebook/opt-125m")
|
||||
LLM(**dataclasses.asdict(engine_args))
|
||||
finally:
|
||||
# Reset the environment after the test
|
||||
# NB: Assuming tests are run in online mode
|
||||
_re_import_modules()
|
||||
|
||||
88
tests/entrypoints/openai/test_collective_rpc.py
Normal file
88
tests/entrypoints/openai/test_collective_rpc.py
Normal file
@ -0,0 +1,88 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
|
||||
|
||||
class TestWorkerExtension:
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
"""Test non-pydantic return type."""
|
||||
return MODEL_NAME
|
||||
|
||||
def echo_args_kwargs(self, *args, **kwargs) -> dict[str, Any]:
|
||||
"""Echo back both args and kwargs."""
|
||||
return dict(
|
||||
args=list(args),
|
||||
kwargs=kwargs,
|
||||
total_items=len(args) + len(kwargs),
|
||||
)
|
||||
|
||||
def return_none(self, *args, **kwargs) -> None:
|
||||
"""Test method that does not return anything"""
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--worker-extension-cls",
|
||||
"tests.entrypoints.openai.test_collective_rpc.TestWorkerExtension",
|
||||
]
|
||||
with RemoteOpenAIServer(
|
||||
MODEL_NAME,
|
||||
args,
|
||||
env_dict={
|
||||
"VLLM_SERVER_DEV_MODE": "1",
|
||||
"CUDA_VISIBLE_DEVICES": "0"
|
||||
},
|
||||
) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
def test_get_model_name(server):
|
||||
"""Test basic response"""
|
||||
response = requests.post(server.url_for("collective_rpc"),
|
||||
json={"method": "get_model_name"})
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
assert "results" in results
|
||||
assert results["results"] == [MODEL_NAME]
|
||||
|
||||
|
||||
def test_return_none(server):
|
||||
"""Test return none"""
|
||||
response = requests.post(server.url_for("collective_rpc"),
|
||||
json={"method": "return_none"})
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
assert results["results"] == [None]
|
||||
|
||||
|
||||
def test_echo_args_kwargs(server):
|
||||
"""Test args, kwargs, and dict response"""
|
||||
args = ["arg1", "arg2"]
|
||||
kwargs = {"key1": "value1", "key2": "value2"}
|
||||
response = requests.post(server.url_for("collective_rpc"),
|
||||
json={
|
||||
"method": "echo_args_kwargs",
|
||||
"args": args,
|
||||
"kwargs": kwargs
|
||||
})
|
||||
assert response.status_code == 200
|
||||
results = response.json()
|
||||
result = results["results"][0]
|
||||
assert result["args"] == args
|
||||
assert result["kwargs"] == kwargs
|
||||
assert result["total_items"] == len(args) + len(kwargs)
|
||||
@ -13,6 +13,127 @@ from ...utils import RemoteOpenAIServer
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
"options": {
|
||||
"$ref": "#/$defs/WeatherOptions",
|
||||
"description": "Optional parameters for weather query",
|
||||
},
|
||||
},
|
||||
"required": ["country", "unit"],
|
||||
"$defs": {
|
||||
"WeatherOptions": {
|
||||
"title": "WeatherOptions",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
"properties": {
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"default": "celsius",
|
||||
"description": "Temperature unit",
|
||||
"title": "Temperature Unit",
|
||||
},
|
||||
"include_forecast": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
"description":
|
||||
"Whether to include a 24-hour forecast",
|
||||
"title": "Include Forecast",
|
||||
},
|
||||
"language": {
|
||||
"type": "string",
|
||||
"default": "zh-CN",
|
||||
"description": "Language of the response",
|
||||
"title": "Language",
|
||||
"enum": ["zh-CN", "en-US", "ja-JP"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_forecast",
|
||||
"description": "Get the weather forecast for a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to get the forecast for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"days": {
|
||||
"type":
|
||||
"integer",
|
||||
"description":
|
||||
"Number of days to get the forecast for (1-7)",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["country", "days", "unit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the current weather is in Berlin and the "\
|
||||
"forecast for the next 5 days, in fahrenheit?",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(): # noqa: F811
|
||||
@ -27,6 +148,8 @@ def server(): # noqa: F811
|
||||
"hermes",
|
||||
"--reasoning-parser",
|
||||
"qwen3",
|
||||
"--gpu-memory-utilization",
|
||||
"0.4"
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@ -54,129 +177,6 @@ async def client(server):
|
||||
async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
|
||||
stream: bool, tool_choice: Union[str, dict],
|
||||
enable_thinking: bool):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to find the weather for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
"options": {
|
||||
"$ref": "#/$defs/WeatherOptions",
|
||||
"description":
|
||||
"Optional parameters for weather query",
|
||||
},
|
||||
},
|
||||
"required": ["country", "unit"],
|
||||
"$defs": {
|
||||
"WeatherOptions": {
|
||||
"title": "WeatherOptions",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
"properties": {
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"default": "celsius",
|
||||
"description": "Temperature unit",
|
||||
"title": "Temperature Unit",
|
||||
},
|
||||
"include_forecast": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
"description":
|
||||
"Whether to include a 24-hour forecast",
|
||||
"title": "Include Forecast",
|
||||
},
|
||||
"language": {
|
||||
"type": "string",
|
||||
"default": "zh-CN",
|
||||
"description": "Language of the response",
|
||||
"title": "Language",
|
||||
"enum": ["zh-CN", "en-US", "ja-JP"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_forecast",
|
||||
"description": "Get the weather forecast for a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The city to get the forecast for, e.g. 'Vienna'",
|
||||
"default": "Vienna",
|
||||
},
|
||||
"country": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The country that the city is in, e.g. 'Austria'",
|
||||
},
|
||||
"days": {
|
||||
"type":
|
||||
"integer",
|
||||
"description":
|
||||
"Number of days to get the forecast for (1-7)",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["country", "days", "unit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hi! How are you doing today?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'm doing well! How can I help you?"
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me what the current weather is in Berlin and the "\
|
||||
"forecast for the next 5 days, in fahrenheit?",
|
||||
},
|
||||
]
|
||||
if not stream:
|
||||
# Non-streaming test
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
|
||||
output.extend(chunk.choices[0].delta.tool_calls)
|
||||
|
||||
assert len(output) > 0
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def k2_server(): # noqa: F811
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"half",
|
||||
"--enable-auto-tool-choice",
|
||||
"--guided-decoding-backend",
|
||||
"xgrammar",
|
||||
"--tool-call-parser",
|
||||
"hermes",
|
||||
"--reasoning-parser",
|
||||
"qwen3",
|
||||
"--gpu-memory-utilization",
|
||||
"0.4",
|
||||
]
|
||||
# hack to test kimi_k2 tool use tool_id format.
|
||||
# avoid error in is_deepseek_mla check by setting kv_lora_rank=null
|
||||
with RemoteOpenAIServer(MODEL_NAME,
|
||||
args,
|
||||
override_hf_configs={
|
||||
"model_type": 'kimi_k2',
|
||||
'kv_lora_rank': None
|
||||
}) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def k2_client(k2_server):
|
||||
async with k2_server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
@pytest.mark.parametrize("tool_choice", ["required"])
|
||||
async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str,
|
||||
stream: bool, tool_choice: str):
|
||||
|
||||
if not stream:
|
||||
# Non-streaming test
|
||||
chat_completion = await k2_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice)
|
||||
assert chat_completion.choices[0].message.tool_calls is not None
|
||||
assert len(chat_completion.choices[0].message.tool_calls) > 0
|
||||
assert chat_completion.choices[0].message.tool_calls[
|
||||
0].id == 'functions.get_current_weather:0'
|
||||
else:
|
||||
# Streaming test
|
||||
output_stream = await k2_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=True)
|
||||
|
||||
output = []
|
||||
async for chunk in output_stream:
|
||||
if chunk.choices and chunk.choices[0].delta.tool_calls:
|
||||
output.extend(chunk.choices[0].delta.tool_calls)
|
||||
for o in output:
|
||||
assert o.id is None or o.id == 'functions.get_current_weather:0'
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@ -294,6 +294,99 @@ async def test_metrics_exist(server: RemoteOpenAIServer,
|
||||
assert metric in response.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_metrics_reset(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncClient, use_v1: bool):
|
||||
|
||||
running_requests, waiting_requests, kv_cache_usage = (
|
||||
_get_running_metrics_from_api(server))
|
||||
|
||||
# Expect no running requests or kvcache usage
|
||||
assert running_requests == 0
|
||||
assert waiting_requests == 0
|
||||
assert kv_cache_usage == 0.0
|
||||
|
||||
# Start some long-running requests that we can abort
|
||||
tasks = []
|
||||
for _ in range(3):
|
||||
task = asyncio.create_task(
|
||||
client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=_TOKENIZED_PROMPT,
|
||||
max_tokens=100, # Long generation to give time to abort
|
||||
temperature=0.0))
|
||||
tasks.append(task)
|
||||
|
||||
# Wait a bit for requests to start processing
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Check that we have running requests
|
||||
running_requests, waiting_requests, kv_cache_usage = (
|
||||
_get_running_metrics_from_api(server))
|
||||
|
||||
# Expect running requests and kvcache usage
|
||||
assert running_requests > 0
|
||||
assert kv_cache_usage > 0
|
||||
|
||||
# Cancel all tasks to abort the requests
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for cancellations to be processed
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Check that metrics have reset to zero
|
||||
response = requests.get(server.url_for("metrics"))
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
# Verify running and waiting requests counts and KV cache usage are zero
|
||||
running_requests_after, waiting_requests_after, kv_cache_usage_after = (
|
||||
_get_running_metrics_from_api(server))
|
||||
|
||||
assert running_requests_after == 0,\
|
||||
(f"Expected 0 running requests after abort, got "
|
||||
f"{running_requests_after}")
|
||||
assert waiting_requests_after == 0,\
|
||||
(f"Expected 0 waiting requests after abort, got "
|
||||
f"{waiting_requests_after}")
|
||||
assert kv_cache_usage_after == 0,\
|
||||
(f"Expected 0% KV cache usage after abort, got "
|
||||
f"{kv_cache_usage_after}")
|
||||
|
||||
|
||||
def _get_running_metrics_from_api(server: RemoteOpenAIServer):
|
||||
"""Return (running_count, waiting_count, kv_cache_usage)"""
|
||||
|
||||
response = requests.get(server.url_for("metrics"))
|
||||
assert response.status_code == HTTPStatus.OK
|
||||
|
||||
# Verify running and waiting requests counts and KV cache usage are zero
|
||||
running_requests, waiting_requests, kv_cache_usage = None, None, None
|
||||
|
||||
for family in text_string_to_metric_families(response.text):
|
||||
if family.name == "vllm:num_requests_running":
|
||||
for sample in family.samples:
|
||||
if sample.name == "vllm:num_requests_running":
|
||||
running_requests = sample.value
|
||||
break
|
||||
elif family.name == "vllm:num_requests_waiting":
|
||||
for sample in family.samples:
|
||||
if sample.name == "vllm:num_requests_waiting":
|
||||
waiting_requests = sample.value
|
||||
break
|
||||
elif family.name == "vllm:gpu_cache_usage_perc":
|
||||
for sample in family.samples:
|
||||
if sample.name == "vllm:gpu_cache_usage_perc":
|
||||
kv_cache_usage = sample.value
|
||||
break
|
||||
|
||||
assert running_requests is not None
|
||||
assert waiting_requests is not None
|
||||
assert kv_cache_usage is not None
|
||||
|
||||
return running_requests, waiting_requests, kv_cache_usage
|
||||
|
||||
|
||||
def test_metrics_exist_run_batch(use_v1: bool):
|
||||
input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/multilingual-e5-small", "input": "You are a helpful assistant."}}""" # noqa: E501
|
||||
|
||||
|
||||
@ -282,9 +282,11 @@ async def test_serving_chat_could_load_correct_generation_config():
|
||||
assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_type", ["gpt_oss", "any"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_serving_chat_did_set_correct_cache_salt():
|
||||
async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
||||
mock_model_config = MockModelConfig()
|
||||
mock_model_config.hf_config.model_type = model_type
|
||||
|
||||
mock_engine = MagicMock(spec=MQLLMEngineClient)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
|
||||
35
tests/evals/gsm8k/README.md
Normal file
35
tests/evals/gsm8k/README.md
Normal file
@ -0,0 +1,35 @@
|
||||
# GSM8K Accuracy Evaluation
|
||||
|
||||
This directory contains a replacement for the lm-eval-harness GSM8K evaluation, using an isolated GSM8K script and vLLM server for better performance and control.
|
||||
|
||||
## Usage
|
||||
|
||||
### Run tests with pytest (like buildkite)
|
||||
|
||||
```bash
|
||||
pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt \
|
||||
--tp-size=1
|
||||
```
|
||||
|
||||
### Run standalone evaluation script
|
||||
|
||||
```bash
|
||||
# Start vLLM server first
|
||||
vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000
|
||||
|
||||
# Run evaluation
|
||||
python tests/gsm8k/gsm8k_eval.py --port 8000
|
||||
```
|
||||
|
||||
## Configuration Format
|
||||
|
||||
Model configs in `configs/` directory use this YAML format:
|
||||
|
||||
```yaml
|
||||
model_name: "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
accuracy_threshold: 0.54 # Minimum expected accuracy
|
||||
num_questions: 1319 # Number of questions (default: full test set)
|
||||
num_fewshot: 5 # Few-shot examples from train set
|
||||
max_model_len: 4096 # Model context length
|
||||
```
|
||||
2
tests/evals/gsm8k/__init__.py
Normal file
2
tests/evals/gsm8k/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
@ -0,0 +1,5 @@
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
|
||||
accuracy_threshold: 0.74
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
@ -0,0 +1,5 @@
|
||||
model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
|
||||
accuracy_threshold: 0.31
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
5
tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml
Normal file
5
tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
|
||||
accuracy_threshold: 0.45
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
@ -0,0 +1,5 @@
|
||||
model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
|
||||
accuracy_threshold: 0.60
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
5
tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml
Normal file
5
tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml
Normal file
@ -0,0 +1,5 @@
|
||||
model_name: "Qwen/Qwen3-0.6B-FP8"
|
||||
accuracy_threshold: 0.375
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
5
tests/evals/gsm8k/configs/models-small.txt
Normal file
5
tests/evals/gsm8k/configs/models-small.txt
Normal file
@ -0,0 +1,5 @@
|
||||
Qwen3-0.6B-FP8.yaml
|
||||
Llama-3.2-1B-Instruct-INT8-CT.yaml
|
||||
Llama-3-8B-Instruct-nonuniform-CT.yaml
|
||||
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
|
||||
Qwen1.5-MoE-W4A16-CT.yaml
|
||||
66
tests/evals/gsm8k/conftest.py
Normal file
66
tests/evals/gsm8k/conftest.py
Normal file
@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add custom command line options."""
|
||||
parser.addoption("--config-list-file",
|
||||
default="configs/models-small.txt",
|
||||
help="File containing list of config files to test")
|
||||
parser.addoption("--tp-size",
|
||||
default=1,
|
||||
type=int,
|
||||
help="Tensor parallel size")
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
"""Generate test parameters from config files."""
|
||||
if "config_filename" in metafunc.fixturenames:
|
||||
config_list_file = metafunc.config.getoption("--config-list-file")
|
||||
tp_size = metafunc.config.getoption("--tp-size")
|
||||
|
||||
# Handle both relative and absolute paths
|
||||
config_list_path = Path(config_list_file)
|
||||
if not config_list_path.is_absolute():
|
||||
# If relative, try relative to test directory first
|
||||
test_dir_path = Path(__file__).parent / config_list_file
|
||||
if test_dir_path.exists():
|
||||
config_list_path = test_dir_path
|
||||
else:
|
||||
# Try relative to current working directory
|
||||
config_list_path = Path.cwd() / config_list_file
|
||||
|
||||
print(f"Looking for config list at: {config_list_path}")
|
||||
|
||||
config_files = []
|
||||
if config_list_path.exists():
|
||||
# Determine config directory (same directory as the list file)
|
||||
config_dir = config_list_path.parent
|
||||
|
||||
with open(config_list_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
config_path = config_dir / line
|
||||
print(f"Checking config file: {config_path}")
|
||||
if config_path.exists():
|
||||
config_files.append(config_path)
|
||||
print(f" ✓ Found: {config_path}")
|
||||
else:
|
||||
print(f" ✗ Missing: {config_path}")
|
||||
else:
|
||||
print(f"Config list file not found: {config_list_path}")
|
||||
|
||||
# Generate test parameters
|
||||
if config_files:
|
||||
metafunc.parametrize(["config_filename", "tp_size"],
|
||||
[(config_file, int(tp_size))
|
||||
for config_file in config_files],
|
||||
ids=[
|
||||
f"{config_file.stem}-tp{tp_size}"
|
||||
for config_file in config_files
|
||||
])
|
||||
else:
|
||||
print("No config files found, test will be skipped")
|
||||
252
tests/evals/gsm8k/gsm8k_eval.py
Normal file
252
tests/evals/gsm8k/gsm8k_eval.py
Normal file
@ -0,0 +1,252 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Isolated GSM8K evaluation script for vLLM serve endpoint.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import requests
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
INVALID = -9999999
|
||||
|
||||
|
||||
def download_and_cache_file(url: str, filename: Optional[str] = None) -> str:
|
||||
"""Download and cache a file from a URL."""
|
||||
if filename is None:
|
||||
filename = os.path.join("/tmp", url.split("/")[-1])
|
||||
|
||||
if os.path.exists(filename):
|
||||
return filename
|
||||
|
||||
print(f"Downloading from {url} to {filename}")
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(filename, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
|
||||
"""Load GSM8K train and test data"""
|
||||
train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl"
|
||||
test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
||||
|
||||
train_file = download_and_cache_file(train_url)
|
||||
test_file = download_and_cache_file(test_url)
|
||||
|
||||
train_data = list(read_jsonl(train_file))
|
||||
test_data = list(read_jsonl(test_file))
|
||||
|
||||
return train_data, test_data
|
||||
|
||||
|
||||
def read_jsonl(filename: str) -> Generator[dict, None, None]:
|
||||
"""Read a JSONL file."""
|
||||
with open(filename) as fin:
|
||||
for line in fin:
|
||||
if not line.startswith("#"):
|
||||
yield json.loads(line)
|
||||
|
||||
|
||||
def get_answer_value(answer_str: str) -> int:
|
||||
"""Extract the numerical answer from the response."""
|
||||
answer_str = answer_str.replace(",", "")
|
||||
numbers = re.findall(r"\d+", answer_str)
|
||||
if len(numbers) < 1:
|
||||
return INVALID
|
||||
try:
|
||||
return ast.literal_eval(numbers[-1])
|
||||
except SyntaxError:
|
||||
return INVALID
|
||||
|
||||
|
||||
async def call_vllm_api(session: aiohttp.ClientSession,
|
||||
prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
stop: Optional[list[str]] = None,
|
||||
url: Optional[str] = None,
|
||||
seed: Optional[int] = None) -> str:
|
||||
"""Call vLLM's OpenAI-compatible completions endpoint."""
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stop": stop,
|
||||
}
|
||||
if seed is not None:
|
||||
data["seed"] = seed
|
||||
|
||||
try:
|
||||
async with session.post(f"{url}/v1/completions",
|
||||
json=data) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
return result["choices"][0]["text"]
|
||||
except Exception as e:
|
||||
print(f"Error calling vLLM API: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def evaluate_gsm8k(num_questions: int = 1319,
|
||||
num_shots: int = 5,
|
||||
max_tokens: int = 256,
|
||||
host: str = "http://127.0.0.1",
|
||||
port: int = 8000,
|
||||
temperature: float = 0.0,
|
||||
seed: Optional[int] = 42) -> dict[str, Union[float, int]]:
|
||||
"""
|
||||
Evaluate GSM8K accuracy using vLLM serve endpoint.
|
||||
|
||||
Returns dict with accuracy, invalid_rate, latency, etc.
|
||||
"""
|
||||
base_url = f"{host}:{port}"
|
||||
|
||||
# Load GSM8K train and test data
|
||||
train_data, test_data = load_gsm8k_data()
|
||||
|
||||
# Limit to available test questions
|
||||
num_questions = min(num_questions, len(test_data))
|
||||
|
||||
# Build few-shot examples from train split (like lm-eval does)
|
||||
few_shot_examples = ""
|
||||
for i in range(num_shots):
|
||||
few_shot_examples += (f"Question: {train_data[i]['question']}\n"
|
||||
f"Answer: {train_data[i]['answer']}\n\n")
|
||||
|
||||
# Prepare test questions and labels from test split
|
||||
questions = []
|
||||
labels = []
|
||||
for i in range(num_questions):
|
||||
questions.append(f"Question: {test_data[i]['question']}\nAnswer:")
|
||||
labels.append(get_answer_value(test_data[i]["answer"]))
|
||||
|
||||
assert all(label != INVALID for label in labels), "Some labels are invalid"
|
||||
|
||||
# Run evaluation
|
||||
async def run_async_evaluation():
|
||||
states: list[str] = [""] * num_questions
|
||||
|
||||
async def get_answer(session: aiohttp.ClientSession, i: int) -> str:
|
||||
prompt = few_shot_examples + questions[i]
|
||||
answer = await call_vllm_api(
|
||||
session=session,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stop=["Question", "Assistant:", "<|separator|>"],
|
||||
url=base_url,
|
||||
seed=seed,
|
||||
)
|
||||
states[i] = answer
|
||||
return answer
|
||||
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
|
||||
total=600)) as session:
|
||||
tasks = [get_answer(session, i) for i in range(num_questions)]
|
||||
await tqdm.gather(*tasks, desc="Evaluating")
|
||||
|
||||
return states
|
||||
|
||||
print(f"Running GSM8K evaluation: {num_questions} questions, "
|
||||
f"{num_shots}-shot")
|
||||
|
||||
tic = time.perf_counter()
|
||||
states = asyncio.run(run_async_evaluation())
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Compute metrics
|
||||
preds = [get_answer_value(state) for state in states]
|
||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||
invalid_rate = np.mean(np.array(preds) == INVALID)
|
||||
|
||||
result = {
|
||||
"accuracy": accuracy,
|
||||
"invalid_rate": invalid_rate,
|
||||
"latency": latency,
|
||||
"questions_per_second": num_questions / latency,
|
||||
"num_questions": num_questions,
|
||||
"num_shots": num_shots,
|
||||
"max_tokens": max_tokens,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GSM8K evaluation for vLLM serve")
|
||||
parser.add_argument("--num-shots",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of few-shot examples")
|
||||
parser.add_argument("--num-questions",
|
||||
type=int,
|
||||
default=1319,
|
||||
help="Number of questions to evaluate")
|
||||
parser.add_argument("--max-tokens",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Max tokens for generation")
|
||||
parser.add_argument("--host",
|
||||
type=str,
|
||||
default="http://127.0.0.1",
|
||||
help="Host URL")
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port number")
|
||||
parser.add_argument("--temperature",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Temperature for generation")
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Random seed for reproducibility")
|
||||
parser.add_argument("--save-results",
|
||||
type=str,
|
||||
help="Save results to JSON file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
result = evaluate_gsm8k(
|
||||
num_questions=args.num_questions,
|
||||
num_shots=args.num_shots,
|
||||
max_tokens=args.max_tokens,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
temperature=args.temperature,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
# Print results to terminal
|
||||
print("\nResults:")
|
||||
print(f"Accuracy: {result['accuracy']:.3f}")
|
||||
print(f"Invalid responses: {result['invalid_rate']:.3f}")
|
||||
print(f"Total latency: {result['latency']:.3f} s")
|
||||
print(f"Questions per second: {result['questions_per_second']:.3f}")
|
||||
|
||||
# Optional file saving
|
||||
if args.save_results:
|
||||
with open(args.save_results, "w") as f:
|
||||
json.dump(result, f, indent=2)
|
||||
print(f"Results saved to {args.save_results}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
90
tests/evals/gsm8k/test_gsm8k_correctness.py
Normal file
90
tests/evals/gsm8k/test_gsm8k_correctness.py
Normal file
@ -0,0 +1,90 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
GSM8K evaluation using vLLM server and isolated GSM8K script.
|
||||
Replacement for lm-eval-harness with better performance and control.
|
||||
|
||||
Usage:
|
||||
pytest -s -v test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt \
|
||||
--tp-size=1
|
||||
"""
|
||||
|
||||
import yaml
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
from .gsm8k_eval import evaluate_gsm8k
|
||||
|
||||
RTOL = 0.08 # Relative tolerance for accuracy comparison
|
||||
|
||||
|
||||
def launch_gsm8k_eval(eval_config, server_url, tp_size):
|
||||
"""Launch GSM8K evaluation using our isolated script."""
|
||||
# Extract host and port from server URL
|
||||
if "://" in server_url:
|
||||
server_url = server_url.split("://")[1]
|
||||
|
||||
host_port = server_url.split("/")[0] # Remove path if present
|
||||
if ":" in host_port:
|
||||
host, port = host_port.split(":")
|
||||
port = int(port)
|
||||
else:
|
||||
host = host_port
|
||||
port = 8000
|
||||
|
||||
# Add http:// prefix if not present
|
||||
if not host.startswith("http"):
|
||||
host = f"http://{host}"
|
||||
|
||||
# Run GSM8K evaluation
|
||||
results = evaluate_gsm8k(
|
||||
num_questions=eval_config["num_questions"],
|
||||
num_shots=eval_config["num_fewshot"],
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_gsm8k_correctness_param(config_filename, tp_size):
|
||||
"""Test GSM8K correctness for a given model configuration."""
|
||||
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
|
||||
|
||||
# Server arguments
|
||||
server_args = [
|
||||
"--max-model-len",
|
||||
str(eval_config.get("max_model_len", 4096)),
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
]
|
||||
|
||||
# Launch server and run evaluation
|
||||
with RemoteOpenAIServer(eval_config["model_name"],
|
||||
server_args,
|
||||
max_wait_seconds=480) as remote_server:
|
||||
server_url = remote_server.url_for("v1")
|
||||
|
||||
results = launch_gsm8k_eval(eval_config, server_url, tp_size)
|
||||
|
||||
# Check accuracy against threshold
|
||||
measured_accuracy = results["accuracy"]
|
||||
expected_accuracy = eval_config["accuracy_threshold"]
|
||||
|
||||
print(f"GSM8K Results for {eval_config['model_name']}:")
|
||||
print(f" Accuracy: {measured_accuracy:.3f}")
|
||||
print(f" Expected: {expected_accuracy:.3f}")
|
||||
print(f" Questions: {results['num_questions']}")
|
||||
print(f" Invalid rate: {results['invalid_rate']:.3f}")
|
||||
print(f" Latency: {results['latency']:.1f}s")
|
||||
print(f" QPS: {results['questions_per_second']:.1f}")
|
||||
|
||||
# Verify accuracy is within tolerance
|
||||
assert measured_accuracy >= expected_accuracy - RTOL, (
|
||||
f"Accuracy too low: {measured_accuracy:.3f} < "
|
||||
f"{expected_accuracy:.3f} - {RTOL:.3f}")
|
||||
|
||||
print(f"✅ GSM8K test passed for {eval_config['model_name']}")
|
||||
@ -207,6 +207,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
'topk_ids': topk_ids,
|
||||
'w1_scale': moe_tensors.w1_scale,
|
||||
'w2_scale': moe_tensors.w2_scale,
|
||||
'ab_strides1': moe_tensors.ab_strides1,
|
||||
'ab_strides2': moe_tensors.ab_strides2,
|
||||
'c_strides1': moe_tensors.c_strides1,
|
||||
'c_strides2': moe_tensors.c_strides2,
|
||||
'per_act_token': per_act_token,
|
||||
'a1_scale': None #moe_tensors.a_scale
|
||||
}
|
||||
@ -424,8 +428,8 @@ def test_run_cutlass_moe_fp8(
|
||||
topk_ids[0][1] = 1
|
||||
|
||||
workspace13_shape = (m * topk, max(2 * n, k))
|
||||
workspace2_shape = (m * topk, n)
|
||||
output_shape = (m * topk, k)
|
||||
workspace2_shape = (m * topk, max(n, k))
|
||||
output_shape = (m, k)
|
||||
|
||||
workspace13 = torch.empty(prod(workspace13_shape),
|
||||
device="cuda",
|
||||
@ -440,6 +444,11 @@ def test_run_cutlass_moe_fp8(
|
||||
expert_map[start:end] = list(range(num_local_experts))
|
||||
expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda")
|
||||
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
activation = lambda o, i: torch.ops._C.silu_and_mul(o, i)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale,
|
||||
torch.float8_e4m3fn,
|
||||
@ -448,8 +457,9 @@ def test_run_cutlass_moe_fp8(
|
||||
func = lambda output: run_cutlass_moe_fp8(
|
||||
output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation,
|
||||
global_num_experts, expert_map, mt.w1_scale, mt.w2_scale,
|
||||
a1q_scale, None, workspace13, workspace2, None, mt.a.dtype,
|
||||
per_act_token, per_out_channel, False)
|
||||
a1q_scale, None, ab_strides1, ab_strides2, c_strides1, c_strides2,
|
||||
workspace13, workspace2, None, mt.a.dtype, per_act_token,
|
||||
per_out_channel, False, topk_weights)
|
||||
|
||||
workspace13.random_()
|
||||
output_random_workspace = torch.empty(output_shape,
|
||||
|
||||
248
tests/kernels/moe/test_flashinfer.py
Normal file
248
tests/kernels/moe/test_flashinfer.py
Normal file
@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8,
|
||||
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
|
||||
swap_w13_to_w31)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
input_to_float8)
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if not has_flashinfer_cutlass_fused_moe(
|
||||
) or not current_platform.has_device_capability(100):
|
||||
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
|
||||
allow_module_level=True)
|
||||
|
||||
NUM_EXPERTS = [16]
|
||||
TOP_KS = [1]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(256, 8192, 5120),
|
||||
(256, 4096, 5120),
|
||||
(127, 8192, 5120),
|
||||
(127, 4096, 5120),
|
||||
(10, 8192, 5120),
|
||||
(10, 4096, 5120),
|
||||
(1, 8192, 5120),
|
||||
(1, 4096, 5120),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
def quant_fp8_per_tensor_batches(a):
|
||||
num_batches = a.size(0)
|
||||
a_quant = []
|
||||
a_scales = []
|
||||
|
||||
for i in range(num_batches):
|
||||
a_fp8, a_global_sf = input_to_float8(a[i])
|
||||
a_global_sf = 1.0 / a_global_sf
|
||||
a_quant.append(a_fp8)
|
||||
a_scales.append(a_global_sf)
|
||||
|
||||
result_a_quant = torch.stack(a_quant)
|
||||
result_a_scales = torch.stack(a_scales)
|
||||
|
||||
return result_a_quant, result_a_scales
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestData:
|
||||
hidden_states: torch.Tensor
|
||||
w13_quantized: torch.Tensor
|
||||
w2_quantized: torch.Tensor
|
||||
a1_scale: torch.Tensor
|
||||
a2_scale: torch.Tensor
|
||||
w13_weight_scale: torch.Tensor
|
||||
w2_weight_scale: torch.Tensor
|
||||
layer: torch.nn.Module
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
|
||||
reorder: bool) -> "TestData":
|
||||
hidden_states = torch.randn(
|
||||
(m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Scale to fp8
|
||||
_, a1_scale = input_to_float8(hidden_states)
|
||||
a1_scale = 1.0 / a1_scale
|
||||
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
|
||||
dtype=torch.float32)
|
||||
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
|
||||
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
|
||||
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = w13_quantized.clone()
|
||||
layer.w2_weight = w2_quantized.clone()
|
||||
layer.w13_input_scale = a1_scale
|
||||
layer.w2_input_scale = a2_scale
|
||||
layer.w13_weight_scale = w13_weight_scale
|
||||
layer.w2_weight_scale = w2_weight_scale
|
||||
|
||||
register_moe_scaling_factors(layer)
|
||||
|
||||
# flashinfer expects swapped rows for w13
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
if reorder:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
||||
layer.w2_weight)
|
||||
layer.custom_routing_function = Llama4MoE.custom_routing_function
|
||||
layer.intermediate_size_per_partition = n
|
||||
layer.ep_rank = 0
|
||||
layer.local_num_experts = e
|
||||
|
||||
return TestData(
|
||||
hidden_states=hidden_states,
|
||||
w13_quantized=w13_quantized,
|
||||
w2_quantized=w2_quantized,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
w13_weight_scale=w13_weight_scale,
|
||||
w2_weight_scale=w2_weight_scale,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
monkeypatch,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=td.hidden_states,
|
||||
router_logits=score,
|
||||
use_grouped_topk=False,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
td.w2_quantized,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=td.layer,
|
||||
hidden_states=td.hidden_states,
|
||||
router_logits=score,
|
||||
routing_bias=None,
|
||||
global_num_experts=e,
|
||||
top_k=topk,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
apply_router_weight_on_input=True)
|
||||
|
||||
torch.testing.assert_close(output,
|
||||
flashinfer_output,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
"Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
|
||||
)
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
monkeypatch,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=td.hidden_states,
|
||||
router_logits=score,
|
||||
use_grouped_topk=False,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
td.w2_quantized,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
td.layer.dp_size = 1
|
||||
|
||||
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
|
||||
td.hidden_states,
|
||||
td.layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation="silu",
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output,
|
||||
flashinfer_cutlass_output,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
@ -238,7 +238,11 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
|
||||
atol=0,
|
||||
rtol=0)
|
||||
# check mindice
|
||||
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
|
||||
# current kernel usage assumes deepgemm requires align_block_size
|
||||
# when it's not provided then we don't compute m_indices (for cutlass)
|
||||
if align_block_size is not None:
|
||||
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
|
||||
|
||||
# check permuted_hidden_states, only valid token
|
||||
torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx],
|
||||
permuted_hidden_states[valid_row_idx],
|
||||
|
||||
@ -4,15 +4,27 @@
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
|
||||
"quark") is not None and version.parse(
|
||||
importlib.metadata.version("amd-quark")) >= version.parse('0.8.99')
|
||||
|
||||
TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
|
||||
) and current_platform.is_device_capability(100)
|
||||
|
||||
if TRTLLM_GEN_MXFP4_AVAILABLE:
|
||||
from flashinfer import (fp4_quantize, mxfp8_quantize,
|
||||
next_positive_power_of_2,
|
||||
reorder_rows_for_gated_act_gemm, shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCase:
|
||||
@ -54,4 +66,410 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
|
||||
|
||||
output = llm.generate_greedy("Today I am in the French Alps and",
|
||||
max_tokens=20)
|
||||
assert output
|
||||
assert output
|
||||
|
||||
|
||||
def swiglu(x,
|
||||
alpha: float = 1.702,
|
||||
beta: float = 1.0,
|
||||
limit: Optional[float] = None):
|
||||
# Note we add an extra bias of 1 to the linear layer
|
||||
x_glu, x_linear = torch.chunk(x, 2, dim=-1)
|
||||
if limit is not None:
|
||||
x_glu = x_glu.clamp(max=limit)
|
||||
x_linear = x_linear.clamp(min=-limit, max=limit)
|
||||
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
|
||||
return out_glu * (x_linear + beta)
|
||||
|
||||
|
||||
fp4_lookup_table = [
|
||||
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6
|
||||
]
|
||||
|
||||
|
||||
def mxfp4_dequantize(x, scale):
|
||||
assert x.dtype == torch.uint8
|
||||
x = x.view(torch.uint8).to(torch.int32)
|
||||
x_unpacked = torch.zeros(*x.shape[:-1],
|
||||
x.shape[-1] * 2,
|
||||
dtype=torch.int32,
|
||||
device=x.device)
|
||||
x_unpacked[..., 0::2].copy_(x & 0xF)
|
||||
x_unpacked[..., 1::2].copy_((x >> 4) & 0xF)
|
||||
|
||||
x_float = torch.zeros(x_unpacked.shape,
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
for i, val in enumerate(fp4_lookup_table):
|
||||
x_float[x_unpacked == i] = val
|
||||
|
||||
scale = scale.view(torch.uint8).to(torch.int32)
|
||||
scale = (scale << 23).view(torch.float32)
|
||||
scale = scale.reshape(*x.shape[:-1], -1)
|
||||
scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
|
||||
|
||||
return x_float * scale
|
||||
|
||||
|
||||
def mxfp8_dequantize(x, scale):
|
||||
assert x.dtype == torch.float8_e4m3fn
|
||||
x_float = x.to(torch.float32)
|
||||
|
||||
scale = scale.view(torch.uint8).to(torch.int32)
|
||||
scale = (scale << 23).view(torch.float32)
|
||||
scale = scale.reshape(*x.shape[:-1], -1)
|
||||
scale = torch.stack([scale] * 32, dim=-1).reshape(*x_float.shape)
|
||||
|
||||
return x_float * scale
|
||||
|
||||
|
||||
def reference_moe(
|
||||
roouting_logits,
|
||||
topk,
|
||||
num_experts,
|
||||
hidden_states,
|
||||
w13,
|
||||
bias13,
|
||||
w2,
|
||||
bias2,
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
act_type,
|
||||
):
|
||||
# renormalize routing
|
||||
experts = torch.topk(roouting_logits, k=topk, dim=-1, sorted=True)
|
||||
expert_weights = torch.nn.functional.softmax(experts.values, dim=1)
|
||||
expert_indices = experts.indices
|
||||
t = hidden_states.clone()
|
||||
# MLP #1
|
||||
mlp1_weight = w13[expert_indices, ...]
|
||||
mlp1_bias = bias13[expert_indices, ...]
|
||||
t = torch.einsum("beck,bk->bec", mlp1_weight, t) + mlp1_bias
|
||||
t = swiglu(t, alpha=alpha, beta=beta, limit=limit)
|
||||
|
||||
if act_type == 'mxfp8':
|
||||
t_quantized, t_scale = mxfp8_quantize(t.to(torch.bfloat16),
|
||||
is_sf_swizzled_layout=False)
|
||||
t = mxfp8_dequantize(t_quantized, t_scale)
|
||||
# MLP #2
|
||||
mlp2_weight = w2[expert_indices, ...]
|
||||
mlp2_bias = bias2[expert_indices, ...]
|
||||
t = torch.einsum("beck,bek->bec", mlp2_weight, t) + mlp2_bias
|
||||
# Weighted sum of experts
|
||||
t = torch.einsum("bec,be->bc", t, expert_weights)
|
||||
assert t.shape == hidden_states.shape
|
||||
return t.to(torch.bfloat16)
|
||||
|
||||
|
||||
def get_tile_tokens_dim(x: torch.Tensor, top_k: int, num_experts: int):
|
||||
# Number of tokens in the input tensor.
|
||||
num_tokens = x.shape[0]
|
||||
# Factor to account for the imbalance of the experts.
|
||||
# factor equals to the
|
||||
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
|
||||
# - 1.0 means perfect expert distribution.
|
||||
# - > 1.0 means some experts have more
|
||||
# tokens than the perfect distribution.
|
||||
# - < 1.0 does not make sense.
|
||||
imbalance_factor = 1.3
|
||||
# Calculate the number of tokens per expert
|
||||
# assuming perfect distribution.
|
||||
num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
||||
# Apply the imbalance factor.
|
||||
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
|
||||
# And pad the number to the next power of 2.
|
||||
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
|
||||
# Cap to 8-64 tokens per CTA tile
|
||||
# as it's the range supported by the kernel.
|
||||
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||
return tile_tokens_dim
|
||||
|
||||
|
||||
def tg_mxfp4_moe(
|
||||
router_logits,
|
||||
topk,
|
||||
num_experts,
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
w13_weight,
|
||||
w13_weight_scale,
|
||||
w13_bias,
|
||||
w2_weight,
|
||||
w2_weight_scale,
|
||||
w2_bias,
|
||||
act_type,
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
) -> torch.Tensor:
|
||||
sf_block_size = 32
|
||||
assert (w13_weight.dim() == 3 and w13_weight.shape[0] == num_experts
|
||||
and w13_weight.shape[1] == intermediate_size * 2
|
||||
and w13_weight.shape[2] == hidden_size // 2)
|
||||
assert (w13_weight_scale.dim() == 3
|
||||
and w13_weight_scale.shape[0] == num_experts
|
||||
and w13_weight_scale.shape[1] == intermediate_size * 2
|
||||
and w13_weight_scale.shape[2] == hidden_size // sf_block_size)
|
||||
assert (w2_weight.dim() == 3 and w2_weight.shape[0] == num_experts
|
||||
and w2_weight.shape[1] == hidden_size
|
||||
and w2_weight.shape[2] == intermediate_size // 2)
|
||||
assert (w2_weight_scale.dim() == 3
|
||||
and w2_weight_scale.shape[1] == hidden_size
|
||||
and w2_weight_scale.shape[2] == intermediate_size // sf_block_size)
|
||||
assert (w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts
|
||||
and w13_bias.shape[1] == intermediate_size * 2)
|
||||
assert (w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts
|
||||
and w2_bias.shape[1] == hidden_size)
|
||||
|
||||
# Swap w1 and w3 as the defenition of
|
||||
# swiglu is different in the trtllm-gen
|
||||
w13_weight_scale_ = w13_weight_scale.clone()
|
||||
w13_weight_ = w13_weight.clone()
|
||||
w13_bias_ = w13_bias.clone()
|
||||
w13_weight[:, :intermediate_size, :].copy_(
|
||||
w13_weight_[:, intermediate_size:, :])
|
||||
w13_weight[:, intermediate_size:, :].copy_(
|
||||
w13_weight_[:, :intermediate_size, :])
|
||||
w13_weight_scale[:, :intermediate_size, :].copy_(
|
||||
w13_weight_scale_[:, intermediate_size:, :])
|
||||
w13_weight_scale[:, intermediate_size:, :].copy_(
|
||||
w13_weight_scale_[:, :intermediate_size, :])
|
||||
w13_bias[:, :intermediate_size].copy_(w13_bias_[:, intermediate_size:])
|
||||
w13_bias[:, intermediate_size:].copy_(w13_bias_[:, :intermediate_size])
|
||||
|
||||
# Interleave the weights and scaling factors for activation
|
||||
w13_weight_interleaved = []
|
||||
w13_weight_scale_interleaved = []
|
||||
w13_bias_interleaved = []
|
||||
for i in range(num_experts):
|
||||
w13_weight_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_weight[i].clone()))
|
||||
w13_weight_scale_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_weight_scale[i].clone()))
|
||||
w13_bias_interleaved.append(
|
||||
reorder_rows_for_gated_act_gemm(w13_bias[i].clone().reshape(-1,
|
||||
1)))
|
||||
w13_weight = torch.stack(w13_weight_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size // 2)
|
||||
w13_weight_scale = torch.stack(w13_weight_scale_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size, hidden_size // 32)
|
||||
w13_bias = torch.stack(w13_bias_interleaved).reshape(
|
||||
num_experts, 2 * intermediate_size)
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_shuffled = []
|
||||
gemm1_scales_shuffled = []
|
||||
gemm2_weights_shuffled = []
|
||||
gemm2_scales_shuffled = []
|
||||
gemm1_bias_shuffled = []
|
||||
gemm2_bias_shuffled = []
|
||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||
for i in range(num_experts):
|
||||
gemm1_weights_shuffled.append(
|
||||
shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m))
|
||||
gemm1_scales_shuffled.append(
|
||||
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
|
||||
gemm2_weights_shuffled.append(
|
||||
shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m))
|
||||
gemm2_scales_shuffled.append(
|
||||
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm1_bias_shuffled.append(
|
||||
shuffle_matrix_a(w13_bias[i].reshape(-1, 1), epilogue_tile_m))
|
||||
gemm2_bias_shuffled.append(
|
||||
shuffle_matrix_a(w2_bias[i].reshape(-1, 1), epilogue_tile_m))
|
||||
|
||||
w13_weight = torch.stack(gemm1_weights_shuffled)
|
||||
w13_weight_scale = torch.stack(gemm1_scales_shuffled).reshape(
|
||||
num_experts, 2 * intermediate_size,
|
||||
hidden_size // sf_block_size).view(torch.float8_e4m3fn)
|
||||
w13_bias = torch.stack(gemm1_bias_shuffled).reshape(num_experts, -1)
|
||||
|
||||
w2_weight = torch.stack(gemm2_weights_shuffled)
|
||||
w2_weight_scale = torch.stack(gemm2_scales_shuffled).reshape(
|
||||
num_experts, hidden_size,
|
||||
intermediate_size // sf_block_size).view(torch.float8_e4m3fn)
|
||||
w2_bias = torch.stack(gemm2_bias_shuffled).reshape(num_experts, -1)
|
||||
|
||||
tg_result = trtllm_fp4_block_scale_moe(
|
||||
routing_logits=router_logits.to(torch.bfloat16),
|
||||
routing_bias=None,
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
gemm1_weights=w13_weight,
|
||||
gemm1_weights_scale=w13_weight_scale,
|
||||
gemm1_bias=w13_bias,
|
||||
gemm1_alpha=alpha,
|
||||
gemm1_beta=beta,
|
||||
gemm1_clamp_limit=limit,
|
||||
gemm2_weights=w2_weight,
|
||||
gemm2_weights_scale=w2_weight_scale,
|
||||
gemm2_bias=w2_bias,
|
||||
output1_scale_scalar=None,
|
||||
output1_scale_gate_scalar=None,
|
||||
output2_scale_scalar=None,
|
||||
num_experts=num_experts,
|
||||
top_k=topk,
|
||||
n_group=None,
|
||||
topk_group=None,
|
||||
intermediate_size=intermediate_size,
|
||||
local_expert_offset=0,
|
||||
local_num_experts=num_experts,
|
||||
routed_scaling_factor=None,
|
||||
tile_tokens_dim=get_tile_tokens_dim(hidden_states, topk, num_experts),
|
||||
routing_method_type=1, # renormalize
|
||||
do_finalize=True)[0]
|
||||
return tg_result
|
||||
|
||||
|
||||
def check_accuracy(a, b, atol, rtol, percent):
|
||||
"""Allow a mismatch percentage of 1 - percent."""
|
||||
if torch.any(torch.isnan(a)):
|
||||
raise Exception("NaN in reference output")
|
||||
if torch.any(torch.isnan(b)):
|
||||
raise Exception("NaN in actual output")
|
||||
if torch.any(torch.isinf(a)):
|
||||
raise Exception("Inf in reference output")
|
||||
if torch.any(torch.isinf(b)):
|
||||
raise Exception("Inf in actual output")
|
||||
assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"
|
||||
|
||||
left = torch.abs(a - b)
|
||||
right = atol + rtol * torch.abs(b)
|
||||
count = torch.sum(left > right)
|
||||
mismatch_percent = count / a.numel()
|
||||
if mismatch_percent > 1 - percent:
|
||||
raise Exception(
|
||||
f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
|
||||
f"(threshold: {1-percent:.4f})")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("topk", [1, 4])
|
||||
@pytest.mark.parametrize("num_experts", [32, 128])
|
||||
@pytest.mark.parametrize("num_tokens", [1, 128, 1024])
|
||||
@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
|
||||
@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
|
||||
(1.702, 1.0, 7.0)])
|
||||
@pytest.mark.parametrize("act_type", ['mxfp8', 'bf16'])
|
||||
@pytest.mark.skipif(
|
||||
not TRTLLM_GEN_MXFP4_AVAILABLE,
|
||||
reason="nvidia gpu and compute capability sm100 is required for this test")
|
||||
def test_trtllm_gen_mxfp4_fused_moe(
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
num_tokens: int,
|
||||
intermediate_size: int,
|
||||
hidden_size: int,
|
||||
alpha: float,
|
||||
beta: float,
|
||||
limit: Optional[float],
|
||||
act_type: str,
|
||||
):
|
||||
seed = 42
|
||||
torch.manual_seed(seed)
|
||||
hidden_states = torch.randn(num_tokens,
|
||||
hidden_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16)
|
||||
w13 = (torch.randn(num_experts,
|
||||
intermediate_size * 2,
|
||||
hidden_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16))
|
||||
w2 = (torch.randn(num_experts,
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16))
|
||||
bias13 = torch.randn(num_experts, intermediate_size * 2,
|
||||
device="cuda:0") * 10
|
||||
bias2 = torch.randn(num_experts, hidden_size, device="cuda:0") * 10
|
||||
router_logits = torch.rand(num_tokens, num_experts,
|
||||
dtype=torch.float32).cuda()
|
||||
|
||||
w13, w13_scale = fp4_quantize(w13,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False)
|
||||
w13_scale = w13_scale.view(torch.float8_e4m3fn).reshape(
|
||||
num_experts, intermediate_size * 2, hidden_size // 32)
|
||||
w2, w2_scale = fp4_quantize(w2,
|
||||
torch.tensor(1.0, device="cuda:0"),
|
||||
32,
|
||||
sf_use_ue8m0=True,
|
||||
is_sf_swizzled_layout=False)
|
||||
w2_scale = w2_scale.view(torch.float8_e4m3fn).reshape(
|
||||
num_experts, hidden_size, intermediate_size // 32)
|
||||
if act_type == 'mxfp8':
|
||||
hidden_states, hidden_states_scale = mxfp8_quantize(
|
||||
hidden_states, is_sf_swizzled_layout=False)
|
||||
hidden_states_scale = hidden_states_scale.view(
|
||||
torch.float8_e4m3fn).reshape(-1)
|
||||
else:
|
||||
hidden_states_scale = None
|
||||
|
||||
# reference result
|
||||
ref_result = torch.empty_like(hidden_states, dtype=torch.bfloat16)
|
||||
w13_ref = mxfp4_dequantize(w13.clone(), w13_scale.clone())
|
||||
w2_ref = mxfp4_dequantize(w2.clone(), w2_scale.clone())
|
||||
bias13_ref = bias13
|
||||
bias2_ref = bias2
|
||||
if act_type == 'mxfp8':
|
||||
hidden_states_ref = mxfp8_dequantize(
|
||||
hidden_states, hidden_states_scale).to(torch.float32)
|
||||
else:
|
||||
hidden_states_ref = hidden_states.to(torch.float32)
|
||||
# Process tokens in chunks of 32 to reduce memory usage
|
||||
chunk_size = 32
|
||||
num_chunks = (num_tokens + chunk_size - 1) // chunk_size
|
||||
for i in range(num_chunks):
|
||||
start_idx = i * chunk_size
|
||||
end_idx = min(start_idx + chunk_size, num_tokens)
|
||||
chunk_result = reference_moe(
|
||||
router_logits[start_idx:end_idx].to(torch.float32),
|
||||
topk,
|
||||
num_experts,
|
||||
hidden_states_ref[start_idx:end_idx],
|
||||
w13_ref,
|
||||
bias13_ref,
|
||||
w2_ref,
|
||||
bias2_ref,
|
||||
alpha,
|
||||
beta,
|
||||
limit,
|
||||
act_type,
|
||||
)
|
||||
ref_result[start_idx:end_idx].copy_(chunk_result)
|
||||
|
||||
# trtllm-gen result
|
||||
if alpha is not None:
|
||||
alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
|
||||
if limit is not None:
|
||||
limit = torch.full((num_experts, ), limit, device=hidden_states.device)
|
||||
if beta is not None:
|
||||
beta = torch.full((num_experts, ), beta, device=hidden_states.device)
|
||||
tg_result = tg_mxfp4_moe(router_logits,
|
||||
topk,
|
||||
num_experts,
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
hidden_states,
|
||||
hidden_states_scale,
|
||||
w13,
|
||||
w13_scale,
|
||||
bias13,
|
||||
w2,
|
||||
w2_scale,
|
||||
bias2,
|
||||
act_type,
|
||||
alpha=alpha,
|
||||
beta=beta,
|
||||
limit=limit)
|
||||
# relatively loose check since the mxfp4 quantization is less accurate
|
||||
check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
|
||||
|
||||
@ -76,6 +76,7 @@ def pplx_cutlass_moe(
|
||||
assert torch.cuda.current_device() == pgi.local_rank
|
||||
|
||||
num_tokens, hidden_dim = a.shape
|
||||
intermediate_dim = w2.shape[2]
|
||||
num_experts = w1.shape[0]
|
||||
block_size = hidden_dim # TODO support more cases
|
||||
device = pgi.device
|
||||
@ -124,8 +125,27 @@ def pplx_cutlass_moe(
|
||||
num_local_experts=num_local_experts,
|
||||
num_dispatchers=num_dispatchers)
|
||||
|
||||
ab_strides1 = torch.full((num_local_experts, ),
|
||||
hidden_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_local_experts, ),
|
||||
intermediate_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_local_experts, ),
|
||||
2 * intermediate_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_local_experts, ),
|
||||
hidden_dim,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
|
||||
experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
|
||||
out_dtype, per_act_token, per_out_ch)
|
||||
out_dtype, per_act_token, per_out_ch,
|
||||
ab_strides1, ab_strides2, c_strides1,
|
||||
c_strides2)
|
||||
|
||||
fused_cutlass_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
|
||||
@ -535,7 +535,7 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
|
||||
expert_offsets = torch.zeros((num_experts + 1),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
dtype=torch.int64)
|
||||
|
||||
problem_sizes = torch.zeros((num_experts, 3),
|
||||
device=device,
|
||||
|
||||
@ -95,23 +95,23 @@ TEST_TYPES = [
|
||||
token_scale_type=None)
|
||||
for w_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
for a_type in [torch.float16, torch.bfloat16]),
|
||||
# QQQ style
|
||||
*(TypeConfig(act_type=torch.int8,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=torch.float16,
|
||||
group_scale_type=group_scale_type,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=torch.float,
|
||||
token_scale_type=torch.float)
|
||||
for group_scale_type in [None, torch.float16]),
|
||||
*(TypeConfig(act_type=torch.float8_e4m3fn,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=torch.float16,
|
||||
group_scale_type=group_scale_type,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=torch.float,
|
||||
token_scale_type=torch.float)
|
||||
for group_scale_type in [None, torch.float16]),
|
||||
# # QQQ style
|
||||
# *(TypeConfig(act_type=torch.int8,
|
||||
# weight_type=scalar_types.uint4b8,
|
||||
# output_type=torch.float16,
|
||||
# group_scale_type=group_scale_type,
|
||||
# group_zero_type=None,
|
||||
# channel_scale_type=torch.float,
|
||||
# token_scale_type=torch.float)
|
||||
# for group_scale_type in [None, torch.float16]),
|
||||
# *(TypeConfig(act_type=torch.float8_e4m3fn,
|
||||
# weight_type=scalar_types.uint4b8,
|
||||
# output_type=torch.float16,
|
||||
# group_scale_type=group_scale_type,
|
||||
# group_zero_type=None,
|
||||
# channel_scale_type=torch.float,
|
||||
# token_scale_type=torch.float)
|
||||
# for group_scale_type in [None, torch.float16]),
|
||||
]
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
|
||||
@ -13,11 +13,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
|
||||
MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
|
||||
query_marlin_supported_quant_types)
|
||||
@ -31,8 +27,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
marlin_weights)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501
|
||||
marlin_qqq_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
|
||||
from vllm.scalar_type import scalar_types
|
||||
@ -449,68 +443,6 @@ def test_hqq_marlin_gemm(
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_marlin_qqq_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
num_bits,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
):
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
# Quantize activations
|
||||
s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(
|
||||
torch.float)
|
||||
q_a = (a_input / s_a).round().clamp(int8_traits.min,
|
||||
int8_traits.max).to(torch.int8)
|
||||
|
||||
# Quantize weights
|
||||
w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \
|
||||
marlin_qqq_quantize(b_weight, num_bits, group_size)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
|
||||
MARLIN_QQQ_MAX_PARALLEL)
|
||||
|
||||
opcheck(torch.ops._C.marlin_qqq_gemm,
|
||||
(q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
|
||||
marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1]))
|
||||
|
||||
output = ops.marlin_qqq_gemm(
|
||||
q_a,
|
||||
marlin_qqq_q_w,
|
||||
s_a,
|
||||
marlin_qqq_s_channel,
|
||||
marlin_qqq_s_group,
|
||||
workspace.scratch,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
)
|
||||
output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
def test_marlin_gemm_subset_input():
|
||||
quant_type = scalar_types.uint4b8
|
||||
group_size = 128
|
||||
@ -602,18 +534,3 @@ def test_marlin_gemm_with_bias(size_m):
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
def test_marlin_gemm_opcheck():
|
||||
size_m = 2048
|
||||
size_n = 4096
|
||||
size_k = 4096
|
||||
a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16)
|
||||
w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32)
|
||||
s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16)
|
||||
wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL).scratch
|
||||
x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
|
||||
y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
|
||||
torch.testing.assert_close(x, y)
|
||||
opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))
|
||||
|
||||
144
tests/kernels/test_onednn.py
Normal file
144
tests/kernels/test_onednn.py
Normal file
@ -0,0 +1,144 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Integration tests for FlexAttention backend vs default backend"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import to_int8
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cpu():
|
||||
pytest.skip("skipping CPU-only tests", allow_module_level=True)
|
||||
|
||||
NK_FACTORS = [
|
||||
(256, 128),
|
||||
(4096, 4096),
|
||||
(16384, 4096),
|
||||
(1023, 491),
|
||||
(1001, 15),
|
||||
]
|
||||
M_FACTORS = [
|
||||
(16, 1, 32, 128, 64),
|
||||
(1, 17, 1, 31, 17),
|
||||
]
|
||||
CACHE_SIZES = [2]
|
||||
DTYPE = [torch.bfloat16]
|
||||
|
||||
|
||||
def rand_int8(shape: tuple, device: str = "cpu"):
|
||||
return to_int8(torch.rand(shape, device=device) * 255 - 128)
|
||||
|
||||
|
||||
def ref_int8_scaled_mm(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
azp: Optional[torch.Tensor],
|
||||
bias: Optional[torch.Tensor],
|
||||
output_type: torch.dtype,
|
||||
):
|
||||
if azp is not None:
|
||||
a = a.to(dtype=torch.float32) - azp.to(dtype=torch.float32)
|
||||
output = torch.mm((scale_a * a.to(dtype=torch.float32)),
|
||||
(scale_b * b.to(dtype=torch.float32)))
|
||||
if bias is not None:
|
||||
output += bias.float()
|
||||
|
||||
return output.to(dtype=output_type)
|
||||
|
||||
|
||||
def onednn_int8_gemm_test_helper(primitive_cache_size: int,
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
per_tensor_a_quant: bool,
|
||||
per_tensor_b_quant: bool,
|
||||
use_azp: bool,
|
||||
use_bias: bool,
|
||||
out_dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu"):
|
||||
# Test for a oneDNN kernel with per-tensor / per-token activation
|
||||
# quantization and per-tensor / per-output channel weight quantization.
|
||||
a = to_int8(torch.randn((m, k), device=device) * 5)
|
||||
b = to_int8(torch.randn((n, k), device=device).t() * 5)
|
||||
|
||||
a_scales_shape = (1, 1) if per_tensor_a_quant else (m, 1)
|
||||
b_scales_shape = (1, 1) if per_tensor_b_quant else (1, n)
|
||||
|
||||
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
|
||||
|
||||
if use_azp:
|
||||
azp = torch.rand(a_scales_shape, dtype=torch.float32) * 10 + 1.5
|
||||
azp = (azp / scale_a).round().to(dtype=torch.int32)
|
||||
azp_adj = scale_b * b.sum(dim=0, keepdim=True, dtype=torch.float32)
|
||||
else:
|
||||
azp = None
|
||||
azp_adj = None
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
|
||||
else:
|
||||
bias = None
|
||||
|
||||
handler = ops.create_onednn_scaled_mm(
|
||||
b,
|
||||
scale_b,
|
||||
out_dtype,
|
||||
not per_tensor_a_quant,
|
||||
use_azp,
|
||||
primitive_cache_size,
|
||||
)
|
||||
|
||||
out = torch.zeros((m, n), dtype=out_dtype)
|
||||
ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, bias)
|
||||
baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, bias, out_dtype)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
if use_bias:
|
||||
# To test runtime bias setting
|
||||
out = torch.zeros((m, n), dtype=out_dtype)
|
||||
ops.onednn_scaled_mm(handler, a, out, scale_a, azp, azp_adj, None)
|
||||
baseline = ref_int8_scaled_mm(a, b, scale_a, scale_b, azp, None,
|
||||
out_dtype)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k", NK_FACTORS)
|
||||
@pytest.mark.parametrize("m_list", M_FACTORS)
|
||||
@pytest.mark.parametrize("per_tensor_a_scale", [True, False])
|
||||
@pytest.mark.parametrize("per_tensor_b_scale", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("use_azp", [True, False])
|
||||
@pytest.mark.parametrize("output_type", DTYPE)
|
||||
@pytest.mark.parametrize("primitive_cache_size", CACHE_SIZES)
|
||||
def test_onednn_int8_scaled_gemm(
|
||||
n: int,
|
||||
k: int,
|
||||
m_list: tuple[int],
|
||||
per_tensor_a_scale: bool,
|
||||
per_tensor_b_scale: bool,
|
||||
use_bias: bool,
|
||||
use_azp: bool,
|
||||
output_type: torch.dtype,
|
||||
primitive_cache_size: int,
|
||||
):
|
||||
for m in m_list:
|
||||
onednn_int8_gemm_test_helper(
|
||||
primitive_cache_size=primitive_cache_size,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
per_tensor_a_quant=per_tensor_a_scale,
|
||||
per_tensor_b_quant=per_tensor_b_scale,
|
||||
use_bias=use_bias,
|
||||
use_azp=use_azp,
|
||||
out_dtype=output_type,
|
||||
)
|
||||
@ -31,6 +31,7 @@ HYBRID_MODELS = [
|
||||
"hmellor/tiny-random-BambaForCausalLM",
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
"tiiuae/Falcon-H1-0.5B-Base",
|
||||
"LiquidAI/LFM2-1.2B",
|
||||
]
|
||||
|
||||
HF_UNSUPPORTED_MODELS = [
|
||||
@ -52,18 +53,21 @@ V1_SUPPORTED_MODELS = [
|
||||
"hmellor/tiny-random-BambaForCausalLM",
|
||||
"ibm-granite/granite-4.0-tiny-preview",
|
||||
"tiiuae/Falcon-H1-0.5B-Base",
|
||||
"LiquidAI/LFM2-1.2B",
|
||||
]
|
||||
|
||||
FULL_CUDA_GRAPH_MODELS = [
|
||||
"ai21labs/Jamba-tiny-dev",
|
||||
"Zyphra/Zamba2-1.2B-instruct",
|
||||
]
|
||||
|
||||
V0_UNSUPPORTED_MODELS = [
|
||||
"LiquidAI/LFM2-1.2B",
|
||||
]
|
||||
|
||||
# Avoid OOM
|
||||
MAX_NUM_SEQS = 4
|
||||
|
||||
# Once we add support for FCG in Mamba1, this list will be removed and tests
|
||||
# all test cases will use enforce_eager=False
|
||||
ENFORCE_EAGER_MODELS_V1 = [
|
||||
"state-spaces/mamba-130m-hf",
|
||||
"ai21labs/Jamba-tiny-dev",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@ -96,31 +100,28 @@ def test_models(
|
||||
else:
|
||||
hf_outputs = None
|
||||
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
if model not in V0_UNSUPPORTED_MODELS:
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
vllm_v0_outputs = None
|
||||
|
||||
if model in V1_SUPPORTED_MODELS:
|
||||
enforce_eager = False
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
if model in HYBRID_MODELS:
|
||||
# required due to reorder_batch behaviour
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
|
||||
|
||||
if model in ENFORCE_EAGER_MODELS_V1:
|
||||
enforce_eager = True
|
||||
|
||||
with vllm_runner(model,
|
||||
max_num_seqs=MAX_NUM_SEQS,
|
||||
enforce_eager=enforce_eager,
|
||||
enable_prefix_caching=False) as vllm_model:
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
vllm_v1_outputs = None
|
||||
|
||||
if hf_outputs is not None:
|
||||
if hf_outputs is not None and vllm_v0_outputs is not None:
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_v0_outputs,
|
||||
@ -130,6 +131,7 @@ def test_models(
|
||||
|
||||
if model in V1_SUPPORTED_MODELS:
|
||||
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
|
||||
assert ref_outputs is not None
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=ref_outputs,
|
||||
outputs_1_lst=vllm_v1_outputs,
|
||||
@ -148,6 +150,9 @@ def test_batching(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
if model in V0_UNSUPPORTED_MODELS:
|
||||
pytest.skip(
|
||||
f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
|
||||
|
||||
try:
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
@ -373,7 +378,7 @@ def test_distributed_correctness(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
|
||||
@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
def test_full_cuda_graph(
|
||||
@ -400,9 +405,12 @@ def test_full_cuda_graph(
|
||||
else:
|
||||
hf_outputs = None
|
||||
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
if model not in V0_UNSUPPORTED_MODELS:
|
||||
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
|
||||
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
else:
|
||||
vllm_v0_outputs = None
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
@ -416,7 +424,7 @@ def test_full_cuda_graph(
|
||||
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
if hf_outputs is not None:
|
||||
if hf_outputs is not None and vllm_v0_outputs is not None:
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_v0_outputs,
|
||||
@ -425,6 +433,7 @@ def test_full_cuda_graph(
|
||||
)
|
||||
|
||||
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
|
||||
assert ref_outputs is not None
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=ref_outputs,
|
||||
outputs_1_lst=vllm_v1_outputs,
|
||||
|
||||
@ -14,6 +14,7 @@ from ....utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
|
||||
MAX_MODEL_LEN = 4000
|
||||
ATOL = 0.002
|
||||
|
||||
|
||||
def _arr(arr):
|
||||
@ -97,16 +98,16 @@ def get_test_data():
|
||||
|
||||
def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]):
|
||||
cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0])
|
||||
assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=0.001)
|
||||
assert cosine_sim_q0_d0 == pytest.approx(0.609, abs=ATOL)
|
||||
|
||||
cosine_sim_q0_d1 = 1 - cosine(q_rep[0], d_rep[1])
|
||||
assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=0.001)
|
||||
assert cosine_sim_q0_d1 == pytest.approx(0.101, abs=ATOL)
|
||||
|
||||
cosine_sim_q1_d0 = 1 - cosine(q_rep[1], d_rep[0])
|
||||
assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=0.001)
|
||||
assert cosine_sim_q1_d0 == pytest.approx(0.120, abs=ATOL)
|
||||
|
||||
cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1])
|
||||
assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=0.001)
|
||||
assert cosine_sim_q1_d1 == pytest.approx(0.534, abs=ATOL)
|
||||
|
||||
|
||||
def test_gritlm_offline_embedding(vllm_runner):
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import Optional, overload
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
|
||||
@ -287,8 +288,8 @@ def clear_cache():
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@pytest.mark.skipif(
|
||||
TRANSFORMERS_VERSION == "4.55.0",
|
||||
reason="Transformers v4.55.0 has a regression issue on mllama, "
|
||||
Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
|
||||
reason="Transformers v4.55 has a regression issue on mllama, "
|
||||
"see: https://github.com/huggingface/transformers/pull/40083")
|
||||
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
|
||||
model, sizes, dtype, max_tokens,
|
||||
@ -319,8 +320,8 @@ def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@pytest.mark.skipif(
|
||||
TRANSFORMERS_VERSION == "4.55.0",
|
||||
reason="Transformers v4.55.0 has a regression issue on mllama, "
|
||||
Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
|
||||
reason="Transformers v4.55 has a regression issue on mllama, "
|
||||
"see: https://github.com/huggingface/transformers/pull/40083")
|
||||
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
|
||||
model, dtype, max_tokens, num_logprobs,
|
||||
@ -372,8 +373,8 @@ def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
@pytest.mark.skipif(
|
||||
TRANSFORMERS_VERSION == "4.55.0",
|
||||
reason="Transformers v4.55.0 has a regression issue on mllama, "
|
||||
Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
|
||||
reason="Transformers v4.55 has a regression issue on mllama, "
|
||||
"see: https://github.com/huggingface/transformers/pull/40083")
|
||||
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
|
||||
dtype, max_tokens, num_logprobs,
|
||||
@ -416,8 +417,8 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.skipif(
|
||||
TRANSFORMERS_VERSION == "4.55.0",
|
||||
reason="Transformers v4.55.0 has a regression issue on mllama, "
|
||||
Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
|
||||
reason="Transformers v4.55 has a regression issue on mllama, "
|
||||
"see: https://github.com/huggingface/transformers/pull/40083")
|
||||
def test_models_distributed(
|
||||
hf_runner,
|
||||
|
||||
@ -102,7 +102,7 @@ def _test_processing_correctness(
|
||||
partial(random_video,
|
||||
rng,
|
||||
min_frames=2,
|
||||
max_frames=8,
|
||||
max_frames=16,
|
||||
min_wh=128,
|
||||
max_wh=256),
|
||||
"audio":
|
||||
@ -268,6 +268,7 @@ def _test_processing_correctness_one(
|
||||
"CohereForAI/aya-vision-8b",
|
||||
"Salesforce/blip2-opt-2.7b",
|
||||
"facebook/chameleon-7b",
|
||||
"CohereLabs/command-a-vision-07-2025",
|
||||
"deepseek-ai/deepseek-vl2-tiny",
|
||||
"microsoft/Florence-2-base",
|
||||
"adept/fuyu-8b",
|
||||
@ -275,16 +276,17 @@ def _test_processing_correctness_one(
|
||||
"google/gemma-3n-E2B-it",
|
||||
"zai-org/glm-4v-9b",
|
||||
"zai-org/GLM-4.1V-9B-Thinking",
|
||||
"zai-org/GLM-4.5V",
|
||||
"ibm-granite/granite-speech-3.3-2b",
|
||||
"h2oai/h2ovl-mississippi-800m",
|
||||
"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
|
||||
"HuggingFaceM4/Idefics3-8B-Llama3",
|
||||
"internlm/Intern-S1",
|
||||
"OpenGVLab/InternVL2-1B",
|
||||
"OpenGVLab/InternVL3-1B",
|
||||
"HuggingFaceM4/Idefics3-8B-Llama3",
|
||||
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
|
||||
"Kwai-Keye/Keye-VL-8B-Preview",
|
||||
"moonshotai/Kimi-VL-A3B-Instruct",
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
|
||||
"llava-hf/llava-1.5-7b-hf",
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
"llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
@ -314,11 +316,15 @@ def _test_processing_correctness_one(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
"Qwen/Qwen2-Audio-7B-Instruct",
|
||||
"Qwen/Qwen2.5-Omni-3B",
|
||||
"YannQi/R-4B",
|
||||
"Skywork/Skywork-R1V-38B",
|
||||
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
|
||||
"stepfun-ai/step3",
|
||||
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
||||
"openai/whisper-large-v3",
|
||||
"omni-research/Tarsier-7b",
|
||||
"omni-research/Tarsier2-Recap-7b",
|
||||
"mistralai/Voxtral-Mini-3B-2507",
|
||||
])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
|
||||
@ -24,9 +24,9 @@ from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
from vllm.v1.engine.core import EngineCore as V1EngineCore
|
||||
|
||||
from ...conftest import VllmRunner
|
||||
from ..registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
||||
from ..utils import dummy_hf_overrides
|
||||
from ....conftest import VllmRunner
|
||||
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
||||
from ...utils import dummy_hf_overrides
|
||||
|
||||
ARCH_TO_SKIP = {
|
||||
"MolmoForCausalLM": "incompatible requirements",
|
||||
@ -147,7 +147,6 @@ def get_model_id_to_test(
|
||||
return filtered_results
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize(
|
||||
"model_arch, model_id",
|
||||
get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
|
||||
@ -39,20 +39,3 @@ def test_models(example_prompts, model_name) -> None:
|
||||
expected_str = EXPECTED_STRS_MAP[model_name][i]
|
||||
assert expected_str == output_str, (
|
||||
f"Expected: {expected_str!r}\nvLLM: {output_str!r}")
|
||||
|
||||
curl https://localhost:8002/v1/embeddings \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"input": "Query: What is the capital of France? \n\nDocuments: \n1. Paris is the capital city of France.\n2. Berlin is the capital of Germany.\n \n Rank the documents from most to least relevant to the query and provide a relevance score",
|
||||
"model": "$MODEL",
|
||||
"encoding_format": "float"
|
||||
}'
|
||||
|
||||
|
||||
curl https://localhost:8002/v1/rerank \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"input": "Query: What is the capital of France? \n\nDocuments: \n1. Paris is the capital city of France.\n2. Berlin is the capital of Germany.\n \n Rank the documents from most to least relevant to the query and provide a relevance score",
|
||||
"prompt": "Query: What is the capital of France? \n\nDocuments: \n1. Paris is the capital city of France.\n2. Berlin is the capital of Germany.\n \n Rank the documents from most to least relevant to the query and provide a relevance score"
|
||||
"model": "BAAI/bge-reranker-v2-m3",
|
||||
}'
|
||||
|
||||
@ -215,9 +215,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"HunYuanDenseV1ForCausalLM":_HfExamplesInfo("tencent/Hunyuan-7B-Instruct-0124",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False),
|
||||
"HCXVisionForCausalLM": _HfExamplesInfo(
|
||||
"naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B",
|
||||
trust_remote_code=True),
|
||||
"InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b",
|
||||
trust_remote_code=True),
|
||||
"InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b",
|
||||
@ -233,6 +230,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"tiny": "ai21labs/Jamba-tiny-dev",
|
||||
"random": "ai21labs/Jamba-tiny-random", # noqa: E501
|
||||
}),
|
||||
"Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B",
|
||||
min_transformers_version="4.54"),
|
||||
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
|
||||
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
|
||||
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B", # noqa: E501
|
||||
@ -298,8 +297,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
|
||||
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
|
||||
"Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False),
|
||||
trust_remote_code=True),
|
||||
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct",
|
||||
trust_remote_code=True),
|
||||
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
|
||||
@ -405,22 +403,24 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
|
||||
"Glm4vForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.1V-9B-Thinking"), # noqa: E501
|
||||
"Glm4vMoeForConditionalGeneration": _HfExamplesInfo("zai-org/GLM-4.5V",
|
||||
is_available_online=False), # noqa: E501
|
||||
min_transformers_version="4.56"), # noqa: E501
|
||||
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
|
||||
trust_remote_code=True,
|
||||
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
|
||||
max_transformers_version="4.48", # noqa: E501
|
||||
transformers_version_reason="HF model is not compatible."), # noqa: E501
|
||||
"HCXVisionForCausalLM": _HfExamplesInfo("naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
|
||||
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, # noqa: E501
|
||||
min_transformers_version="4.55.1",
|
||||
transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501
|
||||
"InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1",
|
||||
trust_remote_code=True), # noqa: E501
|
||||
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
|
||||
extras={"2B": "OpenGVLab/InternVL2-2B",
|
||||
"3.0": "OpenGVLab/InternVL3-1B"}, # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1",
|
||||
trust_remote_code=True),
|
||||
"KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
|
||||
@ -464,9 +464,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
transformers_version_reason="HF model is not compatible", # noqa: E501
|
||||
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
|
||||
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
|
||||
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B", trust_remote_code=True,
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason="HF model is not compatible"), # noqa: E501
|
||||
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B",
|
||||
trust_remote_code=True,
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason="HF model is not compatible"), # noqa: E501
|
||||
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
|
||||
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
|
||||
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
|
||||
@ -490,14 +491,15 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
max_model_len=4096),
|
||||
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
|
||||
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
|
||||
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
|
||||
trust_remote_code=True),
|
||||
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
|
||||
trust_remote_code=True),
|
||||
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501
|
||||
min_transformers_version="4.55.1",
|
||||
transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501
|
||||
"Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3",
|
||||
trust_remote_code=True,
|
||||
is_available_online=False),
|
||||
trust_remote_code=True),
|
||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b"), # noqa: E501
|
||||
@ -532,6 +534,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
"DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random",
|
||||
speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random",
|
||||
speculative_model="eagle618/eagle-deepseek-v3-random", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||
trust_remote_code=True,
|
||||
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||
@ -555,6 +560,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
is_available_online=False,
|
||||
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
|
||||
tokenizer="openbmb/MiniCPM-2B-sft-bf16"),
|
||||
"ErnieMTPModel": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
|
||||
trust_remote_code=True,
|
||||
speculative_model="baidu/ERNIE-4.5-21B-A3B-PT"),
|
||||
"Glm4MoeMTPModel": _HfExamplesInfo("zai-org/GLM-4.5",
|
||||
speculative_model="zai-org/GLM-4.5",
|
||||
min_transformers_version="4.54",
|
||||
|
||||
@ -95,6 +95,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
|
||||
def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
|
||||
if model_arch == "Lfm2ForCausalLM":
|
||||
pytest.skip("Skipping until test supports V1-only models")
|
||||
can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS)
|
||||
|
||||
|
||||
|
||||
@ -22,22 +22,12 @@ class ModelPair:
|
||||
MODEL_ARG_EXPTYPES = [
|
||||
# AUTOGPTQ
|
||||
# compat: autogptq <=0.7.1 is_marlin_format: bool
|
||||
# Model Serialized in Marlin Format should always use Marlin kernel.
|
||||
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"),
|
||||
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"),
|
||||
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"),
|
||||
("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"),
|
||||
# Model Serialized in Exllama Format.
|
||||
("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"),
|
||||
("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"),
|
||||
("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"),
|
||||
("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"),
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# Model Serialized in Marlin Format should always use Marlin kernel.
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"),
|
||||
# Model Serialized in Exllama Format.
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"),
|
||||
("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"),
|
||||
|
||||
@ -11,7 +11,6 @@ import torch
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod)
|
||||
|
||||
@ -19,9 +18,7 @@ PROMPT = "On the surface of Mars, we found"
|
||||
|
||||
MODELS_QUANT = [
|
||||
("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", True),
|
||||
("ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", False),
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
|
||||
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)
|
||||
]
|
||||
|
||||
|
||||
@ -41,8 +38,7 @@ def test_lm_head(
|
||||
lm_head_layer = model.lm_head
|
||||
if lm_head_quantized:
|
||||
assert isinstance(lm_head_layer.quant_method,
|
||||
(GPTQLinearMethod, GPTQMarlinLinearMethod,
|
||||
MarlinLinearMethod))
|
||||
(GPTQLinearMethod, GPTQMarlinLinearMethod))
|
||||
else:
|
||||
assert isinstance(lm_head_layer.quant_method,
|
||||
UnquantizedEmbeddingMethod)
|
||||
|
||||
@ -10,13 +10,6 @@ from transformers import AutoModelForSeq2SeqLM
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def v1(run_with_both_engines):
|
||||
"""We can run both engines for this test."""
|
||||
pass
|
||||
|
||||
|
||||
# FIXME(zhuohan): The test can not pass if we:
|
||||
# 1. Increase max_tokens to 256.
|
||||
# 2. Increase beam_width to 8.
|
||||
@ -27,7 +20,7 @@ MM_BEAM_WIDTHS = [2]
|
||||
MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
|
||||
|
||||
|
||||
@pytest.mark.skip_v1 # FIXME: This fails on V1 right now.
|
||||
@pytest.mark.skip(reason="Fails on V1") # FIXME: This fails on V1 right now.
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
|
||||
|
||||
@ -9,13 +9,6 @@ import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def v1(run_with_both_engines):
|
||||
"""We can run both engines for this test."""
|
||||
pass
|
||||
|
||||
|
||||
# We also test with llama because it has generation_config to specify EOS
|
||||
# (past regression).
|
||||
MODELS = ["distilbert/distilgpt2", "meta-llama/Llama-3.2-1B"]
|
||||
@ -31,7 +24,7 @@ def test_ignore_eos(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
with vllm_runner(model, dtype=dtype) as vllm_model:
|
||||
with vllm_runner(model, enforce_eager=True, dtype=dtype) as vllm_model:
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens,
|
||||
ignore_eos=True)
|
||||
|
||||
|
||||
@ -11,19 +11,10 @@ from ..conftest import VllmRunner
|
||||
MODELS = ["distilbert/distilgpt2"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
This module is V0 only since it uses dtype=float, so
|
||||
set VLLM_USE_V1=0 for all tests in the module.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype",
|
||||
["float"]) # needed for comparing logprobs with HF
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
|
||||
@pytest.mark.parametrize("max_num_batched_tokens", [4])
|
||||
@pytest.mark.parametrize("num_top_logprobs", [0, 6]) # 32000 == vocab_size
|
||||
@pytest.mark.parametrize("detokenize", [True, False])
|
||||
def test_get_prompt_logprobs(
|
||||
@ -31,19 +22,11 @@ def test_get_prompt_logprobs(
|
||||
vllm_runner,
|
||||
model,
|
||||
dtype,
|
||||
chunked_prefill_token_size: int,
|
||||
max_num_batched_tokens: int,
|
||||
num_top_logprobs: int,
|
||||
detokenize: bool,
|
||||
example_prompts,
|
||||
):
|
||||
max_num_seqs = 256
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
max_tokens = 5
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_logprobs = hf_model.generate_greedy_logprobs(
|
||||
@ -55,9 +38,8 @@ def test_get_prompt_logprobs(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_logprobs=num_top_logprobs,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_seqs=max_num_seqs,
|
||||
enforce_eager=True,
|
||||
) as vllm_model:
|
||||
vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
|
||||
logprobs=num_top_logprobs,
|
||||
@ -140,7 +122,9 @@ def test_get_prompt_logprobs(
|
||||
|
||||
|
||||
def test_max_logprobs():
|
||||
runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
|
||||
runner = VllmRunner("facebook/opt-125m",
|
||||
max_logprobs=1,
|
||||
enforce_eager=True)
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
# should pass
|
||||
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
|
||||
@ -151,24 +135,16 @@ def test_max_logprobs():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
|
||||
@pytest.mark.parametrize("max_num_batched_tokens", [4])
|
||||
@pytest.mark.parametrize("detokenize", [True, False])
|
||||
def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
|
||||
def test_none_logprobs(vllm_runner, model, max_num_batched_tokens: int,
|
||||
detokenize: bool, example_prompts):
|
||||
max_num_seqs = 256
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
max_tokens = 5
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
enforce_eager=True,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_num_seqs=max_num_seqs,
|
||||
) as vllm_model:
|
||||
sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
|
||||
logprobs=None,
|
||||
|
||||
@ -7,18 +7,11 @@ Run `pytest tests/samplers/test_no_bad_words.py`.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def v1(monkeypatch):
|
||||
"""Only run on vLLM v1."""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '1')
|
||||
|
||||
|
||||
def _generate(
|
||||
llm: LLM,
|
||||
prompt: str,
|
||||
@ -57,7 +50,7 @@ class TestOneTokenBadWord:
|
||||
add_special_tokens=False)[0]
|
||||
|
||||
def test_one_token_bad_word(self, vllm_runner):
|
||||
with vllm_runner(self.MODEL) as llm:
|
||||
with vllm_runner(self.MODEL, enforce_eager=True) as llm:
|
||||
output_token_ids = self._generate(llm)
|
||||
assert output_token_ids[0] == self.target_token_id
|
||||
|
||||
@ -104,7 +97,7 @@ class TestTwoTokenBadWord:
|
||||
add_special_tokens=False)[0]
|
||||
|
||||
def test_two_token_bad_word(self, vllm_runner):
|
||||
with vllm_runner(self.MODEL, dtype="half") as llm:
|
||||
with vllm_runner(self.MODEL, enforce_eager=True) as llm:
|
||||
output_token_ids = self._generate(llm)
|
||||
assert output_token_ids[:2] == [
|
||||
self.target_token_id1, self.target_token_id2
|
||||
|
||||
@ -8,12 +8,6 @@ from vllm import SamplingParams
|
||||
MODELS = ["distilbert/distilgpt2"]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def v1(run_with_both_engines):
|
||||
"""We can run both engines for this test."""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
def test_ranks(
|
||||
@ -26,7 +20,9 @@ def test_ranks(
|
||||
num_top_logprobs = 5
|
||||
num_prompt_logprobs = 5
|
||||
|
||||
with vllm_runner(model, dtype=dtype,
|
||||
with vllm_runner(model,
|
||||
dtype=dtype,
|
||||
enforce_eager=True,
|
||||
max_logprobs=num_top_logprobs) as vllm_model:
|
||||
|
||||
## Test greedy logprobs ranks
|
||||
|
||||
@ -1,769 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import GenerationConfig, GenerationMixin
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import Counter, is_pin_memory_available
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
This file tests V0 internals, so set VLLM_USE_V1=0.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
|
||||
class MockLogitsSampler(Sampler):
|
||||
|
||||
def __init__(self, fake_logits: torch.Tensor):
|
||||
super().__init__()
|
||||
self.fake_logits = fake_logits
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
def _prepare_test(
|
||||
batch_size: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
|
||||
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
|
||||
fake_logits = torch.full((batch_size, VOCAB_SIZE),
|
||||
1e-2,
|
||||
dtype=input_tensor.dtype)
|
||||
sampler = MockLogitsSampler(fake_logits)
|
||||
return input_tensor, fake_logits, sampler
|
||||
|
||||
|
||||
VOCAB_SIZE = 32000
|
||||
RANDOM_SEEDS = list(range(128))
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
def _do_sample(
|
||||
batch_size: int,
|
||||
input_tensor: torch.Tensor,
|
||||
sampler: MockLogitsSampler,
|
||||
sampling_params: SamplingParams,
|
||||
device: str,
|
||||
):
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata] = []
|
||||
seq_lens: list[int] = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=device,
|
||||
pin_memory=is_pin_memory_available())
|
||||
return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_greedy(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler = _prepare_test(batch_size)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
expected = torch.argmax(fake_logits, dim=-1)
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token == expected[i].item()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_random(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
_, fake_logits, sampler = _prepare_test(batch_size)
|
||||
|
||||
for i in range(batch_size):
|
||||
fake_logits[i, i] = 1e2
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
n=random.randint(1, 10),
|
||||
)
|
||||
sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token == i
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_random_seed(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
_, fake_logits, sampler = _prepare_test(batch_size)
|
||||
|
||||
for i in range(batch_size):
|
||||
fake_logits[i, i] = 1e2
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
n=random.randint(1, 10),
|
||||
seed=random.randint(0, 10000),
|
||||
)
|
||||
sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
|
||||
for i, sequence_output in enumerate(sampler_output):
|
||||
for nth_output in sequence_output.samples:
|
||||
assert nth_output.output_token == i
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_random_seed_deterministic(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
_, fake_logits, sampler = _prepare_test(batch_size)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
n=random.randint(1, 10),
|
||||
seed=random.randint(0, 10000),
|
||||
)
|
||||
first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
|
||||
second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
|
||||
assert first_sampler_output == second_sampler_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
seq_id_counter = Counter(start=random.randint(0, 100))
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
def create_sampling_params(min_tokens,
|
||||
eos_token_id=0,
|
||||
*,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
prompt_logprobs: Optional[int] = None):
|
||||
sampling_params = SamplingParams(
|
||||
min_tokens=min_tokens,
|
||||
max_tokens=9999, # keep higher than max of min_tokens
|
||||
stop_token_ids=stop_token_ids,
|
||||
# requesting prompt_logprobs changes the structure of `logits`
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
sampling_params.all_stop_token_ids.add(eos_token_id)
|
||||
return sampling_params
|
||||
|
||||
def create_sequence_data(num_input=3, num_generated=0):
|
||||
seq_data = SequenceData.from_seqs(
|
||||
random.choices(range(0, VOCAB_SIZE), k=num_input))
|
||||
if num_generated > 0:
|
||||
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
|
||||
k=num_generated)
|
||||
return seq_data
|
||||
|
||||
def generate_test_case():
|
||||
# generate multiple seq groups but limit total batch size
|
||||
batch_size = random.randint(1, 128)
|
||||
|
||||
expected_penalization = []
|
||||
sequence_metadata_list: list[SequenceGroupMetadata] = []
|
||||
# 20% chance to generate seq group metadata list with all prompts
|
||||
is_prompt = random.random() < 0.2
|
||||
while batch_size > 0:
|
||||
num_seqs = 1 if is_prompt else random.randint(1, batch_size)
|
||||
|
||||
eos_token_id = random.randint(0, VOCAB_SIZE - 1)
|
||||
min_tokens = random.randint(0, 50)
|
||||
num_stop_tokens = random.randint(0, 8)
|
||||
if num_stop_tokens > 0:
|
||||
stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
|
||||
k=num_stop_tokens)
|
||||
else:
|
||||
stop_token_ids = None
|
||||
|
||||
sampling_params = create_sampling_params(
|
||||
min_tokens=min_tokens,
|
||||
eos_token_id=eos_token_id,
|
||||
stop_token_ids=stop_token_ids)
|
||||
|
||||
seq_data: dict[int, SequenceData] = {}
|
||||
seq_group_penalization: list[bool] = []
|
||||
for _ in range(num_seqs):
|
||||
num_input = random.randint(1, 100)
|
||||
num_generated = 0 if is_prompt else random.randint(1, 100)
|
||||
seq_data[next(seq_id_counter)] = create_sequence_data(
|
||||
num_input=num_input, num_generated=num_generated)
|
||||
seq_group_penalization.append(num_generated < min_tokens)
|
||||
|
||||
expected_penalization.extend(seq_group_penalization)
|
||||
sequence_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{batch_size}",
|
||||
is_prompt=is_prompt,
|
||||
seq_data=seq_data,
|
||||
sampling_params=sampling_params,
|
||||
block_tables={},
|
||||
))
|
||||
batch_size -= num_seqs
|
||||
|
||||
return {
|
||||
"expected_penalization": expected_penalization,
|
||||
"seq_group_metadata_list": sequence_metadata_list,
|
||||
}
|
||||
|
||||
# define some explicit test cases for edge case behavior
|
||||
prompt_without_penalization = {
|
||||
"expected_penalization": [False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
},
|
||||
sampling_params=create_sampling_params(0),
|
||||
block_tables={},
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
prompt_with_penalization = {
|
||||
"expected_penalization": [True],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
},
|
||||
sampling_params=create_sampling_params(1),
|
||||
block_tables={},
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
prompt_with_penalization_and_prompt_logprobs = {
|
||||
"expected_penalization": [False, False, True],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(num_input=3),
|
||||
},
|
||||
sampling_params=create_sampling_params(1, prompt_logprobs=3),
|
||||
block_tables={},
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
stop_penalizing_after_min_tokens = {
|
||||
"expected_penalization": [False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=False,
|
||||
seq_data={
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=1),
|
||||
},
|
||||
sampling_params=create_sampling_params(1),
|
||||
block_tables={},
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
stop_token_ids = [42, 99, 42, 0] # intentional duplication
|
||||
prompt_combination = {
|
||||
"expected_penalization": [False, True, False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_2",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(num_input=2),
|
||||
},
|
||||
sampling_params=create_sampling_params(1, prompt_logprobs=3),
|
||||
block_tables={},
|
||||
),
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_3",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
0, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
stop_token_ids = [1, 999, 37, 37] # intentional duplication
|
||||
decode_combination = {
|
||||
"expected_penalization": [True, False, False, True, False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=False,
|
||||
seq_data={
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=1),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=100),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
2, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
),
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_2",
|
||||
is_prompt=False,
|
||||
seq_data={
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=20),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=1),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=10),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
if seed == 0:
|
||||
test_cases = [
|
||||
prompt_without_penalization,
|
||||
prompt_with_penalization,
|
||||
prompt_with_penalization_and_prompt_logprobs,
|
||||
stop_penalizing_after_min_tokens,
|
||||
prompt_combination,
|
||||
decode_combination,
|
||||
]
|
||||
else:
|
||||
test_cases = [generate_test_case()]
|
||||
|
||||
def run_test_case(*, expected_penalization: list[bool],
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata]):
|
||||
assert expected_penalization, \
|
||||
"Invalid test case, need expected_penalization"
|
||||
assert seq_group_metadata_list, \
|
||||
"Invalid test case, need seq_group_metadata_list"
|
||||
|
||||
batch_size = 0
|
||||
seq_lens: list[int] = []
|
||||
sampling_params_per_row: list[SamplingParams] = []
|
||||
for sgm in seq_group_metadata_list:
|
||||
sampling_params = sgm.sampling_params
|
||||
|
||||
num_rows = len(sgm.seq_data)
|
||||
if sgm.is_prompt:
|
||||
# a prompt seq_group has only one sequence
|
||||
seq_data = next(iter(sgm.seq_data.values()))
|
||||
prompt_len = seq_data.get_prompt_len()
|
||||
seq_lens.append(prompt_len)
|
||||
|
||||
assert sgm.sampling_params is not None
|
||||
if sgm.sampling_params.prompt_logprobs:
|
||||
# with prompt_logprobs each token in the prompt has a row in
|
||||
# logits
|
||||
num_rows = prompt_len
|
||||
|
||||
batch_size += num_rows
|
||||
sampling_params_per_row.extend(
|
||||
itertools.repeat(sampling_params, num_rows))
|
||||
|
||||
assert len(
|
||||
expected_penalization
|
||||
) == batch_size, \
|
||||
("Invalid test case, expected_penalization does not match computed"
|
||||
"batch size")
|
||||
|
||||
_, fake_logits, sampler = _prepare_test(batch_size)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens=seq_lens if seq_lens else None,
|
||||
query_lens=seq_lens if seq_lens else [1] * batch_size,
|
||||
device=device,
|
||||
pin_memory=is_pin_memory_available())
|
||||
# the logits tensor is modified in-place by the sampler
|
||||
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||
|
||||
for logits_idx, (should_penalize, sampling_params) in enumerate(
|
||||
zip(expected_penalization, sampling_params_per_row)):
|
||||
|
||||
tokens_to_check = sampling_params.all_stop_token_ids
|
||||
|
||||
if should_penalize:
|
||||
for token_id in tokens_to_check:
|
||||
assert fake_logits[logits_idx, token_id] == -float(
|
||||
'inf'
|
||||
), f"Expected token {token_id} for logits row {logits_idx}"
|
||||
" to be penalized"
|
||||
# no other tokens should be set to -inf
|
||||
assert torch.count_nonzero(
|
||||
fake_logits[logits_idx, :] == -float('inf')) == len(
|
||||
tokens_to_check
|
||||
), f"Expected only {len(tokens_to_check)} to be penalized"
|
||||
else:
|
||||
# no tokens should be set to -inf
|
||||
assert torch.count_nonzero(
|
||||
fake_logits[logits_idx, :] ==
|
||||
-float('inf')) == 0, "No tokens should have been penalized"
|
||||
|
||||
for test_case in test_cases:
|
||||
run_test_case(**test_case)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_mixed(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler = _prepare_test(batch_size)
|
||||
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata] = []
|
||||
expected_tokens: list[Optional[list[int]]] = []
|
||||
seq_lens: list[int] = []
|
||||
for i in range(batch_size):
|
||||
expected: Optional[list[int]] = None
|
||||
sampling_type = random.randint(0, 2)
|
||||
if sampling_type == 0:
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
|
||||
elif sampling_type in (1, 2):
|
||||
n = random.randint(1, 10)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=random.random() + 0.1,
|
||||
top_p=min(random.random() + 0.1, 1),
|
||||
top_k=random.randint(0, 10),
|
||||
n=n,
|
||||
presence_penalty=random.randint(0, 1),
|
||||
)
|
||||
if sampling_type == 2:
|
||||
sampling_params.seed = random.randint(0, 10000)
|
||||
else:
|
||||
for idx in range(n):
|
||||
fake_logits[i, i + idx] = 1e2
|
||||
expected = list(range(i, i + n))
|
||||
|
||||
expected_tokens.append(expected)
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
|
||||
sampling_params=sampling_params,
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
generators: dict[str, torch.Generator] = {}
|
||||
|
||||
def test_sampling():
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=device,
|
||||
pin_memory=is_pin_memory_available(),
|
||||
generators=generators)
|
||||
sampler_output = sampler(logits=fake_logits,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
for i, (sequence_output, metadata) in enumerate(
|
||||
zip(sampler_output, seq_group_metadata_list)):
|
||||
assert metadata.sampling_params is not None
|
||||
|
||||
if (metadata.sampling_params.seed is not None
|
||||
and expected_tokens[i] is None):
|
||||
# Record seeded random result to compare with results of
|
||||
# second invocation
|
||||
expected_tokens[i] = [
|
||||
nth_output.output_token
|
||||
for nth_output in sequence_output.samples
|
||||
]
|
||||
continue
|
||||
|
||||
expected_tokens_item = expected_tokens[i]
|
||||
assert expected_tokens_item is not None
|
||||
|
||||
for n, nth_output in enumerate(sequence_output.samples):
|
||||
assert metadata.sampling_params is not None
|
||||
|
||||
if (metadata.sampling_params.temperature == 0
|
||||
or metadata.sampling_params.seed is not None):
|
||||
# Ensure exact matches for greedy or random with seed
|
||||
assert nth_output.output_token == expected_tokens_item[n]
|
||||
else:
|
||||
# For non-seeded random check that one of the high-logit
|
||||
# tokens were chosen
|
||||
assert nth_output.output_token in expected_tokens_item
|
||||
|
||||
# Test batch
|
||||
test_sampling()
|
||||
|
||||
# Shuffle the batch and resample
|
||||
target_index = list(range(batch_size))
|
||||
for list_to_shuffle in (target_index, seq_group_metadata_list,
|
||||
expected_tokens, seq_lens):
|
||||
random.Random(seed).shuffle(list_to_shuffle)
|
||||
target_index = torch.tensor(target_index)
|
||||
input_tensor.data = input_tensor.index_select(0, target_index)
|
||||
fake_logits.data = fake_logits.index_select(0, target_index)
|
||||
|
||||
# This time, results of seeded random samples will be compared with
|
||||
# the corresponding sample in the pre-shuffled batch
|
||||
test_sampling()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
top_k = random.randint(100, 500)
|
||||
top_p = random.random() * 0.1
|
||||
vocab_size = 32000
|
||||
input_tensor = torch.rand((batch_size, 1024),
|
||||
device=device,
|
||||
dtype=torch.float16)
|
||||
fake_logits = torch.normal(0,
|
||||
5,
|
||||
size=(batch_size, vocab_size),
|
||||
device=input_tensor.device,
|
||||
dtype=input_tensor.dtype)
|
||||
sampler = MockLogitsSampler(fake_logits)
|
||||
|
||||
generation_model = GenerationMixin()
|
||||
generation_config = GenerationConfig(top_k=top_k,
|
||||
top_p=top_p,
|
||||
do_sample=True)
|
||||
|
||||
@dataclass
|
||||
class MockConfig:
|
||||
is_encoder_decoder: bool = False
|
||||
|
||||
generation_model.config = MockConfig() # needed by the following method
|
||||
generation_model._prepare_special_tokens(generation_config, device=device)
|
||||
processors = generation_model._get_logits_processor(generation_config,
|
||||
None,
|
||||
None,
|
||||
None, [],
|
||||
device=device)
|
||||
assert len(processors) == 2 # top_p and top_k
|
||||
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata] = []
|
||||
seq_lens: list[int] = []
|
||||
for i in range(batch_size):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
|
||||
sampling_params=SamplingParams(
|
||||
temperature=1,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
),
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=device,
|
||||
pin_memory=is_pin_memory_available())
|
||||
|
||||
sample_probs = None
|
||||
|
||||
def mock_sample(probs, *args, **kwargs):
|
||||
nonlocal sample_probs
|
||||
sample_probs = probs
|
||||
return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
|
||||
for prob in probs], None)
|
||||
|
||||
# top-k and top-p is only calculated when flashinfer kernel is not available
|
||||
with patch("vllm.model_executor.layers.sampler._sample", mock_sample), \
|
||||
patch("vllm.model_executor.layers.sampler."
|
||||
"flashinfer_top_k_top_p_sampling", None):
|
||||
sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||
|
||||
assert sample_probs is not None
|
||||
|
||||
hf_probs = processors(torch.zeros_like(fake_logits), fake_logits.clone())
|
||||
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
|
||||
torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
|
||||
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_flashinfer_fallback(seed: int, device: str):
|
||||
if not envs.VLLM_USE_FLASHINFER_SAMPLER:
|
||||
pytest.skip("Flashinfer sampler is disabled")
|
||||
|
||||
pytest.skip("After FlashInfer 0.2.3, sampling will never fail")
|
||||
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
_, fake_logits, sampler = _prepare_test(batch_size)
|
||||
|
||||
def failing_flashinfer_sampling(*_args, **_kwargs):
|
||||
return None, torch.zeros(batch_size, device=device, dtype=torch.int32)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
n=random.randint(1, 10),
|
||||
seed=random.randint(0, 10000),
|
||||
)
|
||||
sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
|
||||
with patch(
|
||||
"vllm.model_executor.layers.sampler."
|
||||
"flashinfer_top_k_top_p_sampling", failing_flashinfer_sampling):
|
||||
fallback_sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
|
||||
assert sampler_output == fallback_sampler_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_repetition_penalty_mixed(device: str):
|
||||
|
||||
vocab_size = 8
|
||||
|
||||
def test_sampling_params(sampling_params: list[SamplingParams]):
|
||||
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata] = []
|
||||
seq_lens: list[int] = []
|
||||
for i in range(2):
|
||||
seq_group_metadata_list.append(
|
||||
SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
|
||||
sampling_params=sampling_params[i],
|
||||
block_tables={0: [1]},
|
||||
))
|
||||
seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=device,
|
||||
pin_memory=is_pin_memory_available())
|
||||
|
||||
fake_logits = torch.full((2, vocab_size),
|
||||
1e-2,
|
||||
device=device,
|
||||
dtype=torch.float16)
|
||||
|
||||
fake_logits[:, 5] = 1.1e-2
|
||||
fake_logits[:, 1] = 1.2e-2
|
||||
|
||||
sampler = MockLogitsSampler(fake_logits)
|
||||
|
||||
sampler_output = sampler(logits=fake_logits,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
generated_tokens = []
|
||||
for output in sampler_output:
|
||||
generated_tokens.append(output.samples[0].output_token)
|
||||
|
||||
return generated_tokens
|
||||
|
||||
# one configuration is greedy with repetition_penalty
|
||||
sampling_params_rep = SamplingParams(
|
||||
temperature=0.0,
|
||||
repetition_penalty=2.0,
|
||||
)
|
||||
|
||||
# other configuration is sampling w/o repetition_penalty
|
||||
sampling_params_sample = SamplingParams(
|
||||
temperature=1.0,
|
||||
top_k=1,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
tokens1 = test_sampling_params(
|
||||
[sampling_params_rep, sampling_params_sample])
|
||||
|
||||
tokens2 = test_sampling_params(
|
||||
[sampling_params_sample, sampling_params_rep])
|
||||
|
||||
assert tokens1[0] == tokens2[1]
|
||||
assert tokens1[1] == tokens2[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_include_gpu_probs_tensor(device: str):
|
||||
set_random_seed(42)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
_, fake_logits, sampler = _prepare_test(batch_size)
|
||||
sampler.include_gpu_probs_tensor = True
|
||||
sampler.should_modify_greedy_probs_inplace = False
|
||||
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
|
||||
mock_inplace = Mock()
|
||||
with patch(
|
||||
"vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
|
||||
mock_inplace):
|
||||
|
||||
sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
mock_inplace.assert_not_called()
|
||||
|
||||
assert sampler_output.sampled_token_probs is not None
|
||||
assert sampler_output.logprobs is not None
|
||||
assert sampler_output.sampled_token_ids is not None
|
||||
@ -1,86 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Verify that seeded random sampling is deterministic.
|
||||
|
||||
Run `pytest tests/samplers/test_seeded_generate.py`.
|
||||
"""
|
||||
import copy
|
||||
import random
|
||||
from itertools import combinations
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
MODEL = "facebook/opt-125m"
|
||||
RANDOM_SEEDS = list(range(5))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_model(vllm_runner, monkeypatch):
|
||||
# This file relies on V0 internals.
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
with vllm_runner(MODEL, dtype="half") as vllm_model:
|
||||
yield vllm_model
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_random_sample_with_seed(
|
||||
vllm_model,
|
||||
example_prompts,
|
||||
seed: int,
|
||||
) -> None:
|
||||
set_random_seed(seed)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
# Parameters to ensure sufficient randomness
|
||||
temperature=3.0,
|
||||
top_p=min(random.random() + 0.3, 1),
|
||||
top_k=random.randint(5, 20),
|
||||
n=random.randint(1, 10),
|
||||
presence_penalty=random.randint(0, 1),
|
||||
max_tokens=8,
|
||||
ignore_eos=True,
|
||||
)
|
||||
|
||||
sampling_params_seed_1 = copy.deepcopy(sampling_params)
|
||||
sampling_params_seed_1.seed = 100
|
||||
sampling_params_seed_2 = copy.deepcopy(sampling_params)
|
||||
sampling_params_seed_2.seed = 200
|
||||
|
||||
llm = vllm_model.llm
|
||||
|
||||
for prompt in example_prompts:
|
||||
for params in (
|
||||
sampling_params,
|
||||
sampling_params_seed_1,
|
||||
sampling_params_seed_2,
|
||||
sampling_params,
|
||||
sampling_params_seed_1,
|
||||
sampling_params_seed_2,
|
||||
):
|
||||
llm._add_request(prompt, params=params)
|
||||
|
||||
results = llm._run_engine(use_tqdm=False)
|
||||
all_outputs = [[out.token_ids for out in output.outputs]
|
||||
for output in results]
|
||||
|
||||
for i in range(0, len(example_prompts), 6):
|
||||
outputs = all_outputs[i:i + 6]
|
||||
|
||||
# verify all non-seeded requests differ
|
||||
for output_a, output_b in combinations(
|
||||
(outputs[0], outputs[1], outputs[2], outputs[3]),
|
||||
2,
|
||||
):
|
||||
assert output_a != output_b
|
||||
|
||||
# verify requests with the same seed match
|
||||
assert outputs[1] == outputs[4]
|
||||
assert outputs[2] == outputs[5]
|
||||
|
||||
# verify generations within the same parallel sampling group differ
|
||||
for output in outputs:
|
||||
for sub_output_a, sub_output_b in combinations(output, 2):
|
||||
assert sub_output_a != sub_output_b
|
||||
@ -5,6 +5,7 @@ import asyncio
|
||||
import copy
|
||||
import functools
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
@ -101,7 +102,8 @@ class RemoteOpenAIServer:
|
||||
env_dict: Optional[dict[str, str]] = None,
|
||||
seed: Optional[int] = 0,
|
||||
auto_port: bool = True,
|
||||
max_wait_seconds: Optional[float] = None) -> None:
|
||||
max_wait_seconds: Optional[float] = None,
|
||||
override_hf_configs: Optional[dict[str, Any]] = None) -> None:
|
||||
if auto_port:
|
||||
if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
|
||||
raise ValueError("You have manually specified the port "
|
||||
@ -120,6 +122,12 @@ class RemoteOpenAIServer:
|
||||
|
||||
vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
|
||||
|
||||
if override_hf_configs is not None:
|
||||
vllm_serve_args = vllm_serve_args + [
|
||||
"--hf-overrides",
|
||||
json.dumps(override_hf_configs)
|
||||
]
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM's remote OpenAI server.")
|
||||
subparsers = parser.add_subparsers(required=False, dest="subparser")
|
||||
|
||||
@ -150,15 +150,15 @@ def create_and_prepopulate_kv_cache(
|
||||
|
||||
# Permute the context blocks (excluding block 0 which is null)
|
||||
if randomize_blocks:
|
||||
perm = torch.randperm(
|
||||
blocks_end - 1) + 1 # Random permutation starting from block 1
|
||||
# Random permutation starting from block 1
|
||||
perm = torch.randperm(blocks_end - 1) + 1
|
||||
else:
|
||||
perm = torch.arange(
|
||||
1, blocks_end) # Sequential order starting from block 1
|
||||
# Sequential order starting from block 1
|
||||
perm = torch.arange(1, blocks_end)
|
||||
|
||||
inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
|
||||
inv_perm[1:] = torch.argsort(
|
||||
perm) + 1 # Add 1 to account for starting from block 1
|
||||
# Add 1 to account for starting from block 1
|
||||
inv_perm[1:] = torch.argsort(perm) + 1
|
||||
kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]
|
||||
|
||||
# Construct the right block table
|
||||
@ -281,7 +281,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
|
||||
@pytest.mark.parametrize("batch_spec_name", [
|
||||
"small_decode", "small_prefill", "mixed_small", "medium_decode",
|
||||
"medium_prefill", "mixed_medium"
|
||||
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
|
||||
"single_decode", "single_prefill"
|
||||
])
|
||||
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
|
||||
def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
@ -302,7 +303,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
"""
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
vllm_config = create_vllm_config(model_name=model,
|
||||
max_model_len=max(batch_spec.seq_lens))
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
num_gpu_blocks=8192)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
@ -465,12 +467,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
|
||||
if not all_close:
|
||||
print(f"[{backend_name}] output differs from SDPA baseline. "
|
||||
f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})")
|
||||
print(f"[{backend_name}] output: {backend_output}")
|
||||
print(f"[{backend_name}] SDPA baseline: {sdpa_output}")
|
||||
|
||||
assert all_close, (
|
||||
f"[{backend_name}] output differs from SDPA baseline. "
|
||||
f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})")
|
||||
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user