Compare commits

..

75 Commits

Author SHA1 Message Date
c00ddd6834 Add buffer donation to benchmark 2024-04-30 21:58:47 +00:00
881b884046 Add block size 2024-04-27 22:35:28 +00:00
98a3df0f8d Disable memory tracking 2024-04-26 08:56:26 +00:00
3f6288cc89 Fix for binary cache 2024-04-26 08:56:12 +00:00
408ff4950c Tune pages_per_compute_block 2024-04-26 08:55:23 +00:00
278e8a1adc Add tpu 2024-04-26 08:54:52 +00:00
07be6ed3eb Improve benchmark 2024-04-26 08:54:41 +00:00
f6637dba18 Use persistent cache 2024-04-26 07:09:44 +00:00
707a5f6473 Move JAX-smi to worker 2024-04-26 07:05:51 +00:00
57690a9c09 Fix bucketing 2024-04-26 07:05:27 +00:00
b15db234ba Add precompilation step 2024-04-26 05:43:08 +00:00
d1591f0f1f Add op benchmark scripts 2024-04-26 05:35:19 +00:00
85d4488458 yapf 2024-04-26 05:31:31 +00:00
8d072dbfbd yapf 2024-04-26 05:30:25 +00:00
d830766c0c yapf 2024-04-26 05:30:08 +00:00
5ae2f81c2b Add warmup + formatting 2024-04-26 05:28:09 +00:00
4ea41d01a9 yapf 2024-04-26 05:27:38 +00:00
d16a348477 Add comment 2024-04-26 05:27:27 +00:00
aa092834bb Format gemma.py 2024-04-26 05:26:38 +00:00
d2c6a32c0c Fix is_tpu 2024-04-26 05:26:24 +00:00
21f35c2289 Change version 2024-04-26 05:00:26 +00:00
2aa9831dd3 Minor 2024-04-25 23:40:44 +00:00
028f528aad Fix KV cache shape 2024-04-25 23:38:07 +00:00
fa5bacd5b0 Add warmup 2024-04-25 05:06:41 +00:00
b62170e4e3 Fix scheduler 2024-04-25 05:06:22 +00:00
98eda57899 Add timer 2024-04-25 05:06:11 +00:00
81b8b813f1 Pad to avoid recompilation 2024-04-25 04:43:33 +00:00
e2c7dedb3a Minor 2024-04-25 03:28:53 +00:00
5323969fcf Increase #blocks 2024-04-24 08:56:58 +00:00
f42b4c27d8 Include argmax to jit 2024-04-24 08:56:45 +00:00
620e7646d3 Fix cache write 2024-04-24 08:56:30 +00:00
d5fb1c20c1 Fix JAX jit OOM 2024-04-24 07:52:56 +00:00
092e3d6d6d Remove hardcoded path 2024-04-19 08:18:10 +00:00
84284302d8 Minor 2024-04-19 08:08:25 +00:00
743695f586 Fix write_to_kv_cache 2024-04-19 07:51:54 +00:00
62b870fa07 Use FlashAttention kernel 2024-04-17 20:24:45 +00:00
7e3a230c38 Fix paged_attn 2024-04-17 20:06:26 +00:00
186c88c497 explictly return new_kv_caches 2024-04-17 18:42:34 +00:00
ef762cb110 Write kV 2024-04-17 18:21:39 +00:00
756c4e78d3 Add write_to_cache ops 2024-04-17 18:20:55 +00:00
4880de35d2 Add attn_mask 2024-04-17 18:12:20 +00:00
0fb07c08d0 Minor 2024-04-17 18:08:33 +00:00
e4377dd698 Add model runner 2024-04-17 18:04:54 +00:00
5cb213c85e Add flash-attn op 2024-04-17 18:02:28 +00:00
25bbc21ef6 Minor 2024-04-17 18:02:16 +00:00
b25fcc06c2 Minor 2024-04-17 18:02:13 +00:00
6661c030c4 Add paged_attn op 2024-04-17 18:02:00 +00:00
8888d1c474 Fix logit indices 2024-04-17 18:01:43 +00:00
cedb67028a Add gemma 2024-04-17 17:00:10 +00:00
91b47e3f2f JAX-based TPU worker 2024-04-16 17:37:11 +00:00
6d62e4c6aa Add torch to dependencies 2024-04-16 17:06:35 +00:00
de82e95787 Minor 2024-04-16 17:04:46 +00:00
b3b89cf755 Renew TPU executor 2024-04-16 09:42:15 +00:00
6692a30266 Minor 2024-04-16 09:41:53 +00:00
eb0a0466a9 Add JAX requirements 2024-04-16 08:05:54 +00:00
c59c1e7b2c Remove 2024-04-16 08:05:36 +00:00
d4adf92beb Merge branch 'main' into woosuk-tpu 2024-04-16 07:56:53 +00:00
363e6a950f Fix flashattn 2024-04-10 08:02:40 +00:00
696b653193 yapf 2024-04-10 08:02:21 +00:00
0d6402ddfd Fix requirements 2024-04-10 07:52:45 +00:00
60ff6b8c5c Merge branch 'main' into woosuk-tpu 2024-04-10 07:51:35 +00:00
d899009a63 [WIP] Add TPU worker 2024-04-01 08:24:23 +00:00
6894d3efef Add JAX to requirements.txt 2024-04-01 08:23:59 +00:00
38e3d33a62 Add TPU to device config 2024-04-01 08:23:44 +00:00
02e614d922 [WIP] Add Pallas backend 2024-04-01 08:23:32 +00:00
46b31ed98d Fix RoPE output shape 2024-04-01 08:22:47 +00:00
31d05f7edb yapf 2024-04-01 07:07:57 +00:00
4cdb732cef Add TPU to setup 2024-04-01 07:07:38 +00:00
27c592b97b Add get_dtype_size 2024-04-01 06:33:06 +00:00
5083aa9092 Add TPUExecutor 2024-04-01 03:24:07 +00:00
824521c987 Add TPU to DeviceConfig 2024-04-01 03:19:17 +00:00
3b8f43024f Add is_tpu 2024-04-01 03:18:36 +00:00
d148c2ef00 Add requirements 2024-04-01 03:17:43 +00:00
86f073edd6 Add reference 2024-04-01 02:02:13 +00:00
52a1e908e4 Add TPU gemma 2024-04-01 02:01:28 +00:00
539 changed files with 17947 additions and 50841 deletions

View File

@ -1,36 +0,0 @@
import os
import zipfile
MAX_SIZE_MB = 200
def print_top_10_largest_files(zip_file):
with zipfile.ZipFile(zip_file, 'r') as z:
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
file_sizes.sort(key=lambda x: x[1], reverse=True)
for f, size in file_sizes[:10]:
print(f"{f}: {size/(1024*1024)} MBs uncompressed.")
def check_wheel_size(directory):
for root, _, files in os.walk(directory):
for f in files:
if f.endswith(".whl"):
wheel_path = os.path.join(root, f)
wheel_size = os.path.getsize(wheel_path)
wheel_size_mb = wheel_size / (1024 * 1024)
if wheel_size_mb > MAX_SIZE_MB:
print(
f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) "
f"compare to the allowed size ({MAX_SIZE_MB} MB).")
print_top_10_largest_files(wheel_path)
return 1
else:
print(f"Wheel {wheel_path} is within the allowed size "
f"({wheel_size_mb} MB).")
return 0
if __name__ == "__main__":
import sys
sys.exit(check_wheel_size(sys.argv[1]))

View File

@ -1,73 +1,38 @@
# This script runs test inside the corresponding ROCm docker container.
# This script build the ROCm docker image and run the API server inside the container.
# It serves a sanity check for compilation and basic model usage.
set -ex
# Print ROCm version
echo "--- ROCm info"
rocminfo
# cleanup older docker images
cleanup_docker() {
# Get Docker's root directory
docker_root=$(docker info -f '{{.DockerRootDir}}')
if [ -z "$docker_root" ]; then
echo "Failed to determine Docker root directory."
exit 1
fi
echo "Docker root directory: $docker_root"
# Check disk usage of the filesystem where Docker's root directory is located
disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//')
# Define the threshold
threshold=70
if [ "$disk_usage" -gt "$threshold" ]; then
echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..."
# Remove dangling images (those that are not tagged and not used by any container)
docker image prune -f
# Remove unused volumes
docker volume prune -f
echo "Docker images and volumes cleanup completed."
else
echo "Disk usage is below $threshold%. No cleanup needed."
fi
}
# Try building the docker image
docker build -t rocm -f Dockerfile.rocm .
# Call the cleanup docker function
cleanup_docker
echo "--- Resetting GPUs"
echo "reset" > /opt/amdgpu/etc/gpu_state
while true; do
sleep 3
if grep -q clean /opt/amdgpu/etc/gpu_state; then
echo "GPUs state is \"clean\""
break
fi
done
echo "--- Building container"
sha=$(git rev-parse --short HEAD)
image_name=rocm_${sha}
container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)
docker build \
-t ${image_name} \
-f Dockerfile.rocm \
--progress plain \
.
remove_docker_container() {
docker rm -f ${container_name} || docker image rm -f ${image_name} || true
}
# Setup cleanup
remove_docker_container() { docker rm -f rocm || true; }
trap remove_docker_container EXIT
remove_docker_container
echo "--- Running container"
# Run the image
docker run --device /dev/kfd --device /dev/dri --network host --name rocm rocm python3 -m vllm.entrypoints.api_server &
docker run \
--device /dev/kfd --device /dev/dri \
--network host \
--rm \
-e HF_TOKEN \
--name ${container_name} \
${image_name} \
/bin/bash -c "${@}"
# Wait for the server to start
wait_for_server_to_start() {
timeout=300
counter=0
while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
sleep 1
counter=$((counter + 1))
if [ $counter -ge $timeout ]; then
echo "Timeout after $timeout seconds"
break
fi
done
}
wait_for_server_to_start
# Test a simple prompt
curl -X POST -H "Content-Type: application/json" \
localhost:8000/generate \
-d '{"prompt": "San Francisco is a"}'

View File

@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.."
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
# run python-based benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
bench_latency_exit_code=$?
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
bench_throughput_exit_code=$?
# run server-based benchmarks and upload the result to buildkite
@ -53,11 +53,6 @@ echo '```' >> benchmark_results.md
tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines
echo '```' >> benchmark_results.md
# if the agent binary is not found, skip uploading the results, exit 0
if [ ! -f /workspace/buildkite-agent ]; then
exit 0
fi
# upload the results to buildkite
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
@ -74,5 +69,4 @@ if [ $bench_serving_exit_code -ne 0 ]; then
exit $bench_serving_exit_code
fi
rm ShareGPT_V3_unfiltered_cleaned_split.json
/workspace/buildkite-agent artifact upload "*.json"
/workspace/buildkite-agent artifact upload openai-*.json

View File

@ -11,4 +11,4 @@ trap remove_docker_container EXIT
remove_docker_container
# Run the image and launch offline inference
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 vllm/examples/offline_inference.py
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py

View File

@ -1,51 +0,0 @@
# This script build the Neuron docker image and run the API server inside the container.
# It serves a sanity check for compilation and basic model usage.
set -e
# Try building the docker image
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com
# prune old image and containers to save disk space, and only once a day
# by using a timestamp file in tmp.
if [ -f /tmp/neuron-docker-build-timestamp ]; then
last_build=$(cat /tmp/neuron-docker-build-timestamp)
current_time=$(date +%s)
if [ $((current_time - last_build)) -gt 86400 ]; then
docker system prune -f
echo $current_time > /tmp/neuron-docker-build-timestamp
fi
else
echo $(date +%s) > /tmp/neuron-docker-build-timestamp
fi
docker build -t neuron -f Dockerfile.neuron .
# Setup cleanup
remove_docker_container() { docker rm -f neuron || true; }
trap remove_docker_container EXIT
remove_docker_container
# Run the image
docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \
--model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 &
# Wait for the server to start
wait_for_server_to_start() {
timeout=300
counter=0
while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
sleep 1
counter=$((counter + 1))
if [ $counter -ge $timeout ]; then
echo "Timeout after $timeout seconds"
break
fi
done
}
wait_for_server_to_start
# Test a simple prompt
curl -X POST -H "Content-Type: application/json" \
localhost:8000/generate \
-d '{"prompt": "San Francisco is a"}'

View File

@ -5,155 +5,100 @@
steps:
- label: Regression Test
mirror_hardwares: [amd]
command: pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional
- label: AsyncEngine Test
#mirror_hardwares: [amd]
command: pytest -v -s async_engine
- label: Basic Correctness Test
mirror_hardwares: [amd]
commands:
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
- VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
- label: Core Test
mirror_hardwares: [amd]
command: pytest -v -s core
- label: Distributed Comm Ops Test
#mirror_hardwares: [amd]
command: pytest -v -s distributed/test_comm_ops.py
working_dir: "/vllm-workspace/tests"
num_gpus: 2
command: pytest -v -s test_comm_ops.py
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.
- label: Distributed Tests
mirror_hardwares: [amd]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.
commands:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- label: Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
working_dir: "/vllm-workspace/tests"
num_gpus: 4
commands:
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
- label: Engine Test
mirror_hardwares: [amd]
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
command: pytest -v -s engine tokenization test_sequence.py test_config.py
- label: Entrypoints Test
mirror_hardwares: [amd]
commands:
- pytest -v -s test_inputs.py
- pytest -v -s entrypoints -m llm
- pytest -v -s entrypoints -m openai
# these tests have to be separated, because each one will allocate all posible GPU memory
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
- pytest -v -s entrypoints/test_server_oot_registration.py
- label: Examples Test
working_dir: "/vllm-workspace/examples"
mirror_hardwares: [amd]
commands:
# install aws cli for llava_example.py
# install tensorizer for tensorize_vllm_model.py
- pip install awscli tensorizer
- pip install awscli
- python3 offline_inference.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 llava_example.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- label: Kernels Test %N
#mirror_hardwares: [amd]
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4
- label: Models Test
#mirror_hardwares: [amd]
commands:
- bash ../.buildkite/download-images.sh
- pytest -v -s models --ignore=models/test_llava.py
- pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py
- label: Llava Test
mirror_hardwares: [amd]
commands:
- bash ../.buildkite/download-images.sh
- pytest -v -s models/test_llava.py
- label: Prefix Caching Test
mirror_hardwares: [amd]
commands:
- pytest -v -s prefix_caching
- label: Samplers Test
#mirror_hardwares: [amd]
command: pytest -v -s samplers
- label: LogitsProcessor Test
mirror_hardwares: [amd]
command: pytest -v -s test_logits_processor.py
- label: Utils Test
command: pytest -v -s test_utils.py
- label: Worker Test
mirror_hardwares: [amd]
command: pytest -v -s worker
- label: Speculative decoding tests
#mirror_hardwares: [amd]
command: pytest -v -s spec_decode
- label: LoRA Test %N
#mirror_hardwares: [amd]
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4
- label: LoRA Long Context (Distributed)
#mirror_hardwares: [amd]
num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands:
# Temporarily run this way because we cannot clean up GPU mem usage
# for multi GPU tests.
# TODO(sang): Fix it.
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
- pytest -v -s lora/test_long_context.py::test_self_consistency
- pytest -v -s lora/test_long_context.py::test_quality
- pytest -v -s lora/test_long_context.py::test_max_len
- label: Tensorizer Test
#mirror_hardwares: [amd]
command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
command: apt-get install curl libsodium23 && pytest -v -s tensorizer
- label: Metrics Test
mirror_hardwares: [amd]
command: pytest -v -s metrics
- label: Quantization Test
#mirror_hardwares: [amd]
command: pytest -v -s quantization
- label: Benchmarks
working_dir: "/vllm-workspace/.buildkite"
mirror_hardwares: [amd]
commands:
- pip install aiohttp
- bash run-benchmarks.sh

View File

