Compare commits

1 commit

SHA1: c1d1875ba3
Message: Updates docs with correction about default cuda version (Correct 12.1 --> 12.4)
Date: 2025-01-07 17:29:07 -05:00
1401 changed files with 20973 additions and 60330 deletions


@ -1,14 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import os import os
import sys import sys
import zipfile import zipfile
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB # Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB
# Note that we have 400 MiB quota, please use it wisely. VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250))
# See https://github.com/pypi/support/issues/3792 .
# Please also sync the value with the one in Dockerfile.
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 400))
def print_top_10_largest_files(zip_file): def print_top_10_largest_files(zip_file):
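For context, the hunk above changes the default wheel-size budget between 250 MB and 400 MiB, notes the PyPI quota, and keeps VLLM_MAX_SIZE_MB as an environment-variable override. A hedged usage sketch follows; the script path and directory argument are assumptions, since the file name is not shown in this view.

# Override the wheel-size budget for a single run (script path and argument assumed):
VLLM_MAX_SIZE_MB=300 python3 .buildkite/check-wheel-size.py dist/
# Or rely on the default documented in the hunk above:
python3 .buildkite/check-wheel-size.py dist/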


@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import os import os


@ -1,11 +0,0 @@
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.6353
- name: "exact_match,flexible-extract"
value: 0.637
limit: null
num_fewshot: null


@ -1,4 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
""" """
LM eval harness on model to compare vs HF baseline computed offline. LM eval harness on model to compare vs HF baseline computed offline.
Configs are found in configs/$MODEL.yaml Configs are found in configs/$MODEL.yaml


@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import json import json
import os import os
from pathlib import Path from pathlib import Path


@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import argparse import argparse
from transformers import AutoTokenizer from transformers import AutoTokenizer


@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import json import json
from pathlib import Path from pathlib import Path


@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
from lmdeploy.serve.openai.api_client import APIClient from lmdeploy.serve.openai.api_client import APIClient
api_client = APIClient("http://localhost:8000") api_client = APIClient("http://localhost:8000")


@ -43,7 +43,7 @@ main() {
# The figures should be generated by a separate process outside the CI/CD pipeline # The figures should be genereated by a separate process outside the CI/CD pipeline
# # generate figures # # generate figures
# python3 -m pip install tabulate pandas matplotlib # python3 -m pip install tabulate pandas matplotlib


@ -301,104 +301,6 @@ run_serving_tests() {
kill_gpu_processes kill_gpu_processes
} }
run_genai_perf_tests() {
# run genai-perf tests
# $1: a json file specifying genai-perf test cases
local genai_perf_test_file
genai_perf_test_file=$1
# Iterate over genai-perf tests
jq -c '.[]' "$genai_perf_test_file" | while read -r params; do
# get the test name, and append the GPU type back to it.
test_name=$(echo "$params" | jq -r '.test_name')
# if TEST_SELECTOR is set, only run the test cases that match the selector
if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then
echo "Skip test case $test_name."
continue
fi
# prepend the current serving engine to the test name
test_name=${CURRENT_LLM_SERVING_ENGINE}_${test_name}
# get common parameters
common_params=$(echo "$params" | jq -r '.common_parameters')
model=$(echo "$common_params" | jq -r '.model')
tp=$(echo "$common_params" | jq -r '.tp')
dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
port=$(echo "$common_params" | jq -r '.port')
num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
reuse_server=$(echo "$common_params" | jq -r '.reuse_server')
# get client and server arguments
server_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_server_parameters")
qps_list=$(echo "$params" | jq -r '.qps_list')
qps_list=$(echo "$qps_list" | jq -r '.[] | @sh')
echo "Running over qps list $qps_list"
# check if there is enough GPU to run the test
if [[ $gpu_count -lt $tp ]]; then
echo "Required num-shard $tp but only $gpu_count GPU found. Skip testcase $test_name."
continue
fi
if [[ $reuse_server == "true" ]]; then
echo "Reuse previous server for test case $test_name"
else
kill_gpu_processes
bash "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh" \
"$server_params" "$common_params"
fi
if wait_for_server; then
echo ""
echo "$CURRENT_LLM_SERVING_ENGINE server is up and running."
else
echo ""
echo "$CURRENT_LLM_SERVING_ENGINE failed to start within the timeout period."
break
fi
# iterate over different QPS
for qps in $qps_list; do
# remove the surrounding single quote from qps
if [[ "$qps" == *"inf"* ]]; then
echo "qps was $qps"
qps=$num_prompts
echo "now qps is $qps"
fi
new_test_name=$test_name"_qps_"$qps
backend=$CURRENT_LLM_SERVING_ENGINE
if [[ "$backend" == *"vllm"* ]]; then
backend="vllm"
fi
#TODO: add output dir.
client_command="genai-perf profile \
-m $model \
--service-kind openai \
--backend vllm \
--endpoint-type chat \
--streaming \
--url localhost:$port \
--request-rate $qps \
--num-prompts $num_prompts \
"
echo "Client command: $client_command"
eval "$client_command"
#TODO: process/record outputs
done
done
kill_gpu_processes
}
prepare_dataset() { prepare_dataset() {
@ -426,17 +328,12 @@ main() {
pip install -U transformers pip install -U transformers
pip install -r requirements-dev.txt
which genai-perf
# check storage # check storage
df -h df -h
ensure_installed wget ensure_installed wget
ensure_installed curl ensure_installed curl
ensure_installed jq ensure_installed jq
# genai-perf dependency
ensure_installed libb64-0d
prepare_dataset prepare_dataset
@ -448,10 +345,6 @@ main() {
# run the test # run the test
run_serving_tests "$BENCHMARK_ROOT/tests/nightly-tests.json" run_serving_tests "$BENCHMARK_ROOT/tests/nightly-tests.json"
# run genai-perf tests
run_genai_perf_tests "$BENCHMARK_ROOT/tests/genai-perf-tests.json"
mv artifacts/ $RESULTS_FOLDER/
# upload benchmark results to buildkite # upload benchmark results to buildkite
python3 -m pip install tabulate pandas python3 -m pip install tabulate pandas
python3 "$BENCHMARK_ROOT/scripts/summary-nightly-results.py" python3 "$BENCHMARK_ROOT/scripts/summary-nightly-results.py"


@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import datetime import datetime
import json import json
import os import os


@ -1,23 +0,0 @@
[
{
"test_name": "llama8B_tp1_genai_perf",
"qps_list": [4,8,16,32],
"common_parameters": {
"model": "meta-llama/Meta-Llama-3-8B-Instruct",
"tp": 1,
"port": 8000,
"num_prompts": 500,
"reuse_server": false
},
"vllm_server_parameters": {
"disable_log_stats": "",
"disable_log_requests": "",
"gpu_memory_utilization": 0.9,
"num_scheduler_steps": 10,
"max_num_seqs": 512,
"dtype": "bfloat16"
},
"genai_perf_input_parameters": {
}
}
]


@ -56,11 +56,6 @@ steps:
env: env:
DOCKER_BUILDKIT: "1" DOCKER_BUILDKIT: "1"
- input: "Provide Release version here"
fields:
- text: "What is the release version?"
key: "release-version"
- block: "Build CPU release image" - block: "Build CPU release image"
key: block-cpu-release-image-build key: block-cpu-release-image-build
depends_on: ~ depends_on: ~
@ -71,7 +66,7 @@ steps:
queue: cpu_queue_postmerge queue: cpu_queue_postmerge
commands: commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --progress plain -f Dockerfile.cpu ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION"
env: env:
DOCKER_BUILDKIT: "1" DOCKER_BUILDKIT: "1"


@ -13,7 +13,7 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build -t cpu-test-"$BUILDKITE_BU
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 -f Dockerfile.cpu . numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 -f Dockerfile.cpu .
# Setup cleanup # Setup cleanup
remove_docker_container() { set -e; docker rm -f cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" || true; } remove_docker_container() { docker rm -f cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" || true; }
trap remove_docker_container EXIT trap remove_docker_container EXIT
remove_docker_container remove_docker_container
@ -30,12 +30,15 @@ function cpu_tests() {
# offline inference # offline inference
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c " docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c "
set -e set -e
python3 examples/offline_inference/basic.py" python3 examples/offline_inference.py"
# Run basic model test # Run basic model test
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
set -e set -e
pip install -r vllm/requirements-test.txt pip install pytest pytest-asyncio \
decord einops librosa peft Pillow sentence-transformers soundfile \
transformers_stream_generator matplotlib datamodel_code_generator
pip install torchvision --index-url https://download.pytorch.org/whl/cpu
pytest -v -s tests/models/decoder_only/language -m cpu_model pytest -v -s tests/models/decoder_only/language -m cpu_model
pytest -v -s tests/models/embedding/language -m cpu_model pytest -v -s tests/models/embedding/language -m cpu_model
pytest -v -s tests/models/encoder_decoder/language -m cpu_model pytest -v -s tests/models/encoder_decoder/language -m cpu_model
@ -61,7 +64,7 @@ function cpu_tests() {
pytest -s -v -k cpu_model \ pytest -s -v -k cpu_model \
tests/basic_correctness/test_chunked_prefill.py" tests/basic_correctness/test_chunked_prefill.py"
# online serving # online inference
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c " docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
set -e set -e
export VLLM_CPU_KVCACHE_SPACE=10 export VLLM_CPU_KVCACHE_SPACE=10
@ -75,14 +78,8 @@ function cpu_tests() {
--num-prompts 20 \ --num-prompts 20 \
--endpoint /v1/completions \ --endpoint /v1/completions \
--tokenizer facebook/opt-125m" --tokenizer facebook/opt-125m"
# Run multi-lora tests
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
set -e
pytest -s -v \
tests/lora/test_qwen2vl.py"
} }
# All of CPU tests are expected to be finished less than 40 mins. # All of CPU tests are expected to be finished less than 25 mins.
export -f cpu_tests export -f cpu_tests
timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"


@ -23,6 +23,6 @@ trap remove_docker_container EXIT
remove_docker_container remove_docker_container
# Run the image and test offline inference # Run the image and test offline inference
docker run -e HF_TOKEN -v /root/.cache/huggingface:/root/.cache/huggingface --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c ' docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c '
python3 examples/offline_inference/cli.py --model meta-llama/Llama-3.2-1B python3 examples/offline_inference.py
' '


@ -8,17 +8,9 @@ set -ex
docker build -t hpu-test-env -f Dockerfile.hpu . docker build -t hpu-test-env -f Dockerfile.hpu .
# Setup cleanup # Setup cleanup
# certain versions of HPU software stack have a bug that can
# override the exit code of the script, so we need to use
# separate remove_docker_container and remove_docker_container_and_exit
# functions, while other platforms only need one remove_docker_container
# function.
EXITCODE=1
remove_docker_container() { docker rm -f hpu-test || true; } remove_docker_container() { docker rm -f hpu-test || true; }
remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; } trap remove_docker_container EXIT
trap remove_docker_container_and_exit EXIT
remove_docker_container remove_docker_container
# Run the image and launch offline inference # Run the image and launch offline inference
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic.py docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py
EXITCODE=$?


@ -25,11 +25,8 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then
last_build=$(cat /tmp/neuron-docker-build-timestamp) last_build=$(cat /tmp/neuron-docker-build-timestamp)
current_time=$(date +%s) current_time=$(date +%s)
if [ $((current_time - last_build)) -gt 86400 ]; then if [ $((current_time - last_build)) -gt 86400 ]; then
# Remove dangling images (those that are not tagged and not used by any container)
docker image prune -f docker image prune -f
# Remove unused volumes / force the system prune for old images as well. docker system prune -f
docker volume prune -f && docker system prune -f
# Remove huggingface model artifacts and compiler cache
rm -rf "${HF_MOUNT:?}/*" rm -rf "${HF_MOUNT:?}/*"
rm -rf "${NEURON_COMPILE_CACHE_MOUNT:?}/*" rm -rf "${NEURON_COMPILE_CACHE_MOUNT:?}/*"
echo "$current_time" > /tmp/neuron-docker-build-timestamp echo "$current_time" > /tmp/neuron-docker-build-timestamp
@ -54,4 +51,4 @@ docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \
-e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \ -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
--name "${container_name}" \ --name "${container_name}" \
${image_name} \ ${image_name} \
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/ -v --capture=tee-sys" /bin/bash -c "python3 /workspace/vllm/examples/offline_inference_neuron.py"


@ -13,4 +13,4 @@ trap remove_docker_container EXIT
remove_docker_container remove_docker_container
# Run the image and launch offline inference # Run the image and launch offline inference
docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference/basic.py docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference.py

.buildkite/run-tpu-test.sh (Executable file → Normal file; 11 changed lines)

@ -14,13 +14,4 @@ remove_docker_container
# For HF_TOKEN. # For HF_TOKEN.
source /etc/environment source /etc/environment
# Run a simple end-to-end example. # Run a simple end-to-end example.
docker run --privileged --net host --shm-size=16G -it \ docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
&& python3 -m pip install pytest \
&& python3 -m pip install lm_eval[api]==0.4.4 \
&& pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \
&& pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
&& python3 /workspace/vllm/tests/tpu/test_compilation.py \
&& python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py"


@ -14,6 +14,6 @@ remove_docker_container
# Run the image and test offline inference/tensor parallel # Run the image and test offline inference/tensor parallel
docker run --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test sh -c ' docker run --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test sh -c '
python3 examples/offline_inference/basic.py python3 examples/offline_inference.py
python3 examples/offline_inference/cli.py -tp 2 python3 examples/offline_inference_cli.py -tp 2
' '


@ -38,7 +38,7 @@ steps:
- pip install -r requirements-docs.txt - pip install -r requirements-docs.txt
- SPHINXOPTS=\"-W\" make html - SPHINXOPTS=\"-W\" make html
# Check API reference (if it fails, you may have missing mock imports) # Check API reference (if it fails, you may have missing mock imports)
- grep \"sig sig-object py\" build/html/api/inference_params.html - grep \"sig sig-object py\" build/html/dev/sampling_params.html
- label: Async Engine, Inputs, Utils, Worker Test # 24min - label: Async Engine, Inputs, Utils, Worker Test # 24min
fast_check: true fast_check: true
@ -50,9 +50,9 @@ steps:
- tests/multimodal - tests/multimodal
- tests/test_utils - tests/test_utils
- tests/worker - tests/worker
- tests/standalone_tests/lazy_imports.py - tests/standalone_tests/lazy_torch_compile.py
commands: commands:
- python3 standalone_tests/lazy_imports.py - python3 standalone_tests/lazy_torch_compile.py
- pytest -v -s mq_llm_engine # MQLLMEngine - pytest -v -s mq_llm_engine # MQLLMEngine
- pytest -v -s async_engine # AsyncLLMEngine - pytest -v -s async_engine # AsyncLLMEngine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
@ -76,9 +76,7 @@ steps:
- tests/basic_correctness/test_basic_correctness - tests/basic_correctness/test_basic_correctness
- tests/basic_correctness/test_cpu_offload - tests/basic_correctness/test_cpu_offload
- tests/basic_correctness/test_preemption - tests/basic_correctness/test_preemption
- tests/basic_correctness/test_cumem.py
commands: commands:
- pytest -v -s basic_correctness/test_cumem.py
- pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py - pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
@ -108,7 +106,7 @@ steps:
source_file_dependencies: source_file_dependencies:
- vllm/ - vllm/
commands: commands:
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
@ -127,17 +125,11 @@ steps:
- tests/distributed - tests/distributed
- tests/spec_decode/e2e/test_integration_dist_tp4 - tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile - tests/compile
- examples/offline_inference/rlhf.py
- examples/offline_inference/ray_placement.py
commands: commands:
- pytest -v -s distributed/test_utils.py - pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py - pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py - pytest -v -s distributed/test_pynccl.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- python3 ../examples/offline_inference/rlhf.py
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py
- label: Metrics, Tracing Test # 10min - label: Metrics, Tracing Test # 10min
num_gpus: 2 num_gpus: 2
@ -185,16 +177,7 @@ steps:
- vllm/ - vllm/
- tests/v1 - tests/v1
commands: commands:
# split the test to avoid interference - VLLM_USE_V1=1 pytest -v -s v1
- VLLM_USE_V1=1 pytest -v -s v1/core
- VLLM_USE_V1=1 pytest -v -s v1/engine
- VLLM_USE_V1=1 pytest -v -s v1/sample
- VLLM_USE_V1=1 pytest -v -s v1/worker
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
# TODO: accuracy does not match, whether setting
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
- VLLM_USE_V1=1 pytest -v -s v1/e2e
- label: Examples Test # 25min - label: Examples Test # 25min
working_dir: "/vllm-workspace/examples" working_dir: "/vllm-workspace/examples"
@ -204,19 +187,19 @@ steps:
- examples/ - examples/
commands: commands:
- pip install tensorizer # for tensorizer test - pip install tensorizer # for tensorizer test
- python3 offline_inference/basic.py - python3 offline_inference.py
- python3 offline_inference/cpu_offload.py - python3 cpu_offload.py
- python3 offline_inference/chat.py - python3 offline_inference_chat.py
- python3 offline_inference/prefix_caching.py - python3 offline_inference_with_prefix.py
- python3 offline_inference/llm_engine_example.py - python3 llm_engine_example.py
- python3 offline_inference/vision_language.py - python3 offline_inference_vision_language.py
- python3 offline_inference/vision_language_multi_image.py - python3 offline_inference_vision_language_multi_image.py
- python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/encoder_decoder.py - python3 offline_inference_encoder_decoder.py
- python3 offline_inference/classification.py - python3 offline_inference_classification.py
- python3 offline_inference/embedding.py - python3 offline_inference_embedding.py
- python3 offline_inference/scoring.py - python3 offline_inference_scoring.py
- python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2
- label: Prefix Caching Test # 9min - label: Prefix Caching Test # 9min
mirror_hardwares: [amd] mirror_hardwares: [amd]
@ -231,7 +214,6 @@ steps:
- vllm/model_executor/layers - vllm/model_executor/layers
- vllm/sampling_metadata.py - vllm/sampling_metadata.py
- tests/samplers - tests/samplers
- tests/conftest.py
commands: commands:
- pytest -v -s samplers - pytest -v -s samplers
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
@ -247,15 +229,13 @@ steps:
- pytest -v -s test_logits_processor.py - pytest -v -s test_logits_processor.py
- pytest -v -s model_executor/test_guided_processors.py - pytest -v -s model_executor/test_guided_processors.py
- label: Speculative decoding tests # 40min - label: Speculative decoding tests # 30min
source_file_dependencies: source_file_dependencies:
- vllm/spec_decode - vllm/spec_decode
- tests/spec_decode - tests/spec_decode
- vllm/model_executor/models/eagle.py
commands: commands:
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py - pytest -v -s spec_decode/e2e/test_multistep_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
- label: LoRA Test %N # 15min each - label: LoRA Test %N # 15min each
mirror_hardwares: [amd] mirror_hardwares: [amd]
@ -351,7 +331,6 @@ steps:
- vllm/ - vllm/
- tests/models - tests/models
commands: commands:
- pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py - pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py - pytest -v -s models/test_initialization.py
@ -388,7 +367,6 @@ steps:
- tests/models/encoder_decoder/vision_language - tests/models/encoder_decoder/vision_language
commands: commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/multimodal
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
- pytest -v -s models/embedding/vision_language -m core_model - pytest -v -s models/embedding/vision_language -m core_model
@ -479,22 +457,16 @@ steps:
- vllm/worker/worker_base.py - vllm/worker/worker_base.py
- vllm/worker/worker.py - vllm/worker/worker.py
- vllm/worker/model_runner.py - vllm/worker/model_runner.py
- entrypoints/llm/test_collective_rpc.py
commands: commands:
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py - pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
# Avoid importing model tests that cause CUDA reinitialization error # Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
# this test fails consistently. - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
# TODO: investigate and fix
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
@ -532,9 +504,7 @@ steps:
- vllm/engine - vllm/engine
- tests/multi_step - tests/multi_step
commands: commands:
# this test is quite flaky - pytest -v -s multi_step/test_correctness_async_llm.py
# TODO: investigate and fix.
# - pytest -v -s multi_step/test_correctness_async_llm.py
- pytest -v -s multi_step/test_correctness_llm.py - pytest -v -s multi_step/test_correctness_llm.py
- label: Pipeline Parallelism Test # 45min - label: Pipeline Parallelism Test # 45min

.github/CODEOWNERS (vendored; 27 changed lines)

@ -2,35 +2,32 @@
# for more info about CODEOWNERS file # for more info about CODEOWNERS file
# This lists cover the "core" components of vLLM that require careful review # This lists cover the "core" components of vLLM that require careful review
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
/vllm/model_executor/guided_decoding @mgoin
/vllm/multimodal @DarkLight1337 @ywang96
CMakeLists.txt @tlrmchlsmth CMakeLists.txt @tlrmchlsmth
# vLLM V1 # vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat /vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic
# Test ownership # Test ownership
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo /tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo
/tests/test_inputs.py @DarkLight1337 @ywang96 /tests/test_inputs.py @DarkLight1337 @ywang96
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo /tests/entrypoints @DarkLight1337 @robertgshaw2-neuralmagic @simon-mo
/tests/models @DarkLight1337 @ywang96 /tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96 /tests/multimodal @DarkLight1337 @ywang96
/tests/prefix_caching @comaniac @KuntaiDu /tests/prefix_caching @comaniac @KuntaiDu
/tests/spec_decode @njhill @LiuXiaoxuanPKU /tests/spec_decode @njhill @LiuXiaoxuanPKU
/tests/kernels @tlrmchlsmth @WoosukKwon /tests/kernels @tlrmchlsmth @WoosukKwon
/tests/quantization @mgoin @robertgshaw2-redhat /tests/quantization @mgoin @robertgshaw2-neuralmagic
/.buildkite/lm-eval-harness @mgoin @simon-mo /.buildkite/lm-eval-harness @mgoin @simon-mo
/tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_multi_node_assignment.py @youkaichao
/tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao /tests/distributed/test_same_node.py @youkaichao
/tests/multi_step @alexm-redhat @comaniac /tests/multi_step @alexm-neuralmagic @comaniac
/tests/weight_loading @mgoin @youkaichao /tests/weight_loading @mgoin @youkaichao
/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac /tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac


@ -30,6 +30,15 @@ body:
</details> </details>
validations: validations:
required: true required: true
- type: textarea
attributes:
label: Model Input Dumps
description: |
If you are facing crashing due to illegal memory access or other issues with model execution, vLLM may dump the problematic input of the model. In this case, you will see the message `Error in model execution (input dumped to /tmp/err_xxx.pkl)`. If you see this message, please zip the file (because GitHub doesn't support .pkl file format) and upload it here. This will help us to reproduce the issue and facilitate the debugging process.
placeholder: |
Upload the dumped input file.
validations:
required: false
- type: textarea - type: textarea
attributes: attributes:
label: 🐛 Describe the bug label: 🐛 Describe the bug

.github/mergify.yml (vendored; 37 changed lines)

@ -35,43 +35,6 @@ pull_request_rules:
add: add:
- frontend - frontend
- name: label-structured-output
description: Automatically apply structured-output label
conditions:
- or:
- files~=^vllm/model_executor/guided_decoding/
- files=tests/model_executor/test_guided_processors.py
- files=tests/entrypoints/llm/test_guided_generate.py
- files=benchmarks/benchmark_serving_guided.py
- files=benchmarks/benchmark_guided.py
actions:
label:
add:
- structured-output
- name: label-speculative-decoding
description: Automatically apply speculative-decoding label
conditions:
- or:
- files~=^vllm/spec_decode/
- files=vllm/model_executor/layers/spec_decode_base_sampler.py
- files~=^tests/spec_decode/
actions:
label:
add:
- speculative-decoding
- name: label-v1
description: Automatically apply v1 label
conditions:
- or:
- files~=^vllm/v1/
- files~=^tests/v1/
actions:
label:
add:
- v1
- name: ping author on conflicts and add 'needs-rebase' label - name: ping author on conflicts and add 'needs-rebase' label
conditions: conditions:
- conflict - conflict

.github/workflows/actionlint.yml (vendored, Normal file; 40 changed lines)

@ -0,0 +1,40 @@
name: Lint GitHub Actions workflows
on:
push:
branches:
- "main"
paths:
- '.github/workflows/*.ya?ml'
- '.github/workflows/actionlint.*'
- '.github/workflows/matchers/actionlint.json'
pull_request:
branches:
- "main"
paths:
- '.github/workflows/*.ya?ml'
- '.github/workflows/actionlint.*'
- '.github/workflows/matchers/actionlint.json'
env:
LC_ALL: en_US.UTF-8
defaults:
run:
shell: bash
permissions:
contents: read
jobs:
actionlint:
runs-on: ubuntu-latest
steps:
- name: "Checkout"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: "Run actionlint"
run: |
echo "::add-matcher::.github/workflows/matchers/actionlint.json"
tools/actionlint.sh -color

.github/workflows/clang-format.yml (vendored, Normal file; 53 changed lines)

@ -0,0 +1,53 @@
name: clang-format
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- '**/*.h'
- '**/*.cpp'
- '**/*.cu'
- '**/*.cuh'
- '.github/workflows/clang-format.yml'
pull_request:
branches:
- main
paths:
- '**/*.h'
- '**/*.cpp'
- '**/*.cu'
- '**/*.cuh'
- '.github/workflows/clang-format.yml'
jobs:
clang-format:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install clang-format==18.1.5
- name: Running clang-format
run: |
EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/quantization/gguf/ggml-common.h'
'csrc/quantization/gguf/dequantize.cuh'
'csrc/quantization/gguf/vecdotq.cuh'
'csrc/quantization/gguf/mmq.cuh'
'csrc/quantization/gguf/mmvq.cuh'
)
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
| xargs clang-format --dry-run --Werror

.github/workflows/codespell.yml (vendored, Normal file; 45 changed lines)

@ -0,0 +1,45 @@
name: codespell
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- "**/*.py"
- "**/*.md"
- "**/*.rst"
- pyproject.toml
- requirements-lint.txt
- .github/workflows/codespell.yml
pull_request:
branches:
- main
paths:
- "**/*.py"
- "**/*.md"
- "**/*.rst"
- pyproject.toml
- requirements-lint.txt
- .github/workflows/codespell.yml
jobs:
codespell:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-lint.txt
- name: Spelling check with codespell
run: |
codespell --toml pyproject.toml


@ -27,7 +27,7 @@ jobs:
version: v3.10.1 version: v3.10.1
- name: Run chart-testing (lint) - name: Run chart-testing (lint)
run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/chart-helm --charts examples/chart-helm
- name: Setup minio - name: Setup minio
run: | run: |
@ -64,8 +64,7 @@ jobs:
run: | run: |
export AWS_ACCESS_KEY_ID=minioadmin export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin export AWS_SECRET_ACCESS_KEY=minioadmin
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/chart-helm -f examples/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
- name: curl test - name: curl test
run: | run: |

.github/workflows/matchers/ruff.json (vendored, Normal file; 17 changed lines)

@ -0,0 +1,17 @@
{
"problemMatcher": [
{
"owner": "ruff",
"pattern": [
{
"regexp": "^(.+?):(\\d+):(\\d+): (\\w+): (.+)$",
"file": 1,
"line": 2,
"column": 3,
"code": 4,
"message": 5
}
]
}
]
}

.github/workflows/mypy.yaml (vendored, Normal file; 51 changed lines)

@ -0,0 +1,51 @@
name: mypy
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- '**/*.py'
- '.github/workflows/mypy.yaml'
- 'tools/mypy.sh'
- 'pyproject.toml'
pull_request:
branches:
- main
# This workflow is only relevant when one of the following files changes.
# However, we have github configured to expect and require this workflow
# to run and pass before github with auto-merge a pull request. Until github
# allows more flexible auto-merge policy, we can just run this on every PR.
# It doesn't take that long to run, anyway.
#paths:
# - '**/*.py'
# - '.github/workflows/mypy.yaml'
# - 'tools/mypy.sh'
# - 'pyproject.toml'
jobs:
mypy:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install mypy==1.11.1
pip install types-setuptools
pip install types-PyYAML
pip install types-requests
pip install types-setuptools
- name: Mypy
run: |
echo "::add-matcher::.github/workflows/matchers/mypy.json"
tools/mypy.sh 1 ${{ matrix.python-version }}

.github/workflows/png-lint.yml (vendored, Normal file; 37 changed lines)

@ -0,0 +1,37 @@
name: Lint PNG exports from excalidraw
on:
push:
branches:
- "main"
paths:
- '*.excalidraw.png'
- '.github/workflows/png-lint.yml'
pull_request:
branches:
- "main"
paths:
- '*.excalidraw.png'
- '.github/workflows/png-lint.yml'
env:
LC_ALL: en_US.UTF-8
defaults:
run:
shell: bash
permissions:
contents: read
jobs:
actionlint:
runs-on: ubuntu-latest
steps:
- name: "Checkout"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: "Run png-lint.sh to check excalidraw exported images"
run: |
tools/png-lint.sh


@ -1,19 +0,0 @@
name: pre-commit
on:
pull_request:
push:
branches: [main]
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: "3.12"
- run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
with:
extra_args: --all-files --hook-stage manual


@ -2,6 +2,7 @@ name: PR Reminder Comment Bot
on: on:
pull_request_target: pull_request_target:
types: [opened] types: [opened]
jobs: jobs:
pr_reminder: pr_reminder:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -14,12 +15,7 @@ jobs:
owner: context.repo.owner, owner: context.repo.owner,
repo: context.repo.repo, repo: context.repo.repo,
issue_number: context.issue.number, issue_number: context.issue.number,
body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
'💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' +
'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
'🚀'
}) })
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.github/workflows/ruff.yml (vendored, Normal file; 52 changed lines)

@ -0,0 +1,52 @@
name: ruff
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- "**/*.py"
- pyproject.toml
- requirements-lint.txt
- .github/workflows/matchers/ruff.json
- .github/workflows/ruff.yml
pull_request:
branches:
- main
# This workflow is only relevant when one of the following files changes.
# However, we have github configured to expect and require this workflow
# to run and pass before github with auto-merge a pull request. Until github
# allows more flexible auto-merge policy, we can just run this on every PR.
# It doesn't take that long to run, anyway.
#paths:
# - "**/*.py"
# - pyproject.toml
# - requirements-lint.txt
# - .github/workflows/matchers/ruff.json
# - .github/workflows/ruff.yml
jobs:
ruff:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-lint.txt
- name: Analysing the code with ruff
run: |
echo "::add-matcher::.github/workflows/matchers/ruff.json"
ruff check --output-format github .
- name: Run isort
run: |
isort . --check-only

.github/workflows/shellcheck.yml (vendored, Normal file; 37 changed lines)

@ -0,0 +1,37 @@
name: Lint shell scripts
on:
push:
branches:
- "main"
paths:
- '**/*.sh'
- '.github/workflows/shellcheck.yml'
pull_request:
branches:
- "main"
paths:
- '**/*.sh'
- '.github/workflows/shellcheck.yml'
env:
LC_ALL: en_US.UTF-8
defaults:
run:
shell: bash
permissions:
contents: read
jobs:
shellcheck:
runs-on: ubuntu-latest
steps:
- name: "Checkout"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: "Check shell scripts"
run: |
tools/shellcheck.sh

.github/workflows/sphinx-lint.yml (vendored, Normal file; 32 changed lines)

@ -0,0 +1,32 @@
name: Lint documentation
on:
push:
branches:
- main
paths:
- "docs/**"
pull_request:
branches:
- main
paths:
- "docs/**"
jobs:
sphinx-lint:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-lint.txt
- name: Linting docs
run: tools/sphinx-lint.sh

.github/workflows/yapf.yml (vendored, Normal file; 38 changed lines)

@ -0,0 +1,38 @@
name: yapf
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- "**/*.py"
- .github/workflows/yapf.yml
pull_request:
branches:
- main
paths:
- "**/*.py"
- .github/workflows/yapf.yml
jobs:
yapf:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install yapf==0.32.0
pip install toml==0.10.2
- name: Running yapf
run: |
yapf --diff --recursive .

.gitignore (vendored; 5 changed lines)

@ -79,7 +79,10 @@ instance/
# Sphinx documentation # Sphinx documentation
docs/_build/ docs/_build/
docs/source/getting_started/examples/ docs/source/getting_started/examples/*.rst
!**/*.template.rst
docs/source/getting_started/examples/*.md
!**/*.template.md
# PyBuilder # PyBuilder
.pybuilder/ .pybuilder/


@ -1,110 +0,0 @@
default_stages:
- pre-commit # Run locally
- manual # Run in CI
repos:
- repo: https://github.com/google/yapf
rev: v0.43.0
hooks:
- id: yapf
args: [--in-place, --verbose]
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.3
hooks:
- id: ruff
args: [--output-format, github]
- repo: https://github.com/codespell-project/codespell
rev: v2.4.0
hooks:
- id: codespell
exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*'
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v19.1.7
hooks:
- id: clang-format
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))'
types_or: [c++, cuda]
args: [--style=file, --verbose]
- repo: https://github.com/jackdewinter/pymarkdown
rev: v0.9.27
hooks:
- id: pymarkdown
files: docs/.*
- repo: https://github.com/rhysd/actionlint
rev: v1.7.7
hooks:
- id: actionlint
- repo: local
hooks:
- id: mypy-local
name: Run mypy for local Python installation
entry: tools/mypy.sh 0 "local"
language: python
types: [python]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
stages: [pre-commit] # Don't run in CI
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9
entry: tools/mypy.sh 1 "3.9"
language: python
types: [python]
additional_dependencies: *mypy_deps
stages: [manual] # Only run in CI
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.10
entry: tools/mypy.sh 1 "3.10"
language: python
types: [python]
additional_dependencies: *mypy_deps
stages: [manual] # Only run in CI
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.11
entry: tools/mypy.sh 1 "3.11"
language: python
types: [python]
additional_dependencies: *mypy_deps
stages: [manual] # Only run in CI
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.12
entry: tools/mypy.sh 1 "3.12"
language: python
types: [python]
additional_dependencies: *mypy_deps
stages: [manual] # Only run in CI
- id: shellcheck
name: Lint shell scripts
entry: tools/shellcheck.sh
language: script
types: [shell]
- id: png-lint
name: Lint PNG exports from excalidraw
entry: tools/png-lint.sh
language: script
types: [png]
- id: signoff-commit
name: Sign-off Commit
entry: bash
args:
- -c
- |
if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" .git/COMMIT_EDITMSG; then
printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> .git/COMMIT_EDITMSG
fi
language: system
verbose: true
stages: [commit-msg]
- id: check-spdx-header
name: Check SPDX headers
entry: python tools/check_spdx_header.py
language: python
types: [python]
- id: suggestion
name: Suggestion
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
language: system
verbose: true
pass_filenames: false
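The removed hooks above cover formatting (yapf, ruff, isort, clang-format), spelling, markdown and workflow linting, mypy, shellcheck, PNG checks, commit sign-off, and SPDX headers. A hedged sketch of how a pre-commit configuration like this one (evidently a .pre-commit-config.yaml, though the file name is not shown in this view) is typically exercised locally, assuming pre-commit is installed from PyPI:

# Install the tool and register the hooks for the stages used in this config
pip install pre-commit
pre-commit install                          # pre-commit stage, runs on git commit
pre-commit install --hook-type commit-msg   # needed for the sign-off hook (commit-msg stage)
# Run the CI-facing hooks on demand (those staged as manual)
pre-commit run --all-files --hook-stage manual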

CMakeLists.txt (Executable file → Normal file; 87 changed lines)

@ -24,6 +24,9 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables # Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}") set(ignoreMe "${VLLM_PYTHON_PATH}")
# Prevent installation of dependencies (cutlass) by default.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
# #
# Supported python versions. These versions will be searched in order, the # Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py. # first match will be selected. These should be kept in sync with setup.py.
@ -178,31 +181,6 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
# Define other extension targets # Define other extension targets
# #
#
# cumem_allocator extension
#
set(VLLM_CUMEM_EXT_SRC
"csrc/cumem_allocator.cpp")
set_gencode_flags_for_srcs(
SRCS "${VLLM_CUMEM_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Enabling cumem allocator extension.")
# link against cuda driver library
list(APPEND CUMEM_LIBS cuda)
define_gpu_extension_target(
cumem_allocator
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_CUMEM_EXT_SRC}
LIBRARIES ${CUMEM_LIBS}
USE_SABI 3.8
WITH_SOABI)
endif()
# #
# _C extension # _C extension
# #
@ -245,7 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare( FetchContent_Declare(
cutlass cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG v3.7.0 GIT_TAG v3.6.0
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
@ -275,7 +253,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs. # Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that # Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet. # are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})
if (MARLIN_ARCHS) if (MARLIN_ARCHS)
set(MARLIN_SRCS set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu" "csrc/quantization/fp8/fp8_marlin.cu"
@ -296,15 +274,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now). # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}") cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${SRCS}" SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@ -356,7 +329,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# 2:4 Sparse Kernels # 2:4 Sparse Kernels
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now). # require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu" set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
@ -537,7 +510,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
endif() endif()
# vllm-flash-attn currently only supported on CUDA # vllm-flash-attn currently only supported on CUDA
if (NOT VLLM_GPU_LANG STREQUAL "CUDA") if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
return() return()
endif () endif ()
@ -560,7 +533,7 @@ endif()
# They should be identical but if they aren't, this is a massive footgun. # They should be identical but if they aren't, this is a massive footgun.
# #
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). # To only install vllm-flash-attn, use --component vllm_flash_attn_c.
# If no component is specified, vllm-flash-attn is still installed. # If no component is specified, vllm-flash-attn is still installed.
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
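As a usage illustration (not part of the change): installing just one of the renamed components from an existing CMake build tree. The build directory name below is an assumption.

```python
# Hypothetical invocation of the renamed install components; the "build"
# directory path is an assumption, not taken from the diff.
import subprocess

# Install only the FA2 portion of vllm-flash-attn; use _vllm_fa3_C for FA3.
subprocess.run(
    ["cmake", "--install", "build", "--component", "_vllm_fa2_C"],
    check=True,
)
```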
@ -572,41 +545,43 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif() endif()
if(VLLM_FLASH_ATTN_SRC_DIR) if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare( FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
vllm-flash-attn SOURCE_DIR
${VLLM_FLASH_ATTN_SRC_DIR}
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
else() else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG d4e09037abf588af1ec47d0e966b237ee376876c GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
) )
endif() endif()
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set(VLLM_PARENT_BUILD ON)
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
# Make sure vllm-flash-attn install rules are nested under vllm/
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)
# Fetch the vllm-flash-attn library # Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn) FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in # Restore the install prefix
# case only one is built, in the case both are built redundant work is done) install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install( install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py"
)
# Copy over the vllm-flash-attn python files
install( install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn DESTINATION vllm/vllm_flash_attn
COMPONENT _vllm_fa3_C COMPONENT vllm_flash_attn_c
FILES_MATCHING PATTERN "*.py" FILES_MATCHING PATTERN "*.py"
) )
# Nothing after vllm-flash-attn, see comment about macros above # Nothing after vllm-flash-attn, see comment about macros above

View File

@ -61,7 +61,7 @@ representative at an online or offline/IRL event.
Instances of abusive, harassing, or otherwise unacceptable behavior may be Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement in the #code-of-conduct reported to the community leaders responsible for enforcement in the #code-of-conduct
channel in the [vLLM Slack](https://slack.vllm.ai). channel in the [vLLM Discord](https://discord.com/invite/jz7wjKhh6g).
All complaints will be reviewed and investigated promptly and fairly. All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the All community leaders are obligated to respect the privacy and security of the


View File

# to run the OpenAI compatible server. # to run the OpenAI compatible server.
# Please update any changes made here to # Please update any changes made here to
# docs/source/contributing/dockerfile/dockerfile.md and # docs/source/dev/dockerfile/dockerfile.md and
# docs/source/assets/contributing/dockerfile-stages-dependency.png # docs/source/assets/dev/dockerfile-stages-dependency.png
ARG CUDA_VERSION=12.4.1 ARG CUDA_VERSION=12.4.1
#################### BASE BUILD IMAGE #################### #################### BASE BUILD IMAGE ####################
@ -52,7 +52,7 @@ WORKDIR /workspace
# after this step # after this step
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \ python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
fi fi
COPY requirements-common.txt requirements-common.txt COPY requirements-common.txt requirements-common.txt
@ -126,8 +126,8 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
# Check the size of the wheel if RUN_WHEEL_CHECK is true # Check the size of the wheel if RUN_WHEEL_CHECK is true
COPY .buildkite/check-wheel-size.py check-wheel-size.py COPY .buildkite/check-wheel-size.py check-wheel-size.py
# sync the default value with .buildkite/check-wheel-size.py # Default max size of the wheel is 250MB
ARG VLLM_MAX_SIZE_MB=400 ARG VLLM_MAX_SIZE_MB=250
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
ARG RUN_WHEEL_CHECK=true ARG RUN_WHEEL_CHECK=true
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
@ -149,8 +149,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################
# image with vLLM installed # image with vLLM installed
# TODO: Restore to base image after FlashInfer AOT wheel fixed FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
ARG CUDA_VERSION=12.4.1 ARG CUDA_VERSION=12.4.1
ARG PYTHON_VERSION=3.12 ARG PYTHON_VERSION=3.12
WORKDIR /vllm-workspace WORKDIR /vllm-workspace
@ -195,30 +194,12 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
--mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose python3 -m pip install dist/*.whl --verbose
# How to build this FlashInfer wheel:
# $ export FLASHINFER_ENABLE_AOT=1
# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX'
# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
# $ cd flashinfer
# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4
# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \ . /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \ python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
fi fi
COPY examples examples COPY examples examples
# Although we build Flashinfer with AOT mode, there's still
# some issues w.r.t. JIT compilation. Therefore we need to
# install build dependencies for JIT compilation.
# TODO: Remove this once FlashInfer AOT wheel is fixed
COPY requirements-build.txt requirements-build.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-build.txt
#################### vLLM installation IMAGE #################### #################### vLLM installation IMAGE ####################
#################### TEST IMAGE #################### #################### TEST IMAGE ####################
@ -269,7 +250,7 @@ ENV VLLM_USAGE_SOURCE production-docker-image
# define sagemaker first, so it is not default from `docker build` # define sagemaker first, so it is not default from `docker build`
FROM vllm-openai-base AS vllm-sagemaker FROM vllm-openai-base AS vllm-sagemaker
COPY examples/online_serving/sagemaker-entrypoint.sh . COPY examples/sagemaker-entrypoint.sh .
RUN chmod +x sagemaker-entrypoint.sh RUN chmod +x sagemaker-entrypoint.sh
ENTRYPOINT ["./sagemaker-entrypoint.sh"] ENTRYPOINT ["./sagemaker-entrypoint.sh"]

View File

@ -26,10 +26,10 @@ RUN pip install intel_extension_for_pytorch==2.5.0
WORKDIR /workspace WORKDIR /workspace
COPY requirements-build.txt requirements-build.txt
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
pip install --upgrade pip && \ pip install --upgrade pip && \
pip install -r requirements-build.txt pip install -r requirements-build.txt
@ -37,9 +37,9 @@ FROM cpu-test-1 AS build
WORKDIR /workspace/vllm WORKDIR /workspace/vllm
COPY requirements-common.txt requirements-common.txt
COPY requirements-cpu.txt requirements-cpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
--mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
pip install -v -r requirements-cpu.txt pip install -v -r requirements-cpu.txt
COPY . . COPY . .

View File

@ -1,4 +1,4 @@
FROM vault.habana.ai/gaudi-docker/1.19.1/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
COPY ./ /workspace/vllm COPY ./ /workspace/vllm

View File

@ -14,7 +14,6 @@ ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
RUN python3 -m pip install -U pip
# install build requirements # install build requirements
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/requirements-build.txt RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/requirements-build.txt
# build vLLM with OpenVINO backend # build vLLM with OpenVINO backend

View File

@ -4,12 +4,12 @@ USER root
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/" ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
RUN apt-get update -y && apt-get install -y git wget kmod curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1
# Some packages in requirements-cpu are installed here # Some packages in requirements-cpu are installed here
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba # IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
# Currently these may not be available for venv or pip directly # Currently these may not be available for venv or pip directly
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 rust && micromamba clean --all --yes RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 torchvision-cpu=0.16.2 rust && micromamba clean --all --yes
COPY ./ /workspace/vllm COPY ./ /workspace/vllm
@ -18,9 +18,11 @@ ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
# These packages will be in rocketce eventually
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
torch==2.3.1 \
-r requirements-cpu.txt \ -r requirements-cpu.txt \
xformers uvloop==0.20.0 xformers uvloop==0.20.0

View File

@ -1,119 +1,174 @@
# default base image # Default ROCm 6.2 base image
ARG REMOTE_VLLM="0" ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0"
ARG USE_CYTHON="0"
ARG BUILD_RPD="1"
ARG COMMON_WORKDIR=/app
ARG BASE_IMAGE=rocm/vllm-dev:base
FROM ${BASE_IMAGE} AS base # Default ROCm ARCHes to build vLLM for.
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
ARG ARG_PYTORCH_ROCM_ARCH # Whether to install CK-based flash-attention
ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}} # If 0, will not install flash-attention
ARG BUILD_FA="1"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
ARG FA_BRANCH="3cea2fb"
# Whether to build triton on rocm
ARG BUILD_TRITON="1"
ARG TRITON_BRANCH="e192dba"
### Base image build stage
FROM $BASE_IMAGE AS base
# Import arg(s) defined before this build stage
ARG PYTORCH_ROCM_ARCH
# Install some basic utilities # Install some basic utilities
RUN apt-get update -q -y && apt-get install -q -y \ RUN apt-get update && apt-get install python3 python3-pip -y
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev RUN apt-get update && apt-get install -y \
# Remove sccache curl \
RUN python3 -m pip install --upgrade pip && pip install setuptools_scm ca-certificates \
sudo \
git \
bzip2 \
libx11-6 \
build-essential \
wget \
unzip \
tmux \
ccache \
&& rm -rf /var/lib/apt/lists/*
# When launching the container, mount the code directory to /vllm-workspace
ARG APP_MOUNT=/vllm-workspace
WORKDIR ${APP_MOUNT}
RUN python3 -m pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
ARG COMMON_WORKDIR
WORKDIR ${COMMON_WORKDIR}
# Install torch == 2.6.0 on ROCm
# ----------------------- RUN --mount=type=cache,target=/root/.cache/pip \
# vLLM fetch stages case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
FROM base AS fetch_vllm_0 *"rocm-6.2"*) \
ONBUILD COPY ./ vllm/ python3 -m pip uninstall -y torch torchvision \
FROM base AS fetch_vllm_1 && python3 -m pip install --pre \
ARG VLLM_REPO="https://github.com/vllm-project/vllm.git" torch==2.6.0.dev20241113+rocm6.2 \
ARG VLLM_BRANCH="main" 'setuptools-scm>=8' \
ONBUILD RUN git clone ${VLLM_REPO} \ torchvision==0.20.0.dev20241113+rocm6.2 \
&& cd vllm \ --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \
&& git checkout ${VLLM_BRANCH}
FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm
# -----------------------
# vLLM build stages
FROM fetch_vllm AS build_vllm
ARG USE_CYTHON
# Build vLLM
RUN cd vllm \
&& python3 -m pip install -r requirements-rocm.txt \
&& python3 setup.py clean --all \
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \
&& python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_vllm
ARG COMMON_WORKDIR
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt /
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
# -----------------------
# Test vLLM image
FROM base AS test
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
# Install vLLM
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
cd /install \
&& pip install -U -r requirements-rocm.txt \
&& pip uninstall -y vllm \
&& pip install *.whl
WORKDIR /vllm-workspace
ARG COMMON_WORKDIR
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
# install development dependencies (for testing)
RUN cd /vllm-workspace \
&& rm -rf vllm \
&& python3 -m pip install -e tests/vllm_test_utils \
&& python3 -m pip install lm-eval[api]==0.4.4 \
&& python3 -m pip install pytest-shard
# -----------------------
# Final vLLM image
FROM base AS final
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually remove it so that later steps of numpy upgrade can continue
RUN case "$(which python3)" in \
*"/opt/conda/envs/py_3.9"*) \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
*) ;; esac *) ;; esac
RUN python3 -m pip install --upgrade huggingface-hub[cli] ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ARG BUILD_RPD ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
RUN if [ ${BUILD_RPD} -eq "1" ]; then \ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git \ ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
&& cd rocmProfileData/rpd_tracer \
&& pip install -r requirements.txt && cd ../ \
&& make && make install \
&& cd hipMarker && python3 setup.py install ; fi
# Install vLLM ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \ ENV CCACHE_DIR=/root/.cache/ccache
cd /install \
&& pip install -U -r requirements-rocm.txt \
&& pip uninstall -y vllm \
&& pip install *.whl
ARG COMMON_WORKDIR
# Copy over the benchmark scripts as well ### AMD-SMI build stage
COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks FROM base AS build_amdsmi
COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples # Build amdsmi wheel always
RUN cd /opt/rocm/share/amd_smi \
&& python3 -m pip wheel . --wheel-dir=/install
### Flash-Attention wheel build stage
FROM base AS build_fa
ARG BUILD_FA
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_FA" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
# Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \
fi
### Triton wheel build stage
FROM base AS build_triton
ARG BUILD_TRITON
ARG TRITON_BRANCH
# Build triton wheel if `BUILD_TRITON = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_TRITON" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& python3 -m pip install ninja cmake wheel pybind11 \
&& git clone https://github.com/OpenAI/triton.git \
&& cd triton \
&& git checkout "${TRITON_BRANCH}" \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=/install; \
# Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \
fi
### Final vLLM build stage
FROM base AS final
# Import the vLLM development directory from the build context
COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
RUN python3 -m pip install --upgrade pip
# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard
# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
# Silences the HF Tokenizers warning
ENV TOKENIZERS_PARALLELISM=false ENV TOKENIZERS_PARALLELISM=false
# Performance environment variable. RUN --mount=type=cache,target=${CCACHE_DIR} \
ENV HIP_FORCE_DEV_KERNARG=1 --mount=type=bind,source=.git,target=.git \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -Ur requirements-rocm.txt \
&& python3 setup.py clean --all \
&& python3 setup.py develop
# Copy amdsmi wheel into final image
RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
mkdir -p libs \
&& cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& python3 -m pip uninstall -y amdsmi;
# Copy triton wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& python3 -m pip uninstall -y triton; fi
# Copy flash-attn wheel(s) into final image if they were built
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
mkdir -p libs \
&& if ls /install/*.whl; then \
cp /install/*.whl libs \
# Preemptively uninstall to avoid same-version no-installs
&& python3 -m pip uninstall -y flash-attn; fi
# Install wheels that were built to the final image
RUN --mount=type=cache,target=/root/.cache/pip \
if ls libs/*.whl; then \
python3 -m pip install libs/*.whl; fi
# install development dependencies (for testing)
RUN python3 -m pip install -e tests/vllm_test_utils
CMD ["/bin/bash"] CMD ["/bin/bash"]

View File

@ -1,158 +0,0 @@
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete
ARG HIPBLASLT_BRANCH="4d40e36"
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
ARG LEGACY_HIPBLASLT_OPTION=
ARG RCCL_BRANCH="648a58d"
ARG RCCL_REPO="https://github.com/ROCm/rccl"
ARG TRITON_BRANCH="e5be006"
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG PYTORCH_BRANCH="8d4926e"
ARG PYTORCH_VISION_BRANCH="v0.19.1"
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
FROM ${BASE_IMAGE} AS base
ENV PATH=/opt/rocm/llvm/bin:$PATH
ENV ROCM_PATH=/opt/rocm
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ARG PYTHON_VERSION=3.12
RUN mkdir -p /app
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN apt-get update -y \
&& apt-get install -y software-properties-common git curl sudo vim less \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
python${PYTHON_VERSION}-lib2to3 python-is-python3 \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
&& python3 --version && python3 -m pip --version
RUN pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
FROM base AS build_hipblaslt
ARG HIPBLASLT_BRANCH
ARG HIPBLAS_COMMON_BRANCH
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
ARG LEGACY_HIPBLASLT_OPTION
RUN git clone https://github.com/ROCm/hipBLAS-common.git
RUN cd hipBLAS-common \
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
&& mkdir build \
&& cd build \
&& cmake .. \
&& make package \
&& dpkg -i ./*.deb
RUN git clone https://github.com/ROCm/hipBLASLt
RUN cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
&& cd build/release \
&& make package
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
FROM base AS build_rccl
ARG RCCL_BRANCH
ARG RCCL_REPO
RUN git clone ${RCCL_REPO}
RUN cd rccl \
&& git checkout ${RCCL_BRANCH} \
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
FROM base AS build_triton
ARG TRITON_BRANCH
ARG TRITON_REPO
RUN git clone ${TRITON_REPO}
RUN cd triton \
&& git checkout ${TRITON_BRANCH} \
&& cd python \
&& python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=dist
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
FROM base AS build_pytorch
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
RUN git clone ${PYTORCH_REPO} pytorch
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
pip install -r requirements.txt && git submodule update --init --recursive \
&& python3 tools/amd_build/build_amd.py \
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
&& python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN git clone ${FA_REPO}
RUN cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install \
&& cp /app/flash-attention/dist/*.whl /app/install
FROM base AS final
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
dpkg -i /install/*deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
dpkg -i /install/*deb \
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
pip install /install/*.whl
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
pip install /install/*.whl
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl
ARG BASE_IMAGE
ARG HIPBLASLT_BRANCH
ARG LEGACY_HIPBLASLT_OPTION
ARG RCCL_BRANCH
ARG RCCL_REPO
ARG TRITON_BRANCH
ARG TRITON_REPO
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
&& echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
&& echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
&& echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
&& echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \
&& echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \
&& echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
&& echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
&& echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt

View File

@ -1,4 +1,4 @@
ARG NIGHTLY_DATE="20250124" ARG NIGHTLY_DATE="20241017"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE FROM $BASE_IMAGE

View File

@ -10,14 +10,12 @@ Easy, fast, and cheap LLM serving for everyone
</h3> </h3>
<p align="center"> <p align="center">
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> | | <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
</p> </p>
--- ---
*Latest News* 🔥 *Latest News* 🔥
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing).
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! - [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). - [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
@ -36,12 +34,10 @@ Easy, fast, and cheap LLM serving for everyone
## About ## About
vLLM is a fast and easy-to-use library for LLM inference and serving. vLLM is a fast and easy-to-use library for LLM inference and serving.
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
vLLM is fast with: vLLM is fast with:
- State-of-the-art serving throughput - State-of-the-art serving throughput
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html) - Efficient management of attention key and value memory with **PagedAttention**
- Continuous batching of incoming requests - Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph - Fast model execution with CUDA/HIP graph
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8. - Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8.
@ -72,16 +68,16 @@ Find the full list of supported models [here](https://docs.vllm.ai/en/latest/mod
## Getting Started ## Getting Started
Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source): Install vLLM with `pip` or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
```bash ```bash
pip install vllm pip install vllm
``` ```
Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more. Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more.
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation/index.html) - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html) - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html) - [List of Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
## Contributing ## Contributing
@ -94,33 +90,28 @@ vLLM is a community project. Our compute resources for development and testing a
<!-- Note: Please sort them in alphabetical order. --> <!-- Note: Please sort them in alphabetical order. -->
<!-- Note: Please keep these consistent with docs/source/community/sponsors.md --> <!-- Note: Please keep these consistent with docs/source/community/sponsors.md -->
Cash Donations:
- a16z
- Dropbox
- Sequoia Capital
- Skywork AI
- ZhenFund
Compute Resources: - a16z
- AMD - AMD
- Anyscale - Anyscale
- AWS - AWS
- Crusoe Cloud - Crusoe Cloud
- Databricks - Databricks
- DeepInfra - DeepInfra
- Dropbox
- Google Cloud - Google Cloud
- Lambda Lab - Lambda Lab
- Nebius - Nebius
- Novita AI
- NVIDIA - NVIDIA
- Replicate - Replicate
- Roblox - Roblox
- RunPod - RunPod
- Sequoia Capital
- Skywork AI
- Trainy - Trainy
- UC Berkeley - UC Berkeley
- UC San Diego - UC San Diego
- ZhenFund
Slack Sponsor: Anyscale
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM. We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
@ -139,7 +130,8 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
## Contact Us ## Contact Us
* For technical questions and feature requests, please use Github issues or discussions. * For technical questions and feature requests, please use Github issues or discussions.
* For discussing with fellow users and coordinating contributions and development, please use Slack. * For discussing with fellow users, please use Discord.
* For coordinating contributions and development, please use Slack.
* For security disclosures, please use Github's security advisory feature. * For security disclosures, please use Github's security advisory feature.
* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. * For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu.

View File

@ -4,7 +4,7 @@
If you believe you have found a security vulnerability in vLLM, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. If you believe you have found a security vulnerability in vLLM, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). Reports will then be triaged by the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). Reports will then be triaged by the [vulnerability management team](https://docs.vllm.ai/contributing/vulnerability_management/).
--- ---

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import json import json
import os import os
import sys import sys
@ -24,7 +22,6 @@ class RequestFuncInput:
prompt_len: int prompt_len: int
output_len: int output_len: int
model: str model: str
model_name: Optional[str] = None
best_of: int = 1 best_of: int = 1
logprobs: Optional[int] = None logprobs: Optional[int] = None
extra_body: Optional[dict] = None extra_body: Optional[dict] = None
@ -37,7 +34,6 @@ class RequestFuncOutput:
generated_text: str = "" generated_text: str = ""
success: bool = False success: bool = False
latency: float = 0.0 latency: float = 0.0
output_tokens: int = 0
ttft: float = 0.0 # Time to first token ttft: float = 0.0 # Time to first token
itl: List[float] = field( itl: List[float] = field(
default_factory=list) # List of inter-token latencies default_factory=list) # List of inter-token latencies
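For context (illustrative, not part of the diff): what the two fields added above are for. `model_name` carries a served-model alias that can differ from the tokenizer/weights path, and `output_tokens` holds the server-reported completion token count. The sketch below uses stand-in classes and placeholder values.

```python
# Self-contained sketch with stand-in classes; names and values are placeholders.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Req:  # stand-in for RequestFuncInput
    model: str                        # tokenizer / weights identifier
    model_name: Optional[str] = None  # alias actually sent in the request payload


@dataclass
class Out:  # stand-in for RequestFuncOutput
    output_tokens: int = 0  # completion token count reported by the server


req = Req(model="meta-llama/Llama-2-7b-chat-hf", model_name="llama2-7b-chat")
payload_model = req.model_name if req.model_name else req.model
print(payload_model)  # -> llama2-7b-chat
```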
@ -53,8 +49,7 @@ async def async_request_tgi(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
params = { params = {
"best_of": request_func_input.best_of, "best_of": request_func_input.best_of,
"max_new_tokens": request_func_input.output_len, "max_new_tokens": request_func_input.output_len,
@ -83,7 +78,7 @@ async def async_request_tgi(
continue continue
chunk_bytes = chunk_bytes.decode("utf-8") chunk_bytes = chunk_bytes.decode("utf-8")
# NOTE: Sometimes TGI returns a ping response without #NOTE: Sometimes TGI returns a ping response without
# any data, we should skip it. # any data, we should skip it.
if chunk_bytes.startswith(":"): if chunk_bytes.startswith(":"):
continue continue
@ -126,8 +121,7 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url api_url = request_func_input.api_url
assert api_url.endswith("generate_stream") assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1 assert request_func_input.best_of == 1
payload = { payload = {
"accumulate_tokens": True, "accumulate_tokens": True,
@ -161,7 +155,7 @@ async def async_request_trt_llm(
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if ttft == 0.0: if ttft == 0.0:
ttft = timestamp - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
# Decoding phase # Decoding phase
@ -191,8 +185,7 @@ async def async_request_deepspeed_mii(
request_func_input: RequestFuncInput, request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None, pbar: Optional[tqdm] = None,
) -> RequestFuncOutput: ) -> RequestFuncOutput:
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1 assert request_func_input.best_of == 1
payload = { payload = {
@ -240,23 +233,17 @@ async def async_request_openai_completions(
("completions", "profile") ("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'." ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model,
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt, "prompt": request_func_input.prompt,
"temperature": 0.0, "temperature": 0.0,
"best_of": request_func_input.best_of, "best_of": request_func_input.best_of,
"max_tokens": request_func_input.output_len, "max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs, "logprobs": request_func_input.logprobs,
"stream": True, "stream": True,
"stream_options": { "ignore_eos": request_func_input.ignore_eos,
"include_usage": True,
},
} }
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
headers = { headers = {
@ -267,6 +254,7 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
generated_text = "" generated_text = ""
ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
@ -281,16 +269,15 @@ async def async_request_openai_completions(
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ") "data: ")
if chunk != "[DONE]": if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
data = json.loads(chunk) data = json.loads(chunk)
# NOTE: Some completion API might have a last # NOTE: Some completion API might have a last
# usage summary response without a token so we # usage summary response without a token so we
# want to check a token was generated # want to check a token was generated
if choices := data.get("choices"): if data["choices"][0]["text"]:
# Note that text could be empty here
# e.g. for special tokens
text = choices[0].get("text")
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if not first_chunk_received: if not first_chunk_received:
@ -304,10 +291,7 @@ async def async_request_openai_completions(
most_recent_timestamp) most_recent_timestamp)
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
generated_text += text or "" generated_text += data["choices"][0]["text"]
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
if first_chunk_received: if first_chunk_received:
output.success = True output.success = True
else: else:
@ -316,7 +300,7 @@ async def async_request_openai_completions(
"Never received a valid chunk to calculate TTFT." "Never received a valid chunk to calculate TTFT."
"This response will be marked as failed!") "This response will be marked as failed!")
output.generated_text = generated_text output.generated_text = generated_text
output.latency = most_recent_timestamp - st output.latency = latency
else: else:
output.error = response.reason or "" output.error = response.reason or ""
output.success = False output.success = False
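To make the new parsing logic concrete, here is a self-contained sketch (hand-written sample chunks, not captured from a real server) of the stream shape it now expects: text deltas arrive under `choices`, and because the payload sets `stream_options: {"include_usage": true}`, a final chunk carries only a `usage` object with `completion_tokens`.

```python
# Hand-written sample chunks illustrating the stream shape parsed above.
import json

chunks = [
    'data: {"choices": [{"text": "Hel"}]}',
    'data: {"choices": [{"text": "lo"}]}',
    'data: {"choices": [], "usage": {"completion_tokens": 2}}',
    'data: [DONE]',
]

generated_text, output_tokens = "", None
for raw in chunks:
    chunk = raw.removeprefix("data: ")
    if chunk == "[DONE]":
        break
    data = json.loads(chunk)
    if choices := data.get("choices"):
        generated_text += choices[0].get("text") or ""
    elif usage := data.get("usage"):
        output_tokens = usage.get("completion_tokens")

print(generated_text, output_tokens)  # -> Hello 2
```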
@ -339,14 +323,12 @@ async def async_request_openai_chat_completions(
"chat/completions" "chat/completions"
), "OpenAI Chat Completions API URL must end with 'chat/completions'." ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
async with aiohttp.ClientSession(trust_env=True, async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}] content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content: if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content) content.append(request_func_input.multi_modal_content)
payload = { payload = {
"model": request_func_input.model_name \ "model": request_func_input.model,
if request_func_input.model_name else request_func_input.model,
"messages": [ "messages": [
{ {
"role": "user", "role": "user",
@ -356,12 +338,8 @@ async def async_request_openai_chat_completions(
"temperature": 0.0, "temperature": 0.0,
"max_completion_tokens": request_func_input.output_len, "max_completion_tokens": request_func_input.output_len,
"stream": True, "stream": True,
"stream_options": { "ignore_eos": request_func_input.ignore_eos,
"include_usage": True,
},
} }
if request_func_input.ignore_eos:
payload["ignore_eos"] = request_func_input.ignore_eos
if request_func_input.extra_body: if request_func_input.extra_body:
payload.update(request_func_input.extra_body) payload.update(request_func_input.extra_body)
headers = { headers = {
@ -387,15 +365,17 @@ async def async_request_openai_chat_completions(
chunk = chunk_bytes.decode("utf-8").removeprefix( chunk = chunk_bytes.decode("utf-8").removeprefix(
"data: ") "data: ")
if chunk != "[DONE]": if chunk == "[DONE]":
latency = time.perf_counter() - st
else:
timestamp = time.perf_counter() timestamp = time.perf_counter()
data = json.loads(chunk) data = json.loads(chunk)
if choices := data.get("choices"): delta = data["choices"][0]["delta"]
content = choices[0]["delta"].get("content") if delta.get("content", None):
# First token # First token
if ttft == 0.0: if ttft == 0.0:
ttft = timestamp - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
# Decoding phase # Decoding phase
@ -403,16 +383,13 @@ async def async_request_openai_chat_completions(
output.itl.append(timestamp - output.itl.append(timestamp -
most_recent_timestamp) most_recent_timestamp)
generated_text += content or "" generated_text += delta["content"]
elif usage := data.get("usage"):
output.output_tokens = usage.get(
"completion_tokens")
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
output.generated_text = generated_text output.generated_text = generated_text
output.success = True output.success = True
output.latency = most_recent_timestamp - st output.latency = latency
else: else:
output.error = response.reason or "" output.error = response.reason or ""
output.success = False output.success = False
@ -440,35 +417,14 @@ def get_model(pretrained_model_name_or_path: str) -> str:
def get_tokenizer( def get_tokenizer(
pretrained_model_name_or_path: str, pretrained_model_name_or_path: str, trust_remote_code: bool
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
**kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
if pretrained_model_name_or_path is not None and not os.path.exists( if pretrained_model_name_or_path is not None and not os.path.exists(
pretrained_model_name_or_path): pretrained_model_name_or_path):
pretrained_model_name_or_path = get_model( pretrained_model_name_or_path = get_model(
pretrained_model_name_or_path) pretrained_model_name_or_path)
if tokenizer_mode == "slow": return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
if kwargs.get("use_fast", False): trust_remote_code=trust_remote_code)
raise ValueError(
"Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False
if tokenizer_mode == "mistral":
try:
from vllm.transformers_utils.tokenizer import MistralTokenizer
except ImportError as e:
raise ImportError("MistralTokenizer requires vllm package.\n"
"Please install it with `pip install vllm` "
"to use mistral tokenizer mode.") from e
return MistralTokenizer.from_pretrained(
str(pretrained_model_name_or_path))
else:
return AutoTokenizer.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
**kwargs,
)
ASYNC_REQUEST_FUNCS = { ASYNC_REQUEST_FUNCS = {
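As a usage illustration (not part of the change): calling the reworked `get_tokenizer` helper from a benchmark script. The model id is a placeholder, and the import assumes the benchmarks directory is on the Python path.

```python
# Hypothetical call site; the model id and import path are assumptions.
from backend_request_func import get_tokenizer

tokenizer = get_tokenizer(
    "meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
    tokenizer_mode="auto",            # "slow" and "mistral" are also accepted now
    trust_remote_code=False,
)
print(tokenizer("Hello vLLM").input_ids)
```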

View File

@ -1,4 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark guided decoding throughput.""" """Benchmark guided decoding throughput."""
import argparse import argparse
import dataclasses import dataclasses

View File

@ -1,4 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark the latency of processing a single batch of requests.""" """Benchmark the latency of processing a single batch of requests."""
import argparse import argparse
import dataclasses import dataclasses
@ -14,7 +13,6 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -42,20 +40,6 @@ def main(args: argparse.Namespace):
"prompt_token_ids": batch "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()] } for batch in dummy_prompt_token_ids.tolist()]
def llm_generate():
if not args.use_beam_search:
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
else:
llm.beam_search(
dummy_prompts,
BeamSearchParams(
beam_width=args.n,
max_tokens=args.output_len,
ignore_eos=True,
))
def run_to_completion(profile_dir: Optional[str] = None): def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir: if profile_dir:
with torch.profiler.profile( with torch.profiler.profile(
@ -65,11 +49,15 @@ def main(args: argparse.Namespace):
], ],
on_trace_ready=torch.profiler.tensorboard_trace_handler( on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p: str(profile_dir))) as p:
llm_generate() llm.generate(dummy_prompts,
print(p.key_averages().table(sort_by="self_cuda_time_total")) sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else: else:
start_time = time.perf_counter() start_time = time.perf_counter()
llm_generate() llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter() end_time = time.perf_counter()
latency = end_time - start_time latency = end_time - start_time
return latency return latency

View File

@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
""" """
Offline benchmark to test the long document QA throughput. Offline benchmark to test the long document QA throughput.
Example usage: Example usage:
# This workload samples 8 different prompts with a default input # This command run the vllm with 50GB CPU memory for offloading
# The workload samples 8 different prompts with a default input
# length of 20000 tokens, then replicates each prompt 2 times # length of 20000 tokens, then replicates each prompt 2 times
# in random order. # in random order.
python benchmark_long_document_qa_throughput.py \ python benchmark_long_document_qa_throughput.py \

View File

@ -1,4 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
""" """
Benchmark the efficiency of prefix caching. Benchmark the efficiency of prefix caching.
@ -11,8 +10,7 @@ Fixed example usage:
--model meta-llama/Llama-2-7b-chat-hf \ --model meta-llama/Llama-2-7b-chat-hf \
--enable-prefix-caching \ --enable-prefix-caching \
--num-prompts 1 \ --num-prompts 1 \
--repeat-count 100 \ --repeat-count 100
--input-length-range 128:256
ShareGPT example usage: ShareGPT example usage:
# This command samples 20 prompts with input lengths # This command samples 20 prompts with input lengths

View File

@ -1,4 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark offline prioritization.""" """Benchmark offline prioritization."""
import argparse import argparse
import dataclasses import dataclasses

View File

@ -1,4 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
r"""Benchmark online serving throughput. r"""Benchmark online serving throughput.
On the server side, run one of the following commands: On the server side, run one of the following commands:
@ -26,7 +25,6 @@ On the client side, run:
import argparse import argparse
import asyncio import asyncio
import base64 import base64
import gc
import io import io
import json import json
import os import os
@ -201,7 +199,7 @@ def sample_sonnet_requests(
return sampled_requests return sampled_requests
def sample_vision_arena_requests( def sample_mmmu_pro_vision_requests(
dataset, dataset,
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
@ -213,7 +211,13 @@ def sample_vision_arena_requests(
if len(sampled_requests) == num_requests: if len(sampled_requests) == num_requests:
break break
prompt = data["turns"][0][0]['content'] # MMMU-Pro vision direct prompt
# Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
prompt = (
"Answer with the option letter from the given choices directly. "
"The last line of your response should be of the following "
"format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
"options.")
prompt_token_ids = tokenizer(prompt).input_ids prompt_token_ids = tokenizer(prompt).input_ids
if fixed_output_len is None: if fixed_output_len is None:
@ -225,10 +229,10 @@ def sample_vision_arena_requests(
output_len = fixed_output_len output_len = fixed_output_len
assert isinstance( assert isinstance(
data["images"][0], data["image"],
Image), ("Input image format must be `PIL.Image.Image`, " Image), ("Input image format must be `PIL.Image.Image`, "
f"given {type(data['image'])}.") f"given {type(data['image'])}.")
image: Image = data["images"][0] image: Image = data["image"]
image = image.convert("RGB") image = image.convert("RGB")
image_data = io.BytesIO() image_data = io.BytesIO()
image.save(image_data, format='JPEG') image.save(image_data, format='JPEG')
@ -247,7 +251,7 @@ def sample_vision_arena_requests(
def sample_hf_requests( def sample_hf_requests(
dataset_path: str, dataset_path: str,
dataset_subset: Optional[str], dataset_subset: str,
dataset_split: str, dataset_split: str,
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
@ -255,17 +259,19 @@ def sample_hf_requests(
fixed_output_len: Optional[int] = None, fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: ) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
# Special case for vision_arena dataset # Special case for MMMU-Pro vision dataset
if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \ if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision':
and dataset_subset is None: assert dataset_split == "test"
assert dataset_split == "train"
dataset = load_dataset(dataset_path, dataset = load_dataset(dataset_path,
name=dataset_subset, name=dataset_subset,
split=dataset_split, split=dataset_split,
streaming=True) streaming=True)
dataset = dataset.shuffle(seed=random_seed) assert "image" in dataset.features, (
return sample_vision_arena_requests(dataset, num_requests, tokenizer, "MMMU/MMMU_Pro vision dataset must have 'image' column.")
fixed_output_len) filter_func = lambda x: isinstance(x["image"], Image)
dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
return sample_mmmu_pro_vision_requests(dataset, num_requests,
tokenizer, fixed_output_len)
dataset = load_dataset(dataset_path, dataset = load_dataset(dataset_path,
name=dataset_subset, name=dataset_subset,
@ -417,7 +423,7 @@ def calculate_metrics(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
selected_percentile_metrics: List[str], selected_percentile_metrics: List[str],
selected_percentiles: List[float], selected_percentiles: List[float],
goodput_config_dict: Dict[str, float], gootput_config_dict: Dict[str, float],
) -> Tuple[BenchmarkMetrics, List[int]]: ) -> Tuple[BenchmarkMetrics, List[int]]:
actual_output_lens: List[int] = [] actual_output_lens: List[int] = []
total_input = 0 total_input = 0
@ -430,23 +436,19 @@ def calculate_metrics(
e2els: List[float] = [] e2els: List[float] = []
for i in range(len(outputs)): for i in range(len(outputs)):
if outputs[i].success: if outputs[i].success:
output_len = outputs[i].output_tokens # We use the tokenizer to count the number of output tokens for all
# serving backends instead of looking at len(outputs[i].itl) since
if output_len is None: # multiple output tokens may be bundled together
# We use the tokenizer to count the number of output tokens # Note : this may inflate the output token count slightly
# for some serving backends instead of looking at output_len = len(
# len(outputs[i].itl) since multiple output tokens may be tokenizer(outputs[i].generated_text,
# bundled together add_special_tokens=False).input_ids)
# Note : this may inflate the output token count slightly
output_len = len(
tokenizer(outputs[i].generated_text,
add_special_tokens=False).input_ids)
actual_output_lens.append(output_len) actual_output_lens.append(output_len)
total_input += input_requests[i][1] total_input += input_requests[i][1]
tpot = 0 tpot = 0
if output_len > 1: if output_len > 1:
latency_minus_ttft = outputs[i].latency - outputs[i].ttft tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
tpot = latency_minus_ttft / (output_len - 1) 1)
tpots.append(tpot) tpots.append(tpot)
# Note: if output_len <= 1, we regard tpot as 0 for goodput # Note: if output_len <= 1, we regard tpot as 0 for goodput
all_tpots.append(tpot) all_tpots.append(tpot)
@ -457,21 +459,21 @@ def calculate_metrics(
else: else:
actual_output_lens.append(0) actual_output_lens.append(0)
if goodput_config_dict: if gootput_config_dict:
valid_metrics = [] valid_metrics = []
slo_values = [] slo_values = []
if "ttft" in goodput_config_dict: if "ttft" in gootput_config_dict:
valid_metrics.append(ttfts) valid_metrics.append(ttfts)
slo_values.append(goodput_config_dict["ttft"] / slo_values.append(gootput_config_dict["ttft"] /
MILLISECONDS_TO_SECONDS_CONVERSION) MILLISECONDS_TO_SECONDS_CONVERSION)
if "tpot" in goodput_config_dict: if "tpot" in gootput_config_dict:
valid_metrics.append(all_tpots) valid_metrics.append(all_tpots)
slo_values.append(goodput_config_dict["tpot"] / slo_values.append(gootput_config_dict["tpot"] /
MILLISECONDS_TO_SECONDS_CONVERSION) MILLISECONDS_TO_SECONDS_CONVERSION)
if "e2el" in goodput_config_dict: if "e2el" in gootput_config_dict:
valid_metrics.append(e2els) valid_metrics.append(e2els)
slo_values.append(goodput_config_dict["e2el"] / slo_values.append(gootput_config_dict["e2el"] /
MILLISECONDS_TO_SECONDS_CONVERSION) MILLISECONDS_TO_SECONDS_CONVERSION)
for req_metric in zip(*valid_metrics): for req_metric in zip(*valid_metrics):
@ -523,7 +525,6 @@ async def benchmark(
api_url: str, api_url: str,
base_url: str, base_url: str,
model_id: str, model_id: str,
model_name: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
input_requests: List[Tuple[str, int, int]], input_requests: List[Tuple[str, int, int]],
logprobs: Optional[int], logprobs: Optional[int],
@ -535,7 +536,7 @@ async def benchmark(
selected_percentile_metrics: List[str], selected_percentile_metrics: List[str],
selected_percentiles: List[str], selected_percentiles: List[str],
ignore_eos: bool, ignore_eos: bool,
goodput_config_dict: Dict[str, float], gootput_config_dict: Dict[str, float],
max_concurrency: Optional[int], max_concurrency: Optional[int],
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
@ -552,7 +553,6 @@ async def benchmark(
"Multi-modal content is only supported on 'openai-chat' backend.") "Multi-modal content is only supported on 'openai-chat' backend.")
test_input = RequestFuncInput( test_input = RequestFuncInput(
model=model_id, model=model_id,
model_name=model_name,
prompt=test_prompt, prompt=test_prompt,
api_url=api_url, api_url=api_url,
prompt_len=test_prompt_len, prompt_len=test_prompt_len,
@ -573,7 +573,6 @@ async def benchmark(
if profile: if profile:
print("Starting profiler...") print("Starting profiler...")
profile_input = RequestFuncInput(model=model_id, profile_input = RequestFuncInput(model=model_id,
model_name=model_name,
prompt=test_prompt, prompt=test_prompt,
api_url=base_url + "/start_profile", api_url=base_url + "/start_profile",
prompt_len=test_prompt_len, prompt_len=test_prompt_len,
@ -617,7 +616,6 @@ async def benchmark(
async for request in get_request(input_requests, request_rate, burstiness): async for request in get_request(input_requests, request_rate, burstiness):
prompt, prompt_len, output_len, mm_content = request prompt, prompt_len, output_len, mm_content = request
request_func_input = RequestFuncInput(model=model_id, request_func_input = RequestFuncInput(model=model_id,
model_name=model_name,
prompt=prompt, prompt=prompt,
api_url=api_url, api_url=api_url,
prompt_len=prompt_len, prompt_len=prompt_len,
@ -659,7 +657,7 @@ async def benchmark(
tokenizer=tokenizer, tokenizer=tokenizer,
selected_percentile_metrics=selected_percentile_metrics, selected_percentile_metrics=selected_percentile_metrics,
selected_percentiles=selected_percentiles, selected_percentiles=selected_percentiles,
goodput_config_dict=goodput_config_dict, gootput_config_dict=gootput_config_dict,
) )
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
@ -671,7 +669,7 @@ async def benchmark(
metrics.total_output)) metrics.total_output))
print("{:<40} {:<10.2f}".format("Request throughput (req/s):", print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
metrics.request_throughput)) metrics.request_throughput))
if goodput_config_dict: if gootput_config_dict:
print("{:<40} {:<10.2f}".format("Request goodput (req/s):", print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
metrics.request_goodput)) metrics.request_goodput))
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
@ -686,7 +684,7 @@ async def benchmark(
"total_output_tokens": metrics.total_output, "total_output_tokens": metrics.total_output,
"request_throughput": metrics.request_throughput, "request_throughput": metrics.request_throughput,
"request_goodput:": "request_goodput:":
metrics.request_goodput if goodput_config_dict else None, metrics.request_goodput if gootput_config_dict else None,
"output_throughput": metrics.output_throughput, "output_throughput": metrics.output_throughput,
"total_token_throughput": metrics.total_token_throughput, "total_token_throughput": metrics.total_token_throughput,
"input_lens": [output.prompt_len for output in outputs], "input_lens": [output.prompt_len for output in outputs],
@ -742,11 +740,11 @@ async def benchmark(
def check_goodput_args(args): def check_goodput_args(args):
# Check and parse goodput arguments # Check and parse goodput arguments
goodput_config_dict = {} gootput_config_dict = {}
VALID_NAMES = ["ttft", "tpot", "e2el"] VALID_NAMES = ["ttft", "tpot", "e2el"]
if args.goodput: if args.goodput:
goodput_config_dict = parse_goodput(args.goodput) gootput_config_dict = parse_goodput(args.goodput)
for slo_name, slo_val in goodput_config_dict.items(): for slo_name, slo_val in gootput_config_dict.items():
if slo_name not in VALID_NAMES: if slo_name not in VALID_NAMES:
raise ValueError( raise ValueError(
f"Invalid metric name found, {slo_name}: {slo_val}. " f"Invalid metric name found, {slo_name}: {slo_val}. "
@ -757,22 +755,22 @@ def check_goodput_args(args):
f"Invalid value found, {slo_name}: {slo_val}. " f"Invalid value found, {slo_name}: {slo_val}. "
"The service level objective value should be " "The service level objective value should be "
"non-negative.") "non-negative.")
return goodput_config_dict return gootput_config_dict
def parse_goodput(slo_pairs): def parse_goodput(slo_pairs):
goodput_config_dict = {} gootput_config_dict = {}
try: try:
for slo_pair in slo_pairs: for slo_pair in slo_pairs:
slo_name, slo_val = slo_pair.split(":") slo_name, slo_val = slo_pair.split(":")
goodput_config_dict[slo_name] = float(slo_val) gootput_config_dict[slo_name] = float(slo_val)
except ValueError as err: except ValueError as err:
raise argparse.ArgumentTypeError( raise argparse.ArgumentTypeError(
"Invalid format found for service level objectives. " "Invalid format found for service level objectives. "
"Specify service level objectives for goodput as \"KEY:VALUE\" " "Specify service level objectives for goodput as \"KEY:VALUE\" "
"pairs, where the key is a metric name, and the value is a " "pairs, where the key is a metric name, and the value is a "
"number in milliseconds.") from err "number in milliseconds.") from err
return goodput_config_dict return gootput_config_dict
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
@ -782,7 +780,6 @@ def main(args: argparse.Namespace):
backend = args.backend backend = args.backend
model_id = args.model model_id = args.model
model_name = args.served_model_name
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer_mode = args.tokenizer_mode tokenizer_mode = args.tokenizer_mode
@ -872,11 +869,7 @@ def main(args: argparse.Namespace):
else: else:
raise ValueError(f"Unknown dataset: {args.dataset_name}") raise ValueError(f"Unknown dataset: {args.dataset_name}")
goodput_config_dict = check_goodput_args(args) gootput_config_dict = check_goodput_args(args)
# Avoid GC processing "static" data - reduce pause times.
gc.collect()
gc.freeze()
benchmark_result = asyncio.run( benchmark_result = asyncio.run(
benchmark( benchmark(
@ -884,7 +877,6 @@ def main(args: argparse.Namespace):
api_url=api_url, api_url=api_url,
base_url=base_url, base_url=base_url,
model_id=model_id, model_id=model_id,
model_name=model_name,
tokenizer=tokenizer, tokenizer=tokenizer,
input_requests=input_requests, input_requests=input_requests,
logprobs=args.logprobs, logprobs=args.logprobs,
@ -898,7 +890,7 @@ def main(args: argparse.Namespace):
float(p) for p in args.metric_percentiles.split(",") float(p) for p in args.metric_percentiles.split(",")
], ],
ignore_eos=args.ignore_eos, ignore_eos=args.ignore_eos,
goodput_config_dict=goodput_config_dict, gootput_config_dict=gootput_config_dict,
max_concurrency=args.max_concurrency, max_concurrency=args.max_concurrency,
)) ))
@ -927,8 +919,8 @@ def main(args: argparse.Namespace):
) )
# Traffic # Traffic
result_json["request_rate"] = (args.request_rate if args.request_rate result_json["request_rate"] = (
< float("inf") else "inf") args.request_rate if args.request_rate < float("inf") else "inf")
result_json["burstiness"] = args.burstiness result_json["burstiness"] = args.burstiness
result_json["max_concurrency"] = args.max_concurrency result_json["max_concurrency"] = args.max_concurrency
@ -1230,12 +1222,5 @@ if __name__ == "__main__":
'always use the slow tokenizer. \n* ' 'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.') '"mistral" will always use the `mistral_common` tokenizer.')
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. "
"If not specified, the model name will be the "
"same as the ``--model`` argument. ")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
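The goodput hunks above describe SLOs passed as "KEY:VALUE" pairs with values in milliseconds and keys limited to ttft, tpot and e2el. Below is a minimal standalone sketch of that parsing; for brevity it converts to seconds up front (the script itself stores milliseconds and divides later), and the function name is illustrative.

MILLISECONDS_TO_SECONDS_CONVERSION = 1000
VALID_NAMES = ["ttft", "tpot", "e2el"]

def parse_slo_pairs(slo_pairs):
    # e.g. ["ttft:500", "tpot:50"] -> {"ttft": 0.5, "tpot": 0.05} (seconds)
    config = {}
    for pair in slo_pairs:
        name, value = pair.split(":")
        if name not in VALID_NAMES:
            raise ValueError(f"Invalid metric name found: {name}")
        if float(value) < 0:
            raise ValueError("SLO values must be non-negative.")
        config[name] = float(value) / MILLISECONDS_TO_SECONDS_CONVERSION
    return config

print(parse_slo_pairs(["ttft:500", "tpot:50"]))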

View File

@ -1,4 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
r"""Benchmark online serving throughput with guided decoding. r"""Benchmark online serving throughput with guided decoding.
On the server side, run one of the following commands: On the server side, run one of the following commands:

View File

@ -1,4 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark offline inference throughput.""" """Benchmark offline inference throughput."""
import argparse import argparse
import dataclasses import dataclasses

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import copy import copy
import itertools import itertools

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# Cutlass bench utils # Cutlass bench utils
from typing import Iterable, Tuple from typing import Iterable, Tuple

View File

@ -1,11 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import copy import copy
import itertools import itertools
import pickle as pkl import pickle as pkl
import time import time
from typing import Callable, Iterable, List, Optional, Tuple from typing import Callable, Iterable, List, Tuple
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
@ -14,8 +12,6 @@ from utils import make_rand_tensors
from weight_shapes import WEIGHT_SHAPES from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@ -42,15 +38,8 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
).blocked_autorange(min_run_time=min_run_time) ).blocked_autorange(min_run_time=min_run_time)
def bench_int8( def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
dtype: torch.dtype, sub_label: str) -> Iterable[TMeasurement]:
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels."""
assert dtype == torch.int8 assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k) a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
@ -59,132 +48,155 @@ def bench_int8(
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
"cutlass_i8_i8_bf16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
"cutlass_i8_i8_bf16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
bias),
"cutlass_i8_i8_bf16_scaled_mm_azp":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj),
"cutlass_i8_i8_bf16_scaled_mm_azp_bias":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, None, bias),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, azp),
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias":
lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch.
bfloat16, azp_adj, azp, bias),
}
timers = [] timers = []
for name, fn in bench_fns.items(): # pytorch impl - bfloat16
# If bench_kernels is None, run all. Otherwise, run only exact matches. timers.append(
if bench_kernels is None or name in bench_kernels: bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
print(f"Running {name}") torch.mm, a.to(dtype=torch.bfloat16),
timers.append(bench_fn(label, sub_label, name, fn)) b.to(dtype=torch.bfloat16)))
# pytorch impl - float16
timers.append(
bench_fn(label, sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
# cutlass impl
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
# cutlass with bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))
# cutlass with azp per-tensor
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj))
# cutlass with azp per-tensor + bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, None, bias))
# cutlass with azp per-token
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, azp))
# cutlass with azp per-token + bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, azp, bias))
return timers return timers
def bench_fp8( def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
dtype: torch.dtype, sub_label: str) -> Iterable[TMeasurement]:
m: int,
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
a_cont = a.contiguous()
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
block_scale_a = torch.rand((m, k // 128),
device="cuda",
dtype=torch.float32)
block_scale_b = torch.rand((k // 128, n // 128),
device="cuda",
dtype=torch.float32)
block_scale_a_M_major = block_scale_a.t().contiguous().t()
block_scale_b_K_major = block_scale_b.t().contiguous().t()
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
print(m, k, n)
bench_fns = {
"pytorch_bf16_bf16_bf16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
),
"pytorch_fp16_fp16_fp16_matmul-no-scales":
lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)),
"pytorch_fp8_fp8_fp16_scaled_mm":
lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.float16),
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum":
lambda: torch._scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.float16,
use_fast_accum=True),
"pytorch_fp8_fp8_bf16_scaled_mm":
lambda: torch._scaled_mm(
a, b, scale_a, scale_b, out_dtype=torch.bfloat16),
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum":
lambda: torch._scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True),
"cutlass_fp8_fp8_bf16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16),
"cutlass_fp8_fp8_fp16_scaled_mm":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16),
"cutlass_fp8_fp8_bf16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16,
bias),
"cutlass_fp8_fp8_fp16_scaled_mm_bias":
lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)),
"triton_fp8_fp8_fp16_scaled_mm_blockwise":
lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a,
block_scale_b.t(), (128, 128)),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise":
lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major,
block_scale_b_K_major, torch.float16),
}
timers = [] timers = []
for name, fn in bench_fns.items():
# If bench_kernels is None, run all. Otherwise, run only exact matches. # pytorch impl w. bf16
if bench_kernels is None or name in bench_kernels: timers.append(
print(f"Running {name}") bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
timers.append(bench_fn(label, sub_label, name, fn)) torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda")))
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16))
# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True))
# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16))
# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
use_fast_accum=True))
# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
# cutlass impl: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))
# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))
# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)))
return timers return timers
def bench(dtype: torch.dtype, def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
m: int, sub_label: str) -> Iterable[TMeasurement]:
k: int,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
if dtype == torch.int8: if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn: if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, m, k, n, label, sub_label, bench_kernels) return bench_fp8(dtype, m, k, n, label, sub_label)
raise ValueError("unsupported type") raise ValueError("unsupported type")
@ -195,22 +207,18 @@ def print_timers(timers: Iterable[TMeasurement]):
def run(dtype: torch.dtype, def run(dtype: torch.dtype,
MKNs: Iterable[Tuple[int, int, int]], MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
results = [] results = []
for m, k, n in MKNs: for m, k, n in MKNs:
timers = bench(dtype, timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
m, f"MKN=({m}x{k}x{n})")
k,
n,
f"scaled-{dtype}-gemm",
f"MKN=({m}x{k}x{n})",
bench_kernels=bench_kernels)
print_timers(timers) print_timers(timers)
results.extend(timers) results.extend(timers)
return results return results
# output makers
def make_output(data: Iterable[TMeasurement], def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]], MKNs: Iterable[Tuple[int, int, int]],
base_description: str, base_description: str,
@ -224,11 +232,15 @@ def make_output(data: Iterable[TMeasurement],
pkl.dump(data, f) pkl.dump(data, f)
# argparse runners
def run_square_bench(args): def run_square_bench(args):
dim_sizes = list( dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment)) range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs, bench_kernels=args.kernels) data = run(args.dtype, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}") make_output(data, MKNs, f"square_bench-{args.dtype}")
@ -239,7 +251,8 @@ def run_range_bench(args):
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns)) MKNs = list(zip(Ms, Ks, Ns))
data = run(args.dtype, MKNs, bench_kernels=args.kernels) data = run(args.dtype, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}") make_output(data, MKNs, f"range_bench-{args.dtype}")
@ -265,7 +278,7 @@ def run_model_bench(args):
for k, n in KNs: for k, n in KNs:
MKNs.append((m, k, n)) MKNs.append((m, k, n))
data = run(args.dtype, MKNs, bench_kernels=args.kernels) data = run(args.dtype, MKNs)
model_bench_data.append(data) model_bench_data.append(data)
# Print all results # Print all results
@ -315,15 +328,6 @@ Benchmark Cutlass GEMM.
type=to_torch_dtype, type=to_torch_dtype,
required=True, required=True,
help="Available options are ['int8', 'fp8']") help="Available options are ['int8', 'fp8']")
parser.add_argument(
"--kernels",
nargs="+",
type=str,
default=None,
help=
"Exact names of the kernels to benchmark. If not set, runs all kernels."
)
subparsers = parser.add_subparsers(dest="cmd") subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench") square_parser = subparsers.add_parser("square_bench")
@ -358,4 +362,4 @@ Benchmark Cutlass GEMM.
model_parser.set_defaults(func=run_model_bench) model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args() args = parser.parse_args()
args.func(args) args.func(args)
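One side of the diff above registers each kernel as a name -> callable entry and runs only exact matches from a --kernels list. Here is a small generic sketch of that registry-plus-filter pattern, using trivial placeholder workloads and time.perf_counter instead of the real vLLM ops and torch.utils.benchmark:

import time
from typing import Callable, Dict, List, Optional

def run_selected(bench_fns: Dict[str, Callable[[], object]],
                 bench_kernels: Optional[List[str]] = None) -> Dict[str, float]:
    results = {}
    for name, fn in bench_fns.items():
        # If bench_kernels is None, run all. Otherwise, run only exact matches.
        if bench_kernels is None or name in bench_kernels:
            start = time.perf_counter()
            fn()
            results[name] = time.perf_counter() - start
    return results

demo_fns = {
    "kernel_a": lambda: sum(range(100_000)),
    "kernel_b": lambda: sorted(range(100_000), reverse=True),
}
print(run_selected(demo_fns, bench_kernels=["kernel_a"]))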

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# Weight Shapes are in the format # Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM) # ([K, N], TP_SPLIT_DIM)
# Example: # Example:

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import os import os
import aiohttp import aiohttp

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import itertools import itertools

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import json import json
import matplotlib.pyplot as plt import matplotlib.pyplot as plt

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import pickle as pkl import pickle as pkl
import time import time
from dataclasses import dataclass from dataclasses import dataclass

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import os import os
import sys import sys
from typing import Optional from typing import Optional

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import time import time
import torch import torch

File diff suppressed because it is too large

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import copy import copy
import itertools import itertools

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List from typing import List
import torch import torch

View File

@ -1,9 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import time import time
from datetime import datetime from datetime import datetime
from itertools import product
from typing import Any, Dict, List, Tuple, TypedDict from typing import Any, Dict, List, Tuple, TypedDict
import ray import ray
@ -16,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
) else torch.float8_e4m3fn
class BenchmarkConfig(TypedDict): class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int BLOCK_SIZE_M: int
@ -86,8 +80,8 @@ def benchmark_config(
a1_scale = torch.randn(1, dtype=torch.float32) a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32) a2_scale = torch.randn(1, dtype=torch.float32)
w1 = w1.to(FP8_DTYPE) w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(FP8_DTYPE) w2 = w2.to(torch.float8_e4m3fn)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
@ -147,172 +141,28 @@ def benchmark_config(
return avg return avg
def get_rocm_tuning_space(use_fp16): def get_configs_compute_bound() -> List[Dict[str, int]]:
block_mn_range = [16, 32, 64, 128, 256] # Reduced search space for faster tuning.
block_k_range = [16, 32, 64, 128, 256] # TODO(woosuk): Increase the search space and use a performance model to
if not use_fp16: # prune the search space.
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
num_stage_range = [2]
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
kpack_range = [1, 2] if use_fp16 else []
param_ranges = {
"BLOCK_SIZE_M": block_mn_range,
"BLOCK_SIZE_N": block_mn_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
"waves_per_eu": waves_per_eu_range,
}
if use_fp16:
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
param_ranges["kpack"] = kpack_range
return param_ranges
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = [] configs: List[BenchmarkConfig] = []
for num_stages in [2, 3, 4, 5]:
if current_platform.is_rocm(): for block_m in [16, 32, 64, 128, 256]:
param_ranges = get_rocm_tuning_space(use_fp16) for block_k in [64, 128, 256]:
else: for block_n in [32, 64, 128, 256]:
# Reduced search space for faster tuning. for num_warps in [4, 8]:
# TODO(woosuk): Increase the search space and use a performance model to for group_size in [1, 16, 32, 64]:
# prune the search space. configs.append({
block_m_range = [16, 32, 64, 128, 256] "BLOCK_SIZE_M": block_m,
block_n_range = [32, 64, 128, 256] "BLOCK_SIZE_N": block_n,
block_k_range = [64, 128, 256] "BLOCK_SIZE_K": block_k,
num_warps_range = [4, 8] "GROUP_SIZE_M": group_size,
group_m_range = [1, 16, 32, 64] "num_warps": num_warps,
num_stage_range = [2, 3, 4, 5] "num_stages": num_stages,
})
param_ranges = {
"BLOCK_SIZE_M": block_m_range,
"BLOCK_SIZE_N": block_n_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
}
keys, values = zip(*param_ranges.items())
for config_values in product(*values):
config = dict(zip(keys, config_values))
configs.append(config)
return configs return configs
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
search_space, is_fp16):
N1, K1 = shard_intermediate_size, hidden_size
N2, K2 = hidden_size, shard_intermediate_size // 2
pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
is_fp16)
pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
is_fp16)
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
return search_space
# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
pruned_configs = []
elemBytes_a = 2 if is_fp16 else 1
elemBytes_b = 2 if is_fp16 else 1
mfma = 16 if M < 32 or N < 32 else 32
# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
if M >= 2048 and N >= 2048:
large_gemm = True
for config in configs:
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps")
if is_fp16:
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
if matrix_instr_nonkdim > mfma:
continue
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
# some layouts could not work properly in case
# number elements per thread is less 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
SPLIT_K = config.get("SPLIT_K", 1)
GROUP_M = config.get("GROUP_SIZE_M")
if is_fp16:
if (matrix_instr_nonkdim > BLOCK_SIZE_M
or matrix_instr_nonkdim > BLOCK_SIZE_N):
continue
if (matrix_instr_nonkdim >= M
and matrix_instr_nonkdim != BLOCK_SIZE_M):
continue
if (matrix_instr_nonkdim >= N
and matrix_instr_nonkdim != BLOCK_SIZE_N):
continue
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
continue
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(M, N, K):
continue
# skip split_k that leads to EVEN_K = false
leap = SPLIT_K * BLOCK_SIZE_K
modv = K % leap
if modv != 0:
continue
# skip large GROUP_M
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
continue
if num_warps < 4:
continue
pruned_configs.append(config)
return pruned_configs
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
def merge_unique_dicts(list1, list2):
result = []
combined_list = list1.copy()
combined_list.extend(list2)
for dictionary in combined_list:
if dictionary not in result:
result.append(dictionary)
return result
@ray.remote(num_gpus=1) @ray.remote(num_gpus=1)
class BenchmarkWorker: class BenchmarkWorker:
@ -320,10 +170,6 @@ class BenchmarkWorker:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
self.seed = seed self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
self.device_id = int(ray.get_gpu_ids()[0])
def benchmark( def benchmark(
self, self,
@ -345,13 +191,9 @@ class BenchmarkWorker:
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
dtype_str) dtype_str)
if op_config is None: if op_config is None:
config = get_default_config(num_tokens, config = get_default_config(num_tokens, num_experts,
num_experts, shard_intermediate_size, hidden_size,
shard_intermediate_size, topk, dtype_str)
hidden_size,
topk,
dtype_str,
is_marlin=False)
else: else:
config = op_config[min(op_config.keys(), config = op_config[min(op_config.keys(),
key=lambda x: abs(x - num_tokens))] key=lambda x: abs(x - num_tokens))]
@ -375,33 +217,25 @@ class BenchmarkWorker:
) -> Dict[str, int]: ) -> Dict[str, int]:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
if current_platform.is_rocm(): for config in tqdm(search_space):
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) try:
search_space = prune_rocm_search_space(num_tokens, kernel_time = benchmark_config(config,
shard_intermediate_size, num_tokens,
hidden_size, search_space, num_experts,
is_fp16) shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=10)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
with torch.cuda.device(self.device_id): if kernel_time < best_time:
for config in tqdm(search_space): best_time = kernel_time
try: best_config = config
kernel_time = benchmark_config(config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=20)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now() now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None assert best_config is not None
@ -410,27 +244,12 @@ class BenchmarkWorker:
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return { return {
"BLOCK_SIZE_M": "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_M"], "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_N": "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
config["BLOCK_SIZE_N"], "GROUP_SIZE_M": config["GROUP_SIZE_M"],
"BLOCK_SIZE_K": "num_warps": config["num_warps"],
config["BLOCK_SIZE_K"], "num_stages": config["num_stages"],
"GROUP_SIZE_M":
config["GROUP_SIZE_M"],
"num_warps":
config["num_warps"],
"num_stages":
config["num_stages"],
**({
"waves_per_eu": config["waves_per_eu"]
} if "waves_per_eu" in config else {}),
**({
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
} if "matrix_instr_nonkdim" in config else {}),
**({
"kpack": config["kpack"]
} if "kpack" in config else {}),
} }
@ -456,8 +275,7 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
print(args) print(args)
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(args.model)
args.model, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "DbrxForCausalLM": if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k topk = config.ffn_config.moe_top_k
@ -468,11 +286,6 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "DeepseekV3ForCausalLM":
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else: else:
# Default: Mixtral. # Default: Mixtral.
E = config.num_local_experts E = config.num_local_experts
@ -481,7 +294,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype dtype = config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16" use_int8_w8a16 = args.dtype == "int8_w8a16"
@ -509,8 +322,7 @@ def main(args: argparse.Namespace):
return ray.get(outputs) return ray.get(outputs)
if args.tune: if args.tune:
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) search_space = get_configs_compute_bound()
search_space = get_configs_compute_bound(is_fp16)
print(f"Start tuning over {len(search_space)} configurations...") print(f"Start tuning over {len(search_space)} configurations...")
start = time.time() start = time.time()
@ -542,11 +354,7 @@ if __name__ == "__main__":
parser.add_argument("--model", parser.add_argument("--model",
type=str, type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1") default="mistralai/Mixtral-8x7B-Instruct-v0.1")
parser.add_argument("--tp-size", parser.add_argument("--tp-size", "-tp", type=int, default=2)
"-tp",
"--tensor-parallel-size",
type=int,
default=2)
parser.add_argument("--dtype", parser.add_argument("--dtype",
type=str, type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16"], choices=["auto", "fp8_w8a8", "int8_w8a16"],
@ -554,7 +362,6 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true") parser.add_argument("--tune", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
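The tuning-space hunks above build configurations as the cartesian product of per-parameter value ranges. A trimmed-down sketch of that construction with itertools.product follows; the ranges are placeholders, not the values the script actually sweeps:

from itertools import product

param_ranges = {
    "BLOCK_SIZE_M": [16, 32],
    "BLOCK_SIZE_N": [32, 64],
    "BLOCK_SIZE_K": [64, 128],
    "GROUP_SIZE_M": [1, 16],
    "num_warps": [4, 8],
    "num_stages": [2, 3],
}

keys, values = zip(*param_ranges.items())
configs = [dict(zip(keys, combo)) for combo in product(*values)]
print(len(configs), configs[0])  # 64 configs; the first uses the smallest value of each range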

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import random import random
import time import time
from typing import List, Optional from typing import List, Optional
@ -100,9 +98,7 @@ def main(
start_time = time.perf_counter() start_time = time.perf_counter()
# Using default kv_scale # Using default kv_scale
k_scale = v_scale = torch.tensor(1.0, k_scale = v_scale = 1.0
dtype=torch.float32,
device=device)
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import time import time
import torch import torch

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import itertools import itertools
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
from itertools import accumulate from itertools import accumulate
from typing import List, Optional from typing import List, Optional

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
WEIGHT_SHAPES = { WEIGHT_SHAPES = {
"ideal": [[4 * 256 * 32, 256 * 32]], "ideal": [[4 * 256 * 32, 256 * 32]],
"mistralai/Mistral-7B-v0.1/TP1": [ "mistralai/Mistral-7B-v0.1/TP1": [

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import math import math
import pickle import pickle
import re import re

View File

@ -1,212 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import dataclasses
from typing import Any, Callable, Iterable, Optional
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
@dataclasses.dataclass
class CudaGraphBenchParams:
num_ops_in_cuda_graph: int
@dataclasses.dataclass
class ArgPool:
"""
When some argument of the benchmarking function is annotated with this type,
the benchmarking class (BenchMM) will collapse the argument to a pick a
single value from the given list of values, during function invocation.
For every invocation during a benchmarking run, it will choose a
different value from the list.
"""
values: Iterable[Any]
def __getitem__(self, index):
return self.values[index]
class Bench:
class ArgsIterator:
def __init__(self, args_list, kwargs_list):
assert len(args_list) == len(kwargs_list)
self.args_list = args_list
self.kwargs_list = kwargs_list
self.n = len(self.args_list)
self.idx = 0
def __next__(self):
while True:
yield (self.args_list[self.idx], self.kwargs_list[self.idx])
self.idx += 1
self.idx = self.idx % self.n
def reset(self):
self.idx = 0
@property
def n_args(self):
return self.n
def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams],
label: str, sub_label: str, description: str, fn: Callable,
*args, **kwargs):
self.cuda_graph_params = cuda_graph_params
self.use_cuda_graph = self.cuda_graph_params is not None
self.label = label
self.sub_label = sub_label
self.description = description
self.fn = fn
# Process args
self._args = args
self._kwargs = kwargs
self.args_list, self.kwargs_list = self.collapse_argpool(
*args, **kwargs)
self.args_iterator = self.ArgsIterator(self.args_list,
self.kwargs_list)
# Cudagraph runner
self.g = None
if self.use_cuda_graph:
self.g = self.get_cuda_graph_runner()
# benchmark run params
self.min_run_time = 1
def collapse_argpool(self, *args, **kwargs):
argpool_args = [arg for arg in args if isinstance(arg, ArgPool)] + [
arg for arg in kwargs.values() if isinstance(arg, ArgPool)
]
if len(argpool_args) == 0:
return [args], [kwargs]
# Make sure all argpools are of the same size
argpool_size = len(argpool_args[0].values)
assert all([argpool_size == len(arg.values) for arg in argpool_args])
# create copies of the args
args_list = []
kwargs_list = []
for _ in range(argpool_size):
args_list.append(args)
kwargs_list.append(kwargs.copy())
for i in range(argpool_size):
# collapse args; Just pick the ith value
args_list[i] = tuple([
arg[i] if isinstance(arg, ArgPool) else arg
for arg in args_list[i]
])
# collapse kwargs
kwargs_i = kwargs_list[i]
arg_pool_keys = [
k for k, v in kwargs_i.items() if isinstance(v, ArgPool)
]
for k in arg_pool_keys:
# again just pick the ith value
kwargs_i[k] = kwargs_i[k][i]
kwargs_list[i] = kwargs_i
return args_list, kwargs_list
def get_cuda_graph_runner(self):
assert self.use_cuda_graph
assert self.args_iterator is not None
num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph
# warmup
args_it = self.args_iterator.__next__()
for _ in range(2):
args, kwargs = next(args_it)
self.fn(*args, **kwargs)
self.args_iterator.reset()
args_it = self.args_iterator.__next__()
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for _ in range(num_graph_ops):
args, kwargs = next(args_it)
self.fn(*args, **kwargs)
return g
def run_cudagrah(self) -> TMeasurement:
assert self.use_cuda_graph
globals = {'g': self.g}
return TBenchmark.Timer(
stmt="g.replay()",
globals=globals,
label=(
f"{self.label}"
f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops"
),
sub_label=self.sub_label,
description=self.description,
).blocked_autorange(min_run_time=self.min_run_time)
def run_eager(self) -> TMeasurement:
setup = None
stmt = None
globals = None
has_arg_pool = self.args_iterator.n_args > 1
if has_arg_pool:
setup = '''
args_iterator.reset()
args_it = args_iterator.__next__()
'''
stmt = '''
args, kwargs = next(args_it)
fn(*args, **kwargs)
'''
globals = {'fn': self.fn, 'args_iterator': self.args_iterator}
else:
# no arg pool. Just use the args and kwargs directly
self.args_iterator.reset()
args_it = self.args_iterator.__next__()
args, kwargs = next(args_it)
setup = ""
stmt = '''
fn(*args, **kwargs)
'''
globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs}
return TBenchmark.Timer(
stmt=stmt,
setup=setup,
globals=globals,
label=self.label,
sub_label=self.sub_label,
description=self.description,
).blocked_autorange(min_run_time=self.min_run_time)
def run(self) -> TMeasurement:
timer = None
if self.use_cuda_graph: # noqa SIM108
timer = self.run_cudagrah()
else:
timer = self.run_eager()
if not timer.meets_confidence() or timer.has_warnings:
print("Doesn't meet confidence - re-running bench ...")
return self.run()
return timer
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
if exc_type:
print(f"exc type {exc_type}")
print(f"exc value {exc_value}")
print(f"exc traceback {traceback}")

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# Weight Shapes are in the format # Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM) # ([K, N], TP_SPLIT_DIM)
# Example: # Example:

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
import cProfile import cProfile
import pstats import pstats

View File

@ -4,11 +4,6 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON) set(CMAKE_CXX_EXTENSIONS ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(MACOSX_FOUND TRUE)
endif()
# #
# Define environment variables for special configurations # Define environment variables for special configurations
# #
@ -18,9 +13,6 @@ endif()
include_directories("${CMAKE_SOURCE_DIR}/csrc") include_directories("${CMAKE_SOURCE_DIR}/csrc")
set (ENABLE_NUMA TRUE)
# #
# Check the compile flags # Check the compile flags
# #
@ -30,28 +22,18 @@ if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
"-mf16c" "-mf16c"
) )
endif() endif()
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")
if(MACOSX_FOUND) execute_process(COMMAND cat /proc/cpuinfo
list(APPEND CXX_COMPILE_FLAGS RESULT_VARIABLE CPUINFO_RET
"-Xpreprocessor" OUTPUT_VARIABLE CPUINFO)
"-fopenmp"
"-DVLLM_CPU_EXTENSION") if (NOT CPUINFO_RET EQUAL 0)
else() message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
list(APPEND CXX_COMPILE_FLAGS
"-fopenmp"
"-DVLLM_CPU_EXTENSION")
endif() endif()
if (NOT MACOSX_FOUND)
execute_process(COMMAND cat /proc/cpuinfo
RESULT_VARIABLE CPUINFO_RET
OUTPUT_VARIABLE CPUINFO)
if (NOT CPUINFO_RET EQUAL 0)
message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
endif()
endif()
function (find_isa CPUINFO TARGET OUT) function (find_isa CPUINFO TARGET OUT)
string(FIND ${CPUINFO} ${TARGET} ISA_FOUND) string(FIND ${CPUINFO} ${TARGET} ISA_FOUND)
if(NOT ISA_FOUND EQUAL -1) if(NOT ISA_FOUND EQUAL -1)
@ -72,17 +54,12 @@ endfunction()
is_avx512_disabled(AVX512_DISABLED) is_avx512_disabled(AVX512_DISABLED)
if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") find_isa(${CPUINFO} "avx2" AVX2_FOUND)
set(APPLE_SILICON_FOUND TRUE) find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
else() find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
find_isa(${CPUINFO} "avx2" AVX2_FOUND) find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
find_isa(${CPUINFO} "avx512f" AVX512_FOUND) find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
find_isa(${CPUINFO} "POWER10" POWER10_FOUND) find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
endif()
if (AVX512_FOUND AND NOT AVX512_DISABLED) if (AVX512_FOUND AND NOT AVX512_DISABLED)
list(APPEND CXX_COMPILE_FLAGS list(APPEND CXX_COMPILE_FLAGS
@ -126,9 +103,6 @@ elseif (ASIMD_FOUND)
set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16") set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16")
endif() endif()
list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS}) list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS})
elseif(APPLE_SILICON_FOUND)
message(STATUS "Apple Silicon Detected")
set(ENABLE_NUMA OFF)
else() else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.") message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.")
endif() endif()
@ -165,12 +139,7 @@ endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
if(ENABLE_NUMA) list(APPEND LIBS numa)
list(APPEND LIBS numa)
else()
message(STATUS "NUMA is disabled")
add_compile_definitions(-DVLLM_NUMA_DISABLED)
endif()
# #
# _C extension # _C extension

View File

@ -1,5 +1,4 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# #
# A command line tool for running pytorch's hipify preprocessor on CUDA # A command line tool for running pytorch's hipify preprocessor on CUDA

View File

@ -58,8 +58,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
# #
set(SRCS ${ORIG_SRCS}) set(SRCS ${ORIG_SRCS})
set(CXX_SRCS ${ORIG_SRCS}) set(CXX_SRCS ${ORIG_SRCS})
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)|(hip)$") list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)|(hip)$") list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
# #
# Generate ROCm/HIP source file names from CUDA file names. # Generate ROCm/HIP source file names from CUDA file names.
@ -259,7 +259,7 @@ endmacro()
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is # We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add # in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS). # 9.0a to the result.
# The result is stored in `OUT_CUDA_ARCHS`. # The result is stored in `OUT_CUDA_ARCHS`.
# #
# Example: # Example:
@ -270,47 +270,34 @@ endmacro()
# #
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
set(_CUDA_ARCHS) set(_CUDA_ARCHS)
if ("9.0a" IN_LIST SRC_CUDA_ARCHS) if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
if ("9.0" IN_LIST TGT_CUDA_ARCHS_) if ("9.0" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
set(_CUDA_ARCHS "9.0a") set(_CUDA_ARCHS "9.0a")
endif() endif()
endif() endif()
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that # for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is
# is less or equal to ARCH (but has the same major version since SASS binary # less or eqault to ARCH
# compatibility is only forward compatible within the same major version). foreach(_ARCH ${CUDA_ARCHS})
foreach(_ARCH ${TGT_CUDA_ARCHS_}) set(_TMP_ARCH)
set(_TMP_ARCH) foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
# Extract the major version of the target arch if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") set(_TMP_ARCH ${_SRC_ARCH})
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) else()
# Extract the major version of the source arch break()
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
# Check major-version match AND version-less-or-equal
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
set(_TMP_ARCH "${_SRC_ARCH}")
endif()
else()
# If we hit a version greater than the target, we can break
break()
endif()
endforeach()
# If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
if (_TMP_ARCH)
list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
endif() endif()
endforeach() endforeach()
if (_TMP_ARCH)
list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
endif()
endforeach()
list(REMOVE_DUPLICATES _CUDA_ARCHS) list(REMOVE_DUPLICATES _CUDA_ARCHS)
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
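Expressed in Python for clarity, the matching rule of cuda_archs_loose_intersection is: for every target arch, keep the highest source arch that is less than or equal to it; one variant shown above additionally requires the same major version. The sketch below ignores the 9.0a special case and assumes plain major.minor strings:

def loose_intersection(src_archs, tgt_archs, require_same_major=True):
    def parse(v):
        major, _, minor = v.partition(".")
        return int(major), int(minor or 0)

    result = []
    for tgt in tgt_archs:
        best = None
        for src in sorted(src_archs, key=parse):
            if parse(src) <= parse(tgt) and (
                    not require_same_major or parse(src)[0] == parse(tgt)[0]):
                best = src
        if best and best not in result:
            result.append(best)
    return result

print(loose_intersection(["7.5", "8.0", "9.0"], ["8.6", "9.0"]))  # ['8.0', '9.0']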

View File

@ -1,5 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa # ruff: noqa
# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py # code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py

View File

@ -9,16 +9,8 @@
namespace vllm { namespace vllm {
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
const scalar_t& y) {
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}
// Activation and gating kernel template. // Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
__global__ void act_and_mul_kernel( __global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
@ -27,7 +19,7 @@ __global__ void act_and_mul_kernel(
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y); out[token_idx * d + idx] = ACT_FN(x) * y;
} }
} }
@ -63,9 +55,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
} // namespace vllm } // namespace vllm
// Launch activation and gating kernel. // Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
// first.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
int d = input.size(-1) / 2; \ int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
@ -74,7 +64,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \ input.scalar_type(), "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \ vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \ input.data_ptr<scalar_t>(), d); \
}); });
@ -82,27 +72,19 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
void silu_and_mul(torch::Tensor& out, // [..., d] void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
void mul_and_silu(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu
// applies the silu to the latter half of the input.
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
} }
void gelu_and_mul(torch::Tensor& out, // [..., d] void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
} }
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
} }
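
The act_first flag threaded through this hunk decides which half of the input goes through the activation: silu_and_mul computes silu(x) * y, while the new mul_and_silu computes x * silu(y). A rough NumPy sketch of the two variants, assuming the last dimension holds the two halves [x, y] (a hypothetical helper, not the CUDA kernel):

import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

def act_and_mul(inp, act_first=True):
    x, y = np.split(inp, 2, axis=-1)          # inp: [..., 2 * d]
    return silu(x) * y if act_first else x * silu(y)

tokens = np.random.randn(4, 8).astype(np.float32)          # d = 4
out_silu_and_mul = act_and_mul(tokens, act_first=True)      # ACT_FN(x) * y
out_mul_and_silu = act_and_mul(tokens, act_first=False)     # x * ACT_FN(y)
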
namespace vllm { namespace vllm {

View File

@ -105,7 +105,7 @@ __device__ void paged_attention_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, *k_scale); k_vec_quant, k_scale);
} }
} }
@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
*v_scale); v_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
@ -549,7 +549,7 @@ __global__ void paged_attention_v2_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
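
The signature changes in this hunk (and in the cache kernels further down) move k_scale / v_scale from host-side floats to device-resident float tensors, so the kernels read them as *k_scale / *v_scale. Conceptually the scales drive a per-tensor fp8 round trip; the PyTorch sketch below is a hedged approximation of that round trip (it assumes a PyTorch build with float8_e4m3fn support and is not the fp8::scaled_convert implementation itself).

import torch

k_scale = torch.tensor([0.05], dtype=torch.float32)   # per-tensor scale, lives on device
key = torch.randn(16, dtype=torch.float32)

key_fp8  = (key / k_scale).to(torch.float8_e4m3fn)    # roughly what the cache write does
key_back = key_fp8.to(torch.float32) * k_scale        # roughly what the attention read does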

View File

@ -41,7 +41,7 @@
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
@ -53,10 +53,10 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale, const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
torch::Tensor& v_scale, const int tp_rank, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -80,8 +80,6 @@ void paged_attention_v1_launcher(
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len = int padded_max_seq_len =
@ -179,9 +177,8 @@ void paged_attention_v1(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, const std::string& kv_cache_dtype, double k_scale, double v_scale,
torch::Tensor& v_scale, const int64_t tp_rank, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);

View File

@ -37,7 +37,7 @@
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \ blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
@ -54,10 +54,10 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale, const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
torch::Tensor& v_scale, const int tp_rank, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -84,8 +84,6 @@ void paged_attention_v2_launcher(
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
@ -190,9 +188,8 @@ void paged_attention_v2(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, const std::string& kv_cache_dtype, double k_scale, double v_scale,
torch::Tensor& v_scale, const int64_t tp_rank, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);

View File

@ -15,26 +15,18 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches, std::vector<torch::Tensor> const& value_caches,
const torch::Tensor& block_mapping); const torch::Tensor& block_mapping);
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
const torch::Tensor& block_mapping);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, const double k_scale,
torch::Tensor& k_scale, torch::Tensor& v_scale); const double v_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& key_cache,
torch::Tensor& value_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale); const double k_scale, const double v_scale);
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& scale);
// Just for unittest // Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,

View File

@ -46,10 +46,7 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
char* src_ptr = static_cast<char*>(src.data_ptr()); char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr()); char* dst_ptr = static_cast<char*>(dst.data_ptr());
// We use the stride instead of numel in case the cache is padded for memory const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
// alignment reasons; we assume each block's data (inclusive of any padding)
// is contiguous in memory
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
const at::cuda::OptionalCUDAGuard device_guard( const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device); src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
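
The swap_blocks change above sizes a block by src.stride(0) rather than src[0].numel(), so any per-block padding is carried along with the copy. A small PyTorch illustration of why the two differ for a padded cache (hypothetical shapes):

import torch

storage = torch.empty(10 * 136, dtype=torch.float32)               # backing store with room for padding
padded  = storage.as_strided((10, 16, 8), (136, 8, 1))             # 8 floats of pad per block
bytes_by_numel  = padded[0].numel() * padded.element_size()        # 128 * 4 = 512 (misses the padding)
bytes_by_stride = padded.stride(0)  * padded.element_size()        # 136 * 4 = 544 (whole block)
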
@ -96,24 +93,6 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
} }
} }
// Kernel for MLA, which works on a single joint kv_cache
// Grid: (num_layers, num_pairs)
template <typename scalar_t>
__global__ void copy_blocks_mla_kernel(
int64_t* cache_ptrs, const int64_t* __restrict__ block_mapping,
const int mem_footprint_per_block) {
const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y;
scalar_t* cache = reinterpret_cast<scalar_t*>(cache_ptrs[layer_idx]);
int64_t src_block = block_mapping[2 * pair_idx];
int64_t dst_block = block_mapping[2 * pair_idx + 1];
int64_t src_offset = src_block * mem_footprint_per_block;
int64_t dst_offset = dst_block * mem_footprint_per_block;
for (int i = threadIdx.x; i < mem_footprint_per_block; i += blockDim.x) {
cache[dst_offset + i] = cache[src_offset + i];
}
}
} // namespace vllm } // namespace vllm
// Note: the key_caches and value_caches vectors are constant but // Note: the key_caches and value_caches vectors are constant but
@ -168,42 +147,6 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
})); }));
} }
// copy blocks kernel for MLA (assumes a joint KV-cache)
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
const torch::Tensor& block_mapping) {
int num_layers = kv_caches.size();
if (num_layers == 0) {
return;
}
torch::Device cache_device = kv_caches[0].device();
TORCH_CHECK(cache_device.is_cuda(), "kv_cache must be on CUDA");
std::vector<int64_t> cache_ptrs(num_layers);
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(kv_caches[layer_idx].data_ptr());
}
torch::Tensor cache_ptrs_tensor =
torch::from_blob(cache_ptrs.data(), {num_layers}, torch::kInt64)
.to(cache_device);
int num_pairs = block_mapping.size(0);
// We use the stride instead of numel in case the cache is padded for memory
// alignment reasons; we assume each block's data (inclusive of any padding)
// is contiguous in memory
int mem_footprint_per_block = kv_caches[0].stride(0);
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, mem_footprint_per_block));
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
kv_caches[0].scalar_type(), "copy_blocks_mla_kernel", ([&] {
vllm::copy_blocks_mla_kernel<scalar_t><<<grid, block, 0, stream>>>(
cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), mem_footprint_per_block);
}));
}
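
copy_blocks_mla introduced above operates on a single joint KV cache per layer and copies whole blocks addressed by block index times the block's memory footprint (stride(0)). A hedged PyTorch reference for the same data movement (not the kernel; it assumes each block's data is contiguous):

import torch

def copy_blocks_mla_reference(kv_caches, block_mapping):
    # kv_caches: one [num_blocks, block_size, entry_dim] tensor per layer
    # block_mapping: [num_pairs, 2] tensor of (src_block, dst_block) pairs
    for cache in kv_caches:
        flat = cache.view(cache.size(0), -1)
        for src, dst in block_mapping.tolist():
            flat[dst].copy_(flat[src])
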
namespace vllm { namespace vllm {
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt> template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
@ -216,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
// block_size] // block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x, const int head_size, const int block_size, const int x, const float k_scale,
const float* k_scale, const float* v_scale) { const float v_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) { if (slot_idx < 0) {
@ -253,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
value_cache[tgt_value_idx] = tgt_value; value_cache[tgt_value_idx] = tgt_value;
} else { } else {
key_cache[tgt_key_idx] = key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
value_cache[tgt_value_idx] = value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
} }
} }
} }
@ -271,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel(
const int64_t* __restrict__ slot_mapping, // [num_tokens] const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride, const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size, const int num_heads, const int head_size, const int block_size,
const float* k_scale, const float* v_scale) { const float k_scale, const float v_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded // NOTE: slot_idx can be -1 if the token is padded
@ -296,57 +239,12 @@ __global__ void reshape_and_cache_flash_kernel(
value_cache[tgt_key_value_idx] = tgt_value; value_cache[tgt_key_value_idx] = tgt_value;
} else { } else {
key_cache[tgt_key_value_idx] = key_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
value_cache[tgt_key_value_idx] = value_cache[tgt_key_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
} }
} }
} }
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
int src_stride, int dst_stride, int size, int offset) {
for (int i = threadIdx.x; i < size; i += blockDim.x) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx =
block_idx * block_stride + block_offset * entry_stride + i + offset;
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
}
}
};
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
} // namespace vllm } // namespace vllm
// KV_T is the stored data type of kv-cache. // KV_T is the stored data type of kv-cache.
@ -360,9 +258,7 @@ __global__ void concat_and_cache_mla_kernel(
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \ slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, \ num_heads, head_size, block_size, x, k_scale, v_scale);
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache( void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
@ -372,8 +268,8 @@ void reshape_and_cache(
torch::Tensor& torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size] value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, torch::Tensor& k_scale, const std::string& kv_cache_dtype, const double k_scale,
torch::Tensor& v_scale) { const double v_scale) {
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
@ -403,9 +299,7 @@ void reshape_and_cache(
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \ slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, \ value_stride, num_heads, head_size, block_size, k_scale, v_scale);
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache_flash( void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
@ -414,8 +308,8 @@ void reshape_and_cache_flash(
torch::Tensor& torch::Tensor&
value_cache, // [num_blocks, block_size, num_heads, head_size] value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& k_scale, const std::string& kv_cache_dtype, const double k_scale,
torch::Tensor& v_scale) { const double v_scale) {
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs. // slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
@ -445,57 +339,6 @@ void reshape_and_cache_flash(
CALL_RESHAPE_AND_CACHE_FLASH); CALL_RESHAPE_AND_CACHE_FLASH);
} }
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
}
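
For the unquantized (kAuto) path, concat_and_cache_mla above amounts to writing the compressed KV and the rotary part of the key side by side into one cache entry per slot. A hedged PyTorch reference of that layout (not the kernel; the fp8 path additionally applies the scale):

import torch

def concat_and_cache_mla_reference(kv_c, k_pe, kv_cache, slot_mapping):
    # kv_c: [num_tokens, kv_lora_rank], k_pe: [num_tokens, pe_dim]
    # kv_cache: [num_blocks, block_size, kv_lora_rank + pe_dim]
    kv_lora_rank = kv_c.size(1)
    block_size = kv_cache.size(1)
    for token_idx, slot_idx in enumerate(slot_mapping.tolist()):
        if slot_idx < 0:                       # padded token, nothing to write
            continue
        block_idx, block_offset = divmod(slot_idx, block_size)
        kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx]
        kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx]
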
namespace vllm { namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>

View File

@ -1,14 +1,7 @@
#pragma once
#include <climits> #include <climits>
#include <iostream> #include <iostream>
inline constexpr uint32_t next_pow_2(uint32_t const num) { inline uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num; if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
template <typename T>
inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
return (a + b - 1) / b;
} }
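
The two helpers above are small enough that their intent is easiest to confirm by example; Python equivalents follow (a behavioural sketch, not the C++):

def next_pow_2(n):
    return n if n <= 1 else 1 << (n - 1).bit_length()

def ceil_div(a, b):
    return (a + b - 1) // b

assert next_pow_2(17) == 32 and next_pow_2(16) == 16 and next_pow_2(1) == 1
assert ceil_div(7, 3) == 3 and ceil_div(6, 3) == 2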

View File

@ -32,7 +32,7 @@ class ScalarType {
signed_(signed_), signed_(signed_),
bias(bias), bias(bias),
finite_values_only(finite_values_only), finite_values_only(finite_values_only),
nan_repr(nan_repr) {}; nan_repr(nan_repr){};
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits - 1, true, bias); return ScalarType(0, size_bits - 1, true, bias);

View File

@ -460,11 +460,11 @@ void paged_attention_v1(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, const std::string& kv_cache_dtype, double k_scale, double v_scale,
torch::Tensor& v_scale, const int64_t tp_rank, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1, TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet."); "CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
@ -782,11 +782,11 @@ void paged_attention_v2(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, const std::string& kv_cache_dtype, double k_scale, double v_scale,
torch::Tensor& v_scale, const int64_t tp_rank, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1, TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet."); "CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",

View File

@ -107,8 +107,10 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, double k_scale,
torch::Tensor& k_scale, torch::Tensor& v_scale) { double v_scale) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);

View File

@ -2,13 +2,13 @@
#define CPU_TYPES_HPP #define CPU_TYPES_HPP
#if defined(__x86_64__) #if defined(__x86_64__)
// x86 implementation //x86 implementation
#include "cpu_types_x86.hpp" #include "cpu_types_x86.hpp"
#elif defined(__POWER9_VECTOR__) #elif defined(__POWER9_VECTOR__)
// ppc implementation //ppc implementation
#include "cpu_types_vsx.hpp" #include "cpu_types_vsx.hpp"
#elif defined(__aarch64__) #elif defined(__aarch64__)
// arm implementation //arm implementation
#include "cpu_types_arm.hpp" #include "cpu_types_arm.hpp"
#else #else
#warning "unsupported vLLM cpu implementation" #warning "unsupported vLLM cpu implementation"

View File

@ -1,50 +1,48 @@
#include <arm_neon.h> #include <arm_neon.h>
#include <torch/all.h> #include <torch/all.h>
#include <cmath> #include <cmath>
namespace vec_op { namespace vec_op {
#ifdef ARM_BF16_SUPPORT #ifdef ARM_BF16_SUPPORT
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#else #else
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#endif #endif
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD #ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME) #define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME) #define CPU_KERNEL_GUARD_OUT(NAME)
#else #else
#define CPU_KERNEL_GUARD_IN(NAME) \ #define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl; std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \ #define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
std::cout << #NAME << " exit." << std::endl;
#endif #endif
#define FORCE_INLINE __attribute__((always_inline)) inline #define FORCE_INLINE __attribute__((always_inline)) inline
namespace { namespace {
template <typename T, T... indexes, typename F> template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) { constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
(f(std::integral_constant<T, indexes>{}), ...); (f(std::integral_constant<T, indexes>{}), ...);
}; };
}; // namespace };
template <typename T, T count, typename F, template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>> typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) { constexpr void unroll_loop(F &&f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f)); unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
} }
template <typename T> template <typename T> struct Vec {
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }; constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
}; };
@ -56,124 +54,70 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
float16x8_t reg; float16x8_t reg;
explicit FP16Vec8(const void* ptr) explicit FP16Vec8(const void *ptr)
: reg(vld1q_f16(static_cast<const __fp16*>(ptr))) {}; : reg(vld1q_f16(static_cast<const __fp16 *>(ptr))) {};
explicit FP16Vec8(const FP32Vec8&); explicit FP16Vec8(const FP32Vec8 &);
void save(void* ptr) const { vst1q_f16(static_cast<__fp16*>(ptr), reg); } void save(void *ptr) const {
vst1q_f16(static_cast<__fp16 *>(ptr), reg);
}
}; };
struct FP16Vec16 : public Vec<FP16Vec16> { struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16; constexpr static int VEC_ELEM_NUM = 16;
float16x8x2_t reg; float16x8x2_t reg;
explicit FP16Vec16(const void* ptr) { explicit FP16Vec16(const void *ptr) {
reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr)); reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8); reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
}
explicit FP16Vec16(const FP32Vec16& vec);
void save(void* ptr) const {
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
}
void save(void* ptr, const int elem_num) const {
int full_blocks = elem_num / 8;
int remainder = elem_num % 8;
if (full_blocks > 0) {
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
if (full_blocks > 1) {
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
}
} }
// Note: below is the unrolled version of the following code: explicit FP16Vec16(const FP32Vec16& vec);
//
// for (int i = 0; i < remainder; ++i) { void save(void *ptr) const {
// reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] = vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
// vgetq_lane_f16(temp, i); vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
// } }
//
// For macOS build (Clang), the arm/neon intrinsics function void save(void *ptr, const int elem_num) const {
// `vgetq_lane_f16` needs the parameter `i` to be constant at compile int full_blocks = elem_num / 8;
// time. int remainder = elem_num % 8;
if (remainder > 0) { if (full_blocks > 0) {
float16x8_t temp = reg.val[full_blocks]; vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
__fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr); if (full_blocks > 1) {
switch (remainder) { vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
case 1: }
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); }
break;
case 2: if (remainder > 0) {
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); float16x8_t temp = reg.val[full_blocks];
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1); for (int i = 0; i < remainder; ++i) {
break; reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] = vgetq_lane_f16(temp, i);
case 3: }
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0); }
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
break;
case 4:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
break;
case 5:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
break;
case 6:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
break;
case 7:
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6);
break;
default:
break;
}
} }
}
}; };
#ifdef ARM_BF16_SUPPORT #ifdef ARM_BF16_SUPPORT
struct BF16Vec8 : public Vec<BF16Vec8> { struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8; constexpr static int VEC_ELEM_NUM = 8;
bfloat16x8_t reg; bfloat16x8_t reg;
explicit BF16Vec8(const void* ptr) explicit BF16Vec8(const void *ptr)
: reg(*reinterpret_cast<const bfloat16x8_t*>(ptr)) {}; : reg(*reinterpret_cast<const bfloat16x8_t *>(ptr)) {};
explicit BF16Vec8(bfloat16x8_t data) : reg(data) {}; explicit BF16Vec8(bfloat16x8_t data) : reg(data) {};
explicit BF16Vec8(const FP32Vec8&); explicit BF16Vec8(const FP32Vec8 &);
explicit BF16Vec8(float32x4x2_t v) explicit BF16Vec8(float32x4x2_t v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
void save(void* ptr) const { *reinterpret_cast<bfloat16x8_t*>(ptr) = reg; } void save(void *ptr) const { *reinterpret_cast<bfloat16x8_t *>(ptr) = reg; }
}; };
struct BF16Vec16 : public Vec<BF16Vec16> { struct BF16Vec16 : public Vec<BF16Vec16> {
@ -181,18 +125,19 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
bfloat16x8x2_t reg; bfloat16x8x2_t reg;
explicit BF16Vec16(const void* ptr) explicit BF16Vec16(const void *ptr)
: reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {}; : reg(*reinterpret_cast<const bfloat16x8x2_t *>(ptr)) {};
explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {}; explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
explicit BF16Vec16(const FP32Vec16&); explicit BF16Vec16(const FP32Vec16 &);
explicit BF16Vec16(float32x4x4_t v) explicit BF16Vec16(float32x4x4_t v) : reg({
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]), vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {}; vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])
}){};
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; }; void save(void *ptr) const { *reinterpret_cast<bfloat16x8x2_t *>(ptr) = reg; };
}; };
struct BF16Vec32 : public Vec<BF16Vec32> { struct BF16Vec32 : public Vec<BF16Vec32> {
@ -200,15 +145,19 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
bfloat16x8x4_t reg; bfloat16x8x4_t reg;
explicit BF16Vec32(const void* ptr) explicit BF16Vec32(const void *ptr)
: reg(*reinterpret_cast<const bfloat16x8x4_t*>(ptr)) {}; : reg(*reinterpret_cast<const bfloat16x8x4_t *>(ptr)) {};
explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {}; explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {};
explicit BF16Vec32(const BF16Vec8& vec8_data) explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}; vec8_data.reg,
vec8_data.reg,
vec8_data.reg,
vec8_data.reg
}) {};
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; }; void save(void *ptr) const { *reinterpret_cast<bfloat16x8x4_t *>(ptr) = reg; };
}; };
#endif #endif
@ -226,11 +175,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {}; explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {};
explicit FP32Vec4(const float* ptr) : reg(vld1q_f32(ptr)) {}; explicit FP32Vec4(const float *ptr) : reg(vld1q_f32(ptr)) {};
explicit FP32Vec4(float32x4_t data) : reg(data) {}; explicit FP32Vec4(float32x4_t data) : reg(data) {};
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}; explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {};
}; };
struct FP32Vec8 : public Vec<FP32Vec8> { struct FP32Vec8 : public Vec<FP32Vec8> {
@ -246,37 +195,32 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}; explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {};
explicit FP32Vec8(const float* ptr) explicit FP32Vec8(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
explicit FP32Vec8(float32x4x2_t data) : reg(data) {}; explicit FP32Vec8(float32x4x2_t data) : reg(data) {};
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}; explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {};
explicit FP32Vec8(const FP16Vec8& v) { explicit FP32Vec8(const FP16Vec8 &v) {
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg)); reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg)); reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
}; };
explicit FP32Vec8(float16x8_t v) explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
: reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
#ifdef ARM_BF16_SUPPORT #ifdef ARM_BF16_SUPPORT
explicit FP32Vec8(bfloat16x8_t v) explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
: reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
explicit FP32Vec8(const BF16Vec8& v) explicit FP32Vec8(const BF16Vec8 &v) : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
: reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
#endif #endif
float reduce_sum() const { float reduce_sum() const {
AliasReg ar; AliasReg ar;
ar.reg = reg; ar.reg = reg;
float answer = 0; float answer = 0;
unroll_loop<int, VEC_ELEM_NUM>( unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
[&answer, &ar](int i) { answer += ar.values[i]; });
return answer; return answer;
} }
@ -323,14 +267,10 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
AliasReg ar; AliasReg ar;
ar.reg = reg; ar.reg = reg;
float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])), float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])), static_cast<float32_t>(erf(ar.values[1]))};
static_cast<float32_t>(erf(ar.values[1]))}; float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])), static_cast<float32_t>(erf(ar.values[3]))};
float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])), float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])), static_cast<float32_t>(erf(ar.values[5]))};
static_cast<float32_t>(erf(ar.values[3]))}; float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])), static_cast<float32_t>(erf(ar.values[7]))};
float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])),
static_cast<float32_t>(erf(ar.values[5]))};
float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])),
static_cast<float32_t>(erf(ar.values[7]))};
float32x4_t result0 = vcombine_f32(er_vec0, er_vec1); float32x4_t result0 = vcombine_f32(er_vec0, er_vec1);
float32x4_t result1 = vcombine_f32(er_vec2, er_vec3); float32x4_t result1 = vcombine_f32(er_vec2, er_vec3);
@ -340,29 +280,25 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
result.val[1] = result1; result.val[1] = result1;
return FP32Vec8(result); return FP32Vec8(result);
}
FP32Vec8 operator*(const FP32Vec8 &b) const {
return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), vmulq_f32(reg.val[1], b.reg.val[1])}));
} }
FP32Vec8 operator*(const FP32Vec8& b) const { FP32Vec8 operator+(const FP32Vec8 &b) const {
return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), vaddq_f32(reg.val[1], b.reg.val[1])}));
vmulq_f32(reg.val[1], b.reg.val[1])}));
} }
FP32Vec8 operator+(const FP32Vec8& b) const { FP32Vec8 operator-(const FP32Vec8 &b) const {
return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), vsubq_f32(reg.val[1], b.reg.val[1])}));
vaddq_f32(reg.val[1], b.reg.val[1])}));
} }
FP32Vec8 operator-(const FP32Vec8& b) const { FP32Vec8 operator/(const FP32Vec8 &b) const {
return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), vdivq_f32(reg.val[1], b.reg.val[1])}));
vsubq_f32(reg.val[1], b.reg.val[1])}));
} }
FP32Vec8 operator/(const FP32Vec8& b) const { void save(float *ptr) const {
return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]),
vdivq_f32(reg.val[1], b.reg.val[1])}));
}
void save(float* ptr) const {
vst1q_f32(ptr, reg.val[0]); vst1q_f32(ptr, reg.val[0]);
vst1q_f32(ptr + 4, reg.val[1]); vst1q_f32(ptr + 4, reg.val[1]);
} }
@ -377,100 +313,103 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
float32x4x4_t reg; float32x4x4_t reg;
explicit FP32Vec16(float v) explicit FP32Vec16(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
: reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
explicit FP32Vec16() explicit FP32Vec16() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}
: reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0),
vmovq_n_f32(0.0)}) {}
explicit FP32Vec16(const float* ptr) explicit FP32Vec16(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), vld1q_f32(ptr + 12)}) {}
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
vld1q_f32(ptr + 12)}) {}
explicit FP32Vec16(float32x4x4_t data) : reg(data) {} explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
explicit FP32Vec16(const FP32Vec8& data) { explicit FP32Vec16(const FP32Vec8 &data) {
reg.val[0] = data.reg.val[0]; reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1]; reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[0]; reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[1]; reg.val[3] = data.reg.val[1];
} }
explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {} explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v.reg)) {} explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v.reg)) {}
#ifdef ARM_BF16_SUPPORT #ifdef ARM_BF16_SUPPORT
explicit FP32Vec16(bfloat16x8x2_t v) explicit FP32Vec16(bfloat16x8x2_t v) : reg({
: reg({vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]), vcvtq_low_f32_bf16(v.val[0]),
vcvtq_low_f32_bf16(v.val[1]), vcvtq_high_f32_bf16(v.val[1])}) {}; vcvtq_high_f32_bf16(v.val[0]),
#endif vcvtq_low_f32_bf16(v.val[1]),
vcvtq_high_f32_bf16(v.val[1])
}) {};
#endif
explicit FP32Vec16(const FP32Vec4& data) { explicit FP32Vec16(const FP32Vec4 &data) {
reg.val[0] = data.reg; reg.val[0] = data.reg;
reg.val[1] = data.reg; reg.val[1] = data.reg;
reg.val[2] = data.reg; reg.val[2] = data.reg;
reg.val[3] = data.reg; reg.val[3] = data.reg;
}; };
#ifdef ARM_BF16_SUPPORT #ifdef ARM_BF16_SUPPORT
explicit FP32Vec16(const BF16Vec16& v) explicit FP32Vec16(const BF16Vec16 &v) : reg({
: reg({vcvtq_low_f32_bf16(v.reg.val[0]), vcvtq_low_f32_bf16(v.reg.val[0]),
vcvtq_high_f32_bf16(v.reg.val[0]), vcvtq_high_f32_bf16(v.reg.val[0]),
vcvtq_low_f32_bf16(v.reg.val[1]), vcvtq_low_f32_bf16(v.reg.val[1]),
vcvtq_high_f32_bf16(v.reg.val[1])}) {}; vcvtq_high_f32_bf16(v.reg.val[1])
}) {};
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}; explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {};
#endif #endif
explicit FP32Vec16(const FP16Vec16& v) { explicit FP32Vec16(const FP16Vec16 &v) {
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0])); reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0])); reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1])); reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1])); reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
}; };
FP32Vec16 operator+(const FP32Vec16& b) const { FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]), return FP32Vec16(float32x4x4_t({
vaddq_f32(reg.val[1], b.reg.val[1]), vaddq_f32(reg.val[0], b.reg.val[0]),
vaddq_f32(reg.val[2], b.reg.val[2]), vaddq_f32(reg.val[1], b.reg.val[1]),
vaddq_f32(reg.val[3], b.reg.val[3])})); vaddq_f32(reg.val[2], b.reg.val[2]),
vaddq_f32(reg.val[3], b.reg.val[3])}));
}; };
FP32Vec16 operator*(const FP32Vec16& b) const { FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(float32x4x4_t({vmulq_f32(reg.val[0], b.reg.val[0]), return FP32Vec16(float32x4x4_t({
vmulq_f32(reg.val[1], b.reg.val[1]), vmulq_f32(reg.val[0], b.reg.val[0]),
vmulq_f32(reg.val[2], b.reg.val[2]), vmulq_f32(reg.val[1], b.reg.val[1]),
vmulq_f32(reg.val[3], b.reg.val[3])})); vmulq_f32(reg.val[2], b.reg.val[2]),
vmulq_f32(reg.val[3], b.reg.val[3])}));
}; };
FP32Vec16 operator-(const FP32Vec16& b) const { FP32Vec16 operator-(const FP32Vec16 &b) const {
return FP32Vec16(float32x4x4_t({vsubq_f32(reg.val[0], b.reg.val[0]), return FP32Vec16(float32x4x4_t({
vsubq_f32(reg.val[1], b.reg.val[1]), vsubq_f32(reg.val[0], b.reg.val[0]),
vsubq_f32(reg.val[2], b.reg.val[2]), vsubq_f32(reg.val[1], b.reg.val[1]),
vsubq_f32(reg.val[3], b.reg.val[3])})); vsubq_f32(reg.val[2], b.reg.val[2]),
vsubq_f32(reg.val[3], b.reg.val[3])
}));
}; };
FP32Vec16 operator/(const FP32Vec16& b) const { FP32Vec16 operator/(const FP32Vec16 &b) const {
return FP32Vec16(float32x4x4_t({vdivq_f32(reg.val[0], b.reg.val[0]), return FP32Vec16(float32x4x4_t({
vdivq_f32(reg.val[1], b.reg.val[1]), vdivq_f32(reg.val[0], b.reg.val[0]),
vdivq_f32(reg.val[2], b.reg.val[2]), vdivq_f32(reg.val[1], b.reg.val[1]),
vdivq_f32(reg.val[3], b.reg.val[3])})); vdivq_f32(reg.val[2], b.reg.val[2]),
vdivq_f32(reg.val[3], b.reg.val[3])
}));
}; };
float reduce_sum() const { float reduce_sum() const {
AliasReg ar; AliasReg ar;
ar.reg = reg; ar.reg = reg;
float answer = 0; float answer = 0;
unroll_loop<int, VEC_ELEM_NUM>( unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
[&answer, &ar](int i) { answer += ar.values[i]; });
return answer; return answer;
}; };
template <int group_size> template <int group_size> float reduce_sub_sum(int idx) {
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0); static_assert(VEC_ELEM_NUM % group_size == 0);
AliasReg ar; AliasReg ar;
@ -483,7 +422,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return answer; return answer;
}; };
void save(float* ptr) const { void save(float *ptr) const {
vst1q_f32(ptr, reg.val[0]); vst1q_f32(ptr, reg.val[0]);
vst1q_f32(ptr + 4, reg.val[1]); vst1q_f32(ptr + 4, reg.val[1]);
vst1q_f32(ptr + 8, reg.val[2]); vst1q_f32(ptr + 8, reg.val[2]);
@ -491,59 +430,43 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
  };
};

template <typename T>
struct VecType {
  using vec_type = void;
};

template <typename T>
using vec_t = typename VecType<T>::vec_type;

template <>
struct VecType<float> {
  using vec_type = FP32Vec8;
};

template <>
struct VecType<c10::Half> {
  using vec_type = FP16Vec8;
};

#ifdef ARM_BF16_SUPPORT
template <>
struct VecType<c10::BFloat16> {
  using vec_type = BF16Vec8;
};
#endif

template <typename T>
void storeFP32(float v, T* ptr) {
  *ptr = v;
}

template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
  *reinterpret_cast<__fp16*>(ptr) = v;
}

inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
  float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
  float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
  float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
  float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);

  reg.val[0] = vcombine_f16(low_0, high_0);
  reg.val[1] = vcombine_f16(low_1, high_1);
};

inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
  float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
  float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);

  reg = vcombine_f16(lower_half, upper_half);
};

inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
  acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]);
  acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]);
  acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]);
@ -551,7 +474,8 @@ inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
};

#ifdef ARM_BF16_SUPPORT
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
  float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0]));
  float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0]));
  float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1]));
@ -570,22 +494,22 @@ inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
#endif

#ifdef ARM_BF16_SUPPORT
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
    : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {
};

inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
    : reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
           vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]),
                               v.reg.val[3])}) {};
#endif

inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); };

#ifdef ARM_BF16_SUPPORT
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  *reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v);
};
#endif
};  // namespace vec_op
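Taken together, the NEON definitions above give the CPU kernels a 16-lane float abstraction: load, arithmetic, fused multiply-add, reduction, and store. A minimal usage sketch follows, assuming FP32Vec16 also exposes the float-pointer load constructor that the VSX variant below shows, and that the buffer length is a multiple of 16; the function and pointer names are illustrative, not code from this diff:

// Illustrative sketch only: out[i] += x[i] * y[i], 16 floats per iteration.
void fma_buffers(const float* x, const float* y, float* out, int n) {
  for (int i = 0; i < n; i += 16) {
    vec_op::FP32Vec16 vx(x + i);   // assumed float* load constructor
    vec_op::FP32Vec16 vy(y + i);
    vec_op::FP32Vec16 acc(out + i);
    vec_op::fma(acc, vx, vy);      // acc += vx * vy via vfmaq_f32, as defined above
    acc.save(out + i);             // save(float*) is shown above
  }
}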

View File

@ -9,40 +9,38 @@
namespace vec_op {

// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)          \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
  std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) \
  std::cout << #NAME << " exit." << std::endl;
#endif

#define FORCE_INLINE __attribute__((always_inline)) inline

namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
  (f(std::integral_constant<T, indexes>{}), ...);
}
};  // namespace

template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}

template <typename T>
struct Vec {
  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
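// Illustrative note (not part of this header): unroll_loop expands the index
// pack at compile time, so the functor is invoked once per index with no
// runtime loop overhead, e.g.
//
//   float vals[4] = {1.f, 2.f, 3.f, 4.f};
//   float total = 0.f;
//   unroll_loop<int, 4>([&](int i) { total += vals[i]; });  // total == 10.f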
@ -70,14 +68,12 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
  __vector signed short reg;

  explicit BF16Vec8(const void* ptr)
      : reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {}

  explicit BF16Vec8(const FP32Vec8&);

  void save(void* ptr) const {
    *reinterpret_cast<__vector signed short*>(ptr) = reg;
  }
};

struct BF16Vec16 : public Vec<BF16Vec16> {
@ -85,18 +81,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
  ss16x8x2_t reg;

  explicit BF16Vec16(const void* ptr) {
    // Load 256 bits in two parts
    reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
    reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
  }

  explicit BF16Vec16(const FP32Vec16&);

  void save(void* ptr) const {
    // Save 256 bits in two parts
    vec_xst(reg.val[0], 0, (signed short*)ptr);
    vec_xst(reg.val[1], 16, (signed short*)ptr);
  }
};
@ -106,15 +102,19 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
  constexpr static int VEC_ELEM_NUM = 32;

  ss16x8x4_t reg;
  explicit BF16Vec32(const void* ptr)
      : reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}

  explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}

  explicit BF16Vec32(const BF16Vec8& vec8_data)
      : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}

  void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
};

struct FP32Vec4 : public Vec<FP32Vec4> {
@ -130,11 +130,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
  explicit FP32Vec4() : reg(vec_splats(0.0f)) {}

  explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}

  explicit FP32Vec4(__vector float data) : reg(data) {}

  explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
};

struct FP32Vec8 : public Vec<FP32Vec8> {
@ -156,19 +156,19 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
    reg.val[1] = vec_splats(0.0f);
  }

  explicit FP32Vec8(const float* ptr) {
    reg.val[0] = vec_xl(0, ptr);
    reg.val[1] = vec_xl(16, ptr);
  }

  explicit FP32Vec8(f32x4x2_t data) : reg(data) {}

  explicit FP32Vec8(const FP32Vec8& data) {
    reg.val[0] = data.reg.val[0];
    reg.val[1] = data.reg.val[1];
  }

  explicit FP32Vec8(const BF16Vec8& v) {
    reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
    reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
  }
@ -177,8 +177,7 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&result, &ar](int i) { result += ar.values[i]; });

    return result;
  }
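  // Illustrative note (not part of this header): reduce_sum() relies on
  // AliasReg, declared earlier in this file, acting as a union-style view of
  // the vector register as a plain float array, so the unrolled loop can
  // accumulate the lanes with scalar adds.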
@ -231,27 +230,23 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
    return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
  }

  FP32Vec8 operator*(const FP32Vec8& b) const {
    return FP32Vec8(
        {vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
  }

  FP32Vec8 operator+(const FP32Vec8& b) const {
    return FP32Vec8(
        {vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
  }

  FP32Vec8 operator-(const FP32Vec8& b) const {
    return FP32Vec8(
        {vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
  }

  FP32Vec8 operator/(const FP32Vec8& b) const {
    return FP32Vec8(
        {vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
  }

  void save(float* ptr) const {
    vec_xst(reg.val[0], 0, ptr);
    vec_xst(reg.val[1], 16, ptr);
  }
@ -280,7 +275,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    reg.val[3] = vec_splats(0.0f);
  }

  explicit FP32Vec16(const float* ptr) {
    reg.val[0] = vec_xl(0, ptr);
    reg.val[1] = vec_xl(16, ptr);
    reg.val[2] = vec_xl(32, ptr);
@ -289,76 +284,78 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
  explicit FP32Vec16(f32x4x4_t data) : reg(data) {}

  explicit FP32Vec16(const FP32Vec16& data) {
    reg.val[0] = data.reg.val[0];
    reg.val[1] = data.reg.val[1];
    reg.val[2] = data.reg.val[2];
    reg.val[3] = data.reg.val[3];
  }

  explicit FP32Vec16(const FP32Vec4& data) {
    reg.val[0] = data.reg;
    reg.val[1] = data.reg;
    reg.val[2] = data.reg;
    reg.val[3] = data.reg;
  }

  explicit FP32Vec16(const FP32Vec8& data) {
    reg.val[0] = data.reg.val[0];
    reg.val[1] = data.reg.val[1];
    reg.val[2] = data.reg.val[0];
    reg.val[3] = data.reg.val[1];
  }

  explicit FP32Vec16(const BF16Vec16& v) {
    reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
    reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
    reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
    reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
  }

  explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}

  FP32Vec16 operator*(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
                                vec_mul(reg.val[1], b.reg.val[1]),
                                vec_mul(reg.val[2], b.reg.val[2]),
                                vec_mul(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 operator+(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
                                vec_add(reg.val[1], b.reg.val[1]),
                                vec_add(reg.val[2], b.reg.val[2]),
                                vec_add(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 operator-(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
                                vec_sub(reg.val[1], b.reg.val[1]),
                                vec_sub(reg.val[2], b.reg.val[2]),
                                vec_sub(reg.val[3], b.reg.val[3])}));
  }

  FP32Vec16 operator/(const FP32Vec16& b) const {
    return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
                                vec_div(reg.val[1], b.reg.val[1]),
                                vec_div(reg.val[2], b.reg.val[2]),
                                vec_div(reg.val[3], b.reg.val[3])}));
  }

  float reduce_sum() const {
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    unroll_loop<int, VEC_ELEM_NUM>(
        [&result, &ar](int i) { result += ar.values[i]; });

    return result;
  }

  template <int group_size>
  float reduce_sub_sum(int idx) {
    static_assert(VEC_ELEM_NUM % group_size == 0);

    AliasReg ar;
@ -371,7 +368,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
    return result;
  }

  void save(float* ptr) const {
    vec_xst(reg.val[0], 0, ptr);
    vec_xst(reg.val[1], 16, ptr);
    vec_xst(reg.val[2], 32, ptr);
@ -379,62 +376,43 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
  }
};

template <typename T>
struct VecType {
  using vec_type = void;
};

template <typename T>
using vec_t = typename VecType<T>::vec_type;

template <>
struct VecType<float> {
  using vec_type = FP32Vec8;
};

template <>
struct VecType<c10::BFloat16> {
  using vec_type = BF16Vec8;
};

template <typename T>
void storeFP32(float v, T* ptr) {
  *ptr = v;
}

inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
  acc = acc + a * b;
}
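// Illustrative note (not part of this header): unlike the NEON fma() above,
// which maps directly to vfmaq_f32, this variant composes the overloaded
// * and + operators, so whether a fused multiply-add instruction is emitted
// is left to the compiler.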
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
  c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
      reinterpret_cast<c10::BFloat16*>(&v);
  *ptr = *(v_ptr + 1);
}

#ifndef __VEC_CLASS_FP_NAN
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif

const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13,
                                             16, 17, 20, 21, 24, 25, 28, 29};
#ifndef _ARCH_PWR10
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
                                           0x00007fff};
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
                                          0x7fc00000};
const static __vector unsigned int sh16 = {16, 16, 16, 16};
const static __vector unsigned int one = {1, 1, 1, 1};
#endif

inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
#ifdef _ARCH_PWR10
  __vector signed short ret[2];
  ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[0]);
  ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[1]);
  reg = vec_perm(ret[0], ret[1], omask);
#elif defined(_ARCH_PWR9)
  __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
@ -447,10 +425,8 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
  __vector unsigned int rnd1 = vec_add(lsb1, bias);
  inp0 = vec_add(inp0, rnd0);
  inp1 = vec_add(inp1, rnd1);
  __vector __bool int sel0 =
      vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
  __vector __bool int sel1 =
      vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
  inp0 = vec_sel(inp0, nan, sel0);
  inp1 = vec_sel(inp1, nan, sel1);
  inp0 = vec_sr(inp0, sh16);
@ -459,17 +435,13 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
#endif
}

inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
#ifdef _ARCH_PWR10
  __vector signed short ret[4];
  ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[0]);
  ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[1]);
  ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[2]);
  ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16(
      (__vector unsigned char)v.reg.val[3]);
  reg.val[0] = vec_perm(ret[0], ret[1], omask);
  reg.val[1] = vec_perm(ret[2], ret[3], omask);
#elif defined(_ARCH_PWR9)
@ -493,14 +465,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
  inp1 = vec_add(inp1, rnd1);
  inp2 = vec_add(inp2, rnd2);
  inp3 = vec_add(inp3, rnd3);
  __vector __bool int sel0 =
      vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
  __vector __bool int sel1 =
      vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
  __vector __bool int sel2 =
      vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
  __vector __bool int sel3 =
      vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
  inp0 = vec_sel(inp0, nan, sel0);
  inp1 = vec_sel(inp1, nan, sel1);
  inp2 = vec_sel(inp2, nan, sel2);
@ -514,10 +482,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
#endif
}

inline void prefetch(const void* addr) {
  __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
}

};  // namespace vec_op

#endif
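For reference, the pre-Power10 path above rounds FP32 to BF16 with round-to-nearest-even and forces NaN inputs to a quiet-NaN pattern. A scalar sketch of the same arithmetic, assuming the usual bfloat16 layout (the function name and the use of <cstring> are illustrative, not from this diff):

#include <cstdint>
#include <cstring>

// Mirror of the vector math above: add 0x7fff plus the lowest kept bit
// (round-to-nearest-even), substitute a quiet NaN for NaN inputs, then keep
// the top 16 bits of the float.
uint16_t fp32_to_bf16_rne(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if ((bits & 0x7fffffffu) > 0x7f800000u) return 0x7fc0;  // NaN -> quiet NaN ("nan" constant)
  uint32_t lsb = (bits >> 16) & 1u;                       // the "one" mask
  bits += 0x00007fffu + lsb;                              // the "bias" constant
  return static_cast<uint16_t>(bits >> 16);               // the "sh16" shift
}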

Some files were not shown because too many files have changed in this diff.