@ -3,8 +3,16 @@
{% set default_working_dir = "/vllm-workspace/tests" %}
steps:
- label: "AMD Test"
agents:
queue: amd
command: bash .buildkite/run-amd-test.sh
- label: "CPU Test"
command: bash .buildkite/run-cpu-test.sh
- label: ":docker: build image"
commands:
commands:
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
- "docker push {{ docker_image }}"
env:
@ -13,35 +21,8 @@ steps:
automatic:
- exit_status: -1 # Agent was lost
limit: 5
- exit_status: -10 # Agent was lost
limit: 5
- wait
- group: "AMD Tests"
depends_on: ~
steps:
{% for step in steps %}
{% if step.mirror_hardwares and "amd" in step.mirror_hardwares %}
- label: "AMD: {{ step.label }}"
agents:
queue: amd
command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}"
env:
DOCKER_BUILDKIT: "1"
{% endif %}
{% endfor %}
- label: "Neuron Test"
depends_on: ~
agents:
queue: neuron
command: bash .buildkite/run-neuron-test.sh
soft_fail: true
- label: "Intel Test"
depends_on: ~
command: bash .buildkite/run-cpu-test.sh
{% for step in steps %}
- label: "{{ step.label }}"
agents:
@ -54,14 +35,9 @@ steps:
automatic:
- exit_status: -1 # Agent was lost
limit: 5
- exit_status: -10 # Agent was lost
limit: 5
plugins:
- kubernetes:
podSpec:
{% if step.num_gpus %}
priorityClassName: gpu-priority-cls-{{ step.num_gpus }}
{% endif %}
volumes:
- name: dshm
emptyDir:

View File

@ -1,26 +0,0 @@
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80
# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left
# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false
# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash
IncludeCategories:
- Regex: '^<'
Priority: 4
- Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
Priority: 3
- Regex: '^"(qoda|\.\.)/'
Priority: 2
- Regex: '.*'
Priority: 1

View File

@ -18,7 +18,6 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py
```
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
value: |
```text
The output of `python collect_env.py`

View File

@ -18,7 +18,6 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py
```
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
value: |
```text
The output of `python collect_env.py`

View File

@ -18,7 +18,6 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py
```
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
value: |
```text
The output of `python collect_env.py`
@ -58,10 +57,6 @@ body:
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues.
If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs.
placeholder: |
A clear and concise description of what the bug is.

View File

@ -39,7 +39,6 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py
```
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
value: |
```text
The output of `python collect_env.py`

View File

@ -1,49 +0,0 @@
name: 💬 Request for comments (RFC).
description: Ask for feedback on major architectural changes or design choices.
title: "[RFC]: "
labels: ["RFC"]
body:
- type: markdown
attributes:
value: >
#### Please take a look at previous [RFCs](https://github.com/vllm-project/vllm/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference.
- type: textarea
attributes:
label: Motivation.
description: >
The motivation of the RFC.
validations:
required: true
- type: textarea
attributes:
label: Proposed Change.
description: >
The proposed change of the RFC.
validations:
required: true
- type: textarea
attributes:
label: Feedback Period.
description: >
The feedback period of the RFC. Usually at least one week.
validations:
required: false
- type: textarea
attributes:
label: CC List.
description: >
The list of people you want to CC.
validations:
required: false
- type: textarea
attributes:
label: Any Other Things.
description: >
Any other things you would like to mention.
validations:
required: false
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!

View File

@ -1,42 +0,0 @@
name: clang-format
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
jobs:
clang-format:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install clang-format==18.1.5
- name: Running clang-format
run: |
EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
)
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
| xargs clang-format --dry-run --Werror

View File

@ -32,19 +32,19 @@ jobs:
pip install types-setuptools
- name: Mypy
run: |
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml

View File

@ -49,19 +49,13 @@ jobs:
matrix:
os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11']
pytorch-version: ['2.3.0'] # Must be the most recent version that meets requirements-cuda.txt.
pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt.
cuda-version: ['11.8', '12.1']
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Setup ccache
uses: hendrikmuhs/ccache-action@v1.2
with:
create-symlink: true
key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
- name: Set up Linux Env
if: ${{ runner.os == 'Linux' }}
run: |
@ -82,8 +76,6 @@ jobs:
- name: Build wheel
shell: bash
env:
CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
run: |
bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename)

View File

@ -8,7 +8,7 @@ module.exports = async (github, context, core) => {
generate_release_notes: true,
name: process.env.RELEASE_TAG,
owner: context.repo.owner,
prerelease: true,
prerelease: false,
repo: context.repo.repo,
tag_name: process.env.RELEASE_TAG,
});

2
.gitignore vendored
View File

@ -70,8 +70,6 @@ instance/
# Sphinx documentation
docs/_build/
docs/source/getting_started/examples/*.rst
!**/*.template.rst
# PyBuilder
.pybuilder/

View File

@ -31,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
@ -167,47 +167,15 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# CUTLASS 3.5.0
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
)
FetchContent_MakeAvailable(cutlass)
list(APPEND VLLM_EXT_SRC
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
endif()
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
"csrc/custom_all_reduce.cu")
endif()
define_gpu_extension_target(
@ -217,7 +185,6 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
WITH_SOABI)
#
@ -243,13 +210,24 @@ define_gpu_extension_target(
set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/punica_ops.cu"
"csrc/punica/punica_pybind.cpp")
"csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
"csrc/punica/punica_ops.cc")
#
# Copy GPU compilation flags+update for punica
@ -273,9 +251,6 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA")
endif()
endforeach()
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif()
if (VLLM_PUNICA_GPU_ARCHES)
@ -310,6 +285,11 @@ add_custom_target(default)
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C)
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
@ -320,8 +300,3 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
add_dependencies(default _punica_C)
endif()
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C)
endif()

View File

@ -1,13 +1,9 @@
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
# to run the OpenAI compatible server.
# Please update any changes made here to
# docs/source/dev/dockerfile/dockerfile.rst and
# docs/source/assets/dev/dockerfile-stages-dependency.png
#################### BASE BUILD IMAGE ####################
# prepare basic build environment
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
RUN apt-get update -y \
&& apt-get install -y python3-pip git
@ -16,7 +12,7 @@ RUN apt-get update -y \
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-12.4/compat/
RUN ldconfig /usr/local/cuda-12.1/compat/
WORKDIR /workspace
@ -75,15 +71,34 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/pip \
python3 setup.py bdist_wheel --dist-dir=dist
# check the size of the wheel, we cannot upload wheels larger than 100MB
COPY .buildkite/check-wheel-size.py check-wheel-size.py
RUN python3 check-wheel-size.py dist
# the `vllm_nccl` package must be installed from source distribution
# pip is too smart to store a wheel in the cache, and other CI jobs
# will directly use the wheel from the cache, which is not what we want.
# we need to remove it manually
RUN --mount=type=cache,target=/root/.cache/pip \
pip cache remove vllm_nccl*
#################### EXTENSION Build IMAGE ####################
#################### FLASH_ATTENTION Build IMAGE ####################
FROM dev as flash-attn-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# flash attention version
ARG flash_attn_version=v2.5.6
ENV FLASH_ATTN_VERSION=${flash_attn_version}
WORKDIR /usr/src/flash-attention-v2
# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
--no-build-isolation --no-deps --no-cache-dir
#################### FLASH_ATTENTION Build IMAGE ####################
#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base
WORKDIR /vllm-workspace
RUN apt-get update -y \
@ -93,12 +108,16 @@ RUN apt-get update -y \
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-12.4/compat/
RUN ldconfig /usr/local/cuda-12.1/compat/
# install vllm wheel first, so that torch etc will be installed
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
pip install dist/*.whl --verbose
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
--mount=type=cache,target=/root/.cache/pip \
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
#################### vLLM installation IMAGE ####################

View File

@ -17,6 +17,4 @@ RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.py
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
WORKDIR /workspace/
CMD ["/bin/bash"]

View File

@ -1,36 +0,0 @@
# default base image
ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04"
FROM $BASE_IMAGE
RUN echo "Base image is $BASE_IMAGE"
# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
COPY ./vllm /app/vllm/vllm
COPY ./setup.py /app/vllm/setup.py
COPY ./requirements-common.txt /app/vllm/requirements-common.txt
COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt
RUN cd /app/vllm \
&& python3 -m pip install -U -r requirements-neuron.txt
ENV VLLM_BUILD_WITH_NEURON 1
RUN cd /app/vllm \
&& pip install -e . \
&& cd ..
CMD ["/bin/bash"]

View File

@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
ARG FA_BRANCH="ae7928c"
ARG FA_BRANCH="3d2b6f5"
RUN echo "FA_BRANCH is $FA_BRANCH"
# whether to build flash-attention
@ -46,7 +46,7 @@ RUN apt-get update && apt-get install -y \
### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/vllm-workspace
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}
@ -89,26 +89,21 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
&& cd ../..; \
fi
WORKDIR /vllm-workspace
COPY . .
COPY ./ /app/vllm
#RUN python3 -m pip install pynvml # to be removed eventually
RUN python3 -m pip install --upgrade pip numba
RUN python3 -m pip install xformers==0.0.23 --no-deps
# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
&& if [ "$BUILD_FA" = "1" ]; then \
bash patch_xformers.rocm.sh; fi \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cd ..
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
CMD ["/bin/bash"]

View File

@ -1,9 +1,6 @@
include LICENSE
include requirements-common.txt
include requirements-cuda.txt
include requirements-rocm.txt
include requirements-neuron.txt
include requirements-cpu.txt
include CMakeLists.txt
recursive-include cmake *

View File

@ -14,17 +14,6 @@ Easy, fast, and cheap LLM serving for everyone
</p>
---
**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)**
We are thrilled to announce our fourth vLLM Meetup!
The vLLM team will share recent updates and roadmap.
We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM.
Please register [here](https://lu.ma/agivllm) and join us!
---
*Latest News* 🔥
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
@ -62,14 +51,40 @@ vLLM is flexible and easy to use with:
- (Experimental) Prefix caching support
- (Experimental) Multi-lora support
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
- Mixture-of-Expert LLMs (e.g., Mixtral)
- Multi-modal LLMs (e.g., LLaVA)
vLLM seamlessly supports many Hugging Face models, including the following architectures:
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
## Getting Started
- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.)
- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.)
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
@ -77,7 +92,9 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
pip install vllm
```
Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more.
## Getting Started
Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
@ -87,32 +104,6 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more.
We welcome and value any contributions and collaborations.
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
## Sponsors
vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!
<!-- Note: Please sort them in alphabetical order. -->
<!-- Note: Please keep these consistent with docs/source/community/sponsors.md -->
- a16z
- AMD
- Anyscale
- AWS
- Crusoe Cloud
- Databricks
- DeepInfra
- Dropbox
- Lambda Lab
- NVIDIA
- Replicate
- Roblox
- RunPod
- Trainy
- UC Berkeley
- UC San Diego
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
## Citation
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):

View File

@ -89,9 +89,6 @@ async def async_request_tgi(
output.latency = most_recent_timestamp - st
output.success = True
output.generated_text = data["generated_text"]
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()
@ -138,7 +135,6 @@ async def async_request_trt_llm(
"data:")
data = json.loads(chunk)
output.generated_text += data["text_output"]
timestamp = time.perf_counter()
# First token
if ttft == 0.0:
@ -153,6 +149,7 @@ async def async_request_trt_llm(
most_recent_timestamp = timestamp
output.latency = most_recent_timestamp - st
output.generated_text = json.loads(data)["text_output"]
output.success = True
else:
@ -279,9 +276,6 @@ async def async_request_openai_completions(
output.generated_text = generated_text
output.success = True
output.latency = latency
else:
output.error = response.reason or ""
output.success = False
except Exception:
output.success = False
exc_info = sys.exc_info()

View File

@ -0,0 +1,148 @@
import functools
import time
from typing import Tuple
import chex
import jax
import jax.numpy as jnp
_PAD_SLOT_ID = -1
@jax.jit
def write_to_kv_cache1(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
num_heads = key.shape[-2]
head_size = key.shape[-1]
key = key.reshape(-1, num_heads, head_size)
key = key.transpose((1, 0, 2))
value = value.reshape(-1, num_heads, head_size)
value = value.transpose((1, 0, 2))
k_cache = k_cache.at[:, slot_mapping.reshape(-1), :].set(key)
v_cache = v_cache.at[:, slot_mapping.reshape(-1), :].set(value)
return k_cache, v_cache
@functools.partial(jax.jit, donate_argnums=(2, 3))
def write_to_kv_cache2(
key: jax.Array, # [batch_size, seq_len, num_heads, head_size]
value: jax.Array, # [batch_size, seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [batch_size, seq_len]
) -> Tuple[jax.Array, jax.Array]:
batch_size = slot_mapping.shape[0]
def cond(val: _IteratorState):
return val.idx < batch_size
def body(val: _IteratorState):
k_cache, v_cache = _write_seq_to_kv_cache(
key[val.idx],
value[val.idx],
val.k_cache,
val.v_cache,
slot_mapping[val.idx],
)
val.k_cache = k_cache
val.v_cache = v_cache
val.idx += 1
return val
iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
iterator = jax.lax.while_loop(cond, body, iterator)
return iterator.k_cache, iterator.v_cache
@functools.partial(jax.jit, donate_argnums=(2, 3))
def _write_seq_to_kv_cache(
key: jax.Array, # [seq_len, num_heads, head_size]
value: jax.Array, # [seq_len, num_heads, head_size]
k_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_heads, num_blocks * block_size, head_size]
slot_mapping: jax.Array, # [seq_len]
) -> Tuple[jax.Array, jax.Array]:
seq_len = slot_mapping.shape[0]
num_heads, _, head_size = k_cache.shape
# Reshape to match the rank of kv_cache.
key = key.reshape(seq_len, num_heads, 1, head_size)
value = value.reshape(seq_len, num_heads, 1, head_size)
def cond(val: _IteratorState):
return jnp.logical_and(
val.idx < seq_len, slot_mapping[val.idx] != _PAD_SLOT_ID)
def body(val: _IteratorState):
slot_idx = slot_mapping[val.idx]
val.k_cache = jax.lax.dynamic_update_slice(
val.k_cache,
key[val.idx],
(0, slot_idx, 0),
)
val.v_cache = jax.lax.dynamic_update_slice(
val.v_cache,
value[val.idx],
(0, slot_idx, 0),
)
val.idx += 1
return val
iterator = _IteratorState(idx=0, k_cache=k_cache, v_cache=v_cache)
iterator = jax.lax.while_loop(cond, body, iterator)
return iterator.k_cache, iterator.v_cache
@chex.dataclass
class _IteratorState:
idx: jnp.int32
k_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
v_cache: jnp.ndarray # [num_heads, num_blocks, block_size, head_size]
def benchmark_write_to_kv_cache(
batch_size: int,
seq_len: int,
num_kv_heads: int,
head_size: int,
num_blocks: int,
block_size: int,
version: int = 1,
):
if version == 1:
f = write_to_kv_cache1
elif version == 2:
f = write_to_kv_cache2
else:
raise ValueError(f"Invalid version: {version}")
rng_key = jax.random.PRNGKey(0)
key = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
value = jax.random.normal(rng_key, (batch_size, seq_len, num_kv_heads, head_size), dtype=jnp.bfloat16)
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
slot_mapping = jax.random.randint(rng_key, (batch_size, seq_len), 0, num_blocks * block_size, dtype=jnp.int32)
# For JIT compilation.
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
k_cache.block_until_ready()
start = time.time()
for _ in range(100):
k_cache, v_cache = f(key, value, k_cache, v_cache, slot_mapping)
k_cache.block_until_ready()
end = time.time()
print(f"Time taken: {(end - start) * 10:.2f} ms")
if __name__ == "__main__":
for num_blocks in [16, 256, 512, 1024, 2048, 8192, 16384]:
print(f"Benchmarking Write to KV Cache w/ {num_blocks} blocks")
benchmark_write_to_kv_cache(16, 256, 16, 256, num_blocks, 16, version=1)

View File

@ -0,0 +1,101 @@
import argparse
import functools
import time
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
BLOCK_SIZE = 16
MAX_NUM_BLOCKS_PER_SEQ = 512
@functools.partial(jax.jit, static_argnums=(6, 7))
def paged_attn(
q: jax.Array, # [batch, 1, num_heads, head_size]
k_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
v_cache: jax.Array, # [num_kv_heads, num_blocks * block_size, head_size]
sm_scale: float,
block_tables: jax.Array, # [batch, max_num_blocks_per_batch]
context_lens: jax.Array, # [batch]
block_size: int,
pages_per_compute_block: int,
) -> jax.Array: # [batch, 1, num_heads, head_size]
q = q.squeeze(1)
q = q * sm_scale
head_size = q.shape[-1]
num_slots = k_cache.shape[-2]
k_cache = k_cache.reshape(-1, num_slots // block_size, block_size, head_size)
v_cache = v_cache.reshape(-1, num_slots // block_size, block_size, head_size)
output = paged_attention(
q,
k_cache,
v_cache,
context_lens,
block_tables,
pages_per_compute_block=pages_per_compute_block,
)
return output.reshape(q.shape[0], 1, q.shape[1], q.shape[2])
def benchmark_paged_attn(
batch_size: int,
num_heads: int,
num_kv_heads: int,
head_size: int,
context_len: int,
num_blocks: int,
block_size: int,
pages_per_compute_block: int,
):
rng_key = jax.random.PRNGKey(0)
query = jax.random.normal(rng_key, (batch_size, 1, num_heads, head_size), dtype=jnp.bfloat16)
k_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
v_cache = jax.random.normal(rng_key, (num_kv_heads, num_blocks * block_size, head_size), dtype=jnp.bfloat16)
sm_scale = head_size ** -0.5
block_tables = jax.random.randint(rng_key, (batch_size, MAX_NUM_BLOCKS_PER_SEQ), 0, num_blocks, dtype=jnp.int32)
context_lens = jnp.array([context_len] * batch_size, dtype=jnp.int32)
# For JIT compilation.
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
output.block_until_ready()
start = time.time()
for _ in range(100):
output = paged_attn(query, k_cache, v_cache, sm_scale, block_tables, context_lens, block_size, pages_per_compute_block)
output.block_until_ready()
end = time.time()
print(f"Time taken: {(end - start) * 10000:.2f} us")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--num-heads", type=int, default=16)
parser.add_argument("--num-kv-heads", type=int, default=16)
parser.add_argument("--head-size", type=int, default=256)
parser.add_argument("--context-len", type=int, default=512)
parser.add_argument("--num-blocks", type=int, default=2048)
args = parser.parse_args()
print(args)
for block_size in [16, 32, 64, 128]:
for pages_per_compute_block in [1, 2, 4, 8, 16, 32, 64, 128]:
if pages_per_compute_block > MAX_NUM_BLOCKS_PER_SEQ:
continue
if block_size * pages_per_compute_block > 1024:
continue
print(f"block_size {block_size}, pages_per_compute_block: {pages_per_compute_block}")
benchmark_paged_attn(
args.batch_size,
args.num_heads,
args.num_kv_heads,
args.head_size,
args.context_len,
args.num_blocks,
block_size,
pages_per_compute_block,
)

View File

@ -1,17 +1,14 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import json
import time
from pathlib import Path
from typing import List, Optional
from typing import Optional
import numpy as np
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
def main(args: argparse.Namespace):
@ -20,8 +17,6 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
@ -32,11 +27,9 @@ def main(args: argparse.Namespace):
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size,
gpu_memory_utilization=args.gpu_memory_utilization)
block_size=args.block_size)
sampling_params = SamplingParams(
n=args.n,
@ -50,9 +43,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_inputs: List[PromptStrictInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
@ -63,13 +54,13 @@ def main(args: argparse.Namespace):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(dummy_inputs,
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(dummy_inputs,
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
@ -101,28 +92,16 @@ def main(args: argparse.Namespace):
for percentage, percentile in zip(percentages, percentiles):
print(f'{percentage}% percentile latency: {percentile} seconds')
# Output JSON results if specified
if args.output_json:
results = {
"avg_latency": np.mean(latencies),
"latencies": latencies.tolist(),
"percentiles": dict(zip(percentages, percentiles.tolist())),
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=[*QUANTIZATION_METHODS, None],
choices=['awq', 'gptq', 'squeezellm', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
@ -157,13 +136,15 @@ if __name__ == '__main__':
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
'--kv-cache-dtype',
"--kv-cache-dtype",
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
choices=['auto', 'fp8'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=str,
@ -199,7 +180,6 @@ if __name__ == '__main__':
action='store_true',
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
@ -210,16 +190,5 @@ if __name__ == '__main__':
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the latency results in JSON format.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
args = parser.parse_args()
main(args)

View File

@ -16,22 +16,20 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
def main(args):
llm = LLM(model=args.model,
llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
tokenizer_mode='auto',
trust_remote_code=True,
enforce_eager=True,
use_v2_block_manager=args.use_v2_block_manager,
tensor_parallel_size=args.tensor_parallel_size,
enable_prefix_caching=args.enable_prefix_caching)
num_prompts = 100
prompts = [PROMPT] * num_prompts
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
sampling_params = SamplingParams(temperature=0, max_tokens=100)
print("------warm up------")
test_prefix(
llm=llm,
prompts=prompts,
prompts=prompts[:1],
sampling_params=sampling_params,
)
@ -47,16 +45,8 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Benchmark the performance with or without automatic '
'prefix caching.')
parser.add_argument('--model',
type=str,
default='baichuan-inc/Baichuan2-13B-Chat')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
args = parser.parse_args()
main(args)

View File

@ -17,10 +17,6 @@ On the client side, run:
--dataset-path <path to dataset> \
--request-rate <request_rate> \ # By default <request_rate> is inf
--num-prompts <num_prompts> # By default <num_prompts> is 1000
when using tgi backend, add
--endpoint /generate_stream
to the end of the command above.
"""
import argparse
import asyncio
@ -31,7 +27,7 @@ import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple
from typing import AsyncGenerator, List, Tuple
import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
@ -62,11 +58,7 @@ def sample_sharegpt_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
@ -76,32 +68,38 @@ def sample_sharegpt_requests(
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
# Shuffle the dataset.
random.shuffle(dataset)
# some of these will be filtered out, so sample more than we need
sampled_indices = random.sample(range(len(dataset)),
int(num_requests * 1.2))
dataset = [dataset[i] for i in sampled_indices]
# Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = []
# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
output_len = len(completion_token_ids[i])
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
# Filter out too long sequences.
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
# This is because TGI causes errors when the input or output length
# is too short.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
return filtered_dataset
# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
def sample_sonnet_requests(
@ -215,11 +213,6 @@ def calculate_metrics(
else:
actual_output_lens.append(0)
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
@ -231,9 +224,9 @@ def calculate_metrics(
1000, # ttfts is empty if streaming is not supported by backend
median_ttft_ms=np.median(ttfts or 0) * 1000,
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
mean_tpot_ms=np.mean(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
mean_tpot_ms=np.mean(tpots) * 1000,
median_tpot_ms=np.median(tpots) * 1000,
p99_tpot_ms=np.percentile(tpots, 99) * 1000,
)
return metrics, actual_output_lens
@ -255,24 +248,6 @@ async def benchmark(
else:
raise ValueError(f"Unknown backend: {backend}")
print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len = input_requests[0]
test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=api_url,
prompt_len=test_prompt_len,
output_len=test_output_len,
best_of=best_of,
use_beam_search=use_beam_search,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}")
else:
print("Initial test run completed. Starting main benchmark run...")
print(f"Traffic request rate: {request_rate}")
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@ -386,7 +361,6 @@ def main(args: argparse.Namespace):
dataset_path=args.dataset,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
)
elif args.dataset_name == "sharegpt":
@ -394,7 +368,6 @@ def main(args: argparse.Namespace):
dataset_path=args.dataset_path,
num_requests=args.num_prompts,
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
)
elif args.dataset_name == "sonnet":
@ -551,12 +524,6 @@ if __name__ == "__main__":
default=1000,
help="Number of prompts to process.",
)
parser.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
parser.add_argument(
"--sonnet-input-len",
type=int,

View File

@ -10,8 +10,6 @@ from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
def sample_requests(
dataset_path: str,
@ -103,22 +101,25 @@ def run_vllm(
)
# Add the requests to the engine.
prompts = []
sampling_params = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
sampling_params = SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
)
# FIXME(woosuk): Do not use internal method.
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=sampling_params,
)
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
# FIXME(woosuk): Do not use internal method.
llm._run_engine(use_tqdm=True)
end = time.perf_counter()
return end - start
@ -242,18 +243,6 @@ def main(args: argparse.Namespace):
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
@ -278,7 +267,7 @@ if __name__ == "__main__":
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=[*QUANTIZATION_METHODS, None],
choices=['awq', 'gptq', 'squeezellm', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
@ -323,13 +312,15 @@ if __name__ == "__main__":
action="store_true",
help="enforce eager execution")
parser.add_argument(
'--kv-cache-dtype',
"--kv-cache-dtype",
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
choices=["auto", "fp8"],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
parser.add_argument(
'--quantization-param-path',
type=str,
@ -344,7 +335,7 @@ if __name__ == "__main__":
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
choices=["cuda", "cpu", "tpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument(
"--enable-prefix-caching",
@ -363,11 +354,6 @@ if __name__ == "__main__":
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model

View File

@ -1,302 +0,0 @@
import argparse
import os
import sys
from typing import Optional
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def torch_mult(
input: torch.Tensor, # [..., in_features]
weights: torch.Tensor,
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
output = F.linear(input, weights)
return output
def dequant_out_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
if bias is None:
output = F.linear(input, weights, bias)
orig_shape = output.shape
flattened_output = output.view(-1, output.size(-1))
f_scales = scales.view(-1, scales.shape[0])
b_scales = f_scales.expand(flattened_output.shape[0], -1)
flattened_output *= b_scales
return flattened_output.view(orig_shape)
else:
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_weight_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_no_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
return F.linear(input, weights, bias)
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:
n = parts.sum().item()
device = torch.device('cuda:0')
code_range = (1 << bits) // 2
ingroups = 8
codes = torch.randint(-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device)
count = 0
for index in range(16):
for i in range(8):
for book in range(nbooks):
codebooks[book, index, 0, i] = count * (10**book)
count += 1
print("codes shape", codes.shape)
for i in range(16):
for book in range(nbooks):
codes[0, i, book] = i
codes[0, -i, book] = i
weights = dequantize_weight(codes, codebooks, None)
weights2 = ops.aqlm_dequant(codes, codebooks, parts)
print("weights shape:", weights.shape)
print("weights2 shape:", weights2.shape)
print("weights are:", weights)
print("weights2 are:", weights2)
print("first 128 weights are", weights[0, 0:128].to(torch.int32))
print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))
print("last 128 weights are", weights[0, -128:])
print("last 128 weights2 are:", weights2[0, -128:])
def main():
parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
# Add arguments
parser.add_argument("--nbooks",
type=int,
default=1,
help="Number of codebooks (default: 1)")
parser.add_argument("--bits",
type=int,
default=16,
help="Number of bits per code element (default: 16)")
parser.add_argument(
"--test",
type=bool,
default=False,
help="Run the decompression/dequant tester rather than benchmarking "
"(default: False)")
# Parse the arguments
args = parser.parse_args()
# Extract values
nbooks = args.nbooks
bits = args.bits
if args.test:
dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
return
# Otherwise, benchmark.
methods = [
ops.aqlm_gemm,
dequant_out_scale,
generic_dequantize_gemm,
optimized_dequantize_gemm,
dequant_weight_scale,
torch_mult,
dequant_no_scale,
]
filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
print(f"writing benchmarks to file {filename}")
with open(filename, "w") as f:
sys.stdout = f
print('m | k | n | n parts', end='')
for method in methods:
print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
print('')
# These are reasonable prefill sizes.
ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
(4096, (11008, 11008)), (11008, (4096, )))
# reasonable ranges for m.
for m in [
1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
128, 256, 512, 1024, 1536, 2048, 3072, 4096
]:
print(f'{m}', file=sys.__stdout__)
for ksp in ksandpartions:
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
methods)
sys.stdout = sys.__stdout__
def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
methods):
# I didn't see visible improvements from increasing these, but feel free :)
num_warmup_trials = 1
num_trials = 1
num_calls = 100
# warmup.
for method in methods:
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_calls,
m=m,
k=k,
parts=parts,
nbooks=nbooks,
bits=bits,
method=method,
)
n = parts.sum().item()
print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
for method in methods:
best_time_us = 1e20
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
m=m,
k=k,
parts=parts,
nbooks=nbooks,
bits=bits,
method=method,
)
kernel_dur_us = 1000 * kernel_dur_ms
if kernel_dur_us < best_time_us:
best_time_us = kernel_dur_us
print(f' | {kernel_dur_us:.0f}', end='')
print('')
def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
nbooks: int, bits: int, method) -> float:
n = parts.sum().item()
device = torch.device('cuda:0')
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
code_range = (1 << bits) // 2
ingroups = 8
codes = torch.randint(-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device)
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
# for comparison to just a pytorch mult.
weights = torch.randn((n, k), dtype=torch.float16, device=device)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
if method is torch_mult:
for i in range(num_calls):
torch_mult(input, weights, scales)
else:
for i in range(num_calls):
method(input, codes, codebooks, scales, parts, None)
end_event.record()
end_event.synchronize()
dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,233 +0,0 @@
import argparse
import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
size_m, size_k, size_n):
label = "Quant Matmul"
sub_label = ("{}, act={} k_full={}, b={}, g={}, "
"MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
group_size, size_m, size_k, size_n))
print(f"Testing: {sub_label}")
a = torch.randn(size_m, size_k).to(torch.half).cuda()
b = torch.rand(size_k, size_n).to(torch.half).cuda()
a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())
# Marlin quant
(
marlin_w_ref,
marlin_q_w,
marlin_s,
marlin_g_idx,
marlin_sort_indices,
marlin_rand_perm,
) = marlin_quantize(b, num_bits, group_size, act_order)
# Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
# GPTQ quant
(w_ref, q_w, s, g_idx,
rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx"
# so that group ids are increasing
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
if act_order:
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
# Prepare
marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
globals = {
# Gen params
"num_bits": num_bits,
"group_size": group_size,
"size_m": size_m,
"size_n": size_n,
"size_k": size_k,
"a": a,
"a_tmp": a_tmp,
# Marlin params
"marlin_w_ref": marlin_w_ref,
"marlin_q_w": marlin_q_w,
"marlin_s": marlin_s,
"marlin_g_idx": marlin_g_idx,
"marlin_sort_indices": marlin_sort_indices,
"marlin_rand_perm": marlin_rand_perm,
"marlin_workspace": marlin_workspace,
"is_k_full": is_k_full,
# Marlin_24 params
"marlin_24_w_ref": marlin_24_w_ref,
"marlin_24_q_w_comp": marlin_24_q_w_comp,
"marlin_24_meta": marlin_24_meta,
"marlin_24_s": marlin_24_s,
"marlin_24_workspace": marlin_24_workspace,
# GPTQ params
"q_w_gptq": q_w_gptq,
"repack_sort_indices": repack_sort_indices,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
}
min_run_time = 1
# Warmup pytorch
for i in range(5):
torch.matmul(a, marlin_w_ref)
results.append(
benchmark.Timer(
stmt="torch.matmul(a, marlin_w_ref)",
globals=globals,
label=label,
sub_label=sub_label,
description="pytorch_gemm",
).blocked_autorange(min_run_time=min_run_time))
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm",
).blocked_autorange(min_run_time=min_run_time))
if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_24_gemm",
).blocked_autorange(min_run_time=min_run_time))
results.append(
benchmark.Timer(
stmt=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_repack",
).blocked_autorange(min_run_time=min_run_time))
def main(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
results = []
for model in args.models:
for layer in WEIGHT_SHAPES[model]:
size_k = layer[0]
size_n = layer[1]
if len(args.limit_k) > 0 and size_k not in args.limit_k:
continue
if len(args.limit_n) > 0 and size_n not in args.limit_n:
continue
for act_order in ACT_ORDER_OPTS:
if len(args.limit_act_order
) > 0 and act_order not in args.limit_act_order:
continue
for is_k_full in K_FULL_OPTS:
if len(args.limit_k_full
) > 0 and is_k_full not in args.limit_k_full:
continue
for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
if len(args.limit_num_bits
) > 0 and num_bits not in args.limit_num_bits:
continue
for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
if len(
args.limit_group_size
) > 0 and group_size not in args.limit_group_size:
continue
# For act_order, the group_size must be less than
# size_k
if act_order and (group_size == size_k
or group_size == -1):
continue
for size_m in args.batch_sizes:
bench_run(results, model, act_order, is_k_full,
num_bits, group_size, size_m, size_k,
size_n)
compare = benchmark.Compare(results)
compare.print()
# For quick benchmarking use:
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches")
parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])
args = parser.parse_args()
main(args)

View File

@ -1,4 +1,3 @@
import argparse
import json
import os
import sys
@ -6,70 +5,68 @@ import sys
import torch
import torch.nn.functional as F
import triton
from tqdm import tqdm
from vllm.model_executor.layers.fused_moe import (fused_moe,
get_config_file_name)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def main(model, tp_size, gpu, dtype: str):
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
def main():
method = fused_moe
for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]:
run_grid(bs,
model=model,
method=method,
gpu=gpu,
tp_size=tp_size,
dtype=dtype)
run_grid(bs, method=method)
def run_grid(bs, model, method, gpu, tp_size, dtype: str):
if model == '8x7B':
d_model = 4096
model_intermediate_size = 14336
num_layers = 32
elif model == '8x22B':
d_model = 6144
model_intermediate_size = 16384
num_layers = 56
else:
raise ValueError(f'Unsupported Mixtral model {model}')
def run_grid(bs, method):
d_model = 4096
num_total_experts = 8
top_k = 2
# tp_size = 2
tp_size = 2
model_intermediate_size = 14336
num_layers = 32
num_calls = 100
num_warmup_trials = 1
num_trials = 1
configs = []
if bs <= 16:
BLOCK_SIZES_M = [16]
elif bs <= 32:
BLOCK_SIZES_M = [16, 32]
elif bs <= 64:
BLOCK_SIZES_M = [16, 32, 64]
elif bs <= 128:
BLOCK_SIZES_M = [16, 32, 64, 128]
else:
BLOCK_SIZES_M = [16, 32, 64, 128, 256]
for block_size_n in [32, 64, 128, 256]:
for block_size_m in [16, 32, 64, 128, 256]:
for block_size_m in BLOCK_SIZES_M:
for block_size_k in [64, 128, 256]:
for group_size_m in [1, 16, 32, 64]:
for num_warps in [4, 8]:
for num_stages in [2, 3, 4, 5]:
configs.append({
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": num_stages,
})
configs.append({
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": 4,
})
best_config = None
best_time_us = 1e20
print(f'{tp_size=} {bs=}')
for config in tqdm(configs):
for config in configs:
print(f'{tp_size=} {bs=}')
print(f'{config}')
# warmup
print('warming up')
try:
for _ in range(num_warmup_trials):
run_timing(
@ -82,12 +79,12 @@ def run_grid(bs, model, method, gpu, tp_size, dtype: str):
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
dtype=dtype,
)
except triton.runtime.autotuner.OutOfResources:
continue
# trial
print('benchmarking')
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
@ -99,7 +96,6 @@ def run_grid(bs, model, method, gpu, tp_size, dtype: str):
model_intermediate_size=model_intermediate_size,
method=method,
config=config,
dtype=dtype,
)
kernel_dur_us = 1000 * kernel_dur_ms
@ -109,18 +105,16 @@ def run_grid(bs, model, method, gpu, tp_size, dtype: str):
best_config = config
best_time_us = kernel_dur_us
tqdm.write(
f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
f'{d_model=} {model_intermediate_size=} {num_layers=}')
print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
f'{d_model=} {model_intermediate_size=} {num_layers=}')
print("best_time_us", best_time_us)
print("best_config", best_config)
# holds Dict[str, Dict[str, int]]
filename = get_config_file_name(num_total_experts,
model_intermediate_size // tp_size,
"float8" if dtype == "float8" else None)
model_intermediate_size // tp_size)
print(f"writing config to file {filename}")
existing_content = {}
if os.path.exists(filename):
@ -134,48 +128,27 @@ def run_grid(bs, model, method, gpu, tp_size, dtype: str):
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
top_k: int, tp_size: int, model_intermediate_size: int, method,
config, dtype: str) -> float:
config) -> float:
shard_intermediate_size = model_intermediate_size // tp_size
hidden_states = torch.rand(
(bs, d_model),
device="cuda:0",
dtype=torch.float16,
dtype=torch.bfloat16,
)
w1 = torch.rand(
ws = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
w2 = torch.rand(
w2s = torch.rand(
(num_total_experts, d_model, shard_intermediate_size),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if dtype == "float8":
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1_scale = torch.ones(num_total_experts,
device=hidden_states.device,
dtype=torch.float32)
w2_scale = torch.ones(num_total_experts,
device=hidden_states.device,
dtype=torch.float32)
a1_scale = torch.ones(1,
device=hidden_states.device,
dtype=torch.float32)
a2_scale = torch.ones(1,
device=hidden_states.device,
dtype=torch.float32)
gating_output = F.softmax(torch.rand(
(num_calls, bs, num_total_experts),
device=hidden_states.device,
@ -190,18 +163,13 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
for i in range(num_calls):
hidden_states = method(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
w1=ws,
w2=w2s,
gating_output=gating_output[i],
topk=2,
renormalize=True,
inplace=True,
override_config=config,
use_fp8=dtype == "float8",
)
end_event.record()
end_event.synchronize()
@ -211,29 +179,4 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog='benchmark_mixtral_moe',
description='Benchmark and tune the fused_moe kernel',
)
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['float8', 'float16'],
help='Data type used for fused_moe kernel computations',
)
parser.add_argument('--model',
type=str,
default='8x7B',
choices=['8x7B', '8x22B'],
help='The Mixtral model to benchmark')
parser.add_argument('--tp-size',
type=int,
default=2,
help='Tensor paralleli size')
parser.add_argument('--gpu',
type=int,
default=0,
help="GPU ID for benchmarking")
args = parser.parse_args()
sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))
sys.exit(main())

View File

@ -16,7 +16,7 @@ PARTITION_SIZE = 512
def main(
version: str,
num_seqs: int,
seq_len: int,
context_len: int,
num_query_heads: int,
num_kv_heads: int,
head_size: int,
@ -48,12 +48,12 @@ def main(
dtype=torch.float,
device=device)
seq_lens = [seq_len for _ in range(num_seqs)]
max_seq_len = max(seq_lens)
seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)
context_lens = [context_len for _ in range(num_seqs)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
# Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
@ -77,7 +77,8 @@ def main(
# Prepare for the paged attention kernel.
output = torch.empty_like(query)
if version == "v2":
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
PARTITION_SIZE)
tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype,
@ -109,9 +110,9 @@ def main(
num_kv_heads,
scale,
block_tables,
seq_lens,
context_lens,
block_size,
max_seq_len,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@ -128,9 +129,9 @@ def main(
num_kv_heads,
scale,
block_tables,
seq_lens,
context_lens,
block_size,
max_seq_len,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
@ -165,12 +166,12 @@ if __name__ == '__main__':
choices=["v1", "v2"],
default="v2")
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--seq_len", type=int, default=4096)
parser.add_argument("--context-len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size",
type=int,
choices=[64, 80, 96, 112, 128, 192, 256],
choices=[64, 80, 96, 112, 128, 256],
default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true")
@ -183,11 +184,13 @@ if __name__ == '__main__':
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
choices=["auto", "fp8"],
default="auto",
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
args = parser.parse_args()
print(args)
@ -196,7 +199,7 @@ if __name__ == '__main__':
main(
version=args.version,
num_seqs=args.batch_size,
seq_len=args.seq_len,
context_len=args.context_len,
num_query_heads=args.num_query_heads,
num_kv_heads=args.num_kv_heads,
head_size=args.head_size,

View File

@ -93,7 +93,7 @@ if __name__ == '__main__':
parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size",
type=int,
choices=[64, 80, 96, 112, 128, 192, 256],
choices=[64, 80, 96, 112, 128, 256],
default=128)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype",

View File

@ -1,75 +0,0 @@
WEIGHT_SHAPES = {
"ideal": [[4 * 256 * 32, 256 * 32]],
"mistralai/Mistral-7B-v0.1/TP1": [
[4096, 6144],
[4096, 4096],
[4096, 28672],
[14336, 4096],
],
"mistralai/Mistral-7B-v0.1/TP2": [
[4096, 3072],
[2048, 4096],
[4096, 14336],
[7168, 4096],
],
"mistralai/Mistral-7B-v0.1/TP4": [
[4096, 1536],
[1024, 4096],
[4096, 7168],
[3584, 4096],
],
"meta-llama/Llama-2-7b-hf/TP1": [
[4096, 12288],
[4096, 4096],
[4096, 22016],
[11008, 4096],
],
"meta-llama/Llama-2-7b-hf/TP2": [
[4096, 6144],
[2048, 4096],
[4096, 11008],
[5504, 4096],
],
"meta-llama/Llama-2-7b-hf/TP4": [
[4096, 3072],
[1024, 4096],
[4096, 5504],
[2752, 4096],
],
"meta-llama/Llama-2-13b-hf/TP1": [
[5120, 15360],
[5120, 5120],
[5120, 27648],
[13824, 5120],
],
"meta-llama/Llama-2-13b-hf/TP2": [
[5120, 7680],
[2560, 5120],
[5120, 13824],
[6912, 5120],
],
"meta-llama/Llama-2-13b-hf/TP4": [
[5120, 3840],
[1280, 5120],
[5120, 6912],
[3456, 5120],
],
"meta-llama/Llama-2-70b-hf/TP1": [
[8192, 10240],
[8192, 8192],
[8192, 57344],
[28672, 8192],
],
"meta-llama/Llama-2-70b-hf/TP2": [
[8192, 5120],
[4096, 8192],
[8192, 28672],
[14336, 8192],
],
"meta-llama/Llama-2-70b-hf/TP4": [
[8192, 2560],
[2048, 8192],
[8192, 14336],
[7168, 8192],
],
}

View File

@ -4,7 +4,7 @@ PORT=8000
MODEL=$1
TOKENS=$2
docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
docker run --gpus all --shm-size 1g -p $PORT:80 \
-v $PWD/data:/data \
ghcr.io/huggingface/text-generation-inference:1.4.0 \
--model-id $MODEL \

View File

@ -1,63 +0,0 @@
import argparse
import cProfile
import pstats
from vllm import LLM, SamplingParams
# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)
def main(args):
llm = LLM(
model=args.model,
enforce_eager=True,
enable_prefix_caching=True,
tensor_parallel_size=args.tensor_parallel_size,
use_v2_block_manager=args.use_v2_block_manager,
)
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
profiler = cProfile.Profile()
print("------warm up------")
for i in range(3):
output = llm.generate(LONG_PROMPT, sampling_params)
print(output[0].outputs[0].text)
print("------start generating------")
for i in range(3):
profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
globals(), locals())
# analyze the runtime of hashing function
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
total_time = 0
total_calls = 0
for func in stats.stats:
if 'hash_of_block' in func[2]:
total_time = stats.stats[func][3]
total_calls = stats.stats[func][0]
percentage = (total_time / stats.total_tt) * 100
print(f"Hashing took {total_time:.2f} seconds,"
f"{percentage:.2f}% of the total runtime.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Benchmark the performance of hashing function in'
'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
args = parser.parse_args()
main(args)

View File

@ -99,7 +99,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"Failed to determine torch nvcc compiler flags")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8")
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS
@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8"
"-DENABLE_FP8_E4M3"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")

View File

@ -63,7 +63,6 @@ DEFAULT_CONDA_PATTERNS = {
"magma",
"triton",
"optree",
"nccl",
}
DEFAULT_PIP_PATTERNS = {
@ -74,7 +73,6 @@ DEFAULT_PIP_PATTERNS = {
"triton",
"optree",
"onnx",
"nccl",
}

View File

@ -10,11 +10,11 @@
namespace vllm {
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
@ -23,66 +23,72 @@ __global__ void act_and_mul_kernel(
}
}
template <typename T>
template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
return (T) (((float) x) / (1.0f + expf((float) -x)));
}
template <typename T>
template<typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float)x;
const float f = (float) x;
constexpr float ALPHA = M_SQRT1_2;
return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}
template <typename T>
template<typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float)x;
const float f = (float) x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715;
float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube);
return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
}
} // namespace vllm
} // namespace vllm
// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"act_and_mul_kernel", \
[&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void silu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void gelu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void gelu_tanh_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}
@ -90,11 +96,11 @@ void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
namespace vllm {
// Element-wise activation kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
@ -102,49 +108,54 @@ __global__ void activation_kernel(
}
}
} // namespace vllm
} // namespace vllm
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"activation_kernel", \
[&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
namespace vllm {
template <typename T>
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
const float x3 = (float)(x * x * x);
const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T)0.5) * x * (((T)1.0) + t);
const float x3 = (float) (x * x * x);
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
return ((T) 0.5) * x * (((T) 1.0) + t);
}
template <typename T>
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float)x;
const T t =
(T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
const float f = (float) x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
return ((T) 0.5) * x * (((T) 1.0) + t);
}
} // namespace vllm
} // namespace vllm
void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void gelu_new(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
void gelu_fast(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}

View File

@ -1,6 +1,5 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -23,31 +22,31 @@
namespace vllm {
// A vector type to store Q, K, V elements.
template <typename T, int VEC_SIZE>
template<typename T, int VEC_SIZE>
struct Vec {};
// A vector type to store FP32 accumulators.
template <typename T>
template<typename T>
struct FloatVec {};
// Template vector operations.
template <typename Acc, typename A, typename B>
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);
template <typename T>
template<typename T>
inline __device__ float sum(T v);
template <typename T>
template<typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b));
}
template <typename A, typename T>
template<typename A, typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b));
}
template <typename T>
template<typename T>
inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4;
union {
@ -62,4 +61,4 @@ inline __device__ void zero(T& dst) {
dst = tmp.raw;
}
} // namespace vllm
} // namespace vllm

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,5 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -27,7 +26,7 @@
namespace vllm {
// Q*K^T operation.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
template<int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
@ -46,12 +45,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk;
}
template <typename T, int THREAD_GROUP_SIZE>
template<typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
template <typename Vec, int N>
template<typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
}
};
} // namespace vllm
} // namespace vllm

View File

@ -1,8 +1,6 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -30,8 +28,8 @@
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif
#include <stdint.h>
@ -52,37 +50,37 @@ struct bf16_8_t {
};
// BF16 vector types for Q, K, V.
template <>
template<>
struct Vec<__nv_bfloat16, 1> {
using Type = __nv_bfloat16;
};
template <>
template<>
struct Vec<__nv_bfloat16, 2> {
using Type = __nv_bfloat162;
};
template <>
template<>
struct Vec<__nv_bfloat16, 4> {
using Type = bf16_4_t;
};
template <>
template<>
struct Vec<__nv_bfloat16, 8> {
using Type = bf16_8_t;
};
// FP32 accumulator vector types corresponding to Vec.
template <>
template<>
struct FloatVec<__nv_bfloat16> {
using Type = float;
};
template <>
template<>
struct FloatVec<__nv_bfloat162> {
using Type = float2;
};
template <>
template<>
struct FloatVec<bf16_4_t> {
using Type = Float4_;
};
template <>
template<>
struct FloatVec<bf16_8_t> {
using Type = Float8_;
};
@ -110,9 +108,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
assert(false);
#else
#ifndef USE_ROCM
return a + b;
return a + b;
#else
return __hadd(a, b);
return __hadd(a, b);
#endif
#endif
}
@ -163,7 +161,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
}
// Vector multiplication.
template <>
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
@ -172,7 +170,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#endif
}
template <>
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
@ -181,12 +179,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#endif
}
template <>
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
template <>
template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -194,7 +192,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
return c;
}
template <>
template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
@ -203,7 +201,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
return c;
}
template <>
template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -213,7 +211,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
return c;
}
template <>
template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
@ -224,26 +222,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
return c;
}
template <>
template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = __bfloat162float(a);
float fb = __bfloat162float(b);
return fa * fb;
}
template <>
template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template <>
template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
template <>
template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -251,7 +249,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
return fc;
}
template <>
template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
@ -260,7 +258,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
return fc;
}
template <>
template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -270,7 +268,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
return fc;
}
template <>
template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
@ -282,8 +280,7 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
}
// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
@ -291,8 +288,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
#endif
}
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
@ -383,23 +379,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
}
// Vector sum.
template <>
template<>
inline __device__ float sum(__nv_bfloat16 v) {
return __bfloat162float(v);
}
template <>
template<>
inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
template <>
template<>
inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y);
}
template <>
template<>
inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
@ -452,4 +448,4 @@ inline __device__ void zero(__nv_bfloat16& dst) {
#endif
}
} // namespace vllm
} // namespace vllm

View File

@ -1,8 +1,6 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -32,37 +30,37 @@
namespace vllm {
// FP16 vector types for Q, K, V.
template <>
template<>
struct Vec<uint16_t, 1> {
using Type = uint16_t;
};
template <>
template<>
struct Vec<uint16_t, 2> {
using Type = uint32_t;
};
template <>
template<>
struct Vec<uint16_t, 4> {
using Type = uint2;
};
template <>
template<>
struct Vec<uint16_t, 8> {
using Type = uint4;
};
// FP32 accumulator vector types corresponding to Vec.
template <>
template<>
struct FloatVec<uint16_t> {
using Type = float;
};
template <>
template<>
struct FloatVec<uint32_t> {
using Type = float2;
};
template <>
template<>
struct FloatVec<uint2> {
using Type = Float4_;
};
template <>
template<>
struct FloatVec<uint4> {
using Type = Float8_;
};
@ -75,8 +73,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) {
return b;
#else
union {
uint32_t u32;
uint16_t u16[2];
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u16[0] = a;
tmp.u16[1] = a;
@ -132,12 +130,10 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
} tmp;
#ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
#else
tmp.u16[0] = float_to_half(f.x);
@ -205,7 +201,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
}
// Vector multiplication.
template <>
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c;
#ifndef USE_ROCM
@ -216,7 +212,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
return c;
}
template <>
template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c;
#ifndef USE_ROCM
@ -227,12 +223,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
return c;
}
template <>
template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}
template <>
template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
@ -240,7 +236,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) {
return c;
}
template <>
template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
uint2 c;
@ -249,7 +245,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) {
return c;
}
template <>
template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
@ -259,7 +255,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
return c;
}
template <>
template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
uint4 c;
@ -270,26 +266,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) {
return c;
}
template <>
template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb;
}
template <>
template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template <>
template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}
template <>
template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
@ -297,7 +293,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) {
return fc;
}
template <>
template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
Float4_ fc;
@ -306,7 +302,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
return fc;
}
template <>
template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
@ -316,7 +312,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) {
return fc;
}
template <>
template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
Float8_ fc;
@ -331,13 +327,9 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
#ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
: "=v"(d)
: "v"(a), "v"(b), "v"(c));
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#endif
return d;
}
@ -431,24 +423,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
}
// Vector sum.
template <>
template<>
inline __device__ float sum(uint16_t v) {
return half_to_float(v);
}
template <>
template<>
inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
template <>
template<>
inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y);
return sum(c);
}
template <>
template<>
inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
@ -478,9 +470,13 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
}
// From float16 to float32.
inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
inline __device__ float to_float(uint16_t u) {
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
inline __device__ float2 to_float(uint32_t u) {
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp;
@ -499,6 +495,8 @@ inline __device__ Float8_ to_float(uint4 u) {
}
// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
inline __device__ void zero(uint16_t& dst) {
dst = uint16_t(0);
}
} // namespace vllm
} // namespace vllm

View File

@ -1,8 +1,6 @@
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@ -40,35 +38,37 @@ struct Float8_ {
};
// FP32 vector types for Q, K, V.
template <>
template<>
struct Vec<float, 1> {
using Type = float;
};
template <>
template<>
struct Vec<float, 2> {
using Type = float2;
};
template <>
template<>
struct Vec<float, 4> {
using Type = float4;
};
// FP32 accumulator vector types corresponding to Vec.
template <>
template<>
struct FloatVec<float> {
using Type = float;
};
template <>
template<>
struct FloatVec<float2> {
using Type = float2;
};
template <>
template<>
struct FloatVec<float4> {
using Type = float4;
};
// Vector addition.
inline __device__ float add(float a, float b) { return a + b; }
inline __device__ float add(float a, float b) {
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) {
float2 c;
@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) {
}
// Vector multiplication.
template <>
template<>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}
template <>
template<>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) {
return c;
}
template <>
template<>
inline __device__ float2 mul(float a, float2 b) {
float2 c;
c.x = a * b.x;
@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) {
return c;
}
template <>
template<>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) {
return c;
}
template <>
template<>
inline __device__ float4 mul(float a, float4 b) {
float4 c;
c.x = a * b.x;
@ -129,7 +129,9 @@ inline __device__ float4 mul(float a, float4 b) {
}
// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { return a * b + c; }
inline __device__ float fma(float a, float b, float c) {
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d;
@ -180,33 +182,35 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
}
// Vector sum.
template <>
template<>
inline __device__ float sum(float v) {
return v;
}
template <>
template<>
inline __device__ float sum(float2 v) {
return v.x + v.y;
}
template <>
template<>
inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w;
}
template <>
template<>
inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y;
}
template <>
template<>
inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}
// Vector dot product.
inline __device__ float dot(float a, float b) { return a * b; }
inline __device__ float dot(float a, float b) {
return a * b;
}
inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b);
@ -228,24 +232,42 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
}
// From float to float.
inline __device__ void from_float(float& dst, float src) { dst = src; }
inline __device__ void from_float(float& dst, float src) {
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
inline __device__ void from_float(float2& dst, float2 src) {
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
inline __device__ void from_float(float4& dst, float4 src) {
dst = src;
}
// From float to float.
inline __device__ float to_float(float u) { return u; }
inline __device__ float to_float(float u) {
return u;
}
inline __device__ float2 to_float(float2 u) { return u; }
inline __device__ float2 to_float(float2 u) {
return u;
}
inline __device__ float4 to_float(float4 u) { return u; }
inline __device__ float4 to_float(float4 u) {
return u;
}
inline __device__ Float4_ to_float(Float4_ u) { return u; }
inline __device__ Float4_ to_float(Float4_ u) {
return u;
}
inline __device__ Float8_ to_float(Float8_ u) { return u; }
inline __device__ Float8_ to_float(Float8_ u) {
return u;
}
// Zero-out a variable.
inline __device__ void zero(float& dst) { dst = 0.f; }
inline __device__ void zero(float& dst) {
dst = 0.f;
}
} // namespace vllm
} // namespace vllm

View File

@ -3,39 +3,33 @@
#include "attention_generic.cuh"
#include <stdint.h>
#ifdef ENABLE_FP8
#ifndef USE_ROCM
#include <cuda_fp8.h>
#endif // USE_ROCM
#endif // ENABLE_FP8
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif
namespace vllm {
enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};
#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
// fp8 vector types for quantization of kv cache
template <>
template<>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
using Type = uint8_t;
};
template <>
template<>
struct Vec<uint8_t, 2> {
using Type = uint16_t;
using Type = uint16_t;
};
template <>
template<>
struct Vec<uint8_t, 4> {
using Type = uint32_t;
using Type = uint32_t;
};
template <>
template<>
struct Vec<uint8_t, 8> {
using Type = uint2;
using Type = uint2;
};
#endif // ENABLE_FP8_E5M2
} // namespace vllm
} // namespace vllm

View File

@ -5,24 +5,26 @@
#include <map>
#include <vector>
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping);
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping);
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const float kv_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
void reshape_and_cache(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const float kv_scale);
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float scale, const std::string& kv_cache_dtype);
void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache);

View File

@ -4,11 +4,10 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#if defined(ENABLE_FP8_E5M2)
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#elif defined(ENABLE_FP8_E4M3)
#include "quantization/fp8/amd_detail/quant_utils.cuh"
#endif
#include <algorithm>
@ -18,17 +17,20 @@
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat16 __nv_bfloat16;
#endif
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
TORCH_CHECK(
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = cudaMemcpyDeviceToHost;
@ -38,44 +40,41 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
TORCH_CHECK(false, "Invalid device combination");
}
// NOTE(youkaichao): keep in mind that `block_mapping` should be
// a cpu tensor, otherwise every `item` call will require a gpu-cpu
// synchronization.
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());
char *src_ptr = static_cast<char*>(src.data_ptr());
char *dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
const int64_t num_blocks = block_mapping.size(0);
for (size_t i = 0; i < num_blocks; i++) {
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
int64_t dst_block_number = pair.second;
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
block_size_in_bytes, memcpy_type, stream);
cudaMemcpyAsync(
dst_ptr + dst_offset,
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
}
}
namespace vllm {
// Grid: (num_layers, num_pairs)
template <typename scalar_t>
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
int64_t* value_cache_ptrs,
const int64_t* __restrict__ block_mapping,
const int numel_per_block) {
template<typename scalar_t>
__global__ void copy_blocks_kernel(
int64_t* key_cache_ptrs,
int64_t* value_cache_ptrs,
const int64_t* __restrict__ block_mapping,
const int numel_per_block) {
const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y;
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
@ -93,11 +92,12 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
}
}
} // namespace vllm
} // namespace vllm
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping) {
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
@ -111,23 +111,29 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches,
int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
}
// block_mapping is a 2D tensor with shape (num_pairs, 2).
int num_pairs = block_mapping.size(0);
// Create block mapping array.
std::vector<int64_t> block_mapping_vec;
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
}
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;
// Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor =
torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
torch::Tensor value_cache_ptrs_tensor =
torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
.to(cache_device);
torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
torch::Tensor block_mapping_tensor = torch::from_blob(
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
// Launch the kernel.
const int numel_per_block = key_caches[0][0].numel();
@ -136,28 +142,31 @@ void copy_blocks(std::vector<torch::Tensor>& key_caches,
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping.data_ptr<int64_t>(), numel_per_block);
}));
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int64_t>(),
numel_per_block);
}));
}
namespace vllm {
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
template<typename scalar_t, typename cache_t, bool is_fp8_kv_cache>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
// block_size, x]
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
// block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x,
const float kv_scale) {
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x,
const float kv_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
@ -178,84 +187,60 @@ __global__ void reshape_and_cache_kernel(
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;
const int64_t tgt_key_idx =
block_idx * num_heads * (head_size / x) * block_size * x +
head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
block_offset * x + x_offset;
const int64_t tgt_value_idx =
block_idx * num_heads * head_size * block_size +
head_idx * head_size * block_size + head_offset * block_size +
block_offset;
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
+ block_offset * x
+ x_offset;
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
if constexpr (is_fp8_kv_cache) {
#if defined(ENABLE_FP8_E5M2)
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#elif defined(ENABLE_FP8_E4M3)
key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_key, kv_scale);
value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion<uint8_t, scalar_t>(tgt_value, kv_scale);
#else
assert(false);
#endif
} else {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
}
}
}
template <typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
// head_size]
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
// head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride +
block_offset * num_heads * head_size +
head_idx * head_size + head_offset;
k_cache[tgt_value_idx] = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx];
}
}
} // namespace vllm
} // namespace vllm
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, kv_scale);
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, \
value_stride, \
num_heads, \
head_size, \
block_size, \
x, \
kv_scale);
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const float kv_scale) {
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype,
const float kv_scale)
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
@ -269,78 +254,66 @@ void reshape_and_cache(
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE)
}
void reshape_and_cache_flash(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype) {
// FIXME: only support auto datatype, does not support fp8
if (kv_cache_dtype != "auto") {
if (kv_cache_dtype == "auto") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, float, false);
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
}
} else if (kv_cache_dtype == "fp8") {
if (key.dtype() == at::ScalarType::Float) {
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
} else if (key.dtype() == at::ScalarType::Half) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
} else if (key.dtype() == at::ScalarType::BFloat16) {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
}
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = k_cache.size(1);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = k_cache.stride(0);
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "reshape_and_cache_flash", [&] {
vllm::reshape_and_cache_flash_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
value_stride, num_heads, head_size, block_size);
});
}
namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,
const float kv_scale,
const int64_t block_stride) {
template<typename Tout, typename Tin>
__global__ void convert_fp8_kernel(
const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache,
const int64_t block_stride) {
const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i;
dst_cache[idx] =
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
#if defined(ENABLE_FP8_E5M2)
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#elif defined(ENABLE_FP8_E4M3)
dst_cache[idx] = fp8_e4m3::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
assert(false);
#endif
}
}
} // namespace vllm
} // namespace vllm
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
#define CALL_CONVERT_FP8(Tout, Tin) \
vllm::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
block_stride);
// Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float kv_scale, const std::string& kv_cache_dtype) {
void convert_fp8(
torch::Tensor& src_cache,
torch::Tensor& dst_cache)
{
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
TORCH_CHECK(
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
at::cuda::OptionalCUDAGuard device_guard(src_device);
int64_t num_blocks = src_cache.size(0);
@ -350,37 +323,17 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
dim3 block(std::min(block_stride, int64_t(512)));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (kv_cache_dtype == "auto") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
}
}

View File

@ -1,10 +1,10 @@
#include "cpu_types.hpp"
namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &),
bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
scalar_t* __restrict__ output) {
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
scalar_t *__restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
}
}
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
return w3 * x * (ones + t);
}
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f);
@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
return w3 * x * (ones + t);
}
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er());
}
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5);
@ -75,36 +75,40 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh());
}
}; // namespace
}; // namespace
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
activation_kernel<scalar_t, silu_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "silu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
}
void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void gelu_and_mul(torch::Tensor &out, // [..., d]
torch::Tensor &input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
activation_kernel<scalar_t, gelu_act, true>(
num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "gelu_and_mul_impl", [&] {
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
input.data_ptr<scalar_t>(),
out.data_ptr<scalar_t>());
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
}
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
torch::Tensor &input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;
@ -119,7 +123,7 @@ void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
});
}
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
void gelu_new(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);
@ -131,7 +135,7 @@ void gelu_new(torch::Tensor& out, torch::Tensor& input) {
});
}
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1);

View File

@ -2,8 +2,7 @@
namespace {
template <typename scalar_t>
struct KernelVecType {
template <typename scalar_t> struct KernelVecType {
using q_load_vec_type = void;
using q_vec_type = void;
using k_load_vec_type = void;
@ -12,8 +11,7 @@ struct KernelVecType {
using v_load_vec_type = void;
};
template <>
struct KernelVecType<float> {
template <> struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16;
@ -23,8 +21,7 @@ struct KernelVecType<float> {
};
#ifdef __AVX512BF16__
template <>
struct KernelVecType<c10::BFloat16> {
template <> struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32;
@ -33,8 +30,7 @@ struct KernelVecType<c10::BFloat16> {
using v_load_vec_type = vec_op::BF16Vec16;
};
#else
template <>
struct KernelVecType<c10::BFloat16> {
template <> struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16;
@ -45,7 +41,7 @@ struct KernelVecType<c10::BFloat16> {
#endif
template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
const int capacity) {
T max = data[0];
for (int i = 1; i < size; ++i) {
@ -71,15 +67,14 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
}
template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
const int capacity,
const float alibi_slope,
const int start_index,
const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1);
FORCE_INLINE std::pair<T, T>
reduceSoftmaxAlibi(T *data, const int size, const int capacity,
const float alibi_slope, const int start_index,
const int context_len) {
data[0] += alibi_slope * (start_index - context_len + 1);
T max = data[0];
for (int i = 1; i < size; ++i) {
T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
T qk = data[i] + alibi_slope * (start_index + i - context_len + 1);
data[i] = qk;
max = max >= qk ? max : qk;
}
@ -103,7 +98,7 @@ FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
}
template <typename T>
FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data,
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data,
const int size) {
T max = max_data[0];
for (int i = 1; i < size; ++i) {
@ -137,9 +132,9 @@ struct reduceQKBlockKernel {
static_assert(k_load_vec_type::get_elem_num() % x == 0);
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
FORCE_INLINE static void call(const scalar_t* __restrict__ q,
const scalar_t* __restrict__ k_block,
float* __restrict__ logits, float scale,
FORCE_INLINE static void call(const scalar_t *__restrict__ q,
const scalar_t *__restrict__ k_block,
float *__restrict__ logits, float scale,
const int token_num) {
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
@ -201,8 +196,8 @@ struct reduceQKBlockKernel {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
acc_t&& acc) {
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
acc_t &&acc) {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
static_assert(BLOCK_SIZE == ELEM_NUM);
@ -214,65 +209,66 @@ FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
});
}
}; // namespace
}; // namespace
// Paged attention v1
namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl {
static void call(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
static void
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs,
// max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) {
const int num_kv_heads, const float scale,
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) {
constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads;
static_assert(BLOCK_SIZE == 16);
int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0);
const int parallel_work_item_num = omp_get_max_threads();
size_t logits_bytes =
parallel_work_item_num * max_seq_len_padded * sizeof(float);
float* logits = (float*)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_seq_len_padded]
parallel_work_item_num * max_context_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_context_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int seq_len = seq_lens[seq_idx];
const int* seq_block_table =
int context_len = context_lens[seq_idx];
const int *seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t* __restrict__ q_vec_ptr =
const scalar_t *__restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
float* __restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_seq_len_padded;
const int last_block_token_num =
context_len - (block_num - 1) * BLOCK_SIZE;
float *__restrict__ thread_block_logits =
logits + omp_get_thread_num() * max_context_len_padded;
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t* __restrict__ k_block_cache_ptr =
const scalar_t *__restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride;
float* __restrict__ head_block_logits =
float *__restrict__ head_block_logits =
thread_block_logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
@ -282,11 +278,12 @@ struct paged_attention_v1_impl {
// Compute softmax
if (alibi_slopes) {
reduceSoftmaxAlibi(thread_block_logits, seq_len,
reduceSoftmaxAlibi(thread_block_logits, context_len,
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
seq_len);
context_len);
} else {
reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
reduceSoftmax(thread_block_logits, context_len,
block_num * BLOCK_SIZE);
}
// Compute value
@ -296,14 +293,14 @@ struct paged_attention_v1_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t* __restrict__ out_ptr =
scalar_t *__restrict__ out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const float* __restrict__ prob_vec_ptr =
const float *__restrict__ prob_vec_ptr =
thread_block_logits + block_idx * BLOCK_SIZE;
const scalar_t* __restrict__ v_block_cache_ptr =
const scalar_t *__restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -314,7 +311,7 @@ struct paged_attention_v1_impl {
if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1];
const scalar_t* __restrict__ next_v_block_cache_ptr =
const scalar_t *__restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -343,16 +340,16 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads);
template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -362,73 +359,68 @@ void paged_attention_v1_impl_launcher(
int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
const float *alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>();
switch (head_size) {
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 192:
LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
break;
case 256:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 256:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context_lens, max_context_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
} // namespace
} // namespace
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query,
torch::Tensor &key_cache, torch::Tensor &value_cache,
int num_kv_heads, float scale,
torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size,
int max_context_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@ -442,24 +434,23 @@ namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl {
static void call(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size]
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs,
// max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float *__restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads, const int max_num_partitions) {
constexpr int x = 16 / sizeof(scalar_t);
@ -474,25 +465,27 @@ struct paged_attention_v2_impl {
for (int partition_idx = 0; partition_idx < max_num_partitions;
++partition_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int seq_len = seq_lens[seq_idx];
const int context_len = context_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= seq_len) continue;
if (start_token_idx >= context_len)
continue;
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
const bool no_reduce = (partition_num == 1);
const int token_num =
(std::min(seq_len, start_token_idx + PARTITION_SIZE) -
const int context_token_num =
(std::min(context_len, start_token_idx + PARTITION_SIZE) -
start_token_idx);
const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int block_num =
(context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num =
token_num - (block_num - 1) * BLOCK_SIZE;
const int* seq_block_table = block_tables +
context_token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t* __restrict__ q_vec_ptr =
const scalar_t *__restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE;
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
@ -500,10 +493,10 @@ struct paged_attention_v2_impl {
// Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t* __restrict__ k_block_cache_ptr =
const scalar_t *__restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride;
float* __restrict__ head_block_logits =
float *__restrict__ head_block_logits =
logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
@ -514,16 +507,16 @@ struct paged_attention_v2_impl {
std::pair<float, float> max_and_sum;
if (alibi_slopes) {
max_and_sum = reduceSoftmaxAlibi(
logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, seq_len);
logits, context_token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, context_len);
} else {
max_and_sum =
reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
max_and_sum = reduceSoftmax(logits, context_token_num,
block_num * BLOCK_SIZE);
}
auto&& [max_logit, exp_sum] = max_and_sum;
auto &&[max_logit, exp_sum] = max_and_sum;
scalar_t* __restrict__ output_buffer = nullptr;
scalar_t *__restrict__ output_buffer = nullptr;
if (!no_reduce) {
auto idx = seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
@ -545,13 +538,13 @@ struct paged_attention_v2_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t* __restrict__ out_ptr =
scalar_t *__restrict__ out_ptr =
output_buffer + head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx];
const float* __restrict__ prob_vec_ptr =
const float *__restrict__ prob_vec_ptr =
logits + block_idx * BLOCK_SIZE;
const scalar_t* __restrict__ v_block_cache_ptr =
const scalar_t *__restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -562,7 +555,7 @@ struct paged_attention_v2_impl {
if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1];
const scalar_t* __restrict__ next_v_block_cache_ptr =
const scalar_t *__restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -590,11 +583,12 @@ struct paged_attention_v2_impl {
#pragma omp parallel for collapse(2) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int seq_len = seq_lens[seq_idx];
const int context_len = context_lens[seq_idx];
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) continue;
if (partition_num == 1)
continue;
reducePartitonSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions +
@ -609,29 +603,30 @@ struct paged_attention_v2_impl {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group =
16; // Note: didn't align with the cacheline size, due to some
// HEAD_SIZE didn't align with 64 bytes
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE
// didn't align with 64 bytes
static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float* __restrict__ rescale_factors = exp_sums;
const float *__restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
const int seq_len = seq_lens[seq_idx];
const int context_len = context_lens[seq_idx];
const int partition_num =
(seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) continue;
if (partition_num == 1)
continue;
const float* __restrict__ seq_head_rescale_factors =
const float *__restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
const scalar_t* __restrict__ seq_head_tmp_out =
const scalar_t *__restrict__ seq_head_tmp_out =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE +
group_idx * head_elem_num_per_group;
scalar_t* __restrict__ seq_head_output =
scalar_t *__restrict__ seq_head_output =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
group_idx * head_elem_num_per_group;
@ -650,21 +645,21 @@ struct paged_attention_v2_impl {
}
};
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions);
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -675,78 +670,73 @@ void paged_attention_v2_impl_launcher(
int max_num_partitions = exp_sums.size(-1);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
const float *alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
T *out_ptr = reinterpret_cast<T *>(out.data_ptr());
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr());
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr());
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>();
switch (head_size) {
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 192:
LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
break;
case 256:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break;
case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break;
case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break;
case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break;
case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break;
case 256:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, block_size, \
max_context_len, alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
} // namespace
} // namespace
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums,
torch::Tensor &max_logits, torch::Tensor &tmp_out,
torch::Tensor &query, torch::Tensor &key_cache,
torch::Tensor &value_cache, int num_kv_heads,
float scale, torch::Tensor &block_tables,
torch::Tensor &context_lens, int block_size,
int max_context_len,
const c10::optional<torch::Tensor> &alibi_slopes,
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)

View File

@ -5,26 +5,25 @@
namespace {
template <typename scalar_t>
void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const torch::Tensor& mapping_pairs,
const int element_num_per_block,
const int layer_num) {
const size_t pair_num = mapping_pairs.size(0);
void copy_blocks_cpu_impl(
std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs,
const int element_num_per_block, const int layer_num) {
const size_t pair_num = mapping_pairs.size();
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset =
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first;
int64_t target_offset =
element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t* source_ptr = key_cache_ptr + source_offset;
scalar_t* target_ptr = key_cache_ptr + target_offset;
element_num_per_block * mapping_pairs[pair].second;
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t *source_ptr = key_cache_ptr + source_offset;
scalar_t *target_ptr = key_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
source_ptr = value_cache_ptr + source_offset;
target_ptr = value_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes);
@ -34,9 +33,9 @@ void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
template <typename scalar_t>
void reshape_and_cache_cpu_impl(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int num_tokens,
const scalar_t *__restrict__ key, const scalar_t *__restrict__ value,
scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache,
const int64_t *__restrict__ slot_mapping, const int num_tokens,
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x) {
const int block_elem_num = num_heads * head_size * block_size;
@ -49,14 +48,14 @@ void reshape_and_cache_cpu_impl(
int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
int src_value_head_idx =
token_idx * value_stride + head_idx * head_size;
const scalar_t* src_key_head_ptr = key + src_key_head_idx;
const scalar_t* src_value_head_ptr = value + src_value_head_idx;
const scalar_t *src_key_head_ptr = key + src_key_head_idx;
const scalar_t *src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
scalar_t* target_key_head_ptr = key_cache +
scalar_t *target_key_head_ptr = key_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
scalar_t* target_value_head_ptr = value_cache +
scalar_t *target_value_head_ptr = value_cache +
block_elem_num * block_index +
head_idx * block_size * head_size;
@ -80,31 +79,39 @@ void reshape_and_cache_cpu_impl(
}
}
}
}; // namespace
}; // namespace
void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping) {
unsigned num_layers = key_caches.size();
void copy_blocks(std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
const std::map<int64_t, std::vector<int64_t>> &block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
return;
}
std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
mapping_pairs.reserve(block_mapping.size());
for (const auto &pair : block_mapping) {
for (const auto &dst : pair.second) {
mapping_pairs.emplace_back(pair.first, dst);
}
}
const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs,
element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
});
}
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, float kv_scale) {
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
torch::Tensor &key_cache, torch::Tensor &value_cache,
torch::Tensor &slot_mapping,
const std::string &kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
int num_tokens = key.size(0);
@ -128,7 +135,7 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
});
}
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
const std::map<int64_t, int64_t> &block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
}

View File

@ -2,10 +2,10 @@
namespace {
template <typename scalar_t>
void rms_norm_impl(scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ weight, const float epsilon,
const int num_tokens, const int hidden_size) {
void rms_norm_impl(scalar_t *__restrict__ out,
const scalar_t *__restrict__ input,
const scalar_t *__restrict__ weight, const float epsilon,
const int num_tokens, const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t* __restrict__ out,
}
template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
scalar_t* __restrict__ residual,
const scalar_t* __restrict__ weight,
const float epsilon, const int num_tokens,
const int hidden_size) {
void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
scalar_t *__restrict__ residual,
const scalar_t *__restrict__ weight,
const float epsilon, const int num_tokens,
const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
}
}
}
} // namespace
} // namespace
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
float epsilon) {
void rms_norm(torch::Tensor &out, torch::Tensor &input,
torch::Tensor &weight, float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(rms_norm_impl)
rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size);
weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size);
CPU_KERNEL_GUARD_OUT(rms_norm_impl)
});
}
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, float epsilon) {
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual,
torch::Tensor &weight, float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;

View File

@ -4,21 +4,22 @@
namespace {
template <typename scalar_t>
void rotary_embedding_impl(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int64_t
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
/// [num_tokens, num_heads, head_size]
scalar_t
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
// [num_tokens, num_kv_heads, head_size]
const scalar_t
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) {
using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
constexpr int ELEM_SIZE = sizeof(scalar_t);
const int embed_dim = rot_dim / 2;
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0);
@ -26,7 +27,7 @@ void rotary_embedding_impl(
#pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
for (int i = 0; i < num_heads; ++i) {
const int head_idx = i;
@ -94,16 +95,16 @@ void rotary_embedding_impl(
template <typename scalar_t>
void rotary_embedding_gptj_impl(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int64_t
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or
/// [num_tokens, num_heads, head_size]
scalar_t
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or
// [num_tokens, num_kv_heads, head_size]
const scalar_t
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) {
@ -113,13 +114,13 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_heads; ++i) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i;
const int64_t token_head =
token_idx * query_stride + head_idx * head_size;
scalar_t* head_query = token_head + query;
scalar_t *head_query = token_head + query;
for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j;
const int x_index = 2 * rot_offset;
@ -141,12 +142,12 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) {
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
scalar_t* head_key = key + token_head;
scalar_t *head_key = key + token_head;
for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j;
const int x_index = 2 * rot_offset;
@ -164,11 +165,11 @@ void rotary_embedding_gptj_impl(
}
}
}
}; // namespace
}; // namespace
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size,
torch::Tensor& cos_sin_cache, bool is_neox) {
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
torch::Tensor &key, int head_size,
torch::Tensor &cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;

View File

@ -8,37 +8,66 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def("paged_attention_v1", &paged_attention_v1,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention.");
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
ops.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
ops.def("gelu_and_mul", &gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
ops.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm
ops.def("rms_norm", &rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
ops.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def("rotary_embedding", &rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
ops.def(
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def("swap_blocks", &swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def("copy_blocks", &copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def("reshape_and_cache", &reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
}

View File

@ -1,7 +1,7 @@
#pragma once
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#endif
#ifndef USE_ROCM
@ -17,8 +17,7 @@
#endif
#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif
@ -29,13 +28,6 @@
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif
#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif
#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
@ -43,3 +35,4 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif

View File

@ -2,6 +2,9 @@
#include <torch/extension.h>
int get_device_attribute(int attribute, int device_id);
int get_device_attribute(
int attribute,
int device_id);
int get_max_shared_memory_per_block_device_attribute(int device_id);
int get_max_shared_memory_per_block_device_attribute(
int device_id);

View File

@ -2,28 +2,34 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#endif
int get_device_attribute(int attribute, int device_id) {
int device, value;
if (device_id < 0) {
cudaGetDevice(&device);
} else {
device = device_id;
}
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
device);
return value;
int get_device_attribute(
int attribute,
int device_id)
{
int device, value;
if (device_id < 0) {
cudaGetDevice(&device);
}
else {
device = device_id;
}
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
return value;
}
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int get_max_shared_memory_per_block_device_attribute(
int device_id)
{
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif
return get_device_attribute(attribute, device_id);
return get_device_attribute(attribute, device_id);
}

View File

@ -7,11 +7,11 @@
// fake pointer type
using fptr_t = uint64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
static_assert(sizeof(void *) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int rank,
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets, int rank,
bool full_nvlink) {
int world_size = offsets.size();
if (world_size > 8)
@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
}
return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}
@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool _is_weak_contiguous(torch::Tensor& t) {
bool _is_weak_contiguous(torch::Tensor &t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size());
}
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16
@ -67,27 +67,28 @@ bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
return false;
}
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
cudaStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
reinterpret_cast<float *>(out.data_ptr()),
out.numel());
break;
}
case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()), out.numel());
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
reinterpret_cast<half *>(out.data_ptr()),
out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
break;
}
#endif
@ -97,7 +98,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
}
}
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
@ -105,8 +106,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
_all_reduce(_fa, inp, out, stream);
}
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out) {
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
torch::Tensor &out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
@ -121,27 +122,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
}
void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
delete fa;
}
int meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
void register_buffer(fptr_t _fa, torch::Tensor &t,
const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
}
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
return fa->get_graph_buffer_ipc_meta();
}
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
fa->register_graph_buffers(handles, offsets);
}

View File

@ -31,9 +31,9 @@ struct Signal {
alignas(128) uint32_t end[kMaxBlocks][8];
};
struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
struct __align__(16) RankSignals { volatile Signal* signals[8]; };
struct __align__(16) RankSignals { volatile Signal *signals[8]; };
// like std::array, but aligned
template <typename T, int sz>
@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half& assign_add(half& a, half b) {
DINLINE half &assign_add(half &a, half b) {
a = __hadd(a, b);
return a;
}
DINLINE float& assign_add(float& a, float b) { return a += b; }
DINLINE float &assign_add(float &a, float b) { return a += b; }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
@ -80,14 +80,14 @@ template <>
DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif
template <typename T, int N>
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
#pragma unroll
for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]);
@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
int rank) {
if (threadIdx.x < ngpus) {
// reset flag for next time
@ -137,7 +137,8 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
// Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x]);
while (!self_sg->start[blockIdx.x][threadIdx.x])
;
}
__syncthreads();
}
@ -146,13 +147,13 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
int rank) {
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not managed to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than
// the memory model.
// the memory model.
if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) {
// reset flag for next time
@ -161,13 +162,14 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
// Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x]);
while (!self_sg->end[blockIdx.x][threadIdx.x])
;
}
if constexpr (!final_sync) __syncthreads();
}
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
DINLINE P packed_reduce(const P *ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]);
#pragma unroll
for (int i = 1; i < ngpus; i++) {
@ -178,8 +180,8 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
volatile Signal* self_sg, T* __restrict__ result,
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result,
int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
@ -190,20 +192,21 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
((P *)result)[idx] =
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
}
end_sync<ngpus, true>(sg, self_sg, rank);
}
template <typename P>
DINLINE P* get_tmp_buf(volatile Signal* sg) {
return (P*)(((Signal*)sg) + 1);
DINLINE P *get_tmp_buf(volatile Signal *sg) {
return (P *)(((Signal *)sg) + 1);
}
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
volatile Signal* self_sg, T* __restrict__ result,
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result,
int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
@ -213,12 +216,12 @@ __global__ void __launch_bounds__(512, 1)
int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus;
const P* ptrs[ngpus];
P* tmps[ngpus];
const P *ptrs[ngpus];
P *tmps[ngpus];
#pragma unroll
for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus;
ptrs[i] = (const P*)_dp->ptrs[target];
ptrs[i] = (const P *)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
@ -240,7 +243,7 @@ __global__ void __launch_bounds__(512, 1)
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
((P *)result)[dst_idx] = tmps[i][idx];
}
}
}
@ -258,14 +261,14 @@ class CustomAllreduce {
// below are device pointers
RankSignals sg_;
std::unordered_map<void*, RankData*> buffers_;
Signal* self_sg_;
std::unordered_map<void *, RankData *> buffers_;
Signal *self_sg_;
// stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void*> graph_unreg_buffers_;
std::vector<void *> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char*> ipc_handles_;
std::map<IPC_KEY, char *> ipc_handles_;
/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
@ -276,22 +279,22 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
const cudaIpcMemHandle_t* handles,
const std::vector<int64_t>& offsets, int rank,
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
const cudaIpcMemHandle_t *handles,
const std::vector<int64_t> &offsets, int rank,
bool full_nvlink = true)
: rank_(rank),
world_size_(offsets.size()),
full_nvlink_(full_nvlink),
self_sg_(meta),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
Signal* rank_sg;
Signal *rank_sg;
if (i != rank_) {
char* handle = open_ipc_handle(&handles[i]);
char *handle = open_ipc_handle(&handles[i]);
handle += offsets[i];
rank_sg = (Signal*)handle;
rank_sg = (Signal *)handle;
} else {
rank_sg = self_sg_;
}
@ -299,13 +302,13 @@ class CustomAllreduce {
}
}
char* open_ipc_handle(const void* ipc_handle) {
char *open_ipc_handle(const void *ipc_handle) {
auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
if (new_handle) {
char* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
*((const cudaIpcMemHandle_t*)ipc_handle),
char *ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
*((const cudaIpcMemHandle_t *)ipc_handle),
cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr;
}
@ -320,7 +323,7 @@ class CustomAllreduce {
std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i];
void* base_ptr;
void *base_ptr;
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr,
@ -328,8 +331,8 @@ class CustomAllreduce {
(CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char*)ptr) - ((char*)base_ptr);
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char *)ptr) - ((char *)base_ptr);
}
return std::make_pair(handles, offsets);
}
@ -341,13 +344,13 @@ class CustomAllreduce {
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}
void register_buffer(const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, void* self) {
void register_buffer(const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets, void *self) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
if (i != rank_) {
char* handle = open_ipc_handle(handles[i].data());
char *handle = open_ipc_handle(handles[i].data());
handle += offsets[i];
data.ptrs[i] = handle;
} else {
@ -368,17 +371,17 @@ class CustomAllreduce {
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void register_graph_buffers(
const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets) {
const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets) {
auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i];
auto& rd = rank_data[i];
auto &rd = rank_data[i];
for (int j = 0; j < world_size_; j++) {
if (j != rank_) {
char* handle =
char *handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i];
rd.ptrs[j] = handle;
@ -402,7 +405,7 @@ class CustomAllreduce {
* will cause contention on NVLink bus.
*/
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size,
void allreduce(cudaStream_t stream, T *input, T *output, int size,
int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size;
if (size % d != 0)
@ -415,7 +418,7 @@ class CustomAllreduce {
std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit));
RankData* ptrs;
RankData *ptrs;
cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) {

View File

@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
}
template <typename T>
__global__ void set_data(T* data, int size, int myRank) {
__global__ void set_data(T *data, int size, int myRank) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
data[idx] = myRank * 0.11f;
@ -56,8 +56,8 @@ __global__ void set_data(T* data, int size, int myRank) {
}
template <typename T>
__global__ void convert_data(const T* data1, const T* data2, double* fdata1,
double* fdata2, int size) {
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
double *fdata2, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
fdata1[idx] = data1[idx];
@ -65,7 +65,7 @@ __global__ void convert_data(const T* data1, const T* data2, double* fdata1,
}
}
__global__ void init_rand(curandState_t* state, int size, int nRanks) {
__global__ void init_rand(curandState_t *state, int size, int nRanks) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
for (int i = 0; i < nRanks; i++) {
@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t* state, int size, int nRanks) {
}
template <typename T>
__global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
int myRank, int nRanks, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) {
@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
}
template <typename T>
void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
int data_size, bool performance_test) {
T* result;
T *result;
cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
cudaIpcMemHandle_t self_data_handle;
cudaIpcMemHandle_t data_handles[8];
vllm::Signal* buffer;
T* self_data_copy;
vllm::Signal *buffer;
T *self_data_copy;
/**
* Allocate IPC buffer
*
@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
MPI_BYTE, MPI_COMM_WORLD));
void* rank_data;
void *rank_data;
size_t rank_data_sz = 16 * 1024 * 1024;
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
std::vector<int64_t> offsets(nRanks, 0);
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
offsets, myRank);
auto* self_data =
reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T));
auto *self_data =
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration
{
std::vector<std::string> handles;
handles.reserve(nRanks);
for (int i = 0; i < nRanks; i++) {
char* begin = (char*)&data_handles[i];
char* end = (char*)&data_handles[i + 1];
char *begin = (char *)&data_handles[i];
char *end = (char *)&data_handles[i + 1];
handles.emplace_back(begin, end);
}
std::vector<int64_t> offsets(nRanks,
@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
fa.register_buffer(handles, offsets, self_data);
}
double* ground_truth;
double *ground_truth;
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
curandState_t* states;
curandState_t *states;
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
CUDACHECK(cudaStreamDestroy(stream));
}
int main(int argc, char** argv) {
int main(int argc, char **argv) {
int nRanks, myRank;
MPICHECK(MPI_Init(&argc, &argv));
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
@ -296,7 +296,7 @@ int main(int argc, char** argv) {
ncclUniqueId id;
ncclComm_t comm;
if (myRank == 0) ncclGetUniqueId(&id);
MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
MPI_COMM_WORLD));
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

View File

@ -6,30 +6,32 @@
#include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

View File

@ -11,24 +11,26 @@
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
template<typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
const float x = (float) input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
@ -38,12 +40,12 @@ __global__ void rms_norm_kernel(
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
float x = (float) input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
}
}
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
@ -52,68 +54,51 @@ __global__ void rms_norm_kernel(
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
static constexpr bool exists = false;
};
template<typename torch_type>
struct _typeConvert { static constexpr bool exists = false; };
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
template<>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
template<>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
template<typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
@ -123,49 +108,51 @@ struct alignas(16) _f16Vec {
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
T2 temp{data[i], data[i+1]};
temp += T2{other.data[i], other.data[i+1]};
data[i] = temp.x;
data[i + 1] = temp.y;
data[i+1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] += other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
T2 temp{data[i], data[i+1]};
temp *= T2{other.data[i], other.data[i+1]};
data[i] = temp.x;
data[i + 1] = temp.y;
data[i+1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
#pragma unroll
for (int i = 0; i < width; ++i)
data[i] *= other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i + 1] = temp.y;
data[i+1] = temp.y;
}
} else {
#pragma unroll
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
@ -177,13 +164,13 @@ struct alignas(16) _f16Vec {
__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
float2 z = Converter::convert(T2{data[i], data[i+1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
@ -197,13 +184,15 @@ struct alignas(16) _f16Vec {
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
template<typename scalar_t, int width>
__global__ std::enable_if_t<
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
@ -214,12 +203,9 @@ fused_add_rms_norm_kernel(
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
@ -229,11 +215,10 @@ fused_add_rms_norm_kernel(
residual_v[id] = temp;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else
variance = blockReduceSum<float, 256>(variance);
} else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
@ -248,50 +233,52 @@ fused_add_rms_norm_kernel(
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
template<typename scalar_t, int width>
__global__ std::enable_if_t<
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
float x = (float) z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else
variance = blockReduceSum<float, 256>(variance);
} else variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
float x = (float) residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
}
}
} // namespace vllm
} // namespace vllm
void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
void rms_norm(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
@ -299,27 +286,40 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"rms_norm_kernel",
[&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, \
num_tokens, hidden_size); \
});
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"fused_add_rms_norm_kernel", \
[&] { \
vllm::fused_add_rms_norm_kernel \
<scalar_t, width><<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), \
epsilon, \
num_tokens, \
hidden_size); \
});
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
void fused_add_rms_norm(
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
@ -342,8 +342,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \
&& wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {

View File

@ -3,6 +3,5 @@
#include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax,
"Apply topk softmax to the gating outputs.");
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
}

View File

@ -2,6 +2,8 @@
#include <torch/extension.h>
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
void topk_softmax(
torch::Tensor& topk_weights,
torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);

View File

@ -7,128 +7,119 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
namespace vllm {
namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
int32_t col) {
// don't worry about overflow because num_experts is relatively small
return row * total_col + col;
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
// don't worry about overflow because num_experts is relatively small
return row * total_col + col;
}
}
} // namespace
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t* sorted_token_ids,
int32_t* expert_ids,
int32_t* total_tokens_post_pad,
int32_t num_experts,
int32_t block_size, size_t numel) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
int32_t *sorted_token_ids,
int32_t *expert_ids,
int32_t *total_tokens_post_pad,
int32_t num_experts,
int32_t block_size,
size_t numel) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ int32_t shared_mem[];
extern __shared__ int32_t shared_mem[];
int32_t* tokens_cnts =
shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
int32_t* cumsum =
shared_mem + (num_experts + 1) *
num_experts; // 1d tensor with shape (num_experts + 1)
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
*total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned
* to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
}
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
__syncthreads();
/**
* Each thread processes a token shard, calculating the index of each token
* after sorting by expert number. Given the example topk_ids =
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
// For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)];
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
}
*total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
/**
* For each expert, each thread processes the tokens of the corresponding blocks
* and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
/**
* Each thread processes a token shard, calculating the index of each token after
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
* stores the indices of the tokens processed by the expert with expert_id within
* the current thread's token shard.
*/
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
}
}
} // namespace vllm
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
int block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const int32_t shared_mem =
((num_experts + 1) * num_experts + (num_experts + 1)) *
sizeof(int32_t);
void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
// set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
AT_CUDA_CHECK(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel());
});
});
}

View File

@ -3,139 +3,159 @@
#include <torch/extension.h>
void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step);
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step);
torch::Tensor& out,
torch::Tensor& exp_sums,
torch::Tensor& max_logits,
torch::Tensor& tmp_out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
int num_kv_heads,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype,
float kv_scale);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
float epsilon);
void rms_norm(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& weight,
float epsilon);
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, float epsilon);
void fused_add_rms_norm(
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
void rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int head_size,
torch::Tensor& cos_sin_cache, bool is_neox,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets);
void batched_rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
void gelu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
void gelu_tanh_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_new(torch::Tensor& out, torch::Tensor& input);
void gelu_new(
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void gelu_fast(
torch::Tensor& out,
torch::Tensor& input);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias);
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes);
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros,
int split_k_iters);
torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, int split_k_iters, int thx,
int thy);
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full);
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor awq_dequantize(
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters,
int thx,
int thy);
torch::Tensor marlin_gemm(
torch::Tensor& a,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
torch::Tensor& workspace,
int64_t size_m,
int64_t size_n,
int64_t size_k);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
float scale);
void squeezellm_gemm(
torch::Tensor vec,
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table);
torch::Tensor gptq_gemm(
torch::Tensor a,
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama,
int bit);
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int bit);
void gptq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm,
int bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
int block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int rank,
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets, int rank,
bool full_nvlink);
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out);
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
torch::Tensor &out);
void dispose(fptr_t _fa);
int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
void register_buffer(fptr_t _fa, torch::Tensor &t,
const std::vector<std::string> &handles,
const std::vector<int64_t> &offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets);
#endif

View File

@ -7,10 +7,14 @@
namespace vllm {
template <typename scalar_t, bool IS_NEOX>
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
scalar_t* __restrict__ arr,
const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index;
scalar_t cos, sin;
if (IS_NEOX) {
@ -33,17 +37,19 @@ inline __device__ void apply_token_rotary_embedding(
arr[y_index] = y * cos + x * sin;
}
template <typename scalar_t, bool IS_NEOX>
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t key_stride) {
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* cache_ptr,
const int head_size,
const int num_heads,
const int num_kv_heads,
const int rot_dim,
const int token_idx,
const int64_t query_stride,
const int64_t key_stride)
{
const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;
@ -53,8 +59,8 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
}
const int nk = num_kv_heads * embed_dim;
@ -62,74 +68,62 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
}
}
template <typename scalar_t, bool IS_NEOX>
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
}
template <typename scalar_t, bool IS_NEOX>
template<typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
// [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
// 2]
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
// or [num_tokens]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens]
const int rot_dim,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
const scalar_t* cache_ptr =
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
}
} // namespace vllm
} // namespace vllm
void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
@ -141,21 +135,36 @@ void rotary_embedding(
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
query_stride, key_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size);
}
});
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
"rotary_embedding",
[&] {
if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}
/*
@ -163,15 +172,14 @@ Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.
*/
void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) {
int64_t num_tokens = cos_sin_cache_offsets.size(0);
int num_heads = query.size(-1) / head_size;
@ -183,21 +191,36 @@ void batched_rotary_embedding(
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
if (is_neox) {
vllm::batched_rotary_embedding_kernel<scalar_t, true>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
}
});
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
"rotary_embedding",
[&] {
if (is_neox) {
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}

View File

@ -2,4 +2,3 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)

View File

@ -2,4 +2,3 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)

View File

@ -28,7 +28,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3328) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
@ -37,7 +36,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6400) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
@ -55,7 +53,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 27648) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
@ -63,7 +60,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
@ -77,77 +73,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3328, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6400, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 27648, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
@ -155,14 +80,4 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)

View File

@ -2,4 +2,3 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)

View File

@ -2,4 +2,3 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)

View File

@ -2,4 +2,3 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)

View File

@ -2,4 +2,3 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)

View File

@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)

View File

@ -1,14 +1,8 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline>
#endif
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>
@ -17,24 +11,6 @@
namespace cg = cooperative_groups;
#ifdef USE_ROCM
template <size_t len>
__host__ __device__
inline void* memcpy_blocking(void *dst, const void *src) {
// Does not handle the case of long datatypes
char *d = reinterpret_cast<char *>(dst);
const char *s = reinterpret_cast<const char *>(src);
size_t i = 0;
#pragma unroll
for (i = 0; i < len; ++i) {
d[i] = s[i];
}
return dst;
}
#endif
#ifndef USE_ROCM
// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
@ -165,81 +141,6 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
}
}
#else
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
size_t j = blockIdx.x;
constexpr size_t tile_size = tx * ty * vec_size;
constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
__shared__ float y_warpwise[ty];
float y = 0;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
x_vec.load(X + (batch_idx * feat_in) +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
}
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
}
__syncthreads();
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
y += sum;
}
}
if (threadIdx.x == 0) {
y_warpwise[threadIdx.y] = y;
}
__syncthreads();
float y_write = 0.f;
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y_write += y_warpwise[i];
}
// write Y;
if (threadIdx.x == 0 && threadIdx.y == 0) {
size_t y_idx = batch_idx * full_y_size + y_offset + j;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
}
}
#endif
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T>
@ -271,11 +172,7 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
#ifndef USE_ROCM
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
#else
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
#endif
}
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
@ -286,14 +183,8 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
sum = g.shfl(sum, 0);
if (threadIdx.x == 0) {
#ifndef USE_ROCM
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
#else
size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
#endif
}
}
@ -308,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr (feat_in <= feat_out) {
if constexpr (feat_in < feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;
@ -345,7 +236,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
scale);
}
} else {
#ifndef USE_ROCM
static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0);
@ -389,50 +279,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
full_y_size, num_layers, layer_idx,
scale);
}
#else
constexpr size_t rocm_warp_size = warpSize;
#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
feat_in % (rocm_warp_size * vec_size_) == 0
#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
constexpr size_t vec_size_shrink = vec_size_; \
constexpr int tx = tx_; \
constexpr int ty = ty_; \
dim3 nblks(feat_out, batch_size); \
dim3 nthrs(tx, ty); \
bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
vec_size_shrink * sizeof(in_T), \
vec_size_shrink * sizeof(W_T), \
tx, ty, tz> \
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
full_y_size, num_layers, layer_idx, \
scale); \
}
static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
CHECK_INPUT_TILEABLE_BY(16) ||
CHECK_INPUT_TILEABLE_BY( 8) ||
CHECK_INPUT_TILEABLE_BY( 4) ||
CHECK_INPUT_TILEABLE_BY( 2) ||
CHECK_INPUT_TILEABLE_BY( 1));
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)
#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
}
}
@ -443,9 +289,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)

View File

@ -10,7 +10,6 @@ TEMPLATE = """
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip() # noqa: E501
for input_dtype in DTYPES:
@ -19,26 +18,6 @@ for input_dtype in DTYPES:
if weight_dtype == "fp32":
# FP32 weights are not supported.
continue
if output_dtype == "fp32":
# LoRA A matrix.
if input_dtype != weight_dtype:
# NOTE(woosuk): While Punica supports the case where the
# input and weight dtypes are different, we only generate
# the kernels the same dtypes to reduce the binary size.
continue
elif input_dtype == "fp32":
# LoRA B matrix.
if output_dtype != weight_dtype:
# NOTE(woosuk): While Punica supports the case where the
# output and weight dtypes are different, we only generate
# the kernels the same dtypes to reduce the binary size.
continue
elif not (input_dtype == output_dtype == weight_dtype):
# NOTE(woosuk): While Punica supports mixed data types for
# input, output, and weight, we only generate the kernels with
# the same data types to reduce the binary size.
continue
kernel_definition = TEMPLATE.format(
input_dtype=DTYPE_MAP[input_dtype],
output_dtype=DTYPE_MAP[output_dtype],

View File

@ -1,6 +1,8 @@
#ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef FLASHINFER_USE_FP8
#include <cuda_fp8.h>
#endif
@ -8,9 +10,6 @@
#include <type_traits>
#include "../type_convert.h"
#include "../../cuda_compat.h"
#define FLASHINFER_INLINE \
inline __attribute__((always_inline)) __device__ __host__

View File

@ -1,11 +1,12 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cstdint>
#include "type_convert.h"
#include "../cuda_compat.h"
#include "bgmv/bgmv_config.h"
namespace {
//====== utils ======
@ -49,23 +50,6 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) {
// NOTE(woosuk): While Punica supports various combinations of input/output
// data types, we limit the supported data types to reduce the binary size.
constexpr bool is_input_float = std::is_same<in_T, float>::value;
constexpr bool is_output_float = std::is_same<out_T, float>::value;
if (is_input_float) {
if (!std::is_same<out_T, W_T>::value) {
return false;
}
} else if (is_output_float) {
if (!std::is_same<in_T, W_T>::value) {
return false;
}
} else if (!(std::is_same<in_T, W_T>::value &&
std::is_same<out_T, W_T>::value)) {
return false;
}
switch (pack_u32(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u32(feat_in, feat_out): \
@ -78,12 +62,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
default:
return false;
}
return true;
}
@ -567,3 +551,15 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
} // namespace
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
"dispatch_bgmv_low_level");
}

View File

@ -1,11 +0,0 @@
#pragma once
#include <torch/extension.h>
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx, float scale);
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
torch::Tensor indicies, int64_t layer_idx,
float scale, int64_t h_in, int64_t h_out,
int64_t y_offset);

View File

@ -1,13 +0,0 @@
#include <torch/extension.h>
#include "punica_ops.h"
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
"dispatch_bgmv_low_level");
}

View File

@ -1,82 +0,0 @@
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
#define CSRC__PUNICA__TYPE_CONVERT_H__
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
typedef __half nv_half;
typedef __hip_bfloat16 nv_bfloat16;
typedef __hip_bfloat162 nv_bfloat162;
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) {
return __hip_bfloat162{val, val};
}
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) {
return __hip_bfloat162{vall, valr};
}
template <typename T_src, typename T_dst>
__TYPE_CONVERT__HOST_DEVICE__
inline T_dst convert_type(T_src val) {
return static_cast<T_dst>(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__half, float>(__half val) {
return __half2float(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half convert_type<float, __half>(float val) {
return __float2half(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) {
return __bfloat162float(val);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) {
return __float2bfloat16(val);
}
template <typename T>
__TYPE_CONVERT__HOST_DEVICE__
inline T vllm_add(T a, T b) {
return a + b;
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __half vllm_add<__half>(__half a, __half b) {
return __hadd(a, b);
}
template <>
__TYPE_CONVERT__HOST_DEVICE__
inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) {
return __hadd(a, b);
}
#undef __TYPE_CONVERT__HOST_DEVICE__
#endif // USE_ROCM
#endif // CSRC__PUNICA__TYPE_CONVERT_H__

View File

@ -8,90 +8,104 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops
ops.def("paged_attention_v1", &paged_attention_v1,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention.");
ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
ops.def(
"paged_attention_v1",
&paged_attention_v1,
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops
ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
ops.def("gelu_and_mul", &gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
ops.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
ops.def(
"gelu_and_mul",
&gelu_and_mul,
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm
ops.def("rms_norm", &rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
ops.def(
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding
ops.def("rotary_embedding", &rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
ops.def(
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
ops.def("batched_rotary_embedding", &batched_rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
"(supports multiple loras)");
ops.def(
"batched_rotary_embedding",
&batched_rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
// Quantization ops
#ifndef USE_ROCM
ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm,
"Marlin (Dense) Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm,
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm,
"gptq_marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_repack", &gptq_marlin_repack,
"gptq_marlin repack from GPTQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
"per-row/column quantization.");
#endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant,
"Compute FP8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant,
"Compute FP8 quantized tensor and scaling factor");
ops.def("moe_align_block_size", &moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size.");
ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
"Compute int8 quantized tensor for given scaling factor");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def("swap_blocks", &swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def("copy_blocks", &copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def("reshape_and_cache", &reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash,
"Reshape the key and value tensors and cache them");
cache_ops.def("convert_fp8", &convert_fp8,
"Convert the key and value cache to fp8 data type");
cache_ops.def(
"swap_blocks",
&swap_blocks,
"Swap in (out) the cache blocks from src to dst");
cache_ops.def(
"copy_blocks",
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
cache_ops.def(
"convert_fp8",
&convert_fp8,
"Convert the key and value cache to fp8 data type");
// Cuda utils
pybind11::module cuda_utils =
m.def_submodule("cuda_utils", "vLLM cuda utils");
cuda_utils.def("get_device_attribute", &get_device_attribute,
"Gets the specified device attribute.");
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
cuda_utils.def(
"get_device_attribute",
&get_device_attribute,
"Gets the specified device attribute.");
cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");
cuda_utils.def(
"get_max_shared_memory_per_block_device_attribute",
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");
#ifndef USE_ROCM
// Custom all-reduce kernels
@ -108,4 +122,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");
#endif
}

View File

@ -1,598 +0,0 @@
/*
* Modified by Neural Magic
* Adapted from https://github.com/Vahe1994/AQLM
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>
#include <iostream>
#include <cstdlib>
namespace vllm {
namespace aqlm {
__global__ void Code1x16MatVec(
const int4* __restrict__ A, const int4* __restrict__ B,
int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
const int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred) {
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) {
codebook += codebook_stride;
++codebook_size;
}
}
int b_gl_rd = 0;
int c_gl_wr = a_gl_rd;
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
__shared__ int4 sh_b[32 * 9];
float res = 0;
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
while (iters--) {
// We pad shared memory to avoid bank conflicts during reads
__syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
}
__syncthreads();
b_gl_rd += 32 * 8;
int b_sh_rd = 9 * (threadIdx.x % 32);
if (pred && a_gl_rd < a_gl_end) {
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
uint32_t dec[4];
// We bypass the L1 cache to avoid massive amounts of memory streaming
// that doesn't actually help us; this brings > 2x speedup.
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*)&codebook[enc[i]]));
half2* a = reinterpret_cast<half2*>(&dec);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {};
#pragma unroll
for (int j = 0; j < 4; j++) res2 = __hfma2(a[j], b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y);
b_sh_rd++;
}
a_gl_rd += 32;
}
}
if (pred) {
#pragma unroll
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
}
}
__global__ void Code2x8MatVec(
const int4* __restrict__ A, const int4* __restrict__ B,
int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred) {
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) {
codebook += codebook_stride;
++codebook_size;
}
}
int b_gl_rd = 0;
int c_gl_wr = a_gl_rd;
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int lane = threadIdx.x % 8;
extern __shared__ int4 sh[];
int4* sh_b = sh;
int4* sh_code = sh_b + 32 * 9;
int4* sh_code0 = sh_code;
int4* sh_code1 = sh_code + 256 * 8;
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i];
#pragma unroll
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
}
__syncthreads();
float res = 0;
int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
while (iters--) {
// We pad shared memory to avoid bank conflicts during reads
__syncthreads();
for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
if (b_gl_rd + i < prob_k / 8) sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
}
__syncthreads();
b_gl_rd += 32 * 8;
int b_sh_rd = 9 * (threadIdx.x % 32);
if (pred && a_gl_rd < a_gl_end) {
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
half2* a0 =
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 =
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
half2 res2 = {};
#pragma unroll
for (int j = 0; j < 4; j++)
res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
res += __half2float(res2.x) + __half2float(res2.y);
b_sh_rd++;
}
a_gl_rd += 32;
}
}
if (pred) {
#pragma unroll
for (int i = 16; i > 0; i /= 2) res += __shfl_down_sync(0xffffffff, res, i);
if (threadIdx.x % 32 == 0)
reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
}
}
__global__ void Code1x16Dequant(
const int4* __restrict__ A, int4* __restrict__ C,
const int4* __restrict__ codebook, int prob_m, int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long, sums to m.
const int codebook_stride // as int4
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred) {
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) {
codebook += codebook_stride;
++codebook_size;
}
}
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int c_gl_stride = prob_k / 8;
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) {
if (pred && a_gl_rd < a_gl_end) {
const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
int4 chunk;
auto dec = reinterpret_cast<uint32_t*>(&chunk);
// We bypass the L1 cache to avoid massive amounts of memory streaming
// that doesn't actually help us; this brings > 2x speedup.
asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
: "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
: "l"((void*)&codebook[enc[i]]));
C[a_gl_rd * 8 + i] = chunk;
}
}
a_gl_rd += 32;
}
}
__global__ void Code2x8Dequant(
const int4* __restrict__ A, int4* __restrict__ C,
const int4* __restrict__ codebook, int prob_m, int prob_k,
const int4
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
// most 3 long, corresponds to cols.
const int codebook_stride // as int4
) {
int a_gl_stride = prob_k / 8 / 8;
int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
bool pred = a_gl_rd < prob_m;
if (pred) {
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto codebook_size = &codebook_a_sizes.x;
while (a_gl_rd >= *codebook_size) {
codebook += codebook_stride;
++codebook_size;
}
}
a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
int lane = threadIdx.x % 8;
int c_gl_stride = prob_k / 8;
int c_gl_wr = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
c_gl_wr = c_gl_stride * c_gl_wr + (threadIdx.x % 32) * 8;
extern __shared__ int4 sh[];
int4* sh_code = sh;
int4* sh_code0 = sh_code;
int4* sh_code1 = sh_code + 256 * 8;
for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
int4 dec = codebook[i];
#pragma unroll
for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
}
__syncthreads();
float res = 0;
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) {
if (pred && a_gl_rd < a_gl_end) {
const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
for (int i = 0; i < 8; i++) {
int4 chunk;
half2* a0 =
reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
half2* a1 =
reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
C[a_gl_rd * 8 + i] = chunk;
}
}
a_gl_rd += 32;
}
}
inline int ceildiv(int a, int b) { return (a + b - 1) / b; }
const int THREAD_M = 16;
void code1x16_matvec_cuda(const void* __restrict__ A,
const void* __restrict__ B, void* __restrict__ C,
const void* __restrict__ codebook, int prob_m,
int prob_k, const int4 codebook_a_sizes,
const int codebook_stride) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
prob_k, codebook_a_sizes, codebook_stride);
}
void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
void* __restrict__ C,
const void* __restrict__ codebook, int prob_m,
int prob_k, const int4 codebook_a_sizes,
const int codebook_stride) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaFuncSetAttribute(Code2x8MatVec,
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code2x8MatVec<<<blocks, threads, shared, stream>>>(
(const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
prob_k, codebook_a_sizes, codebook_stride);
}
void code1x16_dequant_cuda(
const void* __restrict__ A, void* __restrict__ C,
const void* __restrict__ codebook, int prob_m, int prob_k,
const int4 codebook_a_sizes, // cumulative sizes of A spanning each
// codebook, at most 3 long.
const int codebook_stride // as int4.
) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
Code1x16Dequant<<<blocks, threads, 0, stream>>>(
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
// most 3 long.
codebook_stride // as int4.
);
}
// Dequantizes the code and codebook into weights.
void code2x8_dequant_cuda(
const void* __restrict__ A, void* __restrict__ C,
const void* __restrict__ codebook, int prob_m, int prob_k,
const int4
codebook_a_sizes, // cumulative sizes of A spanning each codebook, at
// most 3 long, corresponds to cols.
const int codebook_stride // as int4
) {
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, 0);
int waves = 0;
int thread_m;
do {
waves++;
thread_m = ceildiv(prob_m, waves * sms);
} while (thread_m > THREAD_M);
int blocks = ceildiv(prob_m, thread_m);
int threads = 32 * thread_m;
int shared = 16 * (2 * 256 * 8 + 32 * 9);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cudaFuncSetAttribute(Code2x8Dequant,
cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
Code2x8Dequant<<<blocks, threads, shared, stream>>>(
(const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
codebook_a_sizes, codebook_stride);
}
int codebook_stride(const torch::Tensor& codebooks) {
return codebooks.stride(0) * codebooks.element_size() / sizeof(int4);
}
void code1x16_matvec(
const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
const torch::Tensor& codebook,
const int4 codebook_a_sizes // cumulative sizes of A spanning each
// codebook, at most 3 long.
) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
codebook_stride(codebook));
}
torch::Tensor code1x16_matmat(const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const int4 codebook_a_sizes,
const std::optional<torch::Tensor>& bias) {
auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty(
{flat_input.size(0), out_features},
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i});
code1x16_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
codebook_a_sizes);
}
flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) {
flat_output += bias->unsqueeze(0);
}
auto output_sizes = input_sizes.vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
auto output = flat_output.reshape(output_sizes);
return output;
}
void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
torch::Tensor& C, const torch::Tensor& codebook,
const int4 codebook_a_sizes) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
int prob_m = C.size(0);
int prob_k = B.size(0);
code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
codebook.data_ptr(), prob_m, prob_k, codebook_a_sizes,
2 * codebook_stride(codebook));
}
torch::Tensor code2x8_matmat(const torch::Tensor& input,
const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const int4 codebook_a_sizes,
const std::optional<torch::Tensor>& bias) {
auto input_sizes = input.sizes();
auto out_features = codes.size(0) * codebooks.size(2);
auto flat_input = input.reshape({-1, input.size(-1)});
auto flat_output = torch::empty(
{flat_input.size(0), out_features},
torch::TensorOptions().dtype(input.dtype()).device(input.device()));
for (int i = 0; i < flat_input.size(0); ++i) {
auto input_vec = flat_input.index({i});
auto output_vec = flat_output.index({i});
code2x8_matvec(codes.squeeze(2), input_vec, output_vec, codebooks,
codebook_a_sizes);
}
flat_output *= scales.flatten().unsqueeze(0);
if (bias.has_value()) {
flat_output += bias->unsqueeze(0);
}
auto output_sizes = input_sizes.vec();
output_sizes.pop_back();
output_sizes.push_back(-1);
auto output = flat_output.reshape(output_sizes);
return output;
}
// Accumulate the partition sizes.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
int4 cumulative_sizes;
auto cumulative_size = &cumulative_sizes.x;
int i = 0;
int last = 0;
assert(codebook_partition_sizes.size(0) <= 4);
for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
*cumulative_size = codebook_partition_sizes[i].item<int>() + last;
last = *cumulative_size;
}
// fill in the rest with unreachable.
for (; i < 4; ++i, ++cumulative_size) {
*cumulative_size = last * 10;
}
return cumulative_sizes;
}
} // namespace aqlm
} // namespace vllm
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias) {
int4 cumulative_sizes =
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1);
if (nbooks == 1 && entries == (1 << 16)) {
return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
cumulative_sizes, bias);
}
if (nbooks == 2 && entries == (1 << 8)) {
return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
cumulative_sizes, bias);
}
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
" entries is not currently supported.")
return {};
}
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes) {
int4 cumulative_sizes =
vllm::aqlm::accumulate_sizes(codebook_partition_sizes);
int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
int const entries = codebooks.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));
int rows = codes.size(1);
int cols = codes.size(0);
auto in_features = codes.size(1) * 8;
auto out_features = codes.size(0);
assert(out_features = codebook_partition_sizes.sum().item<int>());
auto weights = torch::empty({out_features, in_features},
torch::TensorOptions()
.dtype(codebooks.dtype())
.device(codebooks.device()));
if (nbooks == 1 && entries == (1 << 16)) {
vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
codebooks.data_ptr(), out_features,
in_features, cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// and not consistent with gemv implementation.) weights *=
// scales.index({"...", 0, 0});
return weights;
}
if (nbooks == 2 && entries == (1 << 8)) {
vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
codebooks.data_ptr(), out_features,
in_features, cumulative_sizes,
vllm::aqlm::codebook_stride(codebooks));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// and not consistent with gemv implementation) weights *=
// scales.index({"...", 0, 0});
return weights;
}
TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
" entries is not currently supported.")
return {};
}

View File

@ -1,11 +1,11 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and
Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
@ -14,88 +14,74 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
namespace vllm {
namespace awq {
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert(false);
#else
uint4 result;
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is
// thanks to the register packing format and the fact that we force our
// integers to be unsigned, and account for this in the fp16 subtractions. In
// addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[0])
: "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[1])
: "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[2])
: "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[3])
: "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
return result;
#endif
}
} // namespace awq
} // namespace vllm
} // namespace awq
} // namespace vllm

Some files were not shown because too many files have changed in this diff Show More