Compare commits: bind_kv_ca...rob-fixes
446 commits

| SHA1 |
|---|
| 220d694080 | |||
| 70e06dd574 | |||
| 7954461d4c | |||
| a10da86677 | |||
| 284d5df45b | |||
| d5b0db449e | |||
| 66349c33a1 | |||
| 28d0396ff1 | |||
| 2f29ae383a | |||
| cf64b0e6a7 | |||
| f51f182d64 | |||
| 79e465f557 | |||
| 2ba687d39f | |||
| 5d57896e2c | |||
| f6f008ca1d | |||
| 24cbbe4778 | |||
| 2fec6e0b5c | |||
| 47a3f26b2a | |||
| 144162fc8c | |||
| 522279ebb9 | |||
| b877031d80 | |||
| 85687b43e7 | |||
| 120bbdfd82 | |||
| 2ceb7bc534 | |||
| 9f7fb5ec84 | |||
| a8a621e419 | |||
| dd861b992f | |||
| eb63ea1e18 | |||
| 2f4bd358f1 | |||
| 8a8b30eac1 | |||
| 2fa0e1396b | |||
| 1c2bec0f82 | |||
| ec870fba9a | |||
| df1430265c | |||
| 4c69e228b3 | |||
| 790b79750b | |||
| cfbb8c930f | |||
| baec0d4de9 | |||
| c21b99b912 | |||
| 93a00d7dde | |||
| 61e8c18350 | |||
| 8afcd0f633 | |||
| 91ca929dc7 | |||
| 84e00adc8a | |||
| 47c7126213 | |||
| a989ca2bf6 | |||
| 0fa3970deb | |||
| da6ea29f7a | |||
| 7297941b38 | |||
| f8a08cb90d | |||
| b15fd2be2a | |||
| e588ac237c | |||
| 5df2da5b97 | |||
| 11b986b3fb | |||
| 296f927f24 | |||
| 0032903a5b | |||
| 47195057e9 | |||
| 6edbfa924d | |||
| 1e508343e1 | |||
| 2e0b4cfde0 | |||
| 10f55fe6c5 | |||
| d3ccbd6350 | |||
| 0cfe7d386d | |||
| 0c6f5023c3 | |||
| 06dd08256f | |||
| b89d89f456 | |||
| 8355358fb3 | |||
| c0b1443345 | |||
| d35dace985 | |||
| 912031ceb5 | |||
| 4f13e89143 | |||
| b9a7dbe769 | |||
| 0cb2e05256 | |||
| d6945ecdf0 | |||
| 298298f97d | |||
| 6c8fae82dd | |||
| 16ed827378 | |||
| 8fa9df7987 | |||
| 27c1afe88b | |||
| ee6607332e | |||
| 7fbf70db57 | |||
| 2c31e4c3ea | |||
| 187f112ccd | |||
| 897db7b93d | |||
| b7ffb43792 | |||
| 6e1fba8a73 | |||
| bfde1688e7 | |||
| 905424ed65 | |||
| 5d20f389d6 | |||
| 2a0cb78016 | |||
| 2b22290ce0 | |||
| d8e82bc06d | |||
| 086b56824c | |||
| 5a0905ba2a | |||
| a8f12a63fd | |||
| 69ae2380c6 | |||
| 27261e40a6 | |||
| e3f813c33b | |||
| c607a2652b | |||
| 3d45e3d749 | |||
| 742369d35a | |||
| bfe2fe0af4 | |||
| a8652f4f0f | |||
| 2f726b241e | |||
| a597a57595 | |||
| ae65f3e237 | |||
| 34868b106a | |||
| 1f16b7fe74 | |||
| b88be22165 | |||
| d8c6d7d6b5 | |||
| 40828ce5fe | |||
| ffa443afed | |||
| 70e500cad9 | |||
| 4cb1c05c9e | |||
| c47aafa37c | |||
| cfbca8a2f2 | |||
| 0fe5609874 | |||
| 22d33baca2 | |||
| b0e96aaebb | |||
| 8310e0b59b | |||
| 26dd972adb | |||
| 61c7a1b856 | |||
| 374ee287d8 | |||
| a4d83661d7 | |||
| 8363cd093d | |||
| 6c5a3195db | |||
| 073d1ed354 | |||
| 3d446433ec | |||
| 1fe0fd12d3 | |||
| dafb4e504a | |||
| 68cf1601d3 | |||
| 61f412187d | |||
| 05ccd0aa35 | |||
| f690372b68 | |||
| 8b3e94a357 | |||
| 437f9162d0 | |||
| 4f065f12f5 | |||
| 228b768db6 | |||
| 027827cc1d | |||
| 72a8639b68 | |||
| 99abb8b650 | |||
| 3a1e648158 | |||
| 46c759c165 | |||
| 179a619c21 | |||
| 452e8fd968 | |||
| 8b793f7ec6 | |||
| af35d3a3cc | |||
| 3b457143d2 | |||
| ab656f2c2f | |||
| 64fc2193dc | |||
| dd732028f5 | |||
| 414919138b | |||
| db7c8ca910 | |||
| f863ffc965 | |||
| 400d483e87 | |||
| d1695758b2 | |||
| 53a0cf8b95 | |||
| 5eeabc2a44 | |||
| 18551e820c | |||
| e41e160263 | |||
| b89fb2a4a1 | |||
| 5340b0e221 | |||
| 37e3806132 | |||
| c0efdd655b | |||
| aaaec52ad9 | |||
| e1eb45d397 | |||
| 89fca671fb | |||
| d20b0c139c | |||
| 166a168b0f | |||
| 2bb0e1a799 | |||
| 6eaf1e5c52 | |||
| 868a8c5b2c | |||
| b4ad56c1bd | |||
| 69698f257e | |||
| cd0cd85102 | |||
| 0a74bfce9c | |||
| dd3b865854 | |||
| 9b87a579aa | |||
| b539222d4e | |||
| 8d6cf89526 | |||
| 583a9778e0 | |||
| a73e183e36 | |||
| 1e799b7ec1 | |||
| 7f6c5ee06c | |||
| faa0275730 | |||
| 8a5a9b70d7 | |||
| bb3aeddfaf | |||
| aecc780dba | |||
| 90df7f23aa | |||
| b9b5bdfc7d | |||
| 31060b2757 | |||
| fc1f67715d | |||
| f6137adbcb | |||
| e53b1350f2 | |||
| d30aa7e9e6 | |||
| d1ad2a57af | |||
| b82662d952 | |||
| 71c1e07107 | |||
| b30c75dda4 | |||
| def232e122 | |||
| 3453b964a3 | |||
| 61c6a5a796 | |||
| 74bc397b0a | |||
| f58aea002c | |||
| 3556a41434 | |||
| 9ed6ee92d6 | |||
| ee3778d5fc | |||
| aaacf17324 | |||
| 4c7629cae9 | |||
| e0fdfa1608 | |||
| 5952d8ab61 | |||
| a2ae496589 | |||
| 877e352262 | |||
| d4d93db2c5 | |||
| 8c0d15d5c5 | |||
| 97ac781c62 | |||
| 776dcec8fe | |||
| ccf02fcbae | |||
| acaea3bb07 | |||
| 9f37422779 | |||
| dd344e0342 | |||
| 54a8804455 | |||
| bbd94a19fc | |||
| 233ffce1eb | |||
| 40677783aa | |||
| 14f301b541 | |||
| 46f98893dd | |||
| fe66b34728 | |||
| 270a5da495 | |||
| 7097b4cc1c | |||
| 977a16772c | |||
| 73deea2fdb | |||
| 9d2b4a70f4 | |||
| 0b0d6421b2 | |||
| 1140991a7b | |||
| 613c5bb945 | |||
| fd8e055ffb | |||
| ab93f1360f | |||
| 40253bab44 | |||
| c77620d22d | |||
| 989ecd2007 | |||
| 54cc46f3eb | |||
| 601bd3268e | |||
| 09269b3127 | |||
| 27b50f1fe6 | |||
| 9532c49836 | |||
| 0c2af17c76 | |||
| a6e0d096dd | |||
| d3d4956261 | |||
| 4059adc31b | |||
| f1f632d9ec | |||
| 95d680b862 | |||
| fb4c7f8ef0 | |||
| 0b1cfa6180 | |||
| 32ef4983cd | |||
| ad19c8a003 | |||
| 2a602b055a | |||
| 7888e1d0a3 | |||
| 60c872d4b6 | |||
| 3fb17d26c8 | |||
| d47807ba08 | |||
| 02fcaa3d0a | |||
| 8a4a2efc6f | |||
| 8e9ffd37d6 | |||
| 01b3fd0af7 | |||
| f53a0586b9 | |||
| b1cc4dfef5 | |||
| 382403921f | |||
| a73122de96 | |||
| bd44b812cb | |||
| 55211b01e8 | |||
| 5d043c1685 | |||
| 36d1ccb286 | |||
| 1bc3b739c4 | |||
| 1bd32bc8dd | |||
| 128bf75283 | |||
| a94a699c3f | |||
| ab426ec9c0 | |||
| 165290d357 | |||
| ce20124671 | |||
| 53be4a8634 | |||
| f5d3acd474 | |||
| 916836bbfb | |||
| d9f83d6206 | |||
| 4a754fcf15 | |||
| c0c25e25fa | |||
| 45f3f3f59e | |||
| ff47aab056 | |||
| debd6bbf09 | |||
| 5c538c37b2 | |||
| e22ee1e7a2 | |||
| e392d85831 | |||
| 77a318bd01 | |||
| 80e78d02ac | |||
| 4a42b9f5d6 | |||
| 47532cd9f4 | |||
| 36e0c8f7da | |||
| 9f583e360c | |||
| b706d898af | |||
| 863d315c86 | |||
| d374f04a33 | |||
| 61a01b27a7 | |||
| 53056731fd | |||
| 4cbf286794 | |||
| c6e14a61ab | |||
| 07b4b7a37f | |||
| 07964e2f30 | |||
| 4bf82d4b90 | |||
| 9ab326713f | |||
| af295e9b01 | |||
| a1c8f3796c | |||
| 08a1a1121d | |||
| 1477ffc381 | |||
| 70b808fe1a | |||
| 63d635d179 | |||
| 1fc973c0b5 | |||
| c982ac5722 | |||
| 4290b704ff | |||
| c91b64f749 | |||
| d6123170d5 | |||
| 485afdd3cb | |||
| 90e88ab756 | |||
| 04421dff8a | |||
| 432d6dad15 | |||
| 5ff0d32580 | |||
| 0967110e42 | |||
| fb0acb6c72 | |||
| 92b0ce2ac7 | |||
| bc2d4473bf | |||
| 3b352a2f92 | |||
| dea985aef0 | |||
| 39be30351f | |||
| 001a9c7b0d | |||
| 89cdaa83e7 | |||
| b0746fae3d | |||
| 60a98b2de5 | |||
| 460f553a6d | |||
| 1253b15774 | |||
| dc74613fa2 | |||
| a21076ed3a | |||
| 212007b168 | |||
| fb16eea48b | |||
| 73ae0b44e9 | |||
| 6d7f037748 | |||
| 10f7552789 | |||
| b0d541947a | |||
| 5f0b53c6ea | |||
| eb8b5eb183 | |||
| 9513290032 | |||
| 0d5e73d30e | |||
| 609ef61fea | |||
| db84f5eb3b | |||
| 206e2577fa | |||
| e02883c400 | |||
| 9085aabd62 | |||
| 8d5aa466fb | |||
| 0b7f06b447 | |||
| 03fe18ae0f | |||
| cb8bdfade2 | |||
| 33f227e16b | |||
| cfd0ae8234 | |||
| 7caff01a7b | |||
| be0b399d74 | |||
| b8b0ccbd2d | |||
| c908a07f57 | |||
| 7b6fd6e486 | |||
| 47512b3200 | |||
| 3b9c6c6947 | |||
| 4aae667668 | |||
| 9f3bc0f58c | |||
| 980385f8c1 | |||
| ca7a2d5f28 | |||
| 333681408f | |||
| ef64044079 | |||
| 66e16a038e | |||
| e1f0835ae0 | |||
| 8ed5421aaa | |||
| c6359e8ca6 | |||
| 952a074980 | |||
| d0feea31c7 | |||
| 58abe35455 | |||
| f7ebad2307 | |||
| 80e9afb5bc | |||
| 1e3598edeb | |||
| f7a6bd0fa1 | |||
| 0ca3b8e01c | |||
| cc10281498 | |||
| 05fb6718f0 | |||
| 12c29a881f | |||
| 70da0c0748 | |||
| c1588a2c94 | |||
| 8ca7a71df7 | |||
| 63137cd922 | |||
| ddd1ef66ec | |||
| e5e03c2c1b | |||
| e1744502c2 | |||
| dae6896977 | |||
| c34eeec58d | |||
| ad60bbb2b2 | |||
| 0578e5a462 | |||
| 04222984f8 | |||
| 6832707e90 | |||
| 6b2ef5cd17 | |||
| 958adce478 | |||
| 99b0915d3b | |||
| 8ca2b21c98 | |||
| d9292786e1 | |||
| cc2f9b32c8 | |||
| cd579352bf | |||
| 9f1710f1ac | |||
| e642ec962c | |||
| ada19210a3 | |||
| bf0560bda9 | |||
| 151b08e0fe | |||
| 81b2f4a45f | |||
| 82551ad616 | |||
| caac5c2e59 | |||
| 6bd1dd9d26 | |||
| 4f27044aab | |||
| 0ddc991f5c | |||
| fa82b93853 | |||
| 69ff99fdcd | |||
| 5d802522a7 | |||
| 1769928079 | |||
| ed6ea06577 | |||
| 5ee10e990d | |||
| 3dbd2d813a | |||
| f5f7f00cd9 | |||
| abcc61e0af | |||
| f6bb18fd9a | |||
| 71eaf8969b | |||
| ca100c90fe | |||
| ffad94397d | |||
| 4dacaa4a83 | |||
| a7ea35aa67 | |||
| 1e3e76b6cc | |||
| 53ea6ad830 | |||
| 1b7624bf5c | |||
| ac60dc7fe1 | |||
| a4f1ee35d6 | |||
| a32c8669ca | |||
| ca2ca8de57 | |||
| f71b00a19e | |||
| 8f808cf86e | |||
| 7bab4bb048 | |||
| e17e4488bd |
@@ -4,8 +4,8 @@ tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.233
value: 0.231
- name: "exact_match,flexible-extract"
value: 0.236
value: 0.22
limit: 1000
num_fewshot: 5

@@ -13,6 +13,7 @@ from pathlib import Path

import lm_eval
import numpy
import pytest
import yaml

RTOL = 0.05

@@ -46,6 +47,10 @@ def test_lm_eval_correctness():
eval_config = yaml.safe_load(
Path(TEST_DATA_FILE).read_text(encoding="utf-8"))

if eval_config[
"model_name"] == "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform": #noqa: E501
pytest.skip("FBGEMM is currently failing on main.")

# Launch eval requests.
results = launch_lm_eval(eval_config)
@@ -426,7 +426,7 @@ main() {

pip install -U transformers

pip install -r requirements-dev.txt
pip install -r requirements/dev.txt
which genai-perf

# check storage

@@ -361,7 +361,7 @@ main() {
# get the current IP address, required by benchmark_serving.py
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# turn of the reporting of the status of each request, to clean up the terminal output
export VLLM_LOG_LEVEL="WARNING"
export VLLM_LOGGING_LEVEL="WARNING"

# prepare for benchmarking
cd benchmarks || exit 1
@@ -82,7 +82,7 @@ steps:
queue: cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --progress plain -f Dockerfile.cpu ."
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain -f Dockerfile.cpu ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
env:
DOCKER_BUILDKIT: "1"
@@ -101,16 +101,30 @@ if [[ $commands == *" kernels "* ]]; then
--ignore=kernels/test_permute_cols.py"
fi

#ignore certain Entrypoints tests
#ignore certain Entrypoints/openai tests
if [[ $commands == *" entrypoints/openai "* ]]; then
commands=${commands//" entrypoints/openai "/" entrypoints/openai \
--ignore=entrypoints/openai/test_accuracy.py \
--ignore=entrypoints/openai/test_audio.py \
--ignore=entrypoints/openai/test_encoder_decoder.py \
--ignore=entrypoints/openai/test_embedding.py \
--ignore=entrypoints/openai/test_oot_registration.py "}
--ignore=entrypoints/openai/test_chat.py \
--ignore=entrypoints/openai/test_shutdown.py \
--ignore=entrypoints/openai/test_completion.py \
--ignore=entrypoints/openai/test_sleep.py \
--ignore=entrypoints/openai/test_models.py \
--ignore=entrypoints/openai/test_prompt_validation.py "}
fi

#ignore certain Entrypoints/llm tests
if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then
commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "}
fi

# --ignore=entrypoints/openai/test_encoder_decoder.py \
# --ignore=entrypoints/openai/test_embedding.py \
# --ignore=entrypoints/openai/test_oot_registration.py
# --ignore=entrypoints/openai/test_accuracy.py \
# --ignore=entrypoints/openai/test_models.py <= Fails on MI250 but passes on MI300 as of 2025-03-13

PARALLEL_JOB_COUNT=8
# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
if [[ $commands == *"--shard-id="* ]]; then
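The AMD test wrapper above prunes individual test files from the Buildkite command string rather than editing the pipeline itself. A minimal, self-contained sketch of that bash parameter-expansion pattern (the command string below is illustrative, not taken from the pipeline):

```bash
#!/bin/bash
# Illustrative command string; the real one comes from the Buildkite step.
commands="cd tests && pytest -v -s entrypoints/llm/test_guided_generate.py && pytest -v -s other_suite"

# ${var//pattern/replacement} replaces every occurrence of the pattern,
# which is how the script above drops a test that cannot run on this hardware.
if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then
    commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "}
fi

echo "$commands"   # -> cd tests   && pytest -v -s other_suite
```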
@@ -19,13 +19,14 @@ remove_docker_container

# Run the image, setting --shm-size=4g for tensor parallel.
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
--cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"
--cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
--cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2
--cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2

function cpu_tests() {
set -e
export NUMA_NODE=$2
export BUILDKITE_BUILD_NUMBER=$3

# offline inference
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c "
@@ -35,7 +36,8 @@ function cpu_tests() {
# Run basic model test
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
set -e
pip install -r vllm/requirements-test.txt
pip install -r vllm/requirements/test.txt
pip install -r vllm/requirements/cpu.txt
pytest -v -s tests/models/decoder_only/language -m cpu_model
pytest -v -s tests/models/embedding/language -m cpu_model
pytest -v -s tests/models/encoder_decoder/language -m cpu_model
@@ -85,4 +87,4 @@ function cpu_tests() {

# All of CPU tests are expected to be finished less than 40 mins.
export -f cpu_tests
timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE $BUILDKITE_BUILD_NUMBER"
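The change above passes $BUILDKITE_BUILD_NUMBER through to cpu_tests as a third positional argument. For reference, a small sketch of the surrounding export -f / timeout pattern, with a placeholder function body and default values standing in for the CI environment:

```bash
#!/bin/bash
set -e

# Placeholder values standing in for the Buildkite environment.
CORE_RANGE=${CORE_RANGE:-0-7}
NUMA_NODE=${NUMA_NODE:-0}
BUILDKITE_BUILD_NUMBER=${BUILDKITE_BUILD_NUMBER:-local}

# Stand-in for the real cpu_tests function; it only echoes its arguments here.
cpu_tests() {
    echo "would run CPU tests: cores=$1 numa_node=$2 build=$3"
}

# export -f makes the function visible to the child bash spawned by timeout,
# and timeout enforces the 40-minute budget mentioned above.
export -f cpu_tests
timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE $BUILDKITE_BUILD_NUMBER"
```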
@@ -44,11 +44,11 @@ remove_docker_container() {
trap remove_docker_container EXIT

# Run the image
docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \
docker run --rm -it --device=/dev/neuron0 --network bridge \
-v "${HF_CACHE}:${HF_MOUNT}" \
-e "HF_HOME=${HF_MOUNT}" \
-v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \
-e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
--name "${container_name}" \
${image_name} \
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/ -v --capture=tee-sys"
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys && python3 -m pytest /workspace/vllm/tests/neuron/2_core/ -v --capture=tee-sys"
@@ -1,16 +0,0 @@
#!/bin/bash

# This script build the OpenVINO docker image and run the offline inference inside the container.
# It serves a sanity check for compilation and basic model usage.
set -ex

# Try building the docker image
docker build -t openvino-test -f Dockerfile.openvino .

# Setup cleanup
remove_docker_container() { docker rm -f openvino-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image and launch offline inference
docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference/basic/generate.py --model facebook/opt-125m
@@ -19,8 +19,20 @@ docker run --privileged --net host --shm-size=16G -it \
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
&& python3 -m pip install pytest \
&& python3 -m pip install lm_eval[api]==0.4.4 \
&& pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \
&& pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
&& export VLLM_USE_V1=1 \
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
&& echo TEST_1 \
&& python3 /workspace/vllm/tests/tpu/test_compilation.py \
&& python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py"
&& echo TEST_2 \
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
&& echo TEST_3 \
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
&& echo TEST_4 \
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
&& echo TEST_5 \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py" \

# TODO: This test fails because it uses RANDOM_SEED sampling
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
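The rewritten TPU step chains every test with && and interleaves echo TEST_N markers so the CI log shows exactly which numbered stage was reached before a failure. A generic sketch of that pattern with placeholder commands:

```bash
#!/bin/bash
# Each stage runs only if the previous one succeeded; the markers make the
# failing stage easy to spot in a long CI log.
echo TEST_1 \
    && python3 -c "print('stage 1 ok')" \
    && echo TEST_2 \
    && python3 -c "print('stage 2 ok')" \
    && echo TEST_3 \
    && python3 -c "print('stage 3 ok')"
```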
@@ -4,16 +4,28 @@
# It serves a sanity check for compilation and basic model usage.
set -ex

image_name="xpu/vllm-ci:${BUILDKITE_COMMIT}"
container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"

# Try building the docker image
docker build -t xpu-test -f Dockerfile.xpu .
docker build -t ${image_name} -f Dockerfile.xpu .

# Setup cleanup
remove_docker_container() { docker rm -f xpu-test || true; }
remove_docker_container() {
docker rm -f "${container_name}" || true;
docker image rm -f "${image_name}" || true;
docker system prune -f || true;
}
trap remove_docker_container EXIT
remove_docker_container

# Run the image and test offline inference/tensor parallel
docker run --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test sh -c '
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
docker run \
--device /dev/dri \
-v /dev/dri/by-path:/dev/dri/by-path \
--entrypoint="" \
--name "${container_name}" \
"${image_name}" \
sh -c '
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
'
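The XPU script now derives a unique container name, registers a cleanup function with trap before doing any work, and removes the image afterwards so repeated CI runs cannot collide. A minimal sketch of that lifecycle, using a throwaway public image instead of the CI-built one:

```bash
#!/bin/bash
set -ex

# Throwaway image and a randomized container name; the real script derives
# both from ${BUILDKITE_COMMIT}.
image_name="alpine:3.19"
container_name="demo_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"

remove_docker_container() {
    docker rm -f "${container_name}" || true
}
# The EXIT trap fires whether the script succeeds or fails, so the container
# never leaks; calling it once up front also clears any stale leftover.
trap remove_docker_container EXIT
remove_docker_container

docker run --name "${container_name}" "${image_name}" echo "ran and will be cleaned up"
```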
@@ -35,13 +35,12 @@ steps:
fast_check: true
no_gpu: True
commands:
- pip install -r requirements-docs.txt
- pip install -r ../../requirements/docs.txt
- SPHINXOPTS=\"-W\" make html
# Check API reference (if it fails, you may have missing mock imports)
- grep \"sig sig-object py\" build/html/api/inference_params.html

- label: Async Engine, Inputs, Utils, Worker Test # 24min
fast_check: true
source_file_dependencies:
- vllm/
- tests/mq_llm_engine
@@ -78,6 +77,7 @@ steps:
- tests/basic_correctness/test_preemption
- tests/basic_correctness/test_cumem.py
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s basic_correctness/test_cumem.py
- pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py
@@ -112,19 +112,19 @@ steps:
- tests/entrypoints/test_chat_utils
- tests/entrypoints/offline_mode
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

- label: Distributed Tests (4 GPUs) # 10min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
fast_check: true
source_file_dependencies:
- vllm/distributed/
- vllm/core/
@@ -136,19 +136,24 @@ steps:
- examples/offline_inference/rlhf_colocate.py
- tests/examples/offline_inference/data_parallel.py
commands:
- VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py
# test with tp=2 and external_dp=2
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp
- python3 ../examples/offline_inference/data_parallel.py
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- python3 ../examples/offline_inference/rlhf.py
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
- pushd ../examples/offline_inference
- python3 rlhf.py
- RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
- popd

- label: Metrics, Tracing Test # 10min
num_gpus: 2
fast_check: true
source_file_dependencies:
- vllm/
- tests/metrics
@@ -196,15 +201,19 @@ steps:
- tests/v1
commands:
# split the test to avoid interference
- VLLM_USE_V1=1 pytest -v -s v1/core
- VLLM_USE_V1=1 pytest -v -s v1/engine
- VLLM_USE_V1=1 pytest -v -s v1/sample
- VLLM_USE_V1=1 pytest -v -s v1/worker
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
- pytest -v -s v1/core
- pytest -v -s v1/entrypoints
- pytest -v -s v1/engine
- pytest -v -s v1/entrypoints
- pytest -v -s v1/sample
- pytest -v -s v1/worker
- pytest -v -s v1/structured_output
- pytest -v -s v1/test_stats.py
- pytest -v -s v1/test_utils.py
- pytest -v -s v1/test_oracle.py
# TODO: accuracy does not match, whether setting
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
- VLLM_USE_V1=1 pytest -v -s v1/e2e
- pytest -v -s v1/e2e
# Integration test for streaming correctness (requires special branch).
- pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
@@ -222,14 +231,17 @@ steps:
- python3 offline_inference/basic/chat.py
- python3 offline_inference/prefix_caching.py
- python3 offline_inference/llm_engine_example.py
- python3 offline_inference/vision_language.py
- python3 offline_inference/vision_language_multi_image.py
- python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/audio_language.py --seed 0
- python3 offline_inference/vision_language.py --seed 0
- python3 offline_inference/vision_language_embedding.py --seed 0
- python3 offline_inference/vision_language_multi_image.py --seed 0
- VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference/encoder_decoder.py
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
- python3 offline_inference/basic/classify.py
- python3 offline_inference/basic/embed.py
- python3 offline_inference/basic/score.py
- python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2

- label: Prefix Caching Test # 9min
mirror_hardwares: [amd]
@@ -279,7 +291,6 @@ steps:
parallelism: 4

- label: PyTorch Fullgraph Smoke Test # 9min
fast_check: true
source_file_dependencies:
- vllm/
- tests/compile
@@ -288,6 +299,7 @@ steps:
# these tests need to be separated, cannot combine
- pytest -v -s compile/piecewise/test_simple.py
- pytest -v -s compile/piecewise/test_toy_llama.py
- pytest -v -s compile/test_pass_manager.py

- label: PyTorch Fullgraph Test # 18min
source_file_dependencies:
@@ -374,7 +386,8 @@ steps:
commands:
- pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py

- label: Language Models Test (Standard) # 32min
#mirror_hardwares: [amd]
@@ -503,8 +516,6 @@ steps:
- entrypoints/llm/test_collective_rpc.py
commands:
- pytest -v -s entrypoints/llm/test_collective_rpc.py
- VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
@@ -517,13 +528,12 @@ steps:
# this test fails consistently.
# TODO: investigate and fix
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py

- label: Plugin Tests (2 GPUs) # 40min
working_dir: "/vllm-workspace/tests"
num_gpus: 2
fast_check: true
source_file_dependencies:
- vllm/plugins/
- tests/plugins/
.github/CODEOWNERS (27 lines changed, vendored)

@@ -10,27 +10,32 @@
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
/vllm/model_executor/guided_decoding @mgoin
/vllm/model_executor/guided_decoding @mgoin @russellb
/vllm/multimodal @DarkLight1337 @ywang96
CMakeLists.txt @tlrmchlsmth

# vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
/vllm/v1/structured_output @mgoin @russellb

# Test ownership
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
/tests/test_inputs.py @DarkLight1337 @ywang96
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo
/tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96
/tests/prefix_caching @comaniac @KuntaiDu
/tests/spec_decode @njhill @LiuXiaoxuanPKU
/tests/kernels @tlrmchlsmth @WoosukKwon
/tests/quantization @mgoin @robertgshaw2-redhat
/.buildkite/lm-eval-harness @mgoin @simon-mo
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac
/tests/distributed/test_multi_node_assignment.py @youkaichao
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo
/tests/entrypoints/llm/test_guided_generate.py @mgoin @russellb
/tests/kernels @tlrmchlsmth @WoosukKwon
/tests/model_executor/test_guided_processors.py @mgoin @russellb
/tests/models @DarkLight1337 @ywang96
/tests/multi_step @alexm-redhat @comaniac
/tests/multimodal @DarkLight1337 @ywang96
/tests/prefix_caching @comaniac @KuntaiDu
/tests/quantization @mgoin @robertgshaw2-redhat
/tests/spec_decode @njhill @LiuXiaoxuanPKU
/tests/test_inputs.py @DarkLight1337 @ywang96
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb
/tests/v1/structured_output @mgoin @russellb
/tests/weight_loading @mgoin @youkaichao
/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac
.github/ISSUE_TEMPLATE/800-misc-discussion.yml (28 lines changed, vendored)

@@ -1,28 +0,0 @@
name: 🎲 Misc/random discussions that do not fit into the above categories.
description: Submit a discussion as you like. Note that developers are heavily overloaded and we mainly rely on community users to answer these issues.
title: "[Misc]: "
labels: ["misc"]

body:
- type: markdown
attributes:
value: >
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
- type: textarea
attributes:
label: Anything you want to discuss about vllm.
description: >
Anything you want to discuss about vllm.
validations:
required: true
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!
- type: checkboxes
id: askllm
attributes:
label: Before submitting a new issue...
options:
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
required: true
.github/ISSUE_TEMPLATE/config.yml (4 lines changed, vendored)

@@ -1 +1,5 @@
blank_issues_enabled: false
contact_links:
- name: Questions
url: https://discuss.vllm.ai
about: Ask questions and discuss with other vLLM community members
.github/mergify.yml (15 lines changed, vendored)

@@ -36,6 +36,21 @@ pull_request_rules:
add:
- frontend

- name: label-multi-modality
description: Automatically apply multi-modality label
conditions:
- or:
- files~=^vllm/multimodal/
- files~=^tests/multimodal/
- files~=^tests/models/multimodal/
- files~=^tests/models/*/audio_language/
- files~=^tests/models/*/vision_language/
- files=tests/models/test_vision.py
actions:
label:
add:
- multi-modality

- name: label-structured-output
description: Automatically apply structured-output label
conditions:
.github/workflows/publish.yml (4 lines changed, vendored)

@@ -39,7 +39,7 @@ jobs:
const script = require('.github/workflows/scripts/create_release.js')
await script(github, context, core)

# NOTE(simon): No longer build wheel using Github Actions. See buildkite's release workflow.
# NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow.
# wheel:
# name: Build Wheel
# runs-on: ${{ matrix.os }}
@@ -50,7 +50,7 @@ jobs:
# matrix:
# os: ['ubuntu-20.04']
# python-version: ['3.9', '3.10', '3.11', '3.12']
# pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
# pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt.
# cuda-version: ['11.8', '12.1']

# steps:
.github/workflows/scripts/build.sh (2 lines changed, vendored)

@@ -9,7 +9,7 @@ PATH=${cuda_home}/bin:$PATH
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH

# Install requirements
$python_executable -m pip install -r requirements-build.txt -r requirements-cuda.txt
$python_executable -m pip install -r requirements/build.txt -r requirements/cuda.txt

# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
.github/workflows/scripts/create_release.js (2 lines changed, vendored)

@@ -1,4 +1,4 @@
// Uses Github's API to create the release and wait for result.
// Uses GitHub's API to create the release and wait for result.
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.

module.exports = async (github, context, core) => {
.gitignore (2 lines changed, vendored)

@@ -197,7 +197,7 @@ _build/
hip_compat.h

# Benchmark dataset
benchmarks/*.json
benchmarks/**/*.json

# Linting
actionlint

@@ -44,8 +44,8 @@ repos:
rev: 0.6.2
hooks:
- id: pip-compile
args: [requirements-test.in, -o, requirements-test.txt]
files: ^requirements-test\.(in|txt)$
args: [requirements/test.in, -o, requirements/test.txt]
files: ^requirements/test\.(in|txt)$
- repo: local
hooks:
- id: mypy-local
@@ -53,7 +53,7 @@ repos:
entry: tools/mypy.sh 0 "local"
language: python
types: [python]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
stages: [pre-commit] # Don't run in CI
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
name: Run mypy for Python 3.9

@@ -18,4 +18,4 @@ formats: []
# Optionally declare the Python requirements required to build your docs
python:
install:
- requirements: docs/requirements-docs.txt
- requirements: requirements/docs.txt
CMakeLists.txt (101 lines changed)

@@ -46,8 +46,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101")
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.5.1")
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1")
set(TORCH_SUPPORTED_VERSION_CUDA "2.6.0")
set(TORCH_SUPPORTED_VERSION_ROCM "2.6.0")

#
# Try to find python package with an executable that exactly matches
@@ -319,7 +319,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

# Only build AllSpark kernels if we are building for at least some compatible archs.
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND ALLSPARK_ARCHS)
if (ALLSPARK_ARCHS)
set(ALLSPARK_SRCS
"csrc/quantization/gptq_allspark/allspark_repack.cu"
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
@@ -330,39 +330,67 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
else()
message(STATUS "Not building AllSpark kernels as no compatible archs found"
" in CUDA target architectures, or CUDA not >= 12.0")
" in CUDA target architectures")
endif()

set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
# CUDA 12.0 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()

# clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
# build any 3x kernels
set(SCALED_MM_3X_ARCHS)
# The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.8 or later
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
# Let scaled_mm_c2x know it doesn't need to build these arches
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
"later if you intend on running FP8 quantized models on "
"Blackwell.")
else()
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
"in CUDA target architectures")
endif()
endif()

#
@@ -394,17 +422,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# 2:4 Sparse Kernels

# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper and Blackwell).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
# require CUDA 12.2 or later (and only work on Hopper).
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
"if you intend on running FP8 sparse quantized models on Hopper.")
@@ -432,22 +461,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(FP4_ARCHS)
endif()

# FP8 Blackwell Archs
cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${BLACKWELL_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}")
else()
# clear BLACKWELL_ARCHS
set(BLACKWELL_ARCHS)
endif()

#
# Machete kernels

@@ -548,11 +561,23 @@ set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
endif()

set_gencode_flags_for_srcs(
SRCS "${VLLM_MOE_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")

if(VLLM_GPU_LANG STREQUAL "CUDA")
set(VLLM_MOE_WNA16_SRC
"csrc/moe/moe_wna16.cu")

set_gencode_flags_for_srcs(
SRCS "${VLLM_MOE_WNA16_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")

list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
set(MARLIN_MOE_SRC
Dockerfile (132 lines changed)

@@ -14,22 +14,21 @@ ARG PYTHON_VERSION=3.12
ARG TARGETPLATFORM
ENV DEBIAN_FRONTEND=noninteractive

# Install Python and other dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl sudo \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
&& python3 --version && python3 -m pip --version
# Install uv for faster pip installs
RUN --mount=type=cache,target=/root/.cache/uv \
python3 -m pip install uv
# Install minimal dependencies and uv
RUN apt-get update -y \
&& apt-get install -y ccache git curl wget sudo \
&& curl -LsSf https://astral.sh/uv/install.sh | sh

# Add uv to PATH
ENV PATH="/root/.local/bin:$PATH"
# Create venv with specified Python and activate by placing at the front of path
ENV VIRTUAL_ENV="/opt/venv"
RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500

# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
# as it was causing spam when compiling the CUTLASS kernels
@@ -47,21 +46,19 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

WORKDIR /workspace

# install build and runtime dependencies

# arm64 (GH200) build follows the practice of "use existing pytorch" build,
# we need to install torch and torchvision from the nightly builds first,
# pytorch will not appear as a vLLM dependency in all of the following steps
# after this step
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \
uv pip install --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \
fi

COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
COPY requirements/common.txt requirements/common.txt
COPY requirements/cuda.txt requirements/cuda.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements-cuda.txt
uv pip install -r requirements/cuda.txt

# cuda arch list used by torch
# can be useful for both `dev` and `test`
@@ -79,15 +76,19 @@ FROM base AS build
ARG TARGETPLATFORM

# install build dependencies
COPY requirements-build.txt requirements-build.txt
COPY requirements/build.txt requirements/build.txt

# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500

RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements-build.txt
uv pip install -r requirements/build.txt

COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi

# max jobs used by Ninja to build extensions
ARG max_jobs=2
@@ -124,6 +125,9 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "$USE_SCCACHE" != "1" ]; then \
# Clean any existing CMake artifacts
rm -rf .deps && \
mkdir -p .deps && \
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
fi

@@ -143,11 +147,15 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
#################### DEV IMAGE ####################
FROM base as dev

COPY requirements-lint.txt requirements-lint.txt
COPY requirements-test.txt requirements-test.txt
COPY requirements-dev.txt requirements-dev.txt
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500

COPY requirements/lint.txt requirements/lint.txt
COPY requirements/test.txt requirements/test.txt
COPY requirements/dev.txt requirements/dev.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements-dev.txt
uv pip install -r requirements/dev.txt
#################### DEV IMAGE ####################

#################### vLLM installation IMAGE ####################

@@ -163,23 +171,22 @@ ARG TARGETPLATFORM
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment

# Install Python and other dependencies
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
&& python3 --version && python3 -m pip --version
# Install uv for faster pip installs
RUN --mount=type=cache,target=/root/.cache/uv \
python3 -m pip install uv
# Install minimal dependencies and uv
RUN apt-get update -y \
&& apt-get install -y ccache git curl wget sudo vim \
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 libibverbs-dev \
&& curl -LsSf https://astral.sh/uv/install.sh | sh

# Add uv to PATH
ENV PATH="/root/.local/bin:$PATH"
# Create venv with specified Python and activate by placing at the front of path
ENV VIRTUAL_ENV="/opt/venv"
RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500

# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
@@ -193,13 +200,13 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
# after this step
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
uv pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
fi

# Install vllm wheel first, so that torch etc will be installed.
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system dist/*.whl --verbose
uv pip install dist/*.whl --verbose

# If we need to build FlashInfer wheel before its release:
# $ export FLASHINFER_ENABLE_AOT=1
@@ -214,9 +221,8 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
# $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl

RUN --mount=type=cache,target=/root/.cache/uv \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post1/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl ; \
uv pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
fi
COPY examples examples

@@ -224,9 +230,9 @@ COPY examples examples
# some issues w.r.t. JIT compilation. Therefore we need to
# install build dependencies for JIT compilation.
# TODO: Remove this once FlashInfer AOT wheel is fixed
COPY requirements-build.txt requirements-build.txt
COPY requirements/build.txt requirements/build.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements-build.txt
uv pip install -r requirements/build.txt

#################### vLLM installation IMAGE ####################

@@ -237,17 +243,21 @@ FROM vllm-base AS test

ADD . /vllm-workspace/

# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements-dev.txt
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500

# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -e tests/vllm_test_utils
uv pip install -r requirements/dev.txt

# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -e tests/vllm_test_utils

# enable fast downloads from hf (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system hf_transfer
uv pip install hf_transfer
ENV HF_HUB_ENABLE_HF_TRANSFER 1

# Copy in the v1 package for testing (it isn't distributed yet)
@@ -265,12 +275,16 @@ RUN mv vllm test_docs/
# base openai image with additional requirements, for any subsequent openai-style images
FROM vllm-base AS vllm-openai-base

# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500

# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
uv pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
else \
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
uv pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
fi

ENV VLLM_USAGE_SOURCE production-docker-image
@ -26,18 +26,18 @@ WORKDIR /workspace
|
||||
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
|
||||
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
|
||||
--mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
|
||||
pip install --upgrade pip && \
|
||||
pip install -r requirements-build.txt
|
||||
pip install -r requirements/build.txt
|
||||
|
||||
FROM cpu-test-arm AS build
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
|
||||
--mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
|
||||
pip install -v -r requirements-cpu.txt
|
||||
--mount=type=bind,src=requirements/common.txt,target=requirements/common.txt \
|
||||
--mount=type=bind,src=requirements/cpu.txt,target=requirements/cpu.txt \
|
||||
pip install -v -r requirements/cpu.txt
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
|
||||
@ -22,25 +22,25 @@ ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/li
|
||||
|
||||
RUN echo 'ulimit -c 0' >> ~/.bashrc
|
||||
|
||||
RUN pip install intel_extension_for_pytorch==2.5.0
|
||||
RUN pip install intel_extension_for_pytorch==2.6.0
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
|
||||
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
|
||||
--mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
|
||||
pip install --upgrade pip && \
|
||||
pip install -r requirements-build.txt
|
||||
pip install -r requirements/build.txt
|
||||
|
||||
FROM cpu-test-1 AS build
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
|
||||
--mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
|
||||
pip install -v -r requirements-cpu.txt
|
||||
--mount=type=bind,src=requirements/common.txt,target=requirements/common.txt \
|
||||
--mount=type=bind,src=requirements/cpu.txt,target=requirements/cpu.txt \
|
||||
pip install -v -r requirements/cpu.txt
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
|
||||
@ -4,7 +4,7 @@ COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN pip install -v -r requirements-hpu.txt
|
||||
RUN pip install -v -r requirements/hpu.txt
|
||||
|
||||
ENV no_proxy=localhost,127.0.0.1
|
||||
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
|
||||
|
||||
@ -36,7 +36,7 @@ RUN --mount=type=bind,source=.git,target=.git \
|
||||
|
||||
RUN python3 -m pip install -U \
|
||||
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
|
||||
-r requirements-neuron.txt
|
||||
-r requirements/neuron.txt
|
||||
|
||||
ENV VLLM_TARGET_DEVICE neuron
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
|
||||
@ -1,29 +0,0 @@
|
||||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
FROM ubuntu:22.04 AS dev
|
||||
|
||||
RUN apt-get update -y && \
|
||||
apt-get install -y \
|
||||
git python3-pip \
|
||||
ffmpeg libsm6 libxext6 libgl1
|
||||
WORKDIR /workspace
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
RUN python3 -m pip install -U pip
|
||||
# install build requirements
|
||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/requirements-build.txt
|
||||
# build vLLM with OpenVINO backend
|
||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace
|
||||
|
||||
COPY examples/ /workspace/examples
|
||||
COPY benchmarks/ /workspace/benchmarks
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@ -6,7 +6,7 @@ ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
|
||||
|
||||
RUN apt-get update -y && apt-get install -y git wget kmod curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev
|
||||
|
||||
# Some packages in requirements-cpu are installed here
|
||||
# Some packages in requirements/cpu are installed here
|
||||
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
|
||||
# Currently these may not be available for venv or pip directly
|
||||
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 rust && micromamba clean --all --yes
|
||||
@ -21,7 +21,7 @@ RUN --mount=type=bind,source=.git,target=.git \
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
|
||||
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
|
||||
-r requirements-cpu.txt \
|
||||
-r requirements/cpu.txt \
|
||||
xformers uvloop==0.20.0
|
||||
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
|
||||
@ -38,14 +38,14 @@ FROM fetch_vllm AS build_vllm
|
||||
ARG USE_CYTHON
|
||||
# Build vLLM
|
||||
RUN cd vllm \
|
||||
&& python3 -m pip install -r requirements-rocm.txt \
|
||||
&& python3 -m pip install -r requirements/rocm.txt \
|
||||
&& python3 setup.py clean --all \
|
||||
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch AS export_vllm
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt /
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements /requirements
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
|
||||
@ -60,7 +60,8 @@ RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
|
||||
# Install vLLM
|
||||
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
cd /install \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& pip install -U -r requirements/rocm.txt \
|
||||
&& pip install -U -r requirements/rocm-test.txt \
|
||||
&& pip uninstall -y vllm \
|
||||
&& pip install *.whl
|
||||
|
||||
@ -99,7 +100,7 @@ RUN if [ ${BUILD_RPD} -eq "1" ]; then \
|
||||
# Install vLLM
|
||||
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
cd /install \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& pip install -U -r requirements/rocm.txt \
|
||||
&& pip uninstall -y vllm \
|
||||
&& pip install *.whl
|
||||
|
||||
|
||||
@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="b7d29fb"
|
||||
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
|
||||
ARG AITER_BRANCH="21d47a9"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
@ -129,8 +131,18 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
|
||||
ARG AITER_REPO
|
||||
ARG AITER_BRANCH
|
||||
RUN git clone --recursive ${AITER_REPO}
|
||||
RUN cd aiter \
|
||||
&& git checkout ${AITER_BRANCH} \
|
||||
&& git submodule update --init --recursive \
|
||||
&& pip install -r requirements.txt \
|
||||
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
|
||||
|
||||
ARG BASE_IMAGE
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG HIPBLAS_COMMON_BRANCH
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
@ -155,4 +167,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
|
||||
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
|
||||
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
|
||||
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
|
||||
|
||||
152 Dockerfile.s390x Normal file
@ -0,0 +1,152 @@
|
||||
# Base UBI image for s390x architecture
|
||||
ARG BASE_UBI_IMAGE_TAG=9.5-1736404155
|
||||
ARG PYTHON_VERSION=3.12
|
||||
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base
|
||||
|
||||
# Install basic dependencies
|
||||
ARG PYTHON_VERSION
|
||||
ENV PYTHON_VERSION=${PYTHON_VERSION}
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
ENV LANG=C.UTF-8 \
|
||||
LC_ALL=C.UTF-8
|
||||
|
||||
# Install development utilities
|
||||
RUN microdnf install -y \
|
||||
which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \
|
||||
libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \
|
||||
openssl-devel openblas openblas-devel autoconf automake libtool cmake && \
|
||||
microdnf clean all
|
||||
|
||||
# Python Installation
|
||||
FROM base AS python-install
|
||||
ARG PYTHON_VERSION
|
||||
|
||||
ENV VIRTUAL_ENV=/opt/vllm
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
ENV PYTHON_VERSION=${PYTHON_VERSION}
|
||||
RUN microdnf install -y \
|
||||
python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel && \
|
||||
python${PYTHON_VERSION} -m venv $VIRTUAL_ENV && pip install --no-cache -U pip wheel uv && microdnf clean all
|
||||
|
||||
FROM python-install AS pyarrow
|
||||
|
||||
# Build Apache Arrow
|
||||
WORKDIR /tmp
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
git clone https://github.com/apache/arrow.git && \
|
||||
cd arrow/cpp && \
|
||||
mkdir release && cd release && \
|
||||
cmake -DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_INSTALL_PREFIX=/usr/local \
|
||||
-DARROW_PYTHON=ON \
|
||||
-DARROW_PARQUET=ON \
|
||||
-DARROW_ORC=ON \
|
||||
-DARROW_FILESYSTEM=ON \
|
||||
-DARROW_WITH_LZ4=ON \
|
||||
-DARROW_WITH_ZSTD=ON \
|
||||
-DARROW_WITH_SNAPPY=ON \
|
||||
-DARROW_JSON=ON \
|
||||
-DARROW_CSV=ON \
|
||||
-DARROW_DATASET=ON \
|
||||
-DPROTOBUF_PROTOC_EXECUTABLE=/usr/bin/protoc \
|
||||
-DARROW_DEPENDENCY_SOURCE=BUNDLED \
|
||||
.. && \
|
||||
make -j$(nproc) && \
|
||||
make install && \
|
||||
cd ../../python && \
|
||||
export PYARROW_PARALLEL=4 && \
|
||||
export ARROW_BUILD_TYPE=release && \
|
||||
uv pip install -r requirements/build.txt && \
|
||||
python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel
|
||||
|
||||
FROM python-install AS numa-build
|
||||
# Install numactl (needed for numa.h dependency)
|
||||
WORKDIR /tmp
|
||||
RUN curl -LO https://github.com/numactl/numactl/archive/refs/tags/v2.0.16.tar.gz && \
|
||||
tar -xvzf v2.0.16.tar.gz && \
|
||||
cd numactl-2.0.16 && \
|
||||
./autogen.sh && \
|
||||
./configure && \
|
||||
make
|
||||
|
||||
# Set include path
|
||||
ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH"
|
||||
|
||||
FROM python-install AS rust
|
||||
ENV CARGO_HOME=/root/.cargo
|
||||
ENV RUSTUP_HOME=/root/.rustup
|
||||
ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"
|
||||
|
||||
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
|
||||
. "$CARGO_HOME/env" && \
|
||||
rustup default stable && \
|
||||
rustup show
|
||||
|
||||
FROM python-install AS torch-vision
|
||||
# Install torchvision
|
||||
ARG TORCH_VERSION=2.7.0.dev20250304
|
||||
ARG TORCH_VISION_VERSION=v0.20.1
|
||||
WORKDIR /tmp
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
git clone https://github.com/pytorch/vision.git && \
|
||||
cd vision && \
|
||||
git checkout $TORCH_VISION_VERSION && \
|
||||
uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
# Final build stage
|
||||
FROM python-install AS vllm-cpu
|
||||
ARG PYTHON_VERSION
|
||||
|
||||
# Set correct library path for torch and numactl
|
||||
ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python${PYTHON_VERSION}/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH"
|
||||
ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH"
|
||||
ENV UV_LINK_MODE=copy
|
||||
ENV CARGO_HOME=/root/.cargo
|
||||
ENV RUSTUP_HOME=/root/.rustup
|
||||
ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"
|
||||
|
||||
COPY . /workspace/vllm
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN --mount=type=bind,from=numa-build,src=/tmp/numactl-2.0.16,target=/numactl \
|
||||
make -C /numactl install
|
||||
|
||||
# Install dependencies, including PyTorch and Apache Arrow
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \
|
||||
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
|
||||
--mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \
|
||||
--mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \
|
||||
sed -i '/^torch/d' requirements/build.txt && \
|
||||
ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \
|
||||
VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \
|
||||
uv pip install -v \
|
||||
$ARROW_WHL_FILE \
|
||||
$VISION_WHL_FILE \
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
|
||||
--index-strategy unsafe-best-match \
|
||||
-r requirements/build.txt \
|
||||
-r requirements/cpu.txt
|
||||
|
||||
# Build and install vllm
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \
|
||||
uv pip install "$(echo dist/*.whl)[tensorizer]"
|
||||
|
||||
# setup non-root user for vllm
|
||||
RUN umask 002 && \
|
||||
useradd --uid 2000 --gid 0 vllm && \
|
||||
mkdir -p /home/vllm && \
|
||||
chmod g+rwx /home/vllm
|
||||
|
||||
COPY LICENSE /licenses/vllm.md
|
||||
COPY examples/*.jinja /app/data/template/
|
||||
|
||||
USER 2000
|
||||
WORKDIR /home/vllm
|
||||
|
||||
# Set the default entrypoint
|
||||
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
@ -15,11 +15,14 @@ ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
|
||||
|
||||
# Remove existing versions of dependencies
|
||||
RUN pip uninstall -y torch torch_xla torchvision
|
||||
|
||||
ENV VLLM_TARGET_DEVICE="tpu"
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
python3 -m pip install \
|
||||
-r requirements-tpu.txt
|
||||
-r requirements/tpu.txt
|
||||
RUN python3 setup.py develop
|
||||
|
||||
# install development dependencies (for testing)
|
||||
|
||||
@ -1,11 +1,7 @@
|
||||
FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS vllm-base
|
||||
# The oneAPI 2025.0.2 docker base image uses the rolling 2448 packages (https://dgpu-docs.intel.com/releases/packages.html?release=Rolling+2448.13&os=Ubuntu+22.04), so we don't need to install the driver manually.
|
||||
FROM intel/deep-learning-essentials:2025.0.2-0-devel-ubuntu22.04 AS vllm-base
|
||||
|
||||
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
|
||||
chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
|
||||
wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
|
||||
echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
|
||||
chmod 644 /usr/share/keyrings/intel-graphics.gpg
|
||||
RUN rm /etc/apt/sources.list.d/intel-graphics.list
|
||||
|
||||
RUN apt-get update -y && \
|
||||
apt-get install -y --no-install-recommends --fix-missing \
|
||||
@ -21,30 +17,20 @@ RUN apt-get update -y && \
|
||||
python3 \
|
||||
python3-dev \
|
||||
python3-pip \
|
||||
# vim \
|
||||
wget
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
COPY requirements-xpu.txt /workspace/vllm/requirements-xpu.txt
|
||||
COPY requirements-common.txt /workspace/vllm/requirements-common.txt
|
||||
COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt
|
||||
COPY requirements/common.txt /workspace/vllm/requirements/common.txt
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install --no-cache-dir \
|
||||
-r requirements-xpu.txt
|
||||
|
||||
RUN git clone https://github.com/intel/pti-gpu && \
|
||||
cd pti-gpu/sdk && \
|
||||
git checkout 6c491f07a777ed872c2654ca9942f1d0dde0a082 && \
|
||||
mkdir build && \
|
||||
cd build && \
|
||||
cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \
|
||||
make -j && \
|
||||
cmake --install . --config Release --prefix "/usr/local"
|
||||
-r requirements/xpu.txt
|
||||
|
||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/"
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
|
||||
|
||||
@ -54,6 +40,12 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
python3 setup.py install
|
||||
|
||||
# Please refer to the XPU doc: we need to manually install intel-extension-for-pytorch 2.6.10+xpu because it has some conflicting dependencies with torch 2.6.0+xpu.
# FIXME: This will be fixed in IPEX 2.7; just leaving this note here for awareness.
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install intel-extension-for-pytorch==2.6.10+xpu \
|
||||
--extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
FROM vllm-base AS vllm-openai
|
||||
|
||||
10 MANIFEST.in
@ -1,9 +1,9 @@
include LICENSE
include requirements-common.txt
include requirements-cuda.txt
include requirements-rocm.txt
include requirements-neuron.txt
include requirements-cpu.txt
include requirements/common.txt
include requirements/cuda.txt
include requirements/rocm.txt
include requirements/neuron.txt
include requirements/cpu.txt
include CMakeLists.txt

recursive-include cmake *

36 README.md
@ -10,36 +10,25 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
We’re excited to invite you to the first **vLLM China Meetup** on **March 16** in **Beijing**!
|
||||
[2025/03] We are collaborating with Ollama to host an [Inference Night](https://lu.ma/vllm-ollama) at Y Combinator in San Francisco on Thursday, March 27, at 6 PM. Discuss all things inference local or data center!
|
||||
|
||||
Join us to connect with the **vLLM team** and explore how vLLM is leveraged in **post-training, fine-tuning, and deployment**, including [verl](https://github.com/volcengine/verl), [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), and [vllm-ascend](https://github.com/vllm-project/vllm-ascend).
|
||||
|
||||
👉 **[Register Now](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)** to be part of the discussion!
|
||||
[2025/04] We're hosting our first-ever *vLLM Asia Developer Day* in Singapore on *April 3rd*! This is a full-day event (9 AM - 9 PM SGT) in partnership with SGInnovate, AMD, and Embedded LLM. Meet the vLLM team and learn about LLM inference for RL, MI300X, and more! [Register Now](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
|
||||
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
|
||||
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
|
||||
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
|
||||
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
|
||||
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing).
|
||||
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
|
||||
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
|
||||
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
|
||||
- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users!
|
||||
- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing).
|
||||
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
|
||||
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
|
||||
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
|
||||
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
||||
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) with IBM! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
||||
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) with a16z! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
|
||||
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||
- [2024/12] vLLM joins [PyTorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
|
||||
|
||||
---
|
||||
|
||||
@ -90,7 +79,7 @@ pip install vllm
```

Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation/index.html)
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)

@ -150,10 +139,11 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs

## Contact Us

- For technical questions and feature requests, please use Github issues or discussions.
- For discussing with fellow users and coordinating contributions and development, please use Slack.
- For security disclosures, please use Github's security advisory feature.
- For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu.
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions)
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
- For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu)

## Media Kit
|
||||
|
||||
|
||||
@ -1,29 +1,268 @@
# Benchmarking vLLM

## Downloading the ShareGPT dataset
This README guides you through running benchmark tests with the extensive
datasets supported on vLLM. It’s a living document, updated as new features and datasets
become available.

You can download the dataset by running:
## Dataset Overview

| Dataset | Online | Offline | Data Path |
|---------|:------:|:-------:|-----------|
| **ShareGPT** | ✅ | ✅ | `wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json` |
| **BurstGPT** | ✅ | ✅ | `wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv` |
| **Sonnet** | ✅ | ✅ | Local file: `benchmarks/sonnet.txt` |
| **Random** | ✅ | ✅ | `synthetic` |
| **HuggingFace** | 🟡 | 🟡 | Specify your dataset path on HuggingFace |
| **VisionArena** | ✅ | ✅ | `lmarena-ai/vision-arena-bench-v0.1` (a HuggingFace dataset) |

✅: supported

🚧: to be supported

🟡: Partial support. Currently, HuggingFaceDataset only supports dataset formats
similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`.
If you need support for other dataset formats, please consider contributing.

**Note**: VisionArena’s `dataset-name` should be set to `hf`

---

## Example - Online Benchmark

First start serving your model:

```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
vllm serve ${MODEL_NAME} --disable-log-requests
```

## Downloading the ShareGPT4V dataset

The json file refers to several image datasets (coco, llava, etc.). The benchmark scripts
will ignore a datapoint if the referred image is missing.
Then run the benchmarking script:

```bash
wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
mkdir coco -p
wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
unzip coco/train2017.zip -d coco/
# download dataset
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
NUM_PROMPTS=10
BACKEND="vllm"
DATASET_NAME="sharegpt"
DATASET_PATH="<your data path>/ShareGPT_V3_unfiltered_cleaned_split.json"
python3 vllm/benchmarks/benchmark_serving.py --backend ${BACKEND} --model ${MODEL_NAME} --endpoint /v1/completions --dataset-name ${DATASET_NAME} --dataset-path ${DATASET_PATH} --num-prompts ${NUM_PROMPTS}
```

# Downloading the BurstGPT dataset
If successful, you will see the following output:

You can download the BurstGPT v1.1 dataset by running:
```
============ Serving Benchmark Result ============
Successful requests: 10
Benchmark duration (s): 5.78
Total input tokens: 1369
Total generated tokens: 2212
Request throughput (req/s): 1.73
Output token throughput (tok/s): 382.89
Total Token throughput (tok/s): 619.85
---------------Time to First Token----------------
Mean TTFT (ms): 71.54
Median TTFT (ms): 73.88
P99 TTFT (ms): 79.49
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 7.91
Median TPOT (ms): 7.96
P99 TPOT (ms): 8.03
---------------Inter-token Latency----------------
Mean ITL (ms): 7.74
Median ITL (ms): 7.70
P99 ITL (ms): 8.39
==================================================
```

### VisionArena Benchmark for Vision Language Models

```bash
wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv
# need a model with vision capability here
vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
```

```bash
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
NUM_PROMPTS=10
BACKEND="openai-chat"
DATASET_NAME="hf"
DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
DATASET_SPLIT='train'

python3 vllm/benchmarks/benchmark_serving.py \
  --backend "${BACKEND}" \
  --model "${MODEL_NAME}" \
  --endpoint "/v1/chat/completions" \
  --dataset-name "${DATASET_NAME}" \
  --dataset-path "${DATASET_PATH}" \
  --hf-split "${DATASET_SPLIT}" \
  --num-prompts "${NUM_PROMPTS}"
```

### HuggingFaceDataset Examples

Currently, HuggingFaceDataset only supports dataset formats
similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`. If you need support for other dataset
formats, please consider contributing.

```bash
# need a model with vision capability here
vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
```

**`lmms-lab/LLaVA-OneVision-Data`**

```bash
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
NUM_PROMPTS=10
BACKEND="openai-chat"
DATASET_NAME="hf"
DATASET_PATH="lmms-lab/LLaVA-OneVision-Data"
DATASET_SPLIT='train'
DATASET_SUBSET='chart2text(cauldron)'
python3 vllm/benchmarks/benchmark_serving.py \
  --backend "${BACKEND}" \
  --model "${MODEL_NAME}" \
  --endpoint "/v1/chat/completions" \
  --dataset-name "${DATASET_NAME}" \
  --dataset-path "${DATASET_PATH}" \
  --hf-split "${DATASET_SPLIT}" \
  --num-prompts "${NUM_PROMPTS}" \
  --hf-subset "${DATASET_SUBSET}"
```

**`Aeala/ShareGPT_Vicuna_unfiltered`**

```bash
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
NUM_PROMPTS=10
BACKEND="openai-chat"
DATASET_NAME="hf"
DATASET_PATH="Aeala/ShareGPT_Vicuna_unfiltered"
DATASET_SPLIT='train'
python3 vllm/benchmarks/benchmark_serving.py \
  --backend "${BACKEND}" \
  --model "${MODEL_NAME}" \
  --endpoint "/v1/chat/completions" \
  --dataset-name "${DATASET_NAME}" \
  --dataset-path "${DATASET_PATH}" \
  --hf-split "${DATASET_SPLIT}" \
  --num-prompts "${NUM_PROMPTS}"
```

---
## Example - Offline Throughput Benchmark

```bash
MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
NUM_PROMPTS=10
DATASET_NAME="sonnet"
DATASET_PATH="vllm/benchmarks/sonnet.txt"

python3 vllm/benchmarks/benchmark_throughput.py \
  --model "${MODEL_NAME}" \
  --dataset-name "${DATASET_NAME}" \
  --dataset-path "${DATASET_PATH}" \
  --num-prompts "${NUM_PROMPTS}"
```

If successful, you will see the following output:

```
Throughput: 7.15 requests/s, 4656.00 total tokens/s, 1072.15 output tokens/s
Total num prompt tokens: 5014
Total num output tokens: 1500
```

### VisionArena Benchmark for Vision Language Models

```bash
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
NUM_PROMPTS=10
DATASET_NAME="hf"
DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
DATASET_SPLIT="train"

python3 vllm/benchmarks/benchmark_throughput.py \
  --model "${MODEL_NAME}" \
  --backend "vllm-chat" \
  --dataset-name "${DATASET_NAME}" \
  --dataset-path "${DATASET_PATH}" \
  --num-prompts "${NUM_PROMPTS}" \
  --hf-split "${DATASET_SPLIT}"
```

The `num prompt tokens` now includes image token counts.

```
Throughput: 2.55 requests/s, 4036.92 total tokens/s, 326.90 output tokens/s
Total num prompt tokens: 14527
Total num output tokens: 1280
```

### Benchmark with LoRA Adapters

```bash
# download dataset
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
MODEL_NAME="meta-llama/Llama-2-7b-hf"
BACKEND="vllm"
DATASET_NAME="sharegpt"
DATASET_PATH="<your data path>/ShareGPT_V3_unfiltered_cleaned_split.json"
NUM_PROMPTS=10
MAX_LORAS=2
MAX_LORA_RANK=8
ENABLE_LORA="--enable-lora"
LORA_PATH="yard1/llama-2-7b-sql-lora-test"

python3 vllm/benchmarks/benchmark_throughput.py \
  --model "${MODEL_NAME}" \
  --backend "${BACKEND}" \
  --dataset_path "${DATASET_PATH}" \
  --dataset_name "${DATASET_NAME}" \
  --num-prompts "${NUM_PROMPTS}" \
  --max-loras "${MAX_LORAS}" \
  --max-lora-rank "${MAX_LORA_RANK}" \
  ${ENABLE_LORA} \
  --lora-path "${LORA_PATH}"
```
@ -14,7 +14,8 @@ from tqdm.asyncio import tqdm
|
||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import get_lock
|
||||
# NOTE(simon): do not import vLLM here so the benchmark script
|
||||
# can run without vLLM installed.
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
@ -27,7 +28,6 @@ class RequestFuncInput:
|
||||
output_len: int
|
||||
model: str
|
||||
model_name: Optional[str] = None
|
||||
best_of: int = 1
|
||||
logprobs: Optional[int] = None
|
||||
extra_body: Optional[dict] = None
|
||||
multi_modal_content: Optional[dict] = None
|
||||
@ -58,13 +58,12 @@ async def async_request_tgi(
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
params = {
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"do_sample": True,
|
||||
"temperature": 0.01, # TGI does not accept 0.0 temperature.
|
||||
"top_p": 0.99, # TGI does not accept 1.0 top_p.
|
||||
"truncate": request_func_input.prompt_len,
|
||||
# TGI does not accept ignore_eos flag.
|
||||
"ignore_eos_token": request_func_input.ignore_eos,
|
||||
}
|
||||
payload = {
|
||||
"inputs": request_func_input.prompt,
|
||||
@ -72,6 +71,10 @@ async def async_request_tgi(
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
if request_func_input.ignore_eos:
|
||||
output.output_tokens = request_func_input.output_len
|
||||
else:
|
||||
output.output_tokens = None
|
||||
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
@ -130,7 +133,6 @@ async def async_request_trt_llm(
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
payload = {
|
||||
"accumulate_tokens": True,
|
||||
"text_input": request_func_input.prompt,
|
||||
@ -195,7 +197,6 @@ async def async_request_deepspeed_mii(
|
||||
) -> RequestFuncOutput:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
|
||||
payload = {
|
||||
"prompt": request_func_input.prompt,
|
||||
@ -249,7 +250,6 @@ async def async_request_openai_completions(
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
@ -338,7 +338,7 @@ async def async_request_openai_chat_completions(
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
"chat/completions"
|
||||
("chat/completions", "profile")
|
||||
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
@ -432,6 +432,8 @@ def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
||||
from modelscope import snapshot_download
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import get_lock
|
||||
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(pretrained_model_name_or_path):
|
||||
|
||||
717 benchmarks/benchmark_dataset.py Normal file
@ -0,0 +1,717 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This module defines a framework for sampling benchmark requests from various
|
||||
datasets. Each dataset subclass of BenchmarkDataset must implement sample
|
||||
generation. Supported dataset types include:
|
||||
- ShareGPT
|
||||
- Random (synthetic)
|
||||
- Sonnet
|
||||
- BurstGPT
|
||||
- HuggingFace
|
||||
- VisionArena
|
||||
|
||||
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
|
||||
SampleRequest instances, similar to the approach used in ShareGPT.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Data Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class SampleRequest:
|
||||
"""
|
||||
Represents a single inference request for benchmarking.
|
||||
"""
|
||||
|
||||
prompt: Union[str, Any]
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
|
||||
lora_request: Optional[LoRARequest] = None
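# Example (illustrative; prompt_len is normally computed with the benchmark
# tokenizer rather than hard-coded):
#     SampleRequest(prompt="Hello, world", prompt_len=4, expected_output_len=16)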
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Benchmark Dataset Base Class
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BenchmarkDataset(ABC):
|
||||
DEFAULT_SEED = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_path: Optional[str] = None,
|
||||
random_seed: int = DEFAULT_SEED,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the BenchmarkDataset with an optional dataset path and random
|
||||
seed. Args:
|
||||
dataset_path (Optional[str]): Path to the dataset. If None, it
|
||||
indicates that a default or random dataset might be used.
|
||||
random_seed (int): Seed value for reproducible shuffling or
|
||||
sampling. Defaults to DEFAULT_SEED.
|
||||
"""
|
||||
self.dataset_path = dataset_path
|
||||
# Set the random seed, ensuring that a None value is replaced with the
|
||||
# default seed.
|
||||
self.random_seed = (random_seed
|
||||
if random_seed is not None else self.DEFAULT_SEED)
|
||||
self.data = None
|
||||
|
||||
def apply_multimodal_chat_transformation(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
|
||||
"""
|
||||
Transform a prompt and optional multimodal content into a chat format.
|
||||
This method is used for chat models that expect a specific conversation
|
||||
format.
|
||||
"""
|
||||
content = [{"text": prompt, "type": "text"}]
|
||||
if mm_content is not None:
|
||||
content.append(mm_content)
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def load_data(self) -> None:
|
||||
"""
|
||||
Load data from the dataset path into self.data.
|
||||
|
||||
This method must be overridden by subclasses since the method to load
|
||||
data will vary depending on the dataset format and source.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If a subclass does not implement this method.
|
||||
"""
|
||||
# TODO (jenniferzhao): add support for downloading data
|
||||
raise NotImplementedError(
|
||||
"load_data must be implemented in subclasses.")
|
||||
|
||||
def get_random_lora_request(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_loras: Optional[int] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
) -> tuple[Optional[LoRARequest], AnyTokenizer]:
|
||||
"""
|
||||
Optionally select a random LoRA request and return its associated
|
||||
tokenizer.
|
||||
|
||||
This method is used when LoRA parameters are provided. It randomly
|
||||
selects a LoRA based on max_loras and retrieves a cached tokenizer for
|
||||
that LoRA if available. Otherwise, it returns the base tokenizer.
|
||||
|
||||
Args:
|
||||
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
|
||||
LoRA is selected. max_loras (Optional[int]): The maximum number of
|
||||
LoRAs available. If None, LoRA is not used. lora_path
|
||||
(Optional[str]): Path to the LoRA parameters on disk. If None, LoRA
|
||||
is not used.
|
||||
|
||||
Returns:
|
||||
tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first
|
||||
element is a LoRARequest (or None if not applicable) and the second
|
||||
element is the tokenizer associated with the LoRA request (or the
|
||||
base tokenizer).
|
||||
"""
|
||||
if max_loras is None or lora_path is None:
|
||||
return None, tokenizer
|
||||
|
||||
# Generate a random LoRA ID in the range [1, max_loras].
|
||||
lora_id = random.randint(1, max_loras)
|
||||
lora_request = LoRARequest(
|
||||
lora_name=str(lora_id),
|
||||
lora_int_id=lora_id,
|
||||
lora_path=lora_path_on_disk(lora_path),
|
||||
)
|
||||
if lora_id not in lora_tokenizer_cache:
|
||||
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
|
||||
# Return lora_request and the cached tokenizer if available; otherwise,
|
||||
# return the base tokenizer
|
||||
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
|
||||
|
||||
@abstractmethod
|
||||
def sample(self, tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int) -> list[SampleRequest]:
|
||||
"""
|
||||
Abstract method to generate sample requests from the dataset.
|
||||
|
||||
Subclasses must override this method to implement dataset-specific logic
|
||||
for generating a list of SampleRequest objects.
|
||||
|
||||
Args:
|
||||
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
|
||||
for processing the dataset's text.
|
||||
num_requests (int): The number of sample requests to generate.
|
||||
|
||||
Returns:
|
||||
list[SampleRequest]: A list of sample requests generated from the
|
||||
dataset.
|
||||
"""
|
||||
raise NotImplementedError("sample must be implemented in subclasses.")
|
||||
|
||||
def maybe_oversample_requests(self, requests: list[SampleRequest],
|
||||
num_requests: int) -> None:
|
||||
"""
|
||||
Oversamples the list of requests if its size is less than the desired
|
||||
number.
|
||||
|
||||
Args:
|
||||
requests (List[SampleRequest]): The current list of sampled
|
||||
requests. num_requests (int): The target number of requests.
|
||||
"""
|
||||
if len(requests) < num_requests:
|
||||
random.seed(self.random_seed)
|
||||
additional = random.choices(requests,
|
||||
k=num_requests - len(requests))
|
||||
requests.extend(additional)
|
||||
logger.info("Oversampled requests to reach %d total samples.",
|
||||
num_requests)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utility Functions and Global Caches
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_valid_sequence(
|
||||
prompt_len: int,
|
||||
output_len: int,
|
||||
min_len: int = 4,
|
||||
max_prompt_len: int = 1024,
|
||||
max_total_len: int = 2048,
|
||||
skip_min_output_len_check: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate a sequence based on prompt and output lengths.
|
||||
|
||||
Default pruning criteria are copied from the original `sample_hf_requests`
|
||||
and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
|
||||
from `sample_requests` in benchmark_throughput.py.
|
||||
"""
|
||||
# Check for invalid conditions
|
||||
prompt_too_short = prompt_len < min_len
|
||||
output_too_short = (not skip_min_output_len_check) and (output_len
|
||||
< min_len)
|
||||
prompt_too_long = prompt_len > max_prompt_len
|
||||
combined_too_long = (prompt_len + output_len) > max_total_len
|
||||
|
||||
# Return True if none of the invalid conditions are met
|
||||
return not (prompt_too_short or output_too_short or prompt_too_long
|
||||
or combined_too_long)
|
||||
|
||||
|
||||
@cache
|
||||
def lora_path_on_disk(lora_path: str) -> str:
|
||||
return get_adapter_absolute_path(lora_path)
|
||||
|
||||
|
||||
# Global cache for LoRA tokenizers.
|
||||
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
|
||||
|
||||
|
||||
def process_image(image: Any) -> Mapping[str, Any]:
|
||||
"""
|
||||
Process a single image input and return a multimedia content dictionary.
|
||||
|
||||
For a PIL.Image.Image input:
|
||||
- Converts the image to RGB.
|
||||
- Saves the image as a JPEG in-memory.
|
||||
- Encodes the JPEG data as a base64 string.
|
||||
- Returns a dictionary with the image as a base64 data URL.
|
||||
|
||||
For a string input:
|
||||
- Treats the string as a URL or file path.
|
||||
- Prepends "file://" if the string doesn't start with "http://" or
|
||||
"file://".
|
||||
- Returns a dictionary with the image URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input is neither a PIL.Image.Image nor a string.
|
||||
"""
|
||||
if isinstance(image, Image.Image):
|
||||
image = image.convert("RGB")
|
||||
with io.BytesIO() as image_data:
|
||||
image.save(image_data, format="JPEG")
|
||||
image_base64 = base64.b64encode(
|
||||
image_data.getvalue()).decode("utf-8")
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(image, str):
|
||||
image_url = (image if image.startswith(
|
||||
("http://", "file://")) else f"file://{image}")
|
||||
return {"type": "image_url", "image_url": {"url": image_url}}
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid image input {image}. Must be a PIL.Image.Image or str.")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Random Dataset Implementation (Synthetic Data)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RandomDataset(BenchmarkDataset):
|
||||
# Default values copied from benchmark_serving.py for the random dataset.
|
||||
DEFAULT_PREFIX_LEN = 0
|
||||
DEFAULT_RANGE_RATIO = 1.0
|
||||
DEFAULT_INPUT_LEN = 1024
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
prefix_len: int = DEFAULT_PREFIX_LEN,
|
||||
range_ratio: float = DEFAULT_RANGE_RATIO,
|
||||
input_len: int = DEFAULT_INPUT_LEN,
|
||||
output_len: int = DEFAULT_OUTPUT_LEN,
|
||||
**kwargs,
|
||||
) -> list[SampleRequest]:
|
||||
vocab_size = tokenizer.vocab_size
|
||||
|
||||
prefix_token_ids = (np.random.randint(
|
||||
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
|
||||
|
||||
input_low = int(input_len * range_ratio)
|
||||
output_low = int(output_len * range_ratio)
|
||||
|
||||
input_lens = np.random.randint(input_low,
|
||||
input_len + 1,
|
||||
size=num_requests)
|
||||
output_lens = np.random.randint(output_low,
|
||||
output_len + 1,
|
||||
size=num_requests)
|
||||
offsets = np.random.randint(0, vocab_size, size=num_requests)
|
||||
|
||||
requests = []
|
||||
for i in range(num_requests):
|
||||
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
|
||||
vocab_size).tolist()
|
||||
token_sequence = prefix_token_ids + inner_seq
|
||||
prompt = tokenizer.decode(token_sequence)
|
||||
total_input_len = prefix_len + int(input_lens[i])
|
||||
requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=total_input_len,
|
||||
expected_output_len=int(output_lens[i]),
|
||||
))
|
||||
return requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ShareGPT Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ShareGPTDataset(BenchmarkDataset):
|
||||
"""
|
||||
Implements the ShareGPT dataset. Loads data from a JSON file and generates
|
||||
sample requests based on conversation turns.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if self.dataset_path is None:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
with open(self.dataset_path, encoding="utf-8") as f:
|
||||
self.data = json.load(f)
|
||||
# Filter entries with at least two conversation turns.
|
||||
self.data = [
|
||||
entry for entry in self.data
|
||||
if "conversations" in entry and len(entry["conversations"]) >= 2
|
||||
]
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.data)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
lora_path: Optional[str] = None,
|
||||
max_loras: Optional[int] = None,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
samples: list = []
|
||||
for entry in self.data:
|
||||
if len(samples) >= num_requests:
|
||||
break
|
||||
prompt, completion = (
|
||||
entry["conversations"][0]["value"],
|
||||
entry["conversations"][1]["value"],
|
||||
)
|
||||
|
||||
lora_request, tokenizer = self.get_random_lora_request(
|
||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
|
||||
prompt_ids = tokenizer(prompt).input_ids
|
||||
completion_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_ids)
|
||||
new_output_len = (len(completion_ids)
|
||||
if output_len is None else output_len)
|
||||
if not is_valid_sequence(prompt_len,
|
||||
new_output_len,
|
||||
skip_min_output_len_check=output_len
|
||||
is not None):
|
||||
continue
|
||||
if enable_multimodal_chat:
|
||||
prompt = self.apply_multimodal_chat_transformation(
|
||||
prompt, None)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=new_output_len,
|
||||
lora_request=lora_request,
|
||||
))
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
return samples
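
# Editor's sketch (not part of the original file): typical ShareGPT usage. The
# JSON path below is a placeholder for a locally downloaded ShareGPT dump.
def _sharegpt_example(tokenizer, path="sharegpt.json"):
    dataset = ShareGPTDataset(random_seed=0, dataset_path=path)
    # With output_len=None, each request's expected output length is taken
    # from the length of the recorded assistant completion.
    return dataset.sample(tokenizer=tokenizer, num_requests=8, output_len=None)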
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Sonnet Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SonnetDataset(BenchmarkDataset):
    """
    Simplified implementation of the Sonnet dataset. Loads poem lines from a
    text file and generates sample requests. Default values here are copied
    from `benchmark_serving.py` for the sonnet dataset.
    """
|
||||
|
||||
DEFAULT_PREFIX_LEN = 200
|
||||
DEFAULT_INPUT_LEN = 550
|
||||
DEFAULT_OUTPUT_LEN = 150
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if not self.dataset_path:
|
||||
raise ValueError("dataset_path must be provided.")
|
||||
with open(self.dataset_path, encoding="utf-8") as f:
|
||||
self.data = f.readlines()
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer,
|
||||
num_requests: int,
|
||||
prefix_len: int = DEFAULT_PREFIX_LEN,
|
||||
input_len: int = DEFAULT_INPUT_LEN,
|
||||
output_len: int = DEFAULT_OUTPUT_LEN,
|
||||
return_prompt_formatted: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
# Calculate average token length for a poem line.
|
||||
tokenized_lines = [tokenizer(line).input_ids for line in self.data]
|
||||
avg_len = sum(len(tokens)
|
||||
for tokens in tokenized_lines) / len(tokenized_lines)
|
||||
|
||||
# Build the base prompt.
|
||||
base_prompt = "Pick as many lines as you can from these poem lines:\n"
|
||||
base_msg = [{"role": "user", "content": base_prompt}]
|
||||
base_fmt = tokenizer.apply_chat_template(base_msg,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
base_offset = len(tokenizer(base_fmt).input_ids)
|
||||
if input_len <= base_offset:
|
||||
raise ValueError(
|
||||
f"'input_len' must be higher than the base prompt length "
|
||||
f"({base_offset}).")
|
||||
|
||||
# Determine how many poem lines to use.
|
||||
num_input_lines = round((input_len - base_offset) / avg_len)
|
||||
num_prefix_lines = round((prefix_len - base_offset) / avg_len)
|
||||
prefix_lines = self.data[:num_prefix_lines]
|
||||
|
||||
samples = []
|
||||
for _ in range(num_requests):
|
||||
extra_lines = random.choices(self.data,
|
||||
k=num_input_lines - num_prefix_lines)
|
||||
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
|
||||
msg = [{"role": "user", "content": prompt}]
|
||||
prompt_formatted = tokenizer.apply_chat_template(
|
||||
msg, add_generation_prompt=True, tokenize=False)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt_formatted
|
||||
if return_prompt_formatted else prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
return samples
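
# Editor's sketch (not part of the original file): sonnet usage. The text-file
# path is a placeholder, and a tokenizer with a chat template is assumed since
# sample() applies it to every prompt.
def _sonnet_example(tokenizer, path="sonnet.txt"):
    dataset = SonnetDataset(dataset_path=path)
    # return_prompt_formatted=True returns chat-template-formatted prompts,
    # matching what completion-style backends expect in benchmark_serving.py.
    return dataset.sample(tokenizer=tokenizer,
                          num_requests=4,
                          prefix_len=200,
                          input_len=550,
                          output_len=150,
                          return_prompt_formatted=True)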
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# BurstGPT Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BurstGPTDataset(BenchmarkDataset):
    """
    Implements the BurstGPT dataset. Loads request/response token counts from
    a CSV trace and generates sample requests with synthetic prompts. Only
    rows with Model "GPT-4" and a positive response-token count are used.
    """
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
    def load_data(self) -> None:
|
||||
if self.dataset_path is None:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
df = pd.read_csv(self.dataset_path)
|
||||
# Filter to keep only GPT-4 rows.
|
||||
gpt4_df = df[df["Model"] == "GPT-4"]
|
||||
# Remove failed requests (where Response tokens is 0 or less).
|
||||
gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
|
||||
        # The requested number of rows is sampled later in _sample_loaded_data().
|
||||
self.data = gpt4_df
|
||||
|
||||
def _sample_loaded_data(self, num_requests: int) -> list:
|
||||
if num_requests <= len(self.data):
|
||||
data = self.data.sample(n=num_requests,
|
||||
random_state=self.random_seed)
|
||||
else:
|
||||
data = self.data.sample(
|
||||
n=num_requests,
|
||||
random_state=self.random_seed,
|
||||
replace=True,
|
||||
)
|
||||
# Convert the dataframe to a list of lists.
|
||||
return data.values.tolist()
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
max_loras: Optional[int] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> list[SampleRequest]:
|
||||
samples = []
|
||||
data = self._sample_loaded_data(num_requests=num_requests)
|
||||
for i in range(num_requests):
|
||||
input_len = int(data[i][2])
|
||||
output_len = int(data[i][3])
|
||||
lora_req, tokenizer = self.get_random_lora_request(
|
||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
|
||||
vocab_size = tokenizer.vocab_size
|
||||
# Generate a synthetic prompt: a list of token IDs computed as (i +
|
||||
# j) modulo vocab_size.
|
||||
token_ids = [(i + j) % vocab_size for j in range(input_len)]
|
||||
prompt = tokenizer.decode(token_ids)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=output_len,
|
||||
lora_request=lora_req,
|
||||
))
|
||||
return samples
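
# Editor's sketch (not part of the original file): BurstGPT usage. The CSV
# path is a placeholder for a downloaded BurstGPT trace.
def _burstgpt_example(tokenizer, path="burstgpt_trace.csv"):
    dataset = BurstGPTDataset(random_seed=0, dataset_path=path)
    # Prompt and output token counts come from the trace; the prompt text
    # itself is synthesized from cyclic token ids and then decoded.
    return dataset.sample(tokenizer=tokenizer, num_requests=8)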
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HuggingFaceDataset(BenchmarkDataset):
    """
    Dataset class for processing a HuggingFace dataset with conversation data
    and optional images.
    """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_split: str,
|
||||
dataset_subset: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.dataset_split = dataset_split
|
||||
self.dataset_subset = dataset_subset
|
||||
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if not self.dataset_path:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
self.data = load_dataset(
|
||||
self.dataset_path,
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=True,
|
||||
)
|
||||
if self.data.features is None or "conversations" \
|
||||
not in self.data.features:
|
||||
raise ValueError(
|
||||
"HuggingFaceDataset currently only supports datasets with "
|
||||
"a 'conversations' column like lmms-lab/LLaVA-OneVision-Data. "
|
||||
"Please consider contributing if you would like to add "
|
||||
"support for additional dataset formats.")
|
||||
# Shuffle and filter examples with at least 2 conversations.
|
||||
self.data = self.data.shuffle(seed=self.random_seed).filter(
|
||||
lambda x: len(x["conversations"]) >= 2)
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs) -> list:
|
||||
sampled_requests = []
|
||||
dynamic_output = output_len is None
|
||||
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
conv = item["conversations"]
|
||||
prompt, completion = conv[0]["value"], conv[1]["value"]
|
||||
|
||||
prompt_ids = tokenizer(prompt).input_ids
|
||||
completion_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_ids)
|
||||
completion_len = len(completion_ids)
|
||||
output_len = completion_len if dynamic_output else output_len
|
||||
assert isinstance(output_len, int) and output_len > 0
|
||||
if dynamic_output and not is_valid_sequence(
|
||||
prompt_len, completion_len):
|
||||
continue
|
||||
mm_content = process_image(
|
||||
item["image"]) if "image" in item else None
|
||||
if enable_multimodal_chat:
|
||||
                # Note: when chat is enabled, the recorded prompt_len is no
                # longer accurate; the actual prompt and output lengths are
                # counted from the request output instead.
|
||||
prompt = self.apply_multimodal_chat_transformation(
|
||||
prompt, mm_content)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
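
# Editor's sketch (not part of the original file): streaming a conversation
# style HF dataset. The subset/split values are assumptions and depend on the
# dataset chosen; the dataset id is the one named in the error message above.
def _hf_conversation_dataset_example(tokenizer):
    dataset = HuggingFaceDataset(
        dataset_path="lmms-lab/LLaVA-OneVision-Data",
        dataset_subset=None,
        dataset_split="train",
        random_seed=0,
    )
    return dataset.sample(tokenizer=tokenizer, num_requests=4, output_len=64)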
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Vision Arena Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VisionArenaDataset(HuggingFaceDataset):
|
||||
"""
|
||||
Vision Arena Dataset.
|
||||
"""
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
VISION_ARENA_DATASET_PATH = "lmarena-ai/vision-arena-bench-v0.1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if self.dataset_path != self.VISION_ARENA_DATASET_PATH:
|
||||
            raise ValueError(
                "Only the Vision Arena dataset is supported; the dataset "
                f"path {self.dataset_path} is not valid.")
|
||||
if self.dataset_subset is None and self.dataset_split != "train":
|
||||
raise ValueError("Dataset split must be 'train'.")
|
||||
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
dataset = load_dataset(
|
||||
self.dataset_path,
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=True,
|
||||
)
|
||||
self.data = dataset.shuffle(seed=self.random_seed)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
sampled_requests = []
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item["turns"][0][0]["content"]
|
||||
mm_content = process_image(item["images"][0])
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
if enable_multimodal_chat:
|
||||
                # Note: when chat is enabled, the recorded prompt_len is no
                # longer accurate; the actual prompt length is counted from
                # the request output instead.
|
||||
prompt = self.apply_multimodal_chat_transformation(
|
||||
prompt, mm_content)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
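
# Editor's sketch (not part of the original file): Vision Arena usage. Only
# the dataset path hard-coded in the class is accepted, and the split must be
# "train" when no subset is given.
def _vision_arena_example(tokenizer):
    dataset = VisionArenaDataset(
        dataset_path=VisionArenaDataset.VISION_ARENA_DATASET_PATH,
        dataset_subset=None,
        dataset_split="train",
        random_seed=0,
    )
    # output_len defaults to DEFAULT_OUTPUT_LEN (128) when not provided.
    return dataset.sample(tokenizer=tokenizer, num_requests=4)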
|
||||
@ -1,507 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Benchmark guided decoding throughput."""
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
import time
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import uvloop
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SampleRequest:
|
||||
"""A class representing a single inference request for benchmarking.
|
||||
|
||||
Attributes:
|
||||
prompt: The input text prompt for the model.
|
||||
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
|
||||
images).
|
||||
prompt_len: The length of the prompt in tokens.
|
||||
expected_output_len: The expected length of the output in tokens.
|
||||
"""
|
||||
prompt: str
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
schema: dict
|
||||
structure_type: str = 'json'
|
||||
    completion: Optional[str] = None
|
||||
|
||||
|
||||
def run_vllm(requests: list[SampleRequest],
|
||||
engine_args: EngineArgs,
|
||||
n: int,
|
||||
guided_decoding_rate: float = 1.0,
|
||||
             warmup: bool = False) -> tuple[float, list[dict]]:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**vars(engine_args))
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[str] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
    # Randomly pick which request indices will use guided decoding.
|
||||
guided_decoding_req_idx = random.sample(
|
||||
range(len(requests)), int(len(requests) * guided_decoding_rate))
|
||||
|
||||
if warmup:
|
||||
        print(">>>>> Running warmup prompts (first 5 requests)")
        # Use the first 5 requests to warm up the FSM; warmup is skipped when
        # the xgrammar dataset is used.
        warmup_requests = requests[:5]
|
||||
for i, request in enumerate(warmup_requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(json=request.schema)
|
||||
if guided_decoding_rate > 0 else None,
|
||||
))
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
|
||||
print(">>>>> Benchmark started...")
|
||||
prompts = []
|
||||
sampling_params = []
|
||||
for i, request in enumerate(requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
**{request.structure_type: request.schema})
|
||||
if i in guided_decoding_req_idx else None,
|
||||
))
|
||||
|
||||
start = time.perf_counter()
|
||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
ret = []
|
||||
for output, request in zip(outputs, requests):
|
||||
generated_text = output.outputs[0].text
|
||||
ret.append({
|
||||
"generated": generated_text,
|
||||
"expected": request.completion
|
||||
})
|
||||
end = time.perf_counter()
|
||||
return end - start, ret
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: list[SampleRequest],
|
||||
engine_args: AsyncEngineArgs,
|
||||
n: int,
|
||||
guided_decoding_rate: float = 1.0,
|
||||
warmup: bool = False,
|
||||
        disable_frontend_multiprocessing: bool = False
) -> tuple[float, list[dict], tuple[pd.Series, pd.Series]]:
|
||||
from vllm import SamplingParams
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
|
||||
assert all(
|
||||
llm.model_config.max_model_len >= (request.prompt_len +
|
||||
request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[str] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
guided_decoding_req_idx = random.sample(
|
||||
range(len(requests)), int(len(requests) * guided_decoding_rate))
|
||||
|
||||
if warmup:
|
||||
            print(">>>>>> Running warmup prompts (first 5 requests)")
            # Use the first 5 requests to warm up the FSM; warmup is skipped
            # when the xgrammar dataset is used.
            warmup_requests = requests[:5]
|
||||
for i, request in enumerate(warmup_requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=request.schema)
|
||||
if guided_decoding_rate > 0 else None,
|
||||
))
|
||||
generators = []
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
|
||||
print(">>>>> Benchmark started...")
|
||||
prompts = []
|
||||
sampling_params = []
|
||||
for i, request in enumerate(requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(json=request.schema)
|
||||
if i in guided_decoding_req_idx else None,
|
||||
))
|
||||
|
||||
generators = []
|
||||
start_time = []
|
||||
latencies = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
start_time.append(time.perf_counter())
|
||||
latencies.append([])
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
generated_texts = [''] * len(requests)
|
||||
async for i, res in all_gens:
|
||||
generated_texts[i] = res.outputs[0].text
|
||||
lat = time.perf_counter() - start_time[i]
|
||||
latencies[i].append(lat)
|
||||
ret = [{
|
||||
'generated': gt,
|
||||
'expected': req.completion
|
||||
} for gt, req in zip(generated_texts, requests)]
|
||||
end = time.perf_counter()
|
||||
first_latency = pd.Series([lat[0] * 1000 for lat in latencies])
|
||||
next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000
|
||||
for lat in latencies])
|
||||
return end - start, ret, (first_latency, next_latency)
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
if args.dataset == 'json':
|
||||
if args.json_schema_path is None:
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
args.json_schema_path = os.path.join(dir_path,
|
||||
"structured_schemas",
|
||||
"structured_schema_1.json")
|
||||
with open(args.json_schema_path) as f:
|
||||
schema = json.load(f)
|
||||
prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "grammar":
|
||||
schema = """
|
||||
?start: select_statement
|
||||
|
||||
?select_statement: "SELECT " column_list " FROM " table_name
|
||||
|
||||
?column_list: column_name ("," column_name)*
|
||||
|
||||
?table_name: identifier
|
||||
|
||||
?column_name: identifier
|
||||
|
||||
?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
|
||||
"""
|
||||
prompt = "Generate an SQL query to show the 'username' \
|
||||
and 'email' from the 'users' table."
|
||||
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "regex":
|
||||
regex = r"\w+@\w+\.com\n"
|
||||
args.regex = regex
|
||||
prompt = "Generate an email address for Alan Turing, \
|
||||
who works in Enigma. End in .com and new line. \
|
||||
Example result: alan.turing@enigma.com\n"
|
||||
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=regex,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "choice":
|
||||
choice = ["Positive", "Negative"]
|
||||
args.choice = choice
|
||||
prompt = "Classify this sentiment: vLLM is wonderful!"
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=choice,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "xgrammar_bench":
|
||||
args.warmup = False
|
||||
requests: list[SampleRequest] = []
|
||||
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
|
||||
split="train")
|
||||
print(f"dataset has {len(dataset)} entries")
|
||||
len_dataset = len(dataset)
|
||||
for data_point_idx in range(args.num_prompts):
|
||||
idx = data_point_idx
|
||||
while idx >= len_dataset:
|
||||
idx -= len_dataset
|
||||
schema = dataset["schema"][idx]
|
||||
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
|
||||
tokenize=False)
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
completion = dataset["completion"][idx]
|
||||
|
||||
requests.append(
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
completion=completion))
|
||||
|
||||
return requests
|
||||
|
||||
|
||||
def evaluate(ret, args):
|
||||
|
||||
def _eval_correctness_json(expected, actual):
|
||||
# extract json string from string using regex
|
||||
import re
|
||||
actual = actual.replace('\n', '').replace(' ', '').strip()
|
||||
try:
|
||||
actual = re.search(r'\{.*\}', actual).group()
|
||||
actual = json.loads(actual)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _eval_correctness_choice(expected, actual):
|
||||
return actual in args.choice
|
||||
|
||||
def _eval_correctness_regex(expected, actual):
|
||||
import re
|
||||
return re.match(args.regex, actual) is not None
|
||||
|
||||
def _eval_correctness(expected, actual):
|
||||
if args.structure_type == 'json':
|
||||
return _eval_correctness_json(expected, actual)
|
||||
elif args.structure_type == 'regex':
|
||||
return _eval_correctness_regex(expected, actual)
|
||||
elif args.structure_type == 'choice':
|
||||
return _eval_correctness_choice(expected, actual)
|
||||
else:
|
||||
return None
|
||||
|
||||
scores = []
|
||||
for res in ret:
|
||||
score = _eval_correctness(res['expected'], res['generated'])
|
||||
res['correctness'] = score
|
||||
scores.append(score)
|
||||
|
||||
not_none_scores = [score for score in scores if score is not None]
|
||||
|
||||
return (sum(not_none_scores) / len(not_none_scores) *
|
||||
100) if len(not_none_scores) > 0 else None
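
# Editor's sketch (not part of the original file): what evaluate() reports for
# a tiny hand-made result list when structure_type is 'choice'. The namespace
# stands in for the parsed CLI args.
def _evaluate_example():
    import types
    fake_args = types.SimpleNamespace(structure_type='choice',
                                      choice=["Positive", "Negative"])
    ret = [{'generated': 'Positive', 'expected': None},
           {'generated': 'maybe', 'expected': None}]
    # One of the two generations is a valid choice, so this returns 50.0.
    return evaluate(ret, fake_args)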
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
|
||||
    # The async engine is only used for the 'json' structure type; it is
    # disabled below for 'grammar', 'regex' and 'choice'.
|
||||
if args.dataset == 'grammar':
|
||||
args.structure_type = 'grammar'
|
||||
args.async_engine = False
|
||||
elif args.dataset == 'regex':
|
||||
args.structure_type = 'regex'
|
||||
args.async_engine = False
|
||||
elif args.dataset == 'choice':
|
||||
args.structure_type = 'choice'
|
||||
args.async_engine = False
|
||||
else:
|
||||
args.structure_type = 'json'
|
||||
|
||||
if args.no_guided_decoding:
|
||||
args.guided_decoding_ratio = 0
|
||||
if args.save_results:
|
||||
result_file_name = f'{args.guided_decoding_ratio}guided'
|
||||
result_file_name += f"_{args.model.split('/')[-1]}"
|
||||
result_file_name += f"_{args.dataset}"
|
||||
result_file_name += f"_{args.num_prompts}"
|
||||
result_file_name += f"_out{args.output_len}"
|
||||
result_file_name += f"_async{args.async_engine}"
|
||||
result_file_name += f"_warmup{args.warmup}"
|
||||
result_file_name += f"_chunkedprefill{args.enable_chunked_prefill}"
|
||||
result_file_name += ".txt"
|
||||
else:
|
||||
result_file_name = None
|
||||
|
||||
# Synthesize a prompt with the given input length.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
requests = sample_requests(tokenizer, args)
|
||||
|
||||
if args.async_engine:
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
elapsed_time, ret, (first_latency, next_latency) = uvloop.run(
|
||||
run_vllm_async(requests, engine_args, args.n,
|
||||
args.guided_decoding_ratio, args.warmup,
|
||||
args.disable_frontend_multiprocessing))
|
||||
else:
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
elapsed_time, ret = run_vllm(requests, engine_args, args.n,
|
||||
args.guided_decoding_ratio, args.warmup)
|
||||
first_latency, next_latency = None, None
|
||||
|
||||
score = evaluate(ret, args)
|
||||
total_num_tokens = sum(request.prompt_len + request.expected_output_len
|
||||
for request in requests)
|
||||
total_output_tokens = sum(request.expected_output_len
|
||||
for request in requests)
|
||||
if first_latency is not None:
|
||||
latency_breakdown = "\nFirst token latency(msecs):\n"
|
||||
latency_breakdown += f"{first_latency.describe()}"
|
||||
latency_breakdown += "\nNext token latency(msecs):\n"
|
||||
latency_breakdown += f"{next_latency.describe()}"
|
||||
print(
|
||||
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s",
|
||||
f"Correct rate is {score} %",
|
||||
f"{latency_breakdown if first_latency is not None else ''}")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json or result_file_name:
|
||||
results = {
|
||||
"elapsed_time": elapsed_time,
|
||||
"num_requests": len(requests),
|
||||
"total_num_tokens": total_num_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"requests_per_second": len(requests) / elapsed_time,
|
||||
"tokens_per_second": f"{total_num_tokens / elapsed_time:.2f}",
|
||||
"output_tokens_per_second":
|
||||
f"{total_output_tokens / elapsed_time:.2f}",
|
||||
"correct_rate(%)": score
|
||||
}
|
||||
results = {"outputs": ret, **results}
|
||||
if first_latency is not None:
|
||||
results["first_token_latency(msecs)"] = first_latency.describe(
|
||||
).to_dict()
|
||||
results["next_token_latency(msecs)"] = next_latency.describe(
|
||||
).to_dict()
|
||||
if args.output_json:
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
elif result_file_name:
|
||||
with open(result_file_name, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark guided decoding.")
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument("--output-len",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
default='json',
|
||||
choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
|
||||
parser.add_argument("--json_schema_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to json schema.")
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the throughput results in JSON format.')
|
||||
parser.add_argument("--async-engine",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.")
|
||||
parser.add_argument("--no-guided-decoding",
|
||||
action='store_true',
|
||||
default=False,
|
||||
                        help="Disable guided decoding.")
|
||||
parser.add_argument("--guided-decoding-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Ratio of Guided Decoding requests")
|
||||
parser.add_argument("--disable-frontend-multiprocessing",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
parser.add_argument("--warmup",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run warmup prompts before benchmark.")
|
||||
parser.add_argument("--save-results",
|
||||
action="store_true",
|
||||
default=False,
|
||||
                        help="Save output results to a file.")
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
main(args)
|
||||
@ -52,6 +52,7 @@ def main(args: argparse.Namespace):
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=args.output_len,
|
||||
detokenize=not args.disable_detokenize,
|
||||
)
|
||||
print(sampling_params)
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
@ -173,6 +174,12 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="Path to save the latency results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -194,7 +194,9 @@ def main(args):
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||
sampling_params = SamplingParams(temperature=0,
|
||||
max_tokens=args.output_len,
|
||||
detokenize=not args.disable_detokenize)
|
||||
|
||||
print("Testing filtered requests")
|
||||
prompts = repeat_and_sort_requests(filtered_requests,
|
||||
@ -243,6 +245,12 @@ if __name__ == "__main__":
|
||||
"subtract this length when filtering prompts. Only used "
|
||||
"when dataset-path is not provided.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--disable-detokenize',
|
||||
action='store_true',
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -23,7 +23,7 @@ def sample_requests(
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
) -> list[tuple[str, int, int]]:
|
||||
) -> list[tuple[str, int, int, int]]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
@ -71,6 +71,7 @@ def run_vllm(
|
||||
requests: list[tuple[str, int, int]],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
@ -95,6 +96,7 @@ def run_vllm(
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
|
||||
start = time.perf_counter()
|
||||
@ -121,7 +123,8 @@ def main(args: argparse.Namespace):
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(requests, args.n,
|
||||
EngineArgs.from_cli_args(args))
|
||||
EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(prompt_len + output_len
|
||||
@ -174,6 +177,12 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the throughput results in JSON format.')
|
||||
parser.add_argument(
|
||||
'--disable-detokenize',
|
||||
action='store_true',
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -25,25 +25,20 @@ On the client side, run:
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator, Collection
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
from datasets import load_dataset
|
||||
from PIL.Image import Image
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
@ -57,6 +52,9 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
@ -92,325 +90,18 @@ class BenchmarkMetrics:
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
def sample_sharegpt_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> list[tuple[str, int, int, None]]:
|
||||
# Load the dataset.
|
||||
with open(dataset_path, encoding='utf-8') as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: list[tuple[str, int, int]] = []
|
||||
for i in range(len(dataset)):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
break
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = dataset[i][0]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
completion = dataset[i][1]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
if prompt_len < 4 or (fixed_output_len is None and output_len < 4):
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
filtered_dataset.append((prompt, prompt_len, output_len, None))
|
||||
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def sample_burstgpt_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
random_seed: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> list[tuple[str, int, int, None]]:
|
||||
df = pd.read_csv(dataset_path)
|
||||
gpt4_df = df[df["Model"] == "GPT-4"]
|
||||
# Remove the failed requests (i.e., response length is 0)
|
||||
gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
|
||||
# Randomly sample num_requests from the dataset
|
||||
if num_requests <= len(gpt4_df):
|
||||
gpt4_df = gpt4_df.sample(n=num_requests, random_state=random_seed)
|
||||
else:
|
||||
gpt4_df = gpt4_df.sample(n=num_requests,
|
||||
random_state=random_seed,
|
||||
replace=True)
|
||||
# Convert the dataframe to a list of tuples
|
||||
dataset = gpt4_df.values.tolist()
|
||||
input_requests = []
|
||||
for i in range(num_requests):
|
||||
input_len = int(dataset[i][2])
|
||||
output_len = int(dataset[i][3])
|
||||
prompt = tokenizer.decode([(i + j) % tokenizer.vocab_size
|
||||
for j in range(input_len)])
|
||||
input_requests.append((prompt, input_len, output_len, None))
|
||||
return input_requests
|
||||
|
||||
|
||||
def sample_sonnet_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
input_len: int,
|
||||
output_len: int,
|
||||
prefix_len: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> list[tuple[str, str, int, int, None]]:
|
||||
assert (
|
||||
input_len > prefix_len
|
||||
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path, encoding='utf-8') as f:
|
||||
poem_lines = f.readlines()
|
||||
|
||||
# Tokenize the poem lines.
|
||||
poem_token_ids = tokenizer(poem_lines).input_ids
|
||||
average_poem_len = sum(
|
||||
len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)
|
||||
|
||||
# Base prefix for all requests.
|
||||
base_prompt = "Pick as many lines as you can from these poem lines:\n"
|
||||
base_message = [{
|
||||
"role": "user",
|
||||
"content": base_prompt,
|
||||
}]
|
||||
base_prompt_formatted = tokenizer.apply_chat_template(
|
||||
base_message, add_generation_prompt=True, tokenize=False)
|
||||
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
|
||||
|
||||
assert (
|
||||
input_len > base_prompt_offset
|
||||
), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
|
||||
num_input_lines = round(
|
||||
(input_len - base_prompt_offset) / average_poem_len)
|
||||
|
||||
# First approximately `prefix_len` number of tokens in the
|
||||
# prompt are fixed poem lines.
|
||||
assert (
|
||||
prefix_len > base_prompt_offset
|
||||
), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."
|
||||
|
||||
num_prefix_lines = round(
|
||||
(prefix_len - base_prompt_offset) / average_poem_len)
|
||||
prefix_lines = poem_lines[:num_prefix_lines]
|
||||
|
||||
# Sample the rest of lines per request.
|
||||
sampled_requests: list[tuple[str, int, int]] = []
|
||||
for _ in range(num_requests):
|
||||
num_lines_needed = num_input_lines - num_prefix_lines
|
||||
sampled_lines = "".join(prefix_lines +
|
||||
random.choices(poem_lines, k=num_lines_needed))
|
||||
|
||||
prompt = f"{base_prompt}{sampled_lines}"
|
||||
message = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
]
|
||||
prompt_formatted = tokenizer.apply_chat_template(
|
||||
message, add_generation_prompt=True, tokenize=False)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
sampled_requests.append(
|
||||
(prompt, prompt_formatted, prompt_len, output_len, None))
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_vision_arena_requests(
|
||||
dataset,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]:
|
||||
sampled_requests: list[tuple[str, int, int, dict[str,
|
||||
Collection[str]]]] = []
|
||||
for data in dataset:
|
||||
if len(sampled_requests) == num_requests:
|
||||
break
|
||||
|
||||
prompt = data["turns"][0][0]['content']
|
||||
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
if fixed_output_len is None:
|
||||
# Default max output len is set to 128
|
||||
print("--hf-output-len is not provided. Using default value 128.")
|
||||
fixed_output_len = 128
|
||||
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = fixed_output_len
|
||||
|
||||
assert isinstance(
|
||||
data["images"][0],
|
||||
Image), ("Input image format must be `PIL.Image.Image`, "
|
||||
f"given {type(data['image'])}.")
|
||||
image: Image = data["images"][0]
|
||||
image = image.convert("RGB")
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='JPEG')
|
||||
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
|
||||
mm_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
},
|
||||
}
|
||||
|
||||
sampled_requests.append((prompt, prompt_len, output_len, mm_content))
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_hf_requests(
|
||||
dataset_path: str,
|
||||
dataset_subset: Optional[str],
|
||||
dataset_split: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
random_seed: int,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]:
|
||||
|
||||
# Special case for vision_arena dataset
|
||||
if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
|
||||
and dataset_subset is None:
|
||||
assert dataset_split == "train"
|
||||
dataset = load_dataset(dataset_path,
|
||||
name=dataset_subset,
|
||||
split=dataset_split,
|
||||
streaming=True)
|
||||
dataset = dataset.shuffle(seed=random_seed)
|
||||
return sample_vision_arena_requests(dataset, num_requests, tokenizer,
|
||||
fixed_output_len)
|
||||
|
||||
dataset = load_dataset(dataset_path,
|
||||
name=dataset_subset,
|
||||
split=dataset_split,
|
||||
streaming=True)
|
||||
assert "conversations" in dataset.features, (
|
||||
"HF Dataset must have 'conversations' column.")
|
||||
filter_func = lambda x: len(x["conversations"]) >= 2
|
||||
filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
|
||||
sampled_requests: list[tuple[str, int, int, dict[str,
|
||||
Collection[str]]]] = []
|
||||
for data in filtered_dataset:
|
||||
if len(sampled_requests) == num_requests:
|
||||
break
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = data["conversations"][0]["value"]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
completion = data["conversations"][1]["value"]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
if fixed_output_len is None and (prompt_len < 4 or output_len < 4):
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if fixed_output_len is None and \
|
||||
(prompt_len > 1024 or prompt_len + output_len > 2048):
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
|
||||
if "image" in data and isinstance(data["image"], Image):
|
||||
image: Image = data["image"]
|
||||
image = image.convert("RGB")
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='JPEG')
|
||||
image_base64 = base64.b64encode(
|
||||
image_data.getvalue()).decode("utf-8")
|
||||
mm_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
},
|
||||
}
|
||||
elif "image" in data and isinstance(data["image"], str):
|
||||
if (data["image"].startswith("http://") or \
|
||||
data["image"].startswith("file://")):
|
||||
image_url = data["image"]
|
||||
else:
|
||||
image_url = f"file://{data['image']}"
|
||||
|
||||
mm_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
},
|
||||
}
|
||||
else:
|
||||
mm_content = None
|
||||
|
||||
sampled_requests.append((prompt, prompt_len, output_len, mm_content))
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_random_requests(
|
||||
prefix_len: int,
|
||||
input_len: int,
|
||||
output_len: int,
|
||||
num_prompts: int,
|
||||
range_ratio: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> list[tuple[str, int, int]]:
|
||||
prefix_token_ids = np.random.randint(0,
|
||||
tokenizer.vocab_size,
|
||||
size=prefix_len).tolist()
|
||||
|
||||
input_lens = np.random.randint(
|
||||
int(input_len * range_ratio),
|
||||
input_len + 1,
|
||||
size=num_prompts,
|
||||
)
|
||||
output_lens = np.random.randint(
|
||||
int(output_len * range_ratio),
|
||||
output_len + 1,
|
||||
size=num_prompts,
|
||||
)
|
||||
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
|
||||
input_requests = []
|
||||
for i in range(num_prompts):
|
||||
prompt = tokenizer.decode(prefix_token_ids +
|
||||
[(offsets[i] + i + j) % tokenizer.vocab_size
|
||||
for j in range(input_lens[i])])
|
||||
|
||||
input_requests.append((prompt, int(prefix_len + input_lens[i]),
|
||||
int(output_lens[i]), None))
|
||||
|
||||
return input_requests
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
input_requests: list[SampleRequest],
|
||||
request_rate: float,
|
||||
burstiness: float = 1.0,
|
||||
) -> AsyncGenerator[tuple[str, int, int], None]:
|
||||
) -> AsyncGenerator[SampleRequest, None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness.
|
||||
|
||||
Args:
|
||||
input_requests:
|
||||
A list of input requests, each represented as a tuple.
|
||||
A list of input requests, each represented as a SampleRequest.
|
||||
request_rate:
|
||||
The rate at which requests are generated (requests/s).
|
||||
burstiness (optional):
|
||||
@ -422,7 +113,7 @@ async def get_request(
|
||||
in more bursty requests, while a higher burstiness value
|
||||
(burstiness > 1) results in a more uniform arrival of requests.
|
||||
"""
|
||||
input_requests = iter(input_requests)
|
||||
input_requests: Iterable[SampleRequest] = iter(input_requests)
|
||||
|
||||
# Calculate scale parameter theta to maintain the desired request_rate.
|
||||
assert burstiness > 0, (
|
||||
@ -444,7 +135,7 @@ async def get_request(
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
input_requests: list[SampleRequest],
|
||||
outputs: list[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -475,7 +166,7 @@ def calculate_metrics(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
actual_output_lens.append(output_len)
|
||||
total_input += input_requests[i][1]
|
||||
total_input += input_requests[i].prompt_len
|
||||
tpot = 0
|
||||
if output_len > 1:
|
||||
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
|
||||
@ -558,19 +249,18 @@ async def benchmark(
|
||||
model_id: str,
|
||||
model_name: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
input_requests: list[SampleRequest],
|
||||
logprobs: Optional[int],
|
||||
best_of: int,
|
||||
request_rate: float,
|
||||
burstiness: float,
|
||||
disable_tqdm: bool,
|
||||
profile: bool,
|
||||
selected_percentile_metrics: list[str],
|
||||
selected_percentiles: list[str],
|
||||
selected_percentiles: list[float],
|
||||
ignore_eos: bool,
|
||||
goodput_config_dict: dict[str, float],
|
||||
max_concurrency: Optional[int],
|
||||
lora_modules: Optional[list[str]],
|
||||
lora_modules: Optional[Iterable[str]],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -578,12 +268,16 @@ async def benchmark(
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
|
||||
input_requests[0])
|
||||
test_prompt, test_prompt_len, test_output_len, test_mm_content = \
|
||||
input_requests[0].prompt, input_requests[0].prompt_len, \
|
||||
input_requests[0].expected_output_len, \
|
||||
input_requests[0].multi_modal_data
|
||||
|
||||
if backend != "openai-chat" and test_mm_content is not None:
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' backend.")
|
||||
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
model_name=model_name,
|
||||
@ -592,7 +286,6 @@ async def benchmark(
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
)
|
||||
@ -608,7 +301,8 @@ async def benchmark(
|
||||
if lora_modules:
|
||||
# For each input request, choose a LoRA module at random.
|
||||
lora_modules = iter(
|
||||
[random.choice(lora_modules) for _ in range(len(input_requests))])
|
||||
[random.choice(lora_modules) \
|
||||
for _ in range(len(input_requests))])
|
||||
|
||||
if profile:
|
||||
print("Starting profiler...")
|
||||
@ -619,7 +313,6 @@ async def benchmark(
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
@ -655,7 +348,9 @@ async def benchmark(
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: list[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
prompt, prompt_len, output_len, mm_content = request
|
||||
prompt, prompt_len, output_len, mm_content = request.prompt, \
|
||||
request.prompt_len, request.expected_output_len, \
|
||||
request.multi_modal_data
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
if lora_modules:
|
||||
req_lora_module = next(lora_modules)
|
||||
@ -668,7 +363,6 @@ async def benchmark(
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos)
|
||||
tasks.append(
|
||||
@ -686,7 +380,6 @@ async def benchmark(
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
@ -872,76 +565,72 @@ def main(args: argparse.Namespace):
|
||||
"Please specify '--dataset-name' and the corresponding "
|
||||
"'--dataset-path' if required.")
|
||||
|
||||
elif args.dataset_name == "sharegpt":
|
||||
input_requests = sample_sharegpt_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
fixed_output_len=args.sharegpt_output_len,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "burstgpt":
|
||||
input_requests = sample_burstgpt_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
random_seed=args.seed,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sonnet":
|
||||
# Do not format the prompt, pass to message directly
|
||||
if args.dataset_name == "sonnet":
|
||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
||||
# For the "sonnet" dataset, formatting depends on the backend.
|
||||
if args.backend == "openai-chat":
|
||||
input_requests = sample_sonnet_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
input_requests = [(prompt, prompt_len, output_len, None)
|
||||
for prompt, prompt_formatted, prompt_len,
|
||||
output_len, _ in input_requests]
|
||||
input_requests = dataset.sample(num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=False)
|
||||
else:
|
||||
assert (
|
||||
tokenizer.chat_template or tokenizer.default_chat_template
|
||||
), "Tokenizer/model must have chat template for sonnet dataset."
|
||||
input_requests = sample_sonnet_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
input_requests = [(prompt_formatted, prompt_len, output_len, None)
|
||||
for prompt, prompt_formatted, prompt_len,
|
||||
output_len, _ in input_requests]
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
input_requests = dataset.sample(num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=True)
|
||||
|
||||
elif args.dataset_name == "hf":
|
||||
input_requests = sample_hf_requests(
|
||||
# Choose between VisionArenaDataset
|
||||
# and HuggingFaceDataset based on provided parameters.
|
||||
dataset_class = (VisionArenaDataset if args.dataset_path
|
||||
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
|
||||
and args.hf_subset is None else HuggingFaceDataset)
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
dataset_split=args.hf_split,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
random_seed=args.seed,
|
||||
fixed_output_len=args.hf_output_len,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "random":
|
||||
input_requests = sample_random_requests(
|
||||
prefix_len=args.random_prefix_len,
|
||||
input_len=args.random_input_len,
|
||||
output_len=args.random_output_len,
|
||||
num_prompts=args.num_prompts,
|
||||
range_ratio=args.random_range_ratio,
|
||||
tokenizer=tokenizer,
|
||||
output_len=args.hf_output_len,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
    # For datasets that follow a similar structure, use a mapping.
    dataset_mapping = {
        "sharegpt":
        lambda: ShareGPTDataset(random_seed=args.seed,
                                dataset_path=args.dataset_path).sample(
            tokenizer=tokenizer,
            num_requests=args.num_prompts,
            output_len=args.sharegpt_output_len,
        ),
        "burstgpt":
        lambda: BurstGPTDataset(random_seed=args.seed,
                                dataset_path=args.dataset_path).
        sample(tokenizer=tokenizer, num_requests=args.num_prompts),
        "random":
        lambda: RandomDataset(dataset_path=args.dataset_path).sample(
            tokenizer=tokenizer,
            num_requests=args.num_prompts,
            prefix_len=args.random_prefix_len,
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            range_ratio=args.random_range_ratio,
        )
    }

    try:
        input_requests = dataset_mapping[args.dataset_name]()
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
|
||||
goodput_config_dict = check_goodput_args(args)
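One design note on the dataset_mapping dispatch above: every value is a zero-argument lambda, so only the dataset that is actually selected ever gets constructed or loaded. A small self-contained sketch of the same pattern (names here are illustrative, not from this diff):

def _load_sharegpt():
    print("loading ShareGPT ...")   # expensive work happens only if selected
    return ["req-1", "req-2"]

def _load_random():
    print("synthesizing random prompts ...")
    return ["req-a", "req-b"]

dataset_mapping = {"sharegpt": _load_sharegpt, "random": _load_random}

selected = "random"
try:
    input_requests = dataset_mapping[selected]()   # only _load_random runs
except KeyError as err:
    raise ValueError(f"Unknown dataset: {selected}") from err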
|
||||
|
||||
# Avoid GC processing "static" data - reduce pause times.
|
||||
@ -958,7 +647,6 @@ def main(args: argparse.Namespace):
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
logprobs=args.logprobs,
|
||||
best_of=args.best_of,
|
||||
request_rate=args.request_rate,
|
||||
burstiness=args.burstiness,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
@ -983,7 +671,6 @@ def main(args: argparse.Namespace):
|
||||
result_json["backend"] = backend
|
||||
result_json["model_id"] = model_id
|
||||
result_json["tokenizer_id"] = tokenizer_id
|
||||
result_json["best_of"] = args.best_of
|
||||
result_json["num_prompts"] = args.num_prompts
|
||||
|
||||
# Metadata
|
||||
@ -997,6 +684,15 @@ def main(args: argparse.Namespace):
|
||||
"Invalid metadata format. Please use KEY=VALUE format."
|
||||
)
|
||||
|
||||
if not args.save_detailed:
|
||||
# Remove fields with too many data points
|
||||
for field in [
|
||||
"input_lens", "output_lens", "ttfts", "itls",
|
||||
"generated_texts", "errors"
|
||||
]:
|
||||
if field in result_json:
|
||||
del result_json[field]
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
@ -1081,13 +777,6 @@ if __name__ == "__main__":
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best-of",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Generates `best_of` sequences per prompt and "
|
||||
"returns the best one.",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
@ -1148,6 +837,12 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Specify to save benchmark results to a json file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-detailed",
|
||||
action="store_true",
|
||||
help="When saving the results, whether to include per request "
|
||||
"information such as response, error, ttfs, tpots, etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
metavar="KEY=VALUE",
|
||||
@ -1312,4 +1007,5 @@ if __name__ == "__main__":
|
||||
"script chooses a LoRA module at random.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
r"""Benchmark online serving throughput with guided decoding.
|
||||
r"""Benchmark online serving throughput with structured outputs.
|
||||
|
||||
On the server side, run one of the following commands:
|
||||
(vLLM OpenAI API server)
|
||||
@ -9,12 +9,12 @@ On the server side, run one of the following commands:
|
||||
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
|
||||
|
||||
On the client side, run:
|
||||
python benchmarks/benchmark_serving_guided.py \
|
||||
python benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend <backend> \
|
||||
--model <your_model> \
|
||||
--dataset json \
|
||||
--guided-decoding-ratio 1.0 \
|
||||
--guided-decoding-backend xgrammar \
|
||||
--structured-output-ratio 1.0 \
|
||||
--structured-output-backend xgrammar \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
|
||||
@ -24,11 +24,13 @@ On the client side, run:
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import copy
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
@ -52,6 +54,9 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from vllm.v1.structured_output.utils import (
|
||||
has_xgrammar_unsupported_json_features)
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
|
||||
|
||||
@ -106,24 +111,43 @@ class SampleRequest:
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
if args.dataset == 'json':
|
||||
if args.dataset == 'json' or args.dataset == 'json-unique':
|
||||
if args.json_schema_path is None:
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
args.json_schema_path = os.path.join(dir_path,
|
||||
"structured_schemas",
|
||||
"structured_schema_1.json")
|
||||
json_schemas = []
|
||||
with open(args.json_schema_path) as f:
|
||||
schema = json.load(f)
|
||||
prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
|
||||
if args.dataset == 'json-unique':
|
||||
json_schemas = [
|
||||
copy.deepcopy(schema) for _ in range(args.num_prompts)
|
||||
]
|
||||
for i in range(len(json_schemas)):
|
||||
json_schemas[i]["properties"][
|
||||
f"__optional_field_{uuid.uuid4()}"] = {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"An unique optional field to avoid cached schemas"
|
||||
}
|
||||
|
||||
def gen_prompt(index: int):
|
||||
schema = json_schemas[index % len(json_schemas)]
|
||||
return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||
|
||||
def get_schema(index: int):
|
||||
return json_schemas[index % len(json_schemas)]
|
||||
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
SampleRequest(prompt=gen_prompt(i),
|
||||
prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
schema=get_schema(i),
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
for i in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "grammar":
|
||||
@ -191,7 +215,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
requests: list[SampleRequest] = []
|
||||
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
|
||||
split="train")
|
||||
print(f"dataset has {len(dataset)} entries")
|
||||
full_dataset_len = len(dataset)
|
||||
|
||||
def _filter_func(item):
|
||||
import json
|
||||
schema = json.loads(item["schema"])
|
||||
return not has_xgrammar_unsupported_json_features(schema)
|
||||
|
||||
dataset = dataset.filter(_filter_func)
|
||||
num_filtered_out = full_dataset_len - len(dataset)
|
||||
print(f"dataset has {len(dataset)} entries after filtering "
|
||||
f"out {num_filtered_out} entries with unsupported features")
|
||||
len_dataset = len(dataset)
|
||||
for data_point_idx in range(args.num_prompts):
|
||||
idx = data_point_idx
|
||||
@ -220,21 +254,21 @@ async def get_request(
|
||||
burstiness: float = 1.0,
|
||||
) -> AsyncGenerator[tuple[int, SampleRequest], None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness.
|
||||
|
||||
|
||||
Args:
|
||||
input_requests:
|
||||
input_requests:
|
||||
A list of input requests, each represented as a tuple.
|
||||
request_rate:
|
||||
request_rate:
|
||||
The rate at which requests are generated (requests/s).
|
||||
burstiness (optional):
|
||||
The burstiness factor of the request generation.
|
||||
burstiness (optional):
|
||||
The burstiness factor of the request generation.
|
||||
Only takes effect when request_rate is not inf.
|
||||
Default value is 1, which follows a Poisson process.
|
||||
Otherwise, the request intervals follow a gamma distribution.
|
||||
A lower burstiness value (0 < burstiness < 1) results
|
||||
in more bursty requests, while a higher burstiness value
|
||||
A lower burstiness value (0 < burstiness < 1) results
|
||||
in more bursty requests, while a higher burstiness value
|
||||
(burstiness > 1) results in a more uniform arrival of requests.
|
||||
"""
|
||||
input_requests = iter(input_requests)
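To make the burstiness description above concrete, here is a minimal, hedged sketch of gamma-distributed inter-arrival sampling (standard library only; the helper name is ours, not the benchmark's):

import asyncio
import random

async def arrive(requests, request_rate: float, burstiness: float = 1.0):
    # Yield requests with gamma-distributed gaps whose mean is 1/request_rate.
    # With burstiness == 1 the gaps are exponential, i.e. a Poisson process;
    # burstiness < 1 clusters requests, burstiness > 1 spreads them out.
    for request in requests:
        yield request
        if request_rate == float("inf"):
            continue  # fire the next request immediately
        scale = 1.0 / (request_rate * burstiness)
        await asyncio.sleep(random.gammavariate(burstiness, scale))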
|
||||
@ -378,8 +412,8 @@ async def benchmark(
|
||||
selected_percentiles: list[str],
|
||||
ignore_eos: bool,
|
||||
max_concurrency: Optional[int],
|
||||
guided_decoding_ratio: float,
|
||||
guided_decoding_backend: str,
|
||||
structured_output_ratio: float,
|
||||
structured_output_backend: str,
|
||||
goodput_config_dict: Optional[dict[str, float]] = None,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
@ -391,16 +425,18 @@ async def benchmark(
|
||||
extra_body = {}
|
||||
# Add the schema to the extra_body
|
||||
extra_body[request.structure_type] = request.schema
|
||||
# Add the specific guided_decoding_backend
|
||||
extra_body["guided_decoding_backend"] = guided_decoding_backend
|
||||
# Add the specific structured_output_backend
|
||||
extra_body["guided_decoding_backend"] = structured_output_backend
|
||||
return extra_body
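For orientation, a hedged sketch (values assumed, not from this diff) of what prepare_extra_body() returns for a JSON-schema request once the backend is xgrammar; "guided_json" is one of the structure_type values assigned later in this file:

extra_body = {
    "guided_json": {                      # request.structure_type
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    },
    "guided_decoding_backend": "xgrammar",   # structured_output_backend
}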
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
guided_decoding_req_idx = random.sample(
|
||||
structured_output_req_idx = random.sample(
|
||||
range(len(input_requests)),
|
||||
int(len(input_requests) * guided_decoding_ratio))
|
||||
int(len(input_requests) * structured_output_ratio))
|
||||
|
||||
test_request = input_requests[0]
|
||||
test_req_extra_body = (prepare_extra_body(test_request)
|
||||
if 0 in structured_output_req_idx else None)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=test_request.prompt,
|
||||
@ -408,7 +444,7 @@ async def benchmark(
|
||||
prompt_len=test_request.prompt_len,
|
||||
output_len=test_request.expected_output_len,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=prepare_extra_body(test_request),
|
||||
extra_body=test_req_extra_body,
|
||||
)
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
if not test_output.success:
|
||||
@ -427,7 +463,7 @@ async def benchmark(
|
||||
prompt_len=test_request.prompt_len,
|
||||
output_len=test_request.expected_output_len,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=prepare_extra_body(test_request),
|
||||
extra_body=test_req_extra_body,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
@ -465,7 +501,7 @@ async def benchmark(
|
||||
async for i, request in get_request(input_requests, request_rate,
|
||||
burstiness):
|
||||
extra_body = prepare_extra_body(
|
||||
request) if i in guided_decoding_req_idx else None
|
||||
request) if i in structured_output_req_idx else None
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=request.prompt,
|
||||
@ -696,8 +732,11 @@ def main(args: argparse.Namespace):
|
||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
tokenizer = get_tokenizer(
|
||||
tokenizer_id,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
tokenizer_mode=args.tokenizer_mode,
|
||||
)
|
||||
|
||||
if args.dataset == 'grammar':
|
||||
args.structure_type = 'guided_grammar'
|
||||
@ -708,10 +747,10 @@ def main(args: argparse.Namespace):
|
||||
else:
|
||||
args.structure_type = 'guided_json'
|
||||
|
||||
if args.no_guided_decoding:
|
||||
args.guided_decoding_ratio = 0
|
||||
if args.no_structured_output:
|
||||
args.structured_output_ratio = 0
|
||||
if args.save_results:
|
||||
result_file_name = f'{args.guided_decoding_ratio}guided'
|
||||
result_file_name = f'{args.structured_output_ratio}guided'
|
||||
result_file_name += f"_{backend}"
|
||||
result_file_name += f"_{args.request_rate}qps"
|
||||
result_file_name += f"_{args.model.split('/')[-1]}"
|
||||
@ -744,8 +783,8 @@ def main(args: argparse.Namespace):
|
||||
],
|
||||
ignore_eos=args.ignore_eos,
|
||||
max_concurrency=args.max_concurrency,
|
||||
guided_decoding_ratio=args.guided_decoding_ratio,
|
||||
guided_decoding_backend=args.guided_decoding_backend,
|
||||
structured_output_ratio=args.structured_output_ratio,
|
||||
structured_output_backend=args.structured_output_backend,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
))
|
||||
|
||||
@ -806,10 +845,12 @@ if __name__ == "__main__":
|
||||
default="/v1/completions",
|
||||
help="API endpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
default='json',
|
||||
choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
|
||||
parser.add_argument("--dataset",
|
||||
default='json',
|
||||
choices=[
|
||||
'json', 'json-unique', 'grammar', 'regex',
|
||||
'choice', 'xgrammar_bench'
|
||||
])
|
||||
parser.add_argument("--json_schema_path",
|
||||
type=str,
|
||||
default=None,
|
||||
@ -838,6 +879,13 @@ if __name__ == "__main__":
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer-mode",
|
||||
type=str,
|
||||
default="auto",
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
type=int,
|
||||
@ -943,19 +991,20 @@ if __name__ == "__main__":
|
||||
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
|
||||
|
||||
parser.add_argument("--no-guided-decoding",
|
||||
parser.add_argument("--no-structured-output",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Whether to disable JSON decoding or not.")
|
||||
parser.add_argument("--guided-decoding-ratio",
|
||||
parser.add_argument("--structured-output-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Ratio of Guided Decoding requests")
|
||||
parser.add_argument("--guided-decoding-backend",
|
||||
type=str,
|
||||
choices=["outlines", "lm-format-enforcer", "xgrammar"],
|
||||
default="xgrammar",
|
||||
help="Backend to use for guided decoding")
|
||||
help="Ratio of Structured Outputs requests")
|
||||
parser.add_argument(
|
||||
"--structured-output-backend",
|
||||
type=str,
|
||||
choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
|
||||
default="xgrammar",
|
||||
help="Backend to use for structured outputs")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@ -6,13 +6,15 @@ import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from functools import cache
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
@ -20,155 +22,19 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SampleRequest:
|
||||
"""A class representing a single inference request for benchmarking.
|
||||
|
||||
Attributes:
|
||||
prompt: The input text prompt for the model.
|
||||
prompt_len: The length of the prompt in tokens.
|
||||
expected_output_len: The expected length of the output in tokens.
|
||||
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
|
||||
images).
|
||||
lora_request: Optional LoRARequest specifying the LoRA to use.
|
||||
"""
|
||||
prompt: str
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
multi_modal_data: Optional[MultiModalDataDict] = None
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
|
||||
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
|
||||
"""Prepend and append special tokens around the question to form a prompt.
|
||||
|
||||
Args:
|
||||
question: The input question text to wrap with special tokens
|
||||
model: The name of the model being used, to determine which special
|
||||
tokens to add
|
||||
|
||||
Returns:
|
||||
The formatted prompt string with appropriate special tokens for the
|
||||
model
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported model name is provided
|
||||
"""
|
||||
model = model.lower()
|
||||
if "pixtral" in model:
|
||||
return f"<s>[INST]{question}\n[IMG][/INST]"
|
||||
raise ValueError(f"Unsupported model {model}")
|
||||
|
||||
|
||||
@cache
|
||||
def lora_path_on_disk(lora_path: str) -> str:
|
||||
return get_adapter_absolute_path(lora_path)
|
||||
|
||||
|
||||
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
|
||||
|
||||
|
||||
def get_random_lora_request(
|
||||
args: argparse.Namespace
|
||||
) -> tuple[LoRARequest, Optional[AnyTokenizer]]:
|
||||
global lora_tokenizer_cache
|
||||
lora_id = random.randint(1, args.max_loras)
|
||||
lora_request = LoRARequest(lora_name=str(lora_id),
|
||||
lora_int_id=lora_id,
|
||||
lora_path=lora_path_on_disk(args.lora_path))
|
||||
if lora_id not in lora_tokenizer_cache:
|
||||
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
|
||||
return lora_request, lora_tokenizer_cache[lora_id]
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
|
||||
dataset_path: str = args.dataset
|
||||
num_requests: int = args.num_prompts
|
||||
fixed_output_len: Optional[int] = args.output_len
|
||||
model: str = args.model
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: list[SampleRequest] = []
|
||||
for data in tqdm(dataset,
|
||||
total=len(filtered_dataset),
|
||||
desc="sampling requests"):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
break
|
||||
|
||||
# Only keep the first two turns of each conversation.
|
||||
prompt = data["conversations"][0]["value"]
|
||||
completion = data["conversations"][1]["value"]
|
||||
|
||||
multi_modal_data: Optional[MultiModalDataDict] = None
|
||||
if "image" in data:
|
||||
multi_modal_data = multi_modal_data or {}
|
||||
image_path = data["image"]
|
||||
# TODO(vllm-project/vllm/issues/9778): Support multiple images.
|
||||
assert isinstance(image_path,
|
||||
str), "Only support single image input"
|
||||
try:
|
||||
multi_modal_data["image"] = Image.open(image_path).convert(
|
||||
"RGB")
|
||||
except FileNotFoundError:
|
||||
# Ignore datapoint where asset is missing
|
||||
continue
|
||||
prompt = _get_prompt_for_image_model(question=prompt, model=model)
|
||||
|
||||
request_tokenizer = tokenizer
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
if args.enable_lora:
|
||||
lora_request, lora_tokenizer = get_random_lora_request(args)
|
||||
if lora_tokenizer:
|
||||
request_tokenizer = lora_tokenizer
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt_token_ids = request_tokenizer(prompt).input_ids
|
||||
completion_token_ids = request_tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
if prompt_len < 4 or output_len < 4:
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
filtered_dataset.append(
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=multi_modal_data,
|
||||
lora_request=lora_request))
|
||||
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def run_vllm(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
) -> float:
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, Optional[list[RequestOutput]]]:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert all(
|
||||
@ -178,10 +44,13 @@ def run_vllm(
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
# Add the requests to the engine.
|
||||
prompts: list[TextPrompt] = []
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data)
|
||||
if "prompt_token_ids" in request.prompt else \
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
sampling_params.append(
|
||||
@ -191,6 +60,7 @@ def run_vllm(
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
if engine_args.enable_lora:
|
||||
@ -198,12 +68,13 @@ def run_vllm(
|
||||
|
||||
use_beam_search = False
|
||||
|
||||
outputs = None
|
||||
if not use_beam_search:
|
||||
start = time.perf_counter()
|
||||
llm.generate(prompts,
|
||||
sampling_params,
|
||||
lora_request=lora_requests,
|
||||
use_tqdm=True)
|
||||
outputs = llm.generate(prompts,
|
||||
sampling_params,
|
||||
lora_request=lora_requests,
|
||||
use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
else:
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
@ -221,7 +92,46 @@ def run_vllm(
|
||||
ignore_eos=True,
|
||||
))
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
|
||||
"""
|
||||
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
|
||||
multimodal models as it properly handles multimodal inputs and chat
|
||||
formatting. For non-multimodal models, use run_vllm() instead.
|
||||
"""
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of "
|
||||
"prompt_len and expected_output_len for all requests.")
|
||||
|
||||
prompts = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
start = time.perf_counter()
|
||||
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
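Given the (elapsed_time, outputs) pair returned above, throughput can be derived as sketched below; a minimal illustration of the counting approach, not the script's exact reporting code:

def summarize(elapsed_time: float, outputs) -> None:
    # Hedged sketch: count tokens straight from vLLM RequestOutput objects,
    # mirroring what the updated main() does for the vllm / vllm-chat backends.
    prompt_tokens = sum(len(o.prompt_token_ids or []) for o in outputs)
    output_tokens = sum(len(c.token_ids) for o in outputs for c in o.outputs)
    total = prompt_tokens + output_tokens
    print(f"Throughput: {len(outputs) / elapsed_time:.2f} requests/s, "
          f"{total / elapsed_time:.2f} total tokens/s, "
          f"{output_tokens / elapsed_time:.2f} output tokens/s")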
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
@ -229,6 +139,7 @@ async def run_vllm_async(
|
||||
n: int,
|
||||
engine_args: AsyncEngineArgs,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
|
||||
@ -242,11 +153,14 @@ async def run_vllm_async(
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[TextPrompt] = []
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
lora_requests: list[Optional[LoRARequest]] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data)
|
||||
if "prompt_token_ids" in request.prompt else \
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
sampling_params.append(
|
||||
@ -256,6 +170,7 @@ async def run_vllm_async(
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
lora_requests.append(request.lora_request)
|
||||
|
||||
@ -282,6 +197,7 @@ def run_hf(
|
||||
n: int,
|
||||
max_batch_size: int,
|
||||
trust_remote_code: bool,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
@ -321,8 +237,9 @@ def run_hf(
|
||||
use_cache=True,
|
||||
max_new_tokens=max_output_len,
|
||||
)
|
||||
# Include the decoding time.
|
||||
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
|
||||
if not disable_detokenize:
|
||||
# Include the decoding time.
|
||||
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
|
||||
pbar.update(len(batch))
|
||||
|
||||
# Clear the batch.
|
||||
@ -369,58 +286,68 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def get_requests(args, tokenizer):
|
||||
# Common parameters for all dataset types.
|
||||
common_kwargs = {
|
||||
"dataset_path": args.dataset_path,
|
||||
"random_seed": args.seed,
|
||||
}
|
||||
sample_kwargs = {
|
||||
"tokenizer": tokenizer,
|
||||
"lora_path": args.lora_path,
|
||||
"max_loras": args.max_loras,
|
||||
"num_requests": args.num_prompts,
|
||||
"input_len": args.input_len,
|
||||
"output_len": args.output_len,
|
||||
}
|
||||
if args.dataset_path is None or args.dataset_name == "random":
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
dataset_cls = RandomDataset
|
||||
elif args.dataset_name == "sharegpt":
|
||||
dataset_cls = ShareGPTDataset
|
||||
if args.backend == "vllm-chat":
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_name == "sonnet":
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
dataset_cls = SonnetDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
sample_kwargs["return_prompt_formatted"] = True
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
if args.backend != "vllm-chat":
|
||||
raise ValueError(
|
||||
"hf datasets only are supported by vllm-chat backend")
|
||||
# Choose between VisionArenaDataset and HuggingFaceDataset based on
|
||||
# provided parameters.
|
||||
dataset_cls = (VisionArenaDataset if args.dataset_path
|
||||
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
|
||||
and args.hf_subset is None else HuggingFaceDataset)
|
||||
common_kwargs['dataset_subset'] = args.hf_subset
|
||||
common_kwargs['dataset_split'] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
|
||||
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
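For orientation, a hedged sketch of the kind of namespace get_requests() expects; the field values below are assumptions for illustration, real runs build args from the CLI parser defined later in this file:

import argparse

args = argparse.Namespace(
    dataset_name="random", dataset_path=None, backend="vllm",
    seed=0, num_prompts=8, input_len=64, output_len=16,
    random_range_ratio=None, prefix_len=None,
    lora_path=None, max_loras=None, hf_subset=None, hf_split=None,
)
# requests = get_requests(args, tokenizer)   # tokenizer: any HF tokenizer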
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
if args.seed is None:
|
||||
args.seed = 0
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
|
||||
# Sample the requests.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
if args.dataset is None:
|
||||
vocab_size = tokenizer.vocab_size
|
||||
requests = []
|
||||
for _ in range(args.num_prompts):
|
||||
|
||||
request_tokenizer = tokenizer
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
if args.enable_lora:
|
||||
lora_request, lora_tokenizer = get_random_lora_request(args)
|
||||
if lora_tokenizer:
|
||||
request_tokenizer = lora_tokenizer
|
||||
|
||||
# Synthesize a prompt with the given input length.
|
||||
candidate_ids = [
|
||||
random.randint(0, vocab_size - 1)
|
||||
for _ in range(args.input_len)
|
||||
]
|
||||
# As tokenizer may add additional tokens like BOS, we need to try
|
||||
# different lengths to get the desired input length.
|
||||
for _ in range(5): # Max attempts to correct
|
||||
candidate_prompt = request_tokenizer.decode(candidate_ids)
|
||||
tokenized_len = len(request_tokenizer.encode(candidate_prompt))
|
||||
|
||||
if tokenized_len == args.input_len:
|
||||
break
|
||||
|
||||
# Adjust length based on difference
|
||||
diff = args.input_len - tokenized_len
|
||||
if diff > 0:
|
||||
candidate_ids.extend([
|
||||
random.randint(100, vocab_size - 100)
|
||||
for _ in range(diff)
|
||||
])
|
||||
else:
|
||||
candidate_ids = candidate_ids[:diff]
|
||||
requests.append(
|
||||
SampleRequest(prompt=candidate_prompt,
|
||||
prompt_len=args.input_len,
|
||||
expected_output_len=args.output_len,
|
||||
lora_request=lora_request))
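The decode-and-re-encode length correction used above generalizes beyond this script; a hedged standalone sketch (tokenizer-agnostic, the helper name is ours):

import random

def synthesize_prompt(tokenizer, target_len: int, max_attempts: int = 5) -> str:
    # Sketch of the correction loop above: BOS insertion and token merges mean
    # a first guess rarely re-tokenizes to exactly target_len, so pad or trim
    # the id list a few times until it does.
    ids = [random.randint(0, tokenizer.vocab_size - 1) for _ in range(target_len)]
    prompt = tokenizer.decode(ids)
    for _ in range(max_attempts):
        prompt = tokenizer.decode(ids)
        diff = target_len - len(tokenizer.encode(prompt))
        if diff == 0:
            break
        if diff > 0:
            ids.extend(random.randint(100, tokenizer.vocab_size - 100)
                       for _ in range(diff))
        else:
            ids = ids[:diff]
    return prompt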
|
||||
else:
|
||||
requests = sample_requests(tokenizer, args)
|
||||
|
||||
requests = get_requests(args, tokenizer)
|
||||
is_multi_modal = any(request.multi_modal_data is not None
|
||||
for request in requests)
|
||||
request_outputs: Optional[list[RequestOutput]] = None
|
||||
if args.backend == "vllm":
|
||||
if args.async_engine:
|
||||
elapsed_time = uvloop.run(
|
||||
@ -429,31 +356,59 @@ def main(args: argparse.Namespace):
|
||||
args.n,
|
||||
AsyncEngineArgs.from_cli_args(args),
|
||||
args.disable_frontend_multiprocessing,
|
||||
args.disable_detokenize,
|
||||
))
|
||||
else:
|
||||
elapsed_time = run_vllm(requests, args.n,
|
||||
EngineArgs.from_cli_args(args))
|
||||
elapsed_time, request_outputs = run_vllm(
|
||||
requests, args.n, EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
args.hf_max_batch_size, args.trust_remote_code)
|
||||
args.hf_max_batch_size, args.trust_remote_code,
|
||||
args.disable_detokenize)
|
||||
elif args.backend == "mii":
|
||||
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
|
||||
args.output_len)
|
||||
elif args.backend == "vllm-chat":
|
||||
elapsed_time, request_outputs = run_vllm_chat(
|
||||
requests, args.n, EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(request.prompt_len + request.expected_output_len
|
||||
for request in requests)
|
||||
total_output_tokens = sum(request.expected_output_len
|
||||
for request in requests)
|
||||
if is_multi_modal:
|
||||
print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
|
||||
|
||||
if request_outputs:
|
||||
# Note: with the vllm and vllm-chat backends,
|
||||
# we have request_outputs, which we use to count tokens.
|
||||
total_prompt_tokens = 0
|
||||
total_output_tokens = 0
|
||||
for ro in request_outputs:
|
||||
if not isinstance(ro, RequestOutput):
|
||||
continue
|
||||
total_prompt_tokens += len(
|
||||
ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
||||
total_output_tokens += sum(
|
||||
len(o.token_ids) for o in ro.outputs if o)
|
||||
total_num_tokens = total_prompt_tokens + total_output_tokens
|
||||
else:
|
||||
total_num_tokens = sum(r.prompt_len + r.expected_output_len
|
||||
for r in requests)
|
||||
total_output_tokens = sum(r.expected_output_len for r in requests)
|
||||
total_prompt_tokens = total_num_tokens - total_output_tokens
|
||||
|
||||
if is_multi_modal and args.backend != "vllm-chat":
|
||||
print("\033[91mWARNING\033[0m: Multi-modal request with "
|
||||
f"{args.backend} backend detected. The "
|
||||
"following metrics are not accurate because image tokens are not"
|
||||
" counted. See vllm-project/vllm/issues/9778 for details.")
|
||||
# TODO(vllm-project/vllm/issues/9778): Count molti-modal token length.
|
||||
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
||||
# vllm-chat backend counts the image tokens now
|
||||
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
|
||||
print(f"Total num prompt tokens: {total_prompt_tokens}")
|
||||
print(f"Total num output tokens: {total_output_tokens}")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
@ -469,18 +424,112 @@ def main(args: argparse.Namespace):
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
|
||||
|
||||
def validate_args(args):
|
||||
"""
|
||||
Validate command-line arguments.
|
||||
"""
|
||||
|
||||
# === Deprecation and Defaulting ===
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next release. "
|
||||
"Please use '--dataset-name' and '--dataset-path' instead.",
|
||||
stacklevel=2)
|
||||
args.dataset_path = args.dataset
|
||||
|
||||
if not getattr(args, "tokenizer", None):
|
||||
args.tokenizer = args.model
|
||||
|
||||
# === Backend Validation ===
|
||||
valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
|
||||
if args.backend not in valid_backends:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
# === Dataset Configuration ===
|
||||
if not args.dataset and not args.dataset_path:
|
||||
print(
|
||||
"When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = 'random'
|
||||
if args.input_len is None:
|
||||
raise ValueError("input_len must be provided for a random dataset")
|
||||
|
||||
# === Dataset Name Specific Checks ===
|
||||
# --hf-subset and --hf-split: only used
|
||||
# when dataset_name is 'hf'
|
||||
if args.dataset_name != "hf" and (
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None):
|
||||
warnings.warn("--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2)
|
||||
elif args.dataset_name == "hf" and args.backend != "vllm-chat":
|
||||
raise ValueError(
|
||||
"When --dataset-name is 'hf', backend must be 'vllm-chat'")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random'
|
||||
if args.dataset_name != 'random' and args.random_range_ratio is not None:
|
||||
warnings.warn("--random-range-ratio will be ignored since \
|
||||
--dataset-name is not 'random'.",
|
||||
stacklevel=2)
|
||||
|
||||
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
||||
# set.
|
||||
if args.dataset_name not in {"random", "sonnet", None
|
||||
} and args.prefix_len is not None:
|
||||
warnings.warn("--prefix-len will be ignored since --dataset-name\
|
||||
is not 'random', 'sonnet', or not set.",
|
||||
stacklevel=2)
|
||||
|
||||
# === LoRA Settings ===
|
||||
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
||||
raise ValueError(
|
||||
"LoRA benchmarking is only supported for vLLM backend")
|
||||
if getattr(args, "enable_lora", False) and args.lora_path is None:
|
||||
raise ValueError("LoRA path must be provided when enable_lora is True")
|
||||
|
||||
# === Backend-specific Validations ===
|
||||
if args.backend == "hf" and args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend")
|
||||
if args.backend != "hf" and args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
|
||||
if args.backend in {"hf", "mii"} and getattr(args, "quantization",
|
||||
None) is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
|
||||
if args.backend == "mii" and args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.backend == "mii" and args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.backend == "mii" and args.tokenizer != args.model:
|
||||
raise ValueError(
|
||||
"Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii"],
|
||||
choices=["vllm", "hf", "mii", "vllm-chat"],
|
||||
default="vllm")
|
||||
parser.add_argument("--dataset",
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the ShareGPT dataset, will be deprecated in\
|
||||
the next release. The dataset is expected to "
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]")
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset. The dataset is expected to "
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]")
|
||||
help="Path to the dataset")
|
||||
parser.add_argument("--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
@ -515,6 +564,11 @@ if __name__ == "__main__":
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=("Do not detokenize the response (i.e. do not include "
|
||||
"detokenization time in the measurement)"))
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
@ -522,43 +576,33 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.")
|
||||
parser.add_argument("--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of prefix tokens per request."
|
||||
"This is for the RandomDataset and SonnetDataset")
|
||||
# random dataset
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for RandomDataSet.",
|
||||
)
|
||||
|
||||
# hf dataset
|
||||
parser.add_argument("--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HF dataset.")
|
||||
parser.add_argument("--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HF dataset.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
if args.dataset is None:
|
||||
assert args.input_len is not None
|
||||
assert args.output_len is not None
|
||||
else:
|
||||
assert args.input_len is None
|
||||
if args.enable_lora:
|
||||
assert args.lora_path is not None
|
||||
|
||||
if args.backend == "vllm":
|
||||
if args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
elif args.backend == "hf":
|
||||
if args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.enable_lora is not None:
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM"
|
||||
" backend")
|
||||
elif args.backend == "mii":
|
||||
if args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
if args.tokenizer != args.model:
|
||||
raise ValueError("Tokenizer must be the same as the model for MII "
|
||||
"backend.")
|
||||
if args.enable_lora is not None:
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM"
|
||||
" backend")
|
||||
validate_args(args)
|
||||
main(args)
|
||||
|
||||
@ -40,7 +40,7 @@ def main(num_tokens: int,
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
|
||||
@ -17,11 +17,7 @@ from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import ArgPool, Bench, CudaGraphBenchParams
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
|
||||
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -153,7 +149,6 @@ def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
|
||||
result = torch.nn.functional.linear(x, w)
|
||||
result *= scaling
|
||||
out_list.append(result)
|
||||
torch.cat(out_list, dim=0)
|
||||
|
||||
cat_result = torch.cat(out_list, dim=0)
|
||||
|
||||
@ -167,52 +162,25 @@ class OpType(Enum):
|
||||
"""
|
||||
LoRA Ops to benchmark and its properties.
|
||||
"""
|
||||
SGMV_SHRINK = auto()
|
||||
BGMV_SHRINK = auto()
|
||||
SGMV_EXPAND = auto()
|
||||
BGMV_EXPAND = auto()
|
||||
BGMV_EXPAND_SLICE = auto()
|
||||
LORA_SHRINK = auto()
|
||||
LORA_EXPAND = auto()
|
||||
|
||||
@staticmethod
|
||||
def from_str(s: str) -> "OpType":
|
||||
if s.lower() == 'sgmv_shrink':
|
||||
return OpType.SGMV_SHRINK
|
||||
if s.lower() == 'sgmv_expand':
|
||||
return OpType.SGMV_EXPAND
|
||||
if s.lower() == 'bgmv_shrink':
|
||||
return OpType.BGMV_SHRINK
|
||||
if s.lower() == 'bgmv_expand':
|
||||
return OpType.BGMV_EXPAND
|
||||
if s.lower() == "bgmv_expand_slice":
|
||||
return OpType.BGMV_EXPAND_SLICE
|
||||
if s.lower() == "lora_shrink":
|
||||
return OpType.LORA_SHRINK
|
||||
if s.lower() == "lora_expand":
|
||||
return OpType.LORA_EXPAND
|
||||
raise ValueError(f"Unrecognized str {s} to convert to OpType")
|
||||
|
||||
def is_shrink_fn(self) -> bool:
|
||||
return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]
|
||||
return self in [OpType.LORA_SHRINK]
|
||||
|
||||
def is_expand_fn(self) -> bool:
|
||||
return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]
|
||||
|
||||
def is_prefill_op(self) -> bool:
|
||||
return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]
|
||||
|
||||
def is_decode_op(self) -> bool:
|
||||
return self in [
|
||||
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
|
||||
]
|
||||
|
||||
def is_expand_slice_fn(self) -> bool:
|
||||
return self in [OpType.BGMV_EXPAND_SLICE]
|
||||
return self in [OpType.LORA_EXPAND]
|
||||
|
||||
def num_slices(self) -> list[int]:
|
||||
if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
|
||||
# SGMV kernels supports slices
|
||||
return [1, 2, 3]
|
||||
if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
|
||||
return [1]
|
||||
if self in [OpType.BGMV_EXPAND_SLICE]:
|
||||
return [2, 3]
|
||||
raise ValueError(f"Unrecognized OpType {self}")
|
||||
return [1, 2, 3]
|
||||
|
||||
def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
|
||||
lora_rank: int) -> tuple[int, int, int]:
|
||||
@ -222,7 +190,7 @@ class OpType(Enum):
|
||||
k = hidden_size
|
||||
n = lora_rank
|
||||
else:
|
||||
assert self.is_expand_fn() or self.is_expand_slice_fn()
|
||||
assert self.is_expand_fn()
|
||||
m = num_tokens
|
||||
k = lora_rank
|
||||
n = hidden_size
|
||||
@ -237,7 +205,7 @@ class OpType(Enum):
|
||||
if self.is_shrink_fn():
|
||||
return op_dtype, op_dtype, torch.float32
|
||||
else:
|
||||
assert self.is_expand_fn() or self.is_expand_slice_fn()
|
||||
assert self.is_expand_fn()
|
||||
return torch.float32, op_dtype, op_dtype
|
||||
|
||||
def matmul_shapes(
|
||||
@ -251,56 +219,39 @@ class OpType(Enum):
|
||||
m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)
|
||||
|
||||
b_shape = (num_loras, n, k) # col-major
|
||||
if self == OpType.SGMV_SHRINK:
|
||||
# SGMV shrink supports num_slices inherently in the kernel
|
||||
if self in [OpType.LORA_SHRINK]:
|
||||
# LoRA shrink kernels support num_slices inherently in the kernel.
|
||||
return ((m, k), b_shape, (num_slices, m, n))
|
||||
if self == OpType.SGMV_EXPAND:
|
||||
# SGMV expand supports num_slices inherently in the kernel
|
||||
if self in [OpType.LORA_EXPAND]:
|
||||
# LoRA expand kernels support num_slices inherently in the kernel
|
||||
return ((num_slices, m, k), b_shape, (m, n * num_slices))
|
||||
if self == OpType.BGMV_SHRINK:
|
||||
return ((m, k), b_shape, (m, n))
|
||||
if self == OpType.BGMV_EXPAND:
|
||||
return ((m, k), b_shape, (m, n))
|
||||
if self == OpType.BGMV_EXPAND_SLICE:
|
||||
return ((num_slices, m, k), b_shape, (m, n * num_slices))
|
||||
|
||||
raise ValueError(f"Unrecognized op_type {self}")
|
||||
|
||||
def bench_fn(self) -> Callable:
|
||||
if self == OpType.LORA_SHRINK:
|
||||
return lora_shrink
|
||||
if self == OpType.LORA_EXPAND:
|
||||
return lora_expand
|
||||
|
||||
def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
|
||||
for x in kwargs_list:
|
||||
bgmv_expand_slice(**x)
|
||||
|
||||
if self == OpType.SGMV_SHRINK:
|
||||
return sgmv_shrink
|
||||
if self == OpType.SGMV_EXPAND:
|
||||
return sgmv_expand
|
||||
if self == OpType.BGMV_SHRINK:
|
||||
return bgmv_shrink
|
||||
if self == OpType.BGMV_EXPAND:
|
||||
return bgmv_expand
|
||||
if self == OpType.BGMV_EXPAND_SLICE:
|
||||
return emulate_bgmv_expand_slice
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
|
||||
lora_weights: list[torch.Tensor],
|
||||
**kwargs) -> Callable:
|
||||
"""Each benchmark operation expected the input, lora_weights and outputs
|
||||
"""Each benchmark operation expects the input, lora_weights and outputs
|
||||
in a slightly different format. Refer to self.matmul_shapes().
|
||||
run_ref_group_gemm accounts for those differences in executing a
|
||||
reference group gemm for correctness testing.
|
||||
"""
|
||||
w_dtype = lora_weights[0].dtype
|
||||
num_slices = len(lora_weights)
|
||||
if self == OpType.SGMV_SHRINK:
|
||||
if self in [OpType.LORA_SHRINK]:
|
||||
for slice_idx in range(num_slices):
|
||||
ref_group_gemm(ref_out=output[slice_idx, :],
|
||||
input=input,
|
||||
lora_weights=lora_weights[slice_idx],
|
||||
**kwargs)
|
||||
if self == OpType.SGMV_EXPAND:
|
||||
elif self in [OpType.LORA_EXPAND]:
|
||||
hidden_size = lora_weights[0].shape[1]
|
||||
for slice_idx in range(num_slices):
|
||||
slice_offset = slice_idx * hidden_size
|
||||
@ -309,28 +260,8 @@ class OpType(Enum):
|
||||
input=input[slice_idx].clone().to(dtype=w_dtype),
|
||||
lora_weights=lora_weights[slice_idx],
|
||||
**kwargs)
|
||||
if self == OpType.BGMV_SHRINK:
|
||||
assert num_slices == 1
|
||||
ref_group_gemm(ref_out=output,
|
||||
input=input,
|
||||
lora_weights=lora_weights[0],
|
||||
**kwargs)
|
||||
if self == OpType.BGMV_EXPAND:
|
||||
assert num_slices == 1
|
||||
ref_group_gemm(ref_out=output,
|
||||
input=input.clone().to(dtype=w_dtype),
|
||||
lora_weights=lora_weights[0],
|
||||
**kwargs)
|
||||
if self == OpType.BGMV_EXPAND_SLICE:
|
||||
hidden_size = lora_weights[0].shape[1]
|
||||
for slice_idx in range(num_slices):
|
||||
slice_offset = slice_idx * hidden_size
|
||||
ref_group_gemm(
|
||||
ref_out=output[:, slice_offset:slice_offset + hidden_size],
|
||||
input=input[slice_idx].clone().to(dtype=w_dtype),
|
||||
lora_weights=lora_weights[slice_idx],
|
||||
**kwargs)
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
else:
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -386,11 +317,11 @@ class BenchmarkTensors:
|
||||
input: torch.Tensor
|
||||
lora_weights_lst: list[torch.Tensor]
|
||||
output: torch.Tensor
|
||||
# metadata tensors
|
||||
# LoRA kernel metadata
|
||||
lora_kernel_meta: LoRAKernelMeta
|
||||
# Metadata tensors used in testing correctness
|
||||
seq_lens: torch.Tensor
|
||||
seq_start_loc: torch.Tensor
|
||||
prompt_lora_mapping: torch.Tensor
|
||||
token_lora_mapping: torch.Tensor
|
||||
|
||||
def io_types(self) -> str:
|
||||
return (f"{dtype_to_str(self.input.dtype)}x"
|
||||
@ -417,26 +348,29 @@ class BenchmarkTensors:
|
||||
assert ctx.num_active_loras <= ctx.num_loras
|
||||
total_tokens = ctx.batch_size * ctx.seq_length
|
||||
|
||||
# Make metadata tensors involved in correctness testing.
|
||||
# Prepare seq lens tensor
|
||||
seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
|
||||
(ctx.batch_size, ))
|
||||
# Prepare seq_start_loc tensor
|
||||
seq_start_loc_tensor = torch.cumsum(torch.tensor(
|
||||
[0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
|
||||
dim=0)
|
||||
assert total_tokens == seq_len_tensor.sum()
|
||||
# Prepare prompt lora indices tensor
|
||||
prompt_lora_indices_tensor = make_prompt_lora_mapping(
|
||||
ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
|
||||
# Prepare token lora indices tensor
|
||||
|
||||
# Make LoRAKernelMeta
|
||||
token_lora_indices_tensor = make_token_lora_mapping(
|
||||
total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
|
||||
seq_len_tensor, "cpu")
|
||||
lora_kernel_meta = LoRAKernelMeta.make(
|
||||
max_loras=ctx.num_loras,
|
||||
max_num_tokens=token_lora_indices_tensor.size(0),
|
||||
device="cpu")
|
||||
lora_kernel_meta.prepare_tensors(
|
||||
token_lora_mapping=token_lora_indices_tensor)
|
||||
|
||||
return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
|
||||
seq_len_tensor, seq_start_loc_tensor,
|
||||
prompt_lora_indices_tensor,
|
||||
token_lora_indices_tensor)
|
||||
lora_kernel_meta, seq_len_tensor,
|
||||
prompt_lora_indices_tensor)
|
||||
|
||||
def sanity_check(self) -> None:
|
||||
"""
|
||||
@ -446,9 +380,9 @@ class BenchmarkTensors:
|
||||
# check metadata tensors
|
||||
assert torch.sum(self.seq_lens) == num_tokens
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
assert self.seq_start_loc.shape[0] == num_seqs
|
||||
#assert self.seq_start_loc.shape[0] == num_seqs
|
||||
assert self.prompt_lora_mapping.shape[0] == num_seqs
|
||||
assert self.token_lora_mapping.shape[0] == num_tokens
|
||||
assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens
|
||||
|
||||
def to_device(self, device: str):
|
||||
"""
|
||||
@ -463,54 +397,31 @@ class BenchmarkTensors:
|
||||
self.input = to_device(self.input)
|
||||
self.output = to_device(self.output)
|
||||
self.seq_lens = to_device(self.seq_lens)
|
||||
self.seq_start_loc = to_device(self.seq_start_loc)
|
||||
self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
|
||||
self.token_lora_mapping = to_device(self.token_lora_mapping)
|
||||
for i in range(len(self.lora_weights_lst)):
|
||||
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
|
||||
|
||||
# LoRA meta
|
||||
for field_name in LoRAKernelMeta.__dataclass_fields__:
|
||||
field = getattr(self.lora_kernel_meta, field_name)
|
||||
assert isinstance(field, torch.Tensor)
|
||||
setattr(self.lora_kernel_meta, field_name, to_device(field))
|
||||
|
||||
def metadata(self) -> tuple[int, int, int]:
|
||||
"""
|
||||
Return num_seqs, num_tokens and max_seq_len
|
||||
"""
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
num_tokens = self.token_lora_mapping.shape[0]
|
||||
num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
|
||||
max_seq_len = torch.max(self.seq_lens).item()
|
||||
num_slices = len(self.lora_weights_lst)
|
||||
return num_seqs, num_tokens, max_seq_len, num_slices
|
||||
|
||||
def convert_to_sgmv_benchmark_tensors(self):
|
||||
"""
|
||||
For sgmv punica kernels, when consecutive sequences have the
|
||||
same LoRA ID, we just merge them together.
|
||||
This happens in punica.py::compute_metadata
|
||||
"""
|
||||
|
||||
# Collapse seq_lens and seq_start_loc
|
||||
_, seq_lens = torch.unique_consecutive(self.token_lora_mapping,
|
||||
return_counts=True)
|
||||
cum_result = torch.cumsum(seq_lens, dim=0)
|
||||
seq_start_loc = torch.zeros_like(seq_lens)
|
||||
seq_start_loc[1:].copy_(cum_result[:-1])
|
||||
|
||||
# Collapse prompt mapping
|
||||
prompt_lora_mapping = torch.unique_consecutive(
|
||||
self.prompt_lora_mapping)
|
||||
|
||||
assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \
|
||||
f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}"
|
||||
|
||||
self.prompt_lora_mapping = prompt_lora_mapping.to(
|
||||
dtype=self.prompt_lora_mapping.dtype)
|
||||
self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
|
||||
self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)
|
||||
|
||||
def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
|
||||
self.convert_to_sgmv_benchmark_tensors()
|
||||
def as_lora_shrink_kwargs(self) -> dict[str, Any]:
|
||||
self.sanity_check()
|
||||
self.to_device(self.input.device)
|
||||
|
||||
num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
|
||||
_, num_tokens, _, num_slices = self.metadata()
|
||||
|
||||
# Sanity check matrix shapes.
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
@ -531,22 +442,20 @@ class BenchmarkTensors:
|
||||
'inputs': self.input,
|
||||
'lora_a_weights': self.lora_weights_lst,
|
||||
'output_tensor': self.output,
|
||||
'b_seq_start_loc': self.seq_start_loc,
|
||||
'seq_len_tensor': self.seq_lens,
|
||||
'lora_indices_tensor': self.prompt_lora_mapping,
|
||||
'batches': num_seqs,
|
||||
'max_seq_length': max_seq_len,
|
||||
'token_nums': num_tokens,
|
||||
'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
|
||||
'token_indices_sorted_by_lora_ids':
|
||||
self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
|
||||
'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
|
||||
'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
|
||||
'lora_ids': self.lora_kernel_meta.active_lora_ids,
|
||||
'scaling': 1.0,
|
||||
}
|
||||
|
||||
def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
|
||||
self.convert_to_sgmv_benchmark_tensors()
|
||||
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
self.sanity_check()
|
||||
self.to_device(self.input.device)
|
||||
|
||||
num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
|
||||
_, num_tokens, _, num_slices = self.metadata()
|
||||
|
||||
# Sanity check matrix shapes.
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
@@ -568,106 +477,16 @@ class BenchmarkTensors:
|
||||
'inputs': self.input,
|
||||
'lora_b_weights': self.lora_weights_lst,
|
||||
'output_tensor': self.output,
|
||||
'b_seq_start_loc': self.seq_start_loc,
|
||||
'seq_len_tensor': self.seq_lens,
|
||||
'lora_indices_tensor': self.prompt_lora_mapping,
|
||||
'batches': num_seqs,
|
||||
'max_seq_length': max_seq_len,
|
||||
'token_nums': num_tokens,
|
||||
'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
|
||||
'token_indices_sorted_by_lora_ids':
|
||||
self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
|
||||
'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
|
||||
'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
|
||||
'lora_ids': self.lora_kernel_meta.active_lora_ids,
|
||||
'offset_start': 0,
|
||||
'add_inputs': add_inputs,
|
||||
}
|
||||
|
||||
def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
|
||||
assert len(self.lora_weights_lst) == 1
|
||||
self.to_device(self.input.device)
|
||||
|
||||
_, num_tokens, _, _ = self.metadata()
|
||||
# Sanity check shapes
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape [num_tokens, hidden_size]
|
||||
assert len(i_shape) == 2
|
||||
assert i_shape[0] == num_tokens
|
||||
hidden_size = i_shape[1]
|
||||
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == hidden_size
|
||||
lora_rank = lw_shape[1]
|
||||
# Expected output shape [num_tokens, lora_rank]
|
||||
assert len(o_shape) == 2
|
||||
assert o_shape == (num_tokens, lora_rank)
|
||||
|
||||
return {
|
||||
'inputs': self.input,
|
||||
'lora_a_weights': self.lora_weights_lst[0],
|
||||
'output_tensor': self.output,
|
||||
'lora_indices_tensor': self.token_lora_mapping,
|
||||
'scaling': 1.0
|
||||
}
|
||||
|
||||
def as_bgmv_expand_kwargs(self, add_inputs: bool):
|
||||
assert len(self.lora_weights_lst) == 1
|
||||
self.to_device(self.input.device)
|
||||
|
||||
_, num_tokens, _, _ = self.metadata()
|
||||
# Sanity check shapes
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape [num_tokens, lora_rank]
|
||||
assert len(i_shape) == 2
|
||||
assert i_shape[0] == num_tokens
|
||||
lora_rank = i_shape[1]
|
||||
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == lora_rank
|
||||
hidden_size = lw_shape[1]
|
||||
# Expected output shape [num_tokens, hidden_size]
|
||||
assert len(o_shape) == 2
|
||||
assert o_shape == (num_tokens, hidden_size)
|
||||
|
||||
return {
|
||||
'inputs': self.input,
|
||||
'lora_b_weights': self.lora_weights_lst[0],
|
||||
'output_tensor': self.output,
|
||||
'lora_indices_tensor': self.token_lora_mapping,
|
||||
'add_inputs': add_inputs
|
||||
}
|
||||
|
||||
def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
|
||||
_, num_tokens, _, num_slices = self.metadata()
|
||||
# Sanity check shapes
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape [num_slices, num_tokens, lora_rank]
|
||||
assert len(i_shape) == 3
|
||||
assert i_shape[0] == num_slices
|
||||
assert i_shape[1] == num_tokens
|
||||
lora_rank = i_shape[2]
|
||||
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == lora_rank
|
||||
hidden_size = lw_shape[1]
|
||||
# Expected output shape [num_tokens, hidden_size * num_slices]
|
||||
assert len(o_shape) == 2
|
||||
assert o_shape == (num_tokens, hidden_size * num_slices)
|
||||
|
||||
self.to_device(self.input.device)
|
||||
|
||||
kwargs_list = []
|
||||
for i in range(num_slices):
|
||||
kwargs_list.append({
|
||||
'inputs': self.input[i],
|
||||
'lora_b_weights': self.lora_weights_lst[i],
|
||||
'output_tensor': self.output,
|
||||
'lora_indices_tensor': self.token_lora_mapping,
|
||||
'slice_offset': i * hidden_size,
|
||||
'slice_size': hidden_size,
|
||||
'add_inputs': add_inputs,
|
||||
})
|
||||
return {'kwargs_list': kwargs_list}
|
||||
|
||||
def bench_fn_kwargs(self,
|
||||
op_type: OpType,
|
||||
add_inputs: Optional[bool] = None) -> dict[str, Any]:
|
||||
@@ -676,16 +495,10 @@ class BenchmarkTensors:
|
||||
else:
|
||||
assert add_inputs is not None
|
||||
|
||||
if op_type == OpType.SGMV_SHRINK:
|
||||
return self.as_sgmv_shrink_kwargs()
|
||||
if op_type == OpType.SGMV_EXPAND:
|
||||
return self.as_sgmv_expand_kwargs(add_inputs)
|
||||
if op_type == OpType.BGMV_SHRINK:
|
||||
return self.as_bgmv_shrink_kwargs()
|
||||
if op_type == OpType.BGMV_EXPAND:
|
||||
return self.as_bgmv_expand_kwargs(add_inputs)
|
||||
if op_type == OpType.BGMV_EXPAND_SLICE:
|
||||
return self.as_bgmv_expand_slice_kwargs(add_inputs)
|
||||
if op_type == OpType.LORA_SHRINK:
|
||||
return self.as_lora_shrink_kwargs()
|
||||
if op_type == OpType.LORA_EXPAND:
|
||||
return self.as_lora_expand_kwargs(add_inputs)
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
def test_correctness(self, op_type: OpType,
|
||||
@@ -873,14 +686,7 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
|
||||
timers = []
|
||||
for bench_ctx in bench_ctxs:
|
||||
for seq_len in args.seq_lengths:
|
||||
bench_ops: list[OpType] = []
|
||||
if seq_len == 1:
|
||||
# bench all decode ops
|
||||
bench_ops = [op for op in args.op_types if op.is_decode_op()]
|
||||
else:
|
||||
# bench all prefill ops
|
||||
bench_ops = [op for op in args.op_types if op.is_prefill_op()]
|
||||
|
||||
bench_ops: list[OpType] = args.op_types
|
||||
seq_len_timers = []
|
||||
for bench_op in bench_ops:
|
||||
for num_slices in bench_op.num_slices():
|
||||
@@ -1090,13 +896,13 @@ Benchmark LoRA kernels:
|
||||
{use_cuda_graph_recommendation()}
|
||||
|
||||
list_bench example:
|
||||
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
|
||||
model_bench example:
|
||||
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
|
||||
range_bench example:
|
||||
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
|
||||
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
|
||||
""", # noqa: E501
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
|
||||
@@ -45,7 +45,6 @@ def terse_type_name(dt):
|
||||
torch.float16: "fp16",
|
||||
torch.int8: "int8",
|
||||
torch.float8_e4m3fn: "fp8",
|
||||
torch.bfloat16: "bf16",
|
||||
torch.float: "float",
|
||||
torch.int: "int",
|
||||
}[dt]
|
||||
@@ -259,7 +258,7 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
|
||||
|
||||
return lambda: ops.machete_mm(
|
||||
a=bt.a,
|
||||
b_q=bt.w_q,
|
||||
b_q=w_q,
|
||||
b_type=bt.wtype,
|
||||
b_group_scales=bt.w_g_s,
|
||||
b_group_zeros=w_g_zp,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
@@ -17,8 +18,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
|
||||
) else torch.float8_e4m3fn
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
@@ -365,6 +365,7 @@ class BenchmarkWorker:
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_quant_shape: List[int] = None,
|
||||
) -> tuple[dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
dtype_str = get_config_dtype_str(dtype,
|
||||
@@ -385,10 +386,17 @@ class BenchmarkWorker:
|
||||
else:
|
||||
config = op_config[min(op_config.keys(),
|
||||
key=lambda x: abs(x - num_tokens))]
|
||||
kernel_time = benchmark_config(config, num_tokens, num_experts,
|
||||
shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8_w8a8,
|
||||
use_int8_w8a16)
|
||||
kernel_time = benchmark_config(config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
block_quant_shape=block_quant_shape)
|
||||
return config, kernel_time
|
||||
|
||||
def tune(
|
||||
@@ -487,6 +495,14 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def get_weight_block_size_safety(config, default_value=None):
|
||||
|
||||
quantization_config = getattr(config, 'quantization_config', {})
|
||||
if isinstance(quantization_config, dict):
|
||||
return quantization_config.get('weight_block_size', default_value)
|
||||
return default_value
|
||||
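For illustration, a minimal sketch of how the helper above behaves; the `SimpleNamespace` configs stand in for a HuggingFace model config and are hypothetical:

```
from types import SimpleNamespace

# quantization_config is a dict carrying a block size -> it is returned.
cfg = SimpleNamespace(quantization_config={"weight_block_size": [128, 128]})
print(get_weight_block_size_safety(cfg))                  # [128, 128]

# No quantization_config at all -> falls back to the default (None here).
print(get_weight_block_size_safety(SimpleNamespace()))    # None
```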
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
block_quant_shape = None
|
||||
@@ -508,7 +524,12 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
block_quant_shape = config.quantization_config['weight_block_size']
|
||||
block_quant_shape = get_weight_block_size_safety(config)
|
||||
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
|
||||
@@ -176,7 +176,7 @@ def main(
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
|
||||
@@ -40,7 +40,7 @@ def main(num_tokens: int,
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
|
||||
@@ -139,7 +139,7 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||
|
||||
print(f"Naive output={output_naive}")
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
print(f"VLLM output={output_vllm}")
|
||||
print(f"vLLM output={output_vllm}")
|
||||
|
||||
if torch.allclose(output_naive, output_flashinfer, atol=1e-2,
|
||||
rtol=1e-2) and torch.allclose(
|
||||
|
||||
benchmarks/kernels/deepgemm/README.md (new file, 129 lines)
@@ -0,0 +1,129 @@
|
||||
# DeepSeek DeepGEMM Kernels Benchmark
|
||||
|
||||
This directory contains benchmarks comparing DeepSeek's DeepGEMM block FP8 kernels against vLLM's existing Triton- and CUTLASS-based kernels.
|
||||
|
||||
Currently this just includes dense GEMMs and only works on Hopper GPUs.
|
||||
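Roughly speaking, every implementation measured here runs the same block-FP8 recipe: activations are quantized with one scale per 1x128 group, weights with one scale per 128x128 block, and the FP8 GEMM applies both sets of scales. A minimal sketch using the vLLM Triton path (shapes are illustrative, and `per_block_cast_to_fp8` is the helper defined in `benchmark_fp8_block_dense_gemm.py` below):

```
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_block_fp8_matmul)

m, n, k = 64, 4096, 7168        # one of the benchmarked shapes
block_size = [128, 128]         # block-FP8 granularity used throughout

A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

A_q, A_scale = per_token_group_quant_fp8(A, block_size[1])   # per 1x128 group
B_q, B_scale = per_block_cast_to_fp8(B)                      # per 128x128 block

C = w8a8_block_fp8_matmul(A_q, B_q, A_scale, B_scale, block_size,
                          output_dtype=torch.bfloat16)
```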
|
||||
## Setup
|
||||
|
||||
You need to install vLLM in your usual fashion, then install DeepGEMM from source in its own directory:
|
||||
|
||||
```
|
||||
git clone --recursive https://github.com/deepseek-ai/DeepGEMM
|
||||
cd DeepGEMM
|
||||
python setup.py install
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
python benchmark_fp8_block_dense_gemm.py
|
||||
INFO 02-26 21:55:13 [__init__.py:207] Automatically detected platform cuda.
|
||||
===== STARTING FP8 GEMM BENCHMARK =====
|
||||
PyTorch version: 2.5.1+cu124
|
||||
CUDA version: 12.4
|
||||
Triton version: 3.1.0
|
||||
Using device: NVIDIA H100 80GB HBM3
|
||||
WARNING 02-26 21:55:15 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||
INFO 02-26 21:55:15 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||
WARNING 02-26 21:55:16 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||
WARNING 02-26 21:55:17 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||
|
||||
===== PERFORMANCE COMPARISON =====
|
||||
|
||||
DeepGEMM Implementation:
|
||||
+------+-------+-------+-----------+--------+--------+
|
||||
| m | n | k | Time (μs) | TFLOPS | GB/s |
|
||||
+------+-------+-------+-----------+--------+--------+
|
||||
| 8 | 4096 | 7168 | 102.9 | 4.6 | 286.4 |
|
||||
| 8 | 7168 | 18432 | 70.8 | 29.8 | 1868.8 |
|
||||
| 8 | 18432 | 7168 | 69.3 | 30.5 | 1911.8 |
|
||||
| 64 | 4096 | 7168 | 69.1 | 54.4 | 439.0 |
|
||||
| 64 | 7168 | 18432 | 69.4 | 243.6 | 1933.6 |
|
||||
| 64 | 18432 | 7168 | 70.4 | 240.3 | 1917.2 |
|
||||
| 64 | 24576 | 1536 | 70.1 | 68.9 | 584.6 |
|
||||
| 64 | 32768 | 512 | 68.4 | 31.4 | 307.1 |
|
||||
| 64 | 7168 | 16384 | 69.5 | 216.3 | 1718.5 |
|
||||
| 128 | 4096 | 7168 | 141.1 | 53.3 | 222.1 |
|
||||
| 128 | 7168 | 18432 | 71.9 | 470.5 | 1896.1 |
|
||||
| 128 | 18432 | 7168 | 69.3 | 488.2 | 1988.2 |
|
||||
| 1024 | 4096 | 7168 | 89.7 | 670.1 | 502.5 |
|
||||
| 1024 | 18432 | 7168 | 279.0 | 969.8 | 635.2 |
|
||||
| 2048 | 4096 | 7168 | 175.1 | 687.0 | 347.4 |
|
||||
| 4096 | 4096 | 7168 | 335.4 | 717.0 | 275.1 |
|
||||
+------+-------+-------+-----------+--------+--------+
|
||||
|
||||
vLLM Triton Implementation:
|
||||
+------+-------+-------+-----------+--------+--------+--------------+
|
||||
| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+
|
||||
| 8 | 4096 | 7168 | 74.0 | 6.3 | 398.2 | 1.39x faster |
|
||||
| 8 | 7168 | 18432 | 89.6 | 23.6 | 1478.1 | 0.79x slower |
|
||||
| 8 | 18432 | 7168 | 113.2 | 18.7 | 1170.4 | 0.61x slower |
|
||||
| 64 | 4096 | 7168 | 79.4 | 47.3 | 382.2 | 0.87x slower |
|
||||
| 64 | 7168 | 18432 | 98.5 | 171.7 | 1363.0 | 0.70x slower |
|
||||
| 64 | 18432 | 7168 | 119.5 | 141.5 | 1129.4 | 0.59x slower |
|
||||
| 64 | 24576 | 1536 | 37.6 | 128.4 | 1089.7 | 1.86x faster |
|
||||
| 64 | 32768 | 512 | 38.7 | 55.5 | 542.6 | 1.77x faster |
|
||||
| 64 | 7168 | 16384 | 86.1 | 174.5 | 1386.4 | 0.81x slower |
|
||||
| 128 | 4096 | 7168 | 90.7 | 82.9 | 345.4 | 1.56x faster |
|
||||
| 128 | 7168 | 18432 | 144.0 | 234.9 | 946.9 | 0.50x slower |
|
||||
| 128 | 18432 | 7168 | 229.5 | 147.4 | 600.1 | 0.30x slower |
|
||||
| 1024 | 4096 | 7168 | 242.3 | 248.2 | 186.1 | 0.37x slower |
|
||||
| 1024 | 18432 | 7168 | 897.8 | 301.4 | 197.4 | 0.31x slower |
|
||||
| 2048 | 4096 | 7168 | 463.0 | 259.7 | 131.4 | 0.38x slower |
|
||||
| 4096 | 4096 | 7168 | 901.8 | 266.7 | 102.3 | 0.37x slower |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+
|
||||
|
||||
vLLM CUTLASS Implementation:
|
||||
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||
| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | vs Triton |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||
| 8 | 4096 | 7168 | 34.6 | 13.6 | 852.3 | 2.98x faster | 2.14x faster |
|
||||
| 8 | 7168 | 18432 | 78.9 | 26.8 | 1677.3 | 0.90x slower | 1.13x faster |
|
||||
| 8 | 18432 | 7168 | 81.2 | 26.0 | 1631.1 | 0.85x slower | 1.39x faster |
|
||||
| 64 | 4096 | 7168 | 36.9 | 101.9 | 822.9 | 1.87x faster | 2.15x faster |
|
||||
| 64 | 7168 | 18432 | 87.4 | 193.4 | 1535.2 | 0.79x slower | 1.13x faster |
|
||||
| 64 | 18432 | 7168 | 85.0 | 199.0 | 1587.6 | 0.83x slower | 1.41x faster |
|
||||
| 64 | 24576 | 1536 | 28.0 | 172.8 | 1465.8 | 2.51x faster | 1.35x faster |
|
||||
| 64 | 32768 | 512 | 28.8 | 74.5 | 728.5 | 2.37x faster | 1.34x faster |
|
||||
| 64 | 7168 | 16384 | 77.9 | 193.0 | 1532.8 | 0.89x slower | 1.11x faster |
|
||||
| 128 | 4096 | 7168 | 39.1 | 192.4 | 802.0 | 3.61x faster | 2.32x faster |
|
||||
| 128 | 7168 | 18432 | 93.7 | 360.8 | 1454.2 | 0.77x slower | 1.54x faster |
|
||||
| 128 | 18432 | 7168 | 85.7 | 394.8 | 1608.0 | 0.81x slower | 2.68x faster |
|
||||
| 1024 | 4096 | 7168 | 99.7 | 603.1 | 452.2 | 0.90x slower | 2.43x faster |
|
||||
| 1024 | 18432 | 7168 | 331.3 | 816.7 | 534.9 | 0.84x slower | 2.71x faster |
|
||||
| 2048 | 4096 | 7168 | 198.3 | 606.6 | 306.7 | 0.88x slower | 2.34x faster |
|
||||
| 4096 | 4096 | 7168 | 392.2 | 613.2 | 235.3 | 0.86x slower | 2.30x faster |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||
|
||||
===== AVERAGE PERFORMANCE =====
|
||||
+----------------+------------+----------+---------------+
|
||||
| Implementation | Avg TFLOPS | Avg GB/s | Avg Time (ms) |
|
||||
+----------------+------------+----------+---------------+
|
||||
| DeepGEMM | 310.98 | 1052.10 | 0.11 |
|
||||
| vLLM Triton | 144.30 | 715.60 | 0.23 |
|
||||
| vLLM CUTLASS | 286.78 | 1076.67 | 0.11 |
|
||||
+----------------+------------+----------+---------------+
|
||||
|
||||
===== AVERAGE SPEEDUPS =====
|
||||
+-----------------------------+--------------+
|
||||
| Comparison | Speedup |
|
||||
+-----------------------------+--------------+
|
||||
| DeepGEMM vs vLLM Triton | 1.71x faster |
|
||||
| DeepGEMM vs vLLM CUTLASS | 0.94x slower |
|
||||
| vLLM CUTLASS vs vLLM Triton | 1.84x faster |
|
||||
+-----------------------------+--------------+
|
||||
|
||||
===== ACCURACY COMPARISON =====
|
||||
+----------------+-----------------------+
|
||||
| Implementation | Avg Diff vs Reference |
|
||||
+----------------+-----------------------+
|
||||
| DeepGEMM | 0.000684 |
|
||||
| vLLM Triton | 0.000684 |
|
||||
| vLLM CUTLASS | 0.000684 |
|
||||
+----------------+-----------------------+
|
||||
```
|
||||
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py (new file, 464 lines)
@@ -0,0 +1,464 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# fmt: off
|
||||
# ruff: noqa: E501
|
||||
import time
|
||||
|
||||
# Import DeepGEMM functions
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
|
||||
|
||||
# Import vLLM functions
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
|
||||
|
||||
# Copied from
|
||||
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
|
||||
def per_token_cast_to_fp8(
|
||||
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert tensor to FP8 format with per-token scaling."""
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
|
||||
torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
|
||||
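As a small numeric illustration of the per-token scaling above (448.0 is the largest finite value of `torch.float8_e4m3fn`; the input values are made up and cover a single 128-wide group):

```
import torch

x = torch.tensor([[0.5, -2.0, 1.0, 0.25] * 32])        # shape (1, 128)
amax = x.abs().amax()                                   # 2.0 for this group
x_fp8 = (x * (448.0 / amax)).to(torch.float8_e4m3fn)   # scaled into FP8 range
scale = amax / 448.0                                    # stored per-group scale
x_roundtrip = x_fp8.to(torch.float32) * scale           # dequantize
print((x_roundtrip - x).abs().max())                    # tiny quantization error
```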
|
||||
|
||||
# Copied from
|
||||
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert tensor to FP8 format with per-block scaling."""
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
|
||||
x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
|
||||
|
||||
def benchmark_shape(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
warmup: int = 100,
|
||||
repeat: int = 10000,
|
||||
verbose: bool = False) -> dict:
|
||||
"""Benchmark all implementations for a specific (m, n, k) shape."""
|
||||
if verbose:
|
||||
print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
|
||||
|
||||
# Create test tensors
|
||||
A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
# Reference result in BF16
|
||||
torch.cuda.synchronize()
|
||||
C_ref = A @ B.t()
|
||||
|
||||
# Pre-quantize B for all implementations
|
||||
# (weights can be pre-quantized offline)
|
||||
B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
|
||||
B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)
|
||||
|
||||
# Block size configuration
|
||||
block_size = [128, 128]
|
||||
|
||||
# Pre-quantize A for all implementations
|
||||
A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
|
||||
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
|
||||
C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
|
||||
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
|
||||
A, block_size[1], column_major_scales=True)
|
||||
|
||||
# === DeepGEMM Implementation ===
|
||||
def deepgemm_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
|
||||
# A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
|
||||
# A, block_size[1])
|
||||
# A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
|
||||
# C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
|
||||
(B_deepgemm, B_scale_deepgemm),
|
||||
C_deepgemm)
|
||||
return C_deepgemm
|
||||
|
||||
# === vLLM Triton Implementation ===
|
||||
def vllm_triton_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
|
||||
return w8a8_block_fp8_matmul(A_vllm,
|
||||
B_vllm,
|
||||
A_scale_vllm,
|
||||
B_scale_vllm,
|
||||
block_size,
|
||||
output_dtype=torch.bfloat16)
|
||||
|
||||
# === vLLM CUTLASS Implementation ===
|
||||
def vllm_cutlass_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
|
||||
# A, block_size[1], column_major_scales=True)
|
||||
return ops.cutlass_scaled_mm(A_vllm_cutlass,
|
||||
B_vllm.T,
|
||||
scale_a=A_scale_vllm_cutlass,
|
||||
scale_b=B_scale_vllm.T,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
# Run correctness check first
|
||||
if verbose:
|
||||
print("Running correctness check...")
|
||||
C_deepgemm = deepgemm_gemm()
|
||||
C_vllm_triton = vllm_triton_gemm()
|
||||
C_vllm_cutlass = vllm_cutlass_gemm()
|
||||
|
||||
deepgemm_diff = calc_diff(C_deepgemm, C_ref)
|
||||
vllm_triton_diff = calc_diff(C_vllm_triton, C_ref)
|
||||
vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref)
|
||||
|
||||
if verbose:
|
||||
print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
|
||||
print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
|
||||
print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
|
||||
print("vLLM Triton vs DeepGEMM difference: "
|
||||
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
|
||||
print("vLLM CUTLASS vs DeepGEMM difference: "
|
||||
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
|
||||
|
||||
# Benchmark implementations
|
||||
implementations = {
|
||||
"DeepGEMM": deepgemm_gemm,
|
||||
"vLLM Triton": vllm_triton_gemm,
|
||||
"vLLM CUTLASS": vllm_cutlass_gemm
|
||||
}
|
||||
|
||||
benchmark_results = {
|
||||
"shape": {
|
||||
"m": m,
|
||||
"n": n,
|
||||
"k": k
|
||||
},
|
||||
"implementations": {}
|
||||
}
|
||||
|
||||
for name, func in implementations.items():
|
||||
# Warmup
|
||||
for _ in range(warmup):
|
||||
func()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Timing loop
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
for _ in range(repeat):
|
||||
func()
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
# Calculate timing and TFLOPS
|
||||
avg_time_ms = (end - start) / repeat * 1000
|
||||
avg_time_us = avg_time_ms * 1000
|
||||
tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12
|
||||
gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3)
|
||||
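As a quick cross-check of the two formulas above, plugging in the vLLM Triton row for (m, n, k) = (128, 4096, 7168) from the README table (90.7 µs per call) reproduces its numbers to within rounding:

```
m, n, k = 128, 4096, 7168
t_s = 90.7e-6                               # 90.7 us per GEMM
flops = 2 * m * n * k                       # ~7.52e9 FLOPs
tflops = flops / t_s / 1e12                 # ~82.9 TFLOPS (matches the table)
bytes_moved = m * k + k * n + m * n * 2     # fp8 A + fp8 B + bf16 C
gb_s = bytes_moved / 1e9 / t_s              # ~345.4 GB/s (matches the table)
print(f"{tflops:.1f} TFLOPS, {gb_s:.1f} GB/s")
```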
|
||||
benchmark_results["implementations"][name] = {
|
||||
"time_ms": avg_time_ms,
|
||||
"time_us": avg_time_us,
|
||||
"tflops": tflops,
|
||||
"gb_s": gb_s,
|
||||
"diff": {
|
||||
"DeepGEMM":
|
||||
0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
|
||||
"Reference":
|
||||
deepgemm_diff if name == "DeepGEMM" else
|
||||
(vllm_triton_diff
|
||||
if name == "vLLM Triton" else vllm_cutlass_diff)
|
||||
}
|
||||
}
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
|
||||
)
|
||||
|
||||
# Calculate speedups
|
||||
baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
|
||||
for name, data in benchmark_results["implementations"].items():
|
||||
if name != "DeepGEMM":
|
||||
speedup = baseline / data["time_ms"]
|
||||
benchmark_results["implementations"][name][
|
||||
"speedup_vs_deepgemm"] = speedup
|
||||
if verbose:
|
||||
print(f"DeepGEMM is {1/speedup:.2f}x "
|
||||
f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")
|
||||
|
||||
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
|
||||
"time_ms"]
|
||||
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
|
||||
"time_ms"]
|
||||
cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
|
||||
benchmark_results["implementations"]["vLLM CUTLASS"][
|
||||
"speedup_vs_triton"] = cutlass_vs_triton
|
||||
if verbose:
|
||||
print(
|
||||
f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
|
||||
f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton"
|
||||
)
|
||||
|
||||
return benchmark_results
|
||||
|
||||
|
||||
def format_table_row(values, widths):
|
||||
"""Format a row with specified column widths."""
|
||||
return "| " + " | ".join(f"{val:{w}}"
|
||||
for val, w in zip(values, widths)) + " |"
|
||||
|
||||
|
||||
def print_table(headers, rows, title=None):
|
||||
"""Print a table with headers and rows."""
|
||||
if title:
|
||||
print(f"\n{title}")
|
||||
|
||||
# Calculate column widths based on headers and data
|
||||
widths = [
|
||||
max(len(str(h)), max(len(str(row[i])) for row in rows))
|
||||
for i, h in enumerate(headers)
|
||||
]
|
||||
|
||||
# Create separator line
|
||||
separator = "+-" + "-+-".join("-" * w for w in widths) + "-+"
|
||||
|
||||
# Print table
|
||||
print(separator)
|
||||
print(format_table_row(headers, widths))
|
||||
print(separator)
|
||||
for row in rows:
|
||||
print(format_table_row(row, widths))
|
||||
print(separator)
|
||||
|
||||
|
||||
def format_speedup(value):
|
||||
"""Format speedup value with indicator if it's faster or slower."""
|
||||
return f"{value:.2f}x {'faster' if value > 1.0 else 'slower'}"
|
||||
|
||||
|
||||
def run_benchmarks(verbose: bool = False):
|
||||
"""Run benchmarks for a set of common shapes."""
|
||||
print("===== STARTING FP8 GEMM BENCHMARK =====")
|
||||
|
||||
# Make sure we're using the GPU
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA not available! Tests require GPU.")
|
||||
return
|
||||
|
||||
# Print system information
|
||||
print(f"PyTorch version: {torch.__version__}")
|
||||
print(f"CUDA version: {torch.version.cuda}")
|
||||
print(f"Triton version: {triton.__version__}")
|
||||
print(f"Using device: {torch.cuda.get_device_name()}")
|
||||
|
||||
# Enable TF32 for better performance
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# Set seeds for reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
|
||||
# Define benchmark shapes (m, n, k)
|
||||
shapes = [
|
||||
(8, 4096, 7168),
|
||||
(8, 7168, 18432),
|
||||
(8, 18432, 7168),
|
||||
(64, 4096, 7168),
|
||||
(64, 7168, 18432),
|
||||
(64, 18432, 7168),
|
||||
(64, 24576, 1536),
|
||||
(64, 32768, 512),
|
||||
(64, 7168, 16384),
|
||||
(128, 4096, 7168),
|
||||
(128, 7168, 18432),
|
||||
(128, 18432, 7168),
|
||||
(1024, 4096, 7168),
|
||||
(1024, 18432, 7168),
|
||||
(2048, 4096, 7168),
|
||||
(4096, 4096, 7168),
|
||||
]
|
||||
shapes = [
|
||||
# (64, 2112, 7168),
|
||||
(64, 24576, 1536),
|
||||
(64, 32768, 512),
|
||||
(64, 7168, 16384),
|
||||
(64, 4096, 7168),
|
||||
(64, 7168, 2048),
|
||||
# (128, 2112, 7168),
|
||||
(128, 24576, 1536),
|
||||
(128, 32768, 512),
|
||||
(128, 7168, 16384),
|
||||
(128, 4096, 7168),
|
||||
(128, 7168, 2048),
|
||||
# (4096, 2112, 7168),
|
||||
(4096, 24576, 1536),
|
||||
(4096, 32768, 512),
|
||||
(4096, 7168, 16384),
|
||||
(4096, 4096, 7168),
|
||||
(4096, 7168, 2048),
|
||||
]
|
||||
|
||||
all_results = []
|
||||
for m, n, k in shapes:
|
||||
result = benchmark_shape(m, n, k, verbose=verbose)
|
||||
all_results.append(result)
|
||||
|
||||
# Print results in a nicely formatted table
|
||||
print("\n===== PERFORMANCE COMPARISON =====")
|
||||
|
||||
# Print DeepGEMM table
|
||||
deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"]
|
||||
deepgemm_rows = []
|
||||
for result in all_results:
|
||||
shape = result["shape"]
|
||||
impl_data = result["implementations"]["DeepGEMM"]
|
||||
deepgemm_rows.append([
|
||||
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
|
||||
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
|
||||
])
|
||||
|
||||
print_table(deepgemm_headers,
|
||||
deepgemm_rows,
|
||||
title="DeepGEMM Implementation:")
|
||||
|
||||
# Print vLLM Triton table
|
||||
triton_headers = [
|
||||
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
|
||||
]
|
||||
triton_rows = []
|
||||
for result in all_results:
|
||||
shape = result["shape"]
|
||||
impl_data = result["implementations"]["vLLM Triton"]
|
||||
speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
|
||||
triton_rows.append([
|
||||
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
|
||||
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
|
||||
format_speedup(speedup)
|
||||
])
|
||||
|
||||
print_table(triton_headers,
|
||||
triton_rows,
|
||||
title="vLLM Triton Implementation:")
|
||||
|
||||
# Print vLLM CUTLASS table
|
||||
cutlass_headers = [
|
||||
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
|
||||
"vs Triton"
|
||||
]
|
||||
cutlass_rows = []
|
||||
for result in all_results:
|
||||
shape = result["shape"]
|
||||
impl_data = result["implementations"]["vLLM CUTLASS"]
|
||||
vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
|
||||
vs_triton = impl_data.get("speedup_vs_triton", 1.0)
|
||||
cutlass_rows.append([
|
||||
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
|
||||
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
|
||||
format_speedup(vs_deepgemm),
|
||||
format_speedup(vs_triton)
|
||||
])
|
||||
|
||||
print_table(cutlass_headers,
|
||||
cutlass_rows,
|
||||
title="vLLM CUTLASS Implementation:")
|
||||
|
||||
# Calculate and print averages
|
||||
print("\n===== AVERAGE PERFORMANCE =====")
|
||||
|
||||
implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
|
||||
avg_metrics = {
|
||||
impl: {
|
||||
"tflops": 0,
|
||||
"gb_s": 0,
|
||||
"time_ms": 0
|
||||
}
|
||||
for impl in implementations
|
||||
}
|
||||
|
||||
for result in all_results:
|
||||
for impl in implementations:
|
||||
impl_data = result["implementations"][impl]
|
||||
avg_metrics[impl]["tflops"] += impl_data["tflops"]
|
||||
avg_metrics[impl]["gb_s"] += impl_data["gb_s"]
|
||||
avg_metrics[impl]["time_ms"] += impl_data["time_ms"]
|
||||
|
||||
num_shapes = len(all_results)
|
||||
avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"]
|
||||
avg_rows = []
|
||||
|
||||
for impl in implementations:
|
||||
avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
|
||||
avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
|
||||
avg_time = avg_metrics[impl]["time_ms"] / num_shapes
|
||||
avg_rows.append([
|
||||
impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
|
||||
])
|
||||
|
||||
print_table(avg_headers, avg_rows)
|
||||
|
||||
# Calculate average speedups
|
||||
avg_speedups = {
|
||||
"DeepGEMM vs vLLM Triton": 0,
|
||||
"DeepGEMM vs vLLM CUTLASS": 0,
|
||||
"vLLM CUTLASS vs vLLM Triton": 0
|
||||
}
|
||||
|
||||
for result in all_results:
|
||||
deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
|
||||
vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
|
||||
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
|
||||
"time_ms"]
|
||||
|
||||
avg_speedups[
|
||||
"DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
|
||||
avg_speedups[
|
||||
"DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
|
||||
avg_speedups[
|
||||
"vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
|
||||
|
||||
print("\n===== AVERAGE SPEEDUPS =====")
|
||||
speedup_headers = ["Comparison", "Speedup"]
|
||||
speedup_rows = []
|
||||
for comparison, total in avg_speedups.items():
|
||||
avg_speedup = total / num_shapes
|
||||
status = "faster" if avg_speedup > 1 else "slower"
|
||||
speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"])
|
||||
|
||||
print_table(speedup_headers, speedup_rows)
|
||||
|
||||
# Average accuracy comparison
|
||||
print("\n===== ACCURACY COMPARISON =====")
|
||||
avg_diff = {impl: 0 for impl in implementations}
|
||||
|
||||
for result in all_results:
|
||||
for impl in implementations:
|
||||
avg_diff[impl] += result["implementations"][impl]["diff"][
|
||||
"Reference"]
|
||||
|
||||
diff_headers = ["Implementation", "Avg Diff vs Reference"]
|
||||
diff_rows = []
|
||||
for impl in implementations:
|
||||
diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"])
|
||||
|
||||
print_table(diff_headers, diff_rows)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_benchmarks(verbose=False)
|
||||
benchmarks/run_structured_output_benchmark.sh (new executable file, 65 lines)
@@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Define the model to use
|
||||
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"}
|
||||
|
||||
# Define the backend to use
|
||||
BACKEND=${2:-"vllm"}
|
||||
|
||||
# Define the dataset to use
|
||||
DATASET=${3:-"xgrammar_bench"}
|
||||
|
||||
# Define the guided decoding backend
|
||||
GUIDED_BACKEND=${4:-"xgrammar"}
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"}
|
||||
|
||||
GUIDED_RATIO=${6:-0.5}
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Define QPS values to test
|
||||
QPS_VALUES=(70 60 50 25 20 15 10)
|
||||
|
||||
# Common parameters
|
||||
COMMON_PARAMS="--backend $BACKEND \
|
||||
--model $MODEL \
|
||||
--dataset $DATASET \
|
||||
--structured-output-backend $GUIDED_BACKEND \
|
||||
--structured-output-ratio $GUIDED_RATIO \
|
||||
--save-results \
|
||||
--result-dir $OUTPUT_DIR"
|
||||
|
||||
echo "Starting structured output benchmark with model: $MODEL"
|
||||
echo "Backend: $BACKEND"
|
||||
echo "Dataset: $DATASET"
|
||||
echo "Structured output backend: $GUIDED_BACKEND"
|
||||
echo "Results will be saved to: $OUTPUT_DIR"
|
||||
echo "----------------------------------------"
|
||||
|
||||
# Run benchmarks with different QPS values
|
||||
for qps in "${QPS_VALUES[@]}"; do
|
||||
echo "Running benchmark with QPS: $qps"
|
||||
|
||||
# Get git hash and branch for the filename
|
||||
GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
|
||||
|
||||
# Construct filename for this run
|
||||
FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
|
||||
|
||||
# Run the benchmark
|
||||
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
|
||||
--request-rate $qps \
|
||||
--result-filename "$FILENAME" \
|
||||
--tokenizer-mode ${TOKENIZER_MODE:-"auto"} \
|
||||
--port ${PORT:-8000}
|
||||
|
||||
echo "Completed benchmark with QPS: $qps"
|
||||
echo "----------------------------------------"
|
||||
done
|
||||
|
||||
echo "All benchmarks completed!"
|
||||
echo "Results saved to: $OUTPUT_DIR"
|
||||
@@ -1,113 +1,19 @@
|
||||
{
|
||||
"$schema":
|
||||
"https://json-schema.org/draft/2020-12/schema",
|
||||
"title":
|
||||
"User Profile",
|
||||
"type":
|
||||
"object",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"userId": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the user."
|
||||
},
|
||||
"personalInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"firstName": {
|
||||
"type": "string",
|
||||
"description": "The user's first name."
|
||||
},
|
||||
"lastName": {
|
||||
"type": "string",
|
||||
"description": "The user's last name."
|
||||
},
|
||||
"age": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"description": "The user's age."
|
||||
},
|
||||
"phoneNumbers": {
|
||||
"type":
|
||||
"array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["home", "work", "mobile"],
|
||||
"description": "Type of phone number."
|
||||
},
|
||||
"number": {
|
||||
"type": "string",
|
||||
"pattern": "^\\+?[1-9]\\d{1,14}$",
|
||||
"description": "Phone number in E.164 format."
|
||||
}
|
||||
},
|
||||
"required": ["type", "number"]
|
||||
},
|
||||
"description":
|
||||
"List of phone numbers associated with the user."
|
||||
}
|
||||
},
|
||||
"required": ["firstName", "lastName"]
|
||||
},
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {
|
||||
"type": "string",
|
||||
"description": "Street address."
|
||||
},
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name."
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "State or province."
|
||||
},
|
||||
"postalCode": {
|
||||
"type": "string",
|
||||
"pattern": "^\\d{5}(-\\d{4})?$",
|
||||
"description": "Postal code."
|
||||
},
|
||||
"country": {
|
||||
"type": "string",
|
||||
"description": "Country name."
|
||||
}
|
||||
},
|
||||
"required": ["street", "city", "state", "postalCode", "country"]
|
||||
},
|
||||
"preferences": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"newsletterSubscribed": {
|
||||
"type":
|
||||
"boolean",
|
||||
"description":
|
||||
"Indicates if the user is subscribed to the newsletter."
|
||||
},
|
||||
"favoriteCategories": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "List of user's favorite categories."
|
||||
}
|
||||
},
|
||||
"required": ["newsletterSubscribed"]
|
||||
},
|
||||
"accountStatus": {
|
||||
"type": "string",
|
||||
"enum": ["active", "inactive", "suspended"],
|
||||
"description": "Current status of the user's account."
|
||||
},
|
||||
"registrationDate": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"description": "ISO 8601 formatted date-time of user registration."
|
||||
}
|
||||
"name": { "type": "string" },
|
||||
"email": { "type": "string" },
|
||||
"street": { "type": "string" },
|
||||
"city": { "type": "string" },
|
||||
"state": { "type": "string" },
|
||||
"zip": { "type": "string" },
|
||||
"phone": { "type": "string" },
|
||||
"website": { "type": "string" },
|
||||
"company": { "type": "string" },
|
||||
"age": { "type": "integer" }
|
||||
},
|
||||
"required":
|
||||
["userId", "personalInfo", "address", "accountStatus", "registrationDate"]
|
||||
}
|
||||
"required": [
|
||||
"name",
|
||||
"email"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -81,6 +81,7 @@ else()
|
||||
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
|
||||
find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
|
||||
find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
|
||||
find_isa(${CPUINFO} "S390" S390_FOUND)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -129,8 +130,16 @@ elseif (ASIMD_FOUND)
|
||||
elseif(APPLE_SILICON_FOUND)
|
||||
message(STATUS "Apple Silicon Detected")
|
||||
set(ENABLE_NUMA OFF)
|
||||
elseif (S390_FOUND)
|
||||
message(STATUS "S390 detected")
|
||||
# Check for S390 VXE support
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-mvx"
|
||||
"-mzvector"
|
||||
"-march=native"
|
||||
"-mtune=native")
|
||||
else()
|
||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.")
|
||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA or ARMv8 support.")
|
||||
endif()
|
||||
|
||||
#
|
||||
@@ -140,7 +149,7 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.6
|
||||
GIT_TAG v3.7.1
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
|
||||
GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
@@ -64,4 +64,4 @@ install(
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -350,8 +350,8 @@ __global__ void concat_and_cache_mla_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
@@ -393,8 +393,8 @@ void reshape_and_cache(
|
||||
CALL_RESHAPE_AND_CACHE)
|
||||
}
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
@@ -446,8 +446,8 @@ void reshape_and_cache_flash(
|
||||
CALL_RESHAPE_AND_CACHE_FLASH);
|
||||
}
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
|
||||
@@ -24,8 +24,8 @@ struct KernelVecType<float> {
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#ifdef __powerpc64__
|
||||
// Power architecture-specific vector types
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power and s390x architecture-specific vector types
|
||||
using q_load_vec_type = vec_op::FP32Vec8;
|
||||
using k_load_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP32Vec16;
|
||||
|
||||
@@ -3,6 +3,12 @@
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
#if defined(__x86_64__)
|
||||
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
|
||||
#else
|
||||
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
|
||||
@@ -95,13 +101,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
}
|
||||
|
||||
const int element_num_per_block = key_caches[0][0].numel();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
|
||||
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
|
||||
element_num_per_block, num_layers);
|
||||
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
|
||||
});
|
||||
DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
|
||||
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
|
||||
element_num_per_block, num_layers);
|
||||
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
@@ -118,16 +123,15 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
|
||||
reshape_and_cache_cpu_impl<scalar_t>(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
|
||||
value_stride, num_heads, head_size, block_size, x);
|
||||
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
|
||||
});
|
||||
DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
|
||||
reshape_and_cache_cpu_impl<scalar_t>(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
|
||||
num_heads, head_size, block_size, x);
|
||||
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
#elif defined(__POWER9_VECTOR__)
|
||||
// ppc implementation
|
||||
#include "cpu_types_vsx.hpp"
|
||||
#elif defined(__s390x__)
|
||||
// s390 implementation
|
||||
#include "cpu_types_vxe.hpp"
|
||||
#elif defined(__aarch64__)
|
||||
// arm implementation
|
||||
#include "cpu_types_arm.hpp"
|
||||
|
||||
csrc/cpu/cpu_types_vxe.hpp (new file, 480 lines)
@@ -0,0 +1,480 @@
|
||||
|
||||
#ifndef CPU_TYPES_VXE_HPP
|
||||
#define CPU_TYPES_VXE_HPP
|
||||
|
||||
#include <vecintrin.h>
|
||||
#include <cmath>
|
||||
#include <torch/all.h>
|
||||
namespace vec_op {
|
||||
|
||||
#define vec_neg(a) (-(a))
|
||||
#define vec_add(a, b) ((a) + (b))
|
||||
#define vec_sub(a, b) ((a) - (b))
|
||||
#define vec_mul(a, b) ((a) * (b))
|
||||
#define vec_div(a, b) ((a) / (b))
|
||||
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
|
||||
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
|
||||
|
||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) \
|
||||
std::cout << #NAME << " exit." << std::endl;
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
|
||||
};
|
||||
|
||||
typedef struct ss16x8x2_t {
|
||||
__vector signed short val[2];
|
||||
} ss16x8x2_t;
|
||||
|
||||
typedef struct ss16x8x4_t {
|
||||
__vector signed short val[4];
|
||||
} ss16x8x4_t;
|
||||
|
||||
typedef struct f32x4x2_t {
|
||||
__vector float val[2];
|
||||
} f32x4x2_t;
|
||||
|
||||
typedef struct f32x4x4_t {
|
||||
__vector float val[4];
|
||||
} f32x4x4_t;
|
||||
|
||||
struct FP32Vec8;
|
||||
struct FP32Vec16;
|
||||
|
||||
struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
constexpr static int VEC_ELEM_NUM = 8;
|
||||
|
||||
__vector signed short reg;
|
||||
|
||||
explicit BF16Vec8(const void* ptr) : reg(*(__vector signed short*)ptr) {}
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void* ptr) const {
|
||||
*reinterpret_cast<__vector signed short*>(ptr) = reg;
|
||||
}
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
|
||||
ss16x8x2_t reg;
|
||||
|
||||
explicit BF16Vec16(const void* ptr) {
|
||||
// Load 256 bits in two parts
|
||||
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
|
||||
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
|
||||
}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const {
|
||||
// Save 256 bits in two parts
|
||||
vec_xst(reg.val[0], 0, (signed short*)ptr);
|
||||
vec_xst(reg.val[1], 16, (signed short*)ptr);
|
||||
}
|
||||
};
|
||||
|
||||
const static __vector signed short zero = vec_splats((signed short)0);
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
constexpr static int VEC_ELEM_NUM = 32;
|
||||
|
||||
ss16x8x4_t reg;
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
|
||||
|
||||
explicit BF16Vec32(const BF16Vec8& vec8_data)
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
constexpr static int VEC_ELEM_NUM = 4;
|
||||
union AliasReg {
|
||||
__vector float reg;
|
||||
float values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
__vector float reg;
|
||||
|
||||
explicit FP32Vec4(float v) : reg(vec_splats(v)) {}
|
||||
|
||||
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
|
||||
|
||||
explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
|
||||
|
||||
explicit FP32Vec4(__vector float data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
constexpr static int VEC_ELEM_NUM = 8;
|
||||
union AliasReg {
|
||||
f32x4x2_t reg;
|
||||
float values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
f32x4x2_t reg;
|
||||
|
||||
explicit FP32Vec8(float v) {
|
||||
reg.val[0] = vec_splats(v);
|
||||
reg.val[1] = vec_splats(v);
|
||||
}
|
||||
|
||||
explicit FP32Vec8() {
|
||||
reg.val[0] = vec_splats(0.0f);
|
||||
reg.val[1] = vec_splats(0.0f);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const float* ptr) {
|
||||
reg.val[0] = vec_xl(0, ptr);
|
||||
reg.val[1] = vec_xl(16, ptr);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8& v) {
|
||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
|
||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
FP32Vec8 exp() const {
|
||||
// TODO: Vectorize this
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
f32x4x4_t ret;
|
||||
ret.val[0][0] = std::exp(ar.values[0]);
|
||||
ret.val[0][1] = std::exp(ar.values[1]);
|
||||
ret.val[0][2] = std::exp(ar.values[2]);
|
||||
ret.val[0][3] = std::exp(ar.values[3]);
|
||||
ret.val[1][0] = std::exp(ar.values[4]);
|
||||
ret.val[1][1] = std::exp(ar.values[5]);
|
||||
ret.val[1][2] = std::exp(ar.values[6]);
|
||||
ret.val[1][3] = std::exp(ar.values[7]);
|
||||
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
|
||||
}
|
||||
|
||||
FP32Vec8 tanh() const {
|
||||
// TODO: Vectorize this
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
f32x4x4_t ret;
|
||||
ret.val[0][0] = std::tanh(ar.values[0]);
|
||||
ret.val[0][1] = std::tanh(ar.values[1]);
|
||||
ret.val[0][2] = std::tanh(ar.values[2]);
|
||||
ret.val[0][3] = std::tanh(ar.values[3]);
|
||||
ret.val[1][0] = std::tanh(ar.values[4]);
|
||||
ret.val[1][1] = std::tanh(ar.values[5]);
|
||||
ret.val[1][2] = std::tanh(ar.values[6]);
|
||||
ret.val[1][3] = std::tanh(ar.values[7]);
|
||||
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
|
||||
}
|
||||
|
||||
FP32Vec8 er() const {
|
||||
// TODO: Vectorize this
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
f32x4x4_t ret;
|
||||
ret.val[0][0] = std::erf(ar.values[0]);
|
||||
ret.val[0][1] = std::erf(ar.values[1]);
|
||||
ret.val[0][2] = std::erf(ar.values[2]);
|
||||
ret.val[0][3] = std::erf(ar.values[3]);
|
||||
ret.val[1][0] = std::erf(ar.values[4]);
|
||||
ret.val[1][1] = std::erf(ar.values[5]);
|
||||
ret.val[1][2] = std::erf(ar.values[6]);
|
||||
ret.val[1][3] = std::erf(ar.values[7]);
|
||||
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
void save(float* ptr) const {
|
||||
vec_xst(reg.val[0], 0, ptr);
|
||||
vec_xst(reg.val[1], 16, ptr);
|
||||
}
|
||||
};
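Note that exp(), tanh() and er() above fall back to per-lane std:: calls; callers are expected to compose them with the overloaded arithmetic operators. As a minimal sketch only (this helper is not part of the header), an erf-based GELU over one 8-lane vector could be written as:

// Sketch, assuming only the FP32Vec8 members defined above.
inline FP32Vec8 gelu_vec8(const FP32Vec8& x) {
  const FP32Vec8 half(0.5f);
  const FP32Vec8 one(1.0f);
  const FP32Vec8 inv_sqrt2(0.70710678f);
  // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), evaluated lane-wise
  return half * x * (one + (x * inv_sqrt2).er());
}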
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
f32x4x4_t reg;
|
||||
float values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
f32x4x4_t reg;
|
||||
|
||||
explicit FP32Vec16(float v) {
|
||||
reg.val[0] = vec_splats(v);
|
||||
reg.val[1] = vec_splats(v);
|
||||
reg.val[2] = vec_splats(v);
|
||||
reg.val[3] = vec_splats(v);
|
||||
}
|
||||
|
||||
explicit FP32Vec16() {
|
||||
reg.val[0] = vec_splats(0.0f);
|
||||
reg.val[1] = vec_splats(0.0f);
|
||||
reg.val[2] = vec_splats(0.0f);
|
||||
reg.val[3] = vec_splats(0.0f);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const float* ptr) {
|
||||
reg.val[0] = vec_xl(0, ptr);
|
||||
reg.val[1] = vec_xl(16, ptr);
|
||||
reg.val[2] = vec_xl(32, ptr);
|
||||
reg.val[3] = vec_xl(48, ptr);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[2];
|
||||
reg.val[3] = data.reg.val[3];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4& data) {
|
||||
reg.val[0] = data.reg;
|
||||
reg.val[1] = data.reg;
|
||||
reg.val[2] = data.reg;
|
||||
reg.val[3] = data.reg;
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16& v) {
|
||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
|
||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
|
||||
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
|
||||
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
|
||||
vec_mul(reg.val[1], b.reg.val[1]),
|
||||
vec_mul(reg.val[2], b.reg.val[2]),
|
||||
vec_mul(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
|
||||
vec_add(reg.val[1], b.reg.val[1]),
|
||||
vec_add(reg.val[2], b.reg.val[2]),
|
||||
vec_add(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
|
||||
vec_sub(reg.val[1], b.reg.val[1]),
|
||||
vec_sub(reg.val[2], b.reg.val[2]),
|
||||
vec_sub(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
|
||||
vec_div(reg.val[1], b.reg.val[1]),
|
||||
vec_div(reg.val[2], b.reg.val[2]),
|
||||
vec_div(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
const int start = idx * group_size;
|
||||
unroll_loop<int, group_size>(
|
||||
[&result, &start, ar](int i) { result += ar.values[start + i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void save(float* ptr) const {
|
||||
vec_xst(reg.val[0], 0, ptr);
|
||||
vec_xst(reg.val[1], 16, ptr);
|
||||
vec_xst(reg.val[2], 32, ptr);
|
||||
vec_xst(reg.val[3], 48, ptr);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc = acc + a * b;
|
||||
}
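A minimal sketch of how the VecType/vec_t trait and the helpers above are typically consumed (the function name and buffers are hypothetical, not part of this diff; tail elements beyond a multiple of the lane width are ignored for brevity):

template <typename scalar_t>
float dot_product_sketch(const scalar_t* a, const scalar_t* b, int n) {
  using load_vec_t = vec_t<scalar_t>;          // FP32Vec8 or BF16Vec8
  constexpr int W = load_vec_t::VEC_ELEM_NUM;  // 8 lanes in both cases
  FP32Vec8 acc(0.0f);
  for (int i = 0; i + W <= n; i += W) {
    FP32Vec8 va(load_vec_t(a + i));            // widens BF16 to FP32 if needed
    FP32Vec8 vb(load_vec_t(b + i));
    acc = acc + va * vb;
  }
  return acc.reduce_sum();
}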
namespace c10 {
|
||||
struct BFloat16 {
|
||||
uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit
|
||||
// value.
|
||||
};
|
||||
} // namespace c10
|
||||
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
|
||||
reinterpret_cast<c10::BFloat16*>(&v);
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
#ifndef __VEC_CLASS_FP_NAN
|
||||
#define __VEC_CLASS_FP_NAN (1 << 6)
|
||||
#endif
|
||||
|
||||
const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15,
|
||||
18, 19, 22, 23, 26, 27, 30, 31};
|
||||
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
|
||||
0x00007fff};
|
||||
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
|
||||
0x7fc00000};
|
||||
const static __vector unsigned int sh16 = {16, 16, 16, 16};
|
||||
const static __vector unsigned int one = {1, 1, 1, 1};
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
|
||||
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
|
||||
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
|
||||
int cc;
|
||||
__vector __bool int sel0 =
|
||||
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
|
||||
__vector __bool int sel1 =
|
||||
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
|
||||
inp0 = vec_sel(inp0, nan, sel0) >> sh16;
|
||||
inp1 = vec_sel(inp1, nan, sel1) >> sh16;
|
||||
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
|
||||
}
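The constructor above quiets NaNs with vec_fp_test_data_class/vec_sel, shifts each 32-bit lane right by 16, and packs the halfwords with vec_perm. A per-lane scalar sketch of the same conversion (assuming the usual bfloat16 encoding as the upper 16 bits of the fp32 pattern; needs <cmath>, <cstdint>, <cstring>):

inline uint16_t fp32_to_bf16_lane(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  if (std::isnan(f)) bits = 0x7fc00000u;     // same canonical quiet NaN as 'nan'
  return static_cast<uint16_t>(bits >> 16);  // truncation, no rounding
}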
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
|
||||
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
|
||||
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
|
||||
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
|
||||
int cc;
|
||||
__vector __bool int sel0 =
|
||||
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
|
||||
__vector __bool int sel1 =
|
||||
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
|
||||
__vector __bool int sel2 =
|
||||
vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
|
||||
__vector __bool int sel3 =
|
||||
vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
|
||||
inp0 = vec_sel(inp0, nan, sel0) >> sh16;
|
||||
inp1 = vec_sel(inp1, nan, sel1) >> sh16;
|
||||
inp2 = vec_sel(inp2, nan, sel2) >> sh16;
|
||||
inp3 = vec_sel(inp3, nan, sel3) >> sh16;
|
||||
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
|
||||
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
|
||||
}
|
||||
|
||||
inline void prefetch(const void* addr) { void __dcbt(const void* addr); }
|
||||
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
@ -16,9 +16,18 @@ namespace vec_op {
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, \
|
||||
VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
|
||||
@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl(
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
torch::Tensor& key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox) {
|
||||
int num_tokens = query.numel() / query.size(-1);
|
||||
int num_tokens = positions.numel();
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
|
||||
@ -25,7 +25,7 @@ struct KernelVecType<c10::BFloat16> {
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#ifdef __powerpc64__
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power architecture-specific vector type
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
|
||||
@ -402,7 +402,7 @@ struct CollectiveMma<
|
||||
|
||||
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
|
||||
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
|
||||
Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1)
|
||||
Layout<Shape<_32>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
|
||||
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
||||
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
||||
|
||||
@ -6,6 +6,11 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
// Need a special dispatch case macro since we will nest the FP8 dispatch.
|
||||
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
|
||||
#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
@ -14,17 +19,32 @@
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
// TODO(luka/varun): use FP8_TYPE macro after refactoring
|
||||
#ifndef USE_ROCM
|
||||
// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
|
||||
// A host-based check at runtime will create a preferred FP8 type for ROCm
|
||||
// such that the correct kernel is dispatched.
|
||||
#ifdef USE_ROCM
|
||||
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
|
||||
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
||||
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
|
||||
#else
|
||||
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
|
||||
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
|
||||
#else
|
||||
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
|
||||
#endif
|
||||
|
||||
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
|
||||
// See AT_DISPATCH_FP8_CASE above.
|
||||
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
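These macros are intended to be nested: the outer dispatch binds scalar_t to the activation dtype and the inner one binds fp8_t (or the quant type) to the output dtype, so one lambda body is instantiated for every combination. A minimal launcher sketch (kernel and variable names here are hypothetical):

VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(), "my_quant_kernel_scalar_type", [&] {
      VLLM_DISPATCH_FP8_TYPES(
          out.scalar_type(), "my_quant_kernel_fp8_type", [&] {
            my_quant_kernel<scalar_t, fp8_t><<<grid, block, 0, stream>>>(
                out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), num_elems);
          });
    });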
@ -21,9 +21,9 @@
|
||||
namespace vllm {
|
||||
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void rms_norm_static_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float* __restrict__ scale, // [1]
|
||||
@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
scaled_fp8_conversion<true>(out_norm, scale_inv);
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
}
|
||||
}
|
||||
|
||||
@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
Additional optimizations we can make in this case are
|
||||
packed and vectorized operations, which help with the
|
||||
memory latency bottleneck. */
|
||||
template <typename scalar_t, int width>
|
||||
template <typename scalar_t, int width, typename fp8_type>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) {
|
||||
out[id * width + i] =
|
||||
scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
|
||||
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
/* Generic fused_add_rms_norm_kernel
|
||||
The width field is not used here but necessary for other specializations.
|
||||
*/
|
||||
template <typename scalar_t, int width>
|
||||
template <typename scalar_t, int width, typename fp8_type>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
scaled_fp8_conversion<true>(out_norm, scale_inv);
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
}
|
||||
}
|
||||
|
||||
@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
|
||||
vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
|
||||
num_tokens, hidden_size);
|
||||
});
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
|
||||
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
|
||||
epsilon, num_tokens, hidden_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
|
||||
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
|
||||
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
|
||||
VLLM_DISPATCH_FP8_TYPES( \
|
||||
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
|
||||
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
|
||||
width, fp8_t> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
|
||||
epsilon, num_tokens, hidden_size); \
|
||||
}); \
|
||||
});
|
||||
|
||||
void fused_add_rms_norm_static_fp8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size],
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
|
||||
@ -18,3 +18,14 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
torch::Tensor b_qweight, torch::Tensor b_scales,
|
||||
std::optional<torch::Tensor> b_qzeros,
|
||||
std::optional<torch::Tensor> topk_weights,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_pad, int64_t top_k,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit);
|
||||
#endif
|
||||
csrc/moe/moe_wna16.cu (new file, 346 lines)
@ -0,0 +1,346 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include "moe_wna16_utils.h"
|
||||
|
||||
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
||||
|
||||
template <typename scalar_t, int bit, int GROUPS>
|
||||
__global__ void moe_wna16_gemm_kernel(
|
||||
const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
|
||||
|
||||
const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
|
||||
const uint32_t* __restrict__ qzeros,
|
||||
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_token_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ num_tokens_post_pad,
|
||||
|
||||
uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m,
|
||||
uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M,
|
||||
uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp,
|
||||
bool mul_topk_weight) {
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
||||
if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
return;
|
||||
} else {
|
||||
#endif
|
||||
|
||||
using Dtype = ScalarType<scalar_t>;
|
||||
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
||||
|
||||
if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return;
|
||||
|
||||
const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x;
|
||||
const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;
|
||||
|
||||
const int32_t expert_id = expert_ids[blockIdx.x];
|
||||
|
||||
int32_t num_valid_tokens = 0;
|
||||
extern __shared__ uint16_t block_input_tmp[];
|
||||
scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
|
||||
scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);
|
||||
|
||||
// load a BLOCK_SIZE_M x BLOCK_SIZE_K tile of the input into shared memory
|
||||
for (int m = 0; m < BLOCK_SIZE_M; m++) {
|
||||
const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m;
|
||||
const int32_t token_index = sorted_token_ids[offset_m];
|
||||
if (token_index / top_k >= size_m) break;
|
||||
|
||||
num_valid_tokens = m + 1;
|
||||
if (blockIdx.z == 0 && offset_n < size_n)
|
||||
output[token_index * size_n + offset_n] = Dtype::int2num(0);
|
||||
|
||||
if (expert_id != -1) {
|
||||
int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
|
||||
for (int i = 0; i < k_per_thread; i++) {
|
||||
int k = BLOCK_SIZE_N * i + threadIdx.x;
|
||||
if (k >= BLOCK_SIZE_K) break;
|
||||
if (offset_k + k >= size_k) break;
|
||||
|
||||
// load input to shared memory
|
||||
// use a special layout to match the layout of the dequantized weight
|
||||
int origin_k;
|
||||
if constexpr (bit == 4) {
|
||||
// [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
|
||||
origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
|
||||
} else {
|
||||
// [0, 2, 1, 3]
|
||||
int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
|
||||
origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
|
||||
}
|
||||
|
||||
origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
|
||||
block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (expert_id == -1) return;
|
||||
__syncthreads();
|
||||
if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;
|
||||
|
||||
float res[64]; // assume BLOCK_SIZE_M <= 64
|
||||
scalar_t2 res2;
|
||||
scalar_t2 scale_f2;
|
||||
scalar_t2 qzero_f2;
|
||||
|
||||
// note that (size_n * size_k * expert_id) may be greater than 2 ** 31
|
||||
constexpr int8_t pack_factor = 32 / bit;
|
||||
const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
|
||||
const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
|
||||
const scalar_t* expert_scales = scales + expert_offset / group_size;
|
||||
const uint32_t* expert_qzeros =
|
||||
qzeros + expert_offset / group_size / pack_factor;
|
||||
|
||||
// load 4 x int32 at a time: 4 x int32 = 128 bits = 1 float4
// the packed weights themselves are loaded inside the k-loop below
|
||||
uint32_t expert_qweight_tmp[4];
|
||||
float4* expert_qweight_tmp_float4 =
|
||||
reinterpret_cast<float4*>(expert_qweight_tmp);
|
||||
|
||||
// load all required scales at once
|
||||
scalar_t expert_scales_groups[GROUPS];
|
||||
int scales_offset_tmp =
|
||||
(offset_n * size_k + offset_k) / group_size / GROUPS;
|
||||
if constexpr (GROUPS == 1) {
|
||||
*expert_scales_groups = expert_scales[scales_offset_tmp];
|
||||
} else if constexpr (GROUPS == 2) {
|
||||
float* expert_scales_groups_tmp =
|
||||
reinterpret_cast<float*>(expert_scales_groups);
|
||||
*expert_scales_groups_tmp =
|
||||
reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
|
||||
} else if constexpr (GROUPS == 4) {
|
||||
float2* expert_scales_groups_tmp =
|
||||
reinterpret_cast<float2*>(expert_scales_groups);
|
||||
*expert_scales_groups_tmp =
|
||||
reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
|
||||
} else if constexpr (GROUPS == 8) {
|
||||
float4* expert_scales_groups_tmp =
|
||||
reinterpret_cast<float4*>(expert_scales_groups);
|
||||
*expert_scales_groups_tmp =
|
||||
reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
|
||||
}
|
||||
|
||||
// load all required qzeros at once
|
||||
uint8_t expert_qzeros_groups[GROUPS];
|
||||
if (!has_zp) {
|
||||
if constexpr (bit == 4) {
|
||||
qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
|
||||
} else {
|
||||
qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
|
||||
}
|
||||
} else {
|
||||
int qzeros_offset_tmp =
|
||||
(offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
|
||||
offset_k / group_size / GROUPS;
|
||||
if constexpr (GROUPS == 1) {
|
||||
uint8_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint8_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
} else if constexpr (GROUPS == 2) {
|
||||
uint16_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint16_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
} else if constexpr (GROUPS == 4) {
|
||||
uint32_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint32_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
} else if constexpr (GROUPS == 8) {
|
||||
uint64_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint64_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
}
|
||||
}
|
||||
|
||||
for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) {
|
||||
int k = offset_k + tmp_k * pack_factor;
|
||||
if (k >= size_k) break;
|
||||
const int32_t weight_offset = offset_n * size_k + k;
|
||||
|
||||
if (tmp_k % 4 == 0) {
|
||||
*expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
|
||||
expert_qweight)[weight_offset / pack_factor / 4];
|
||||
}
|
||||
|
||||
if (tmp_k % (group_size / pack_factor) == 0) {
|
||||
scalar_t scale_f =
|
||||
expert_scales_groups[tmp_k / (group_size / pack_factor)];
|
||||
scale_f2 = Dtype::num2num2(scale_f);
|
||||
|
||||
if (has_zp) {
|
||||
uint8_t qzero =
|
||||
expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
|
||||
if constexpr (bit == 4) {
|
||||
qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
|
||||
}
|
||||
qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
|
||||
}
|
||||
}
|
||||
|
||||
scalar_t2 weight_half2[16 / bit];
|
||||
dequant<scalar_t2, bit>(expert_qweight_tmp[tmp_k % 4], weight_half2);
|
||||
|
||||
for (int m = 0; m < num_valid_tokens; m++) {
|
||||
res2 = {};
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16 / bit; i++) {
|
||||
int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i;
|
||||
res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2),
|
||||
block_input_half2[offset_input], res2);
|
||||
}
|
||||
|
||||
if (tmp_k == 0) {
|
||||
res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
|
||||
} else {
|
||||
res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int m = 0; m < num_valid_tokens; ++m) {
|
||||
const int32_t token_index =
|
||||
sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m];
|
||||
if (mul_topk_weight) {
|
||||
res[m] *= topk_weights[token_index];
|
||||
}
|
||||
atomicAdd(&output[token_index * size_n + offset_n],
|
||||
Dtype::float2num(res[m]));
|
||||
}
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
||||
}
|
||||
#endif
|
||||
}
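The shuffled shared-memory layout used when staging the activations exists so that block_input_half2[i] lines up with weight_half2[i] coming out of dequant(): for bit == 4 the lop3-based dequant emits the packed nibbles as the half2 pairs (n0,n4), (n1,n5), (n2,n6), (n3,n7), which is exactly the [0, 4, 1, 5, 2, 6, 3, 7] order applied to the input indices above (and analogously [0, 2, 1, 3] for bit == 8). A small host-side sketch of that mapping (illustration only, not part of the kernel):

// Position of nibble j (0..7) in the dequantized half sequence for bit == 4.
constexpr int dequant_position_4bit(int j) { return (j % 4) * 2 + j / 4; }
// dequant_position_4bit(0..7) == {0, 2, 4, 6, 1, 3, 5, 7}, i.e. position p
// holds nibble {0, 4, 1, 5, 2, 6, 3, 7}[p], matching the kernel comment.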
template <typename scalar_t>
|
||||
void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output,
|
||||
const uint32_t* b_qweight, const scalar_t* b_scales,
|
||||
const uint32_t* b_qzeros, const float* topk_weights,
|
||||
const int32_t* sorted_token_ids,
|
||||
const int32_t* expert_ids,
|
||||
const int32_t* num_tokens_post_pad, int num_experts,
|
||||
int group_size, int num_token_blocks, int top_k,
|
||||
int size_m, int size_n, int size_k, int BLOCK_SIZE_M,
|
||||
int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit,
|
||||
bool has_zp, bool mul_topk_weight) {
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_SIZE_N;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = num_token_blocks;
|
||||
gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K);
|
||||
|
||||
auto kernel = moe_wna16_gemm_kernel<scalar_t, 4, 1>;
|
||||
if (bit == 4) {
|
||||
if (BLOCK_SIZE_K / group_size == 2) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 2>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 4) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 4>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 8) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 8>;
|
||||
}
|
||||
} else {
|
||||
if (BLOCK_SIZE_K / group_size == 1) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 1>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 2) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 2>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 4) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 4>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 8) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 8>;
|
||||
}
|
||||
}
|
||||
|
||||
const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
kernel<<<gridDim, blockDim, shared_mem_size, stream>>>(
|
||||
input, output, b_qweight, b_scales, b_qzeros, topk_weights,
|
||||
sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts,
|
||||
group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K, has_zp, mul_topk_weight);
|
||||
}
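The third template argument selected above is simply BLOCK_SIZE_K / group_size, i.e. how many quantization groups one k-block spans, which in turn sizes the per-thread scale and zero-point preload in the kernel. A tiny sketch (hypothetical helper, not in the diff):

constexpr int wna16_groups(int block_size_k, int group_size) {
  // e.g. 128 / 32 == 4 -> moe_wna16_gemm_kernel<scalar_t, bit, 4>
  return block_size_k / group_size;
}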
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
torch::Tensor b_qweight, torch::Tensor b_scales,
|
||||
std::optional<torch::Tensor> b_qzeros,
|
||||
std::optional<torch::Tensor> topk_weights,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_pad, int64_t top_k,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device());
|
||||
|
||||
const int num_experts = b_qweight.size(0);
|
||||
const int size_m = input.size(0);
|
||||
const int size_n = b_qweight.size(1);
|
||||
const int size_k = input.size(1);
|
||||
const int group_size = size_k / b_scales.size(2);
|
||||
|
||||
int64_t EM = sorted_token_ids.size(0);
|
||||
if (size_m <= BLOCK_SIZE_M) {
|
||||
EM = min(EM, size_m * BLOCK_SIZE_M * top_k);
|
||||
}
|
||||
const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
|
||||
|
||||
const uint32_t* b_qzeros_ptr;
|
||||
if (b_qzeros.has_value())
|
||||
b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
|
||||
const float* topk_weights_ptr;
|
||||
if (topk_weights.has_value())
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
|
||||
|
||||
int groups_per_block_row = BLOCK_SIZE_K / group_size;
|
||||
TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
|
||||
TORCH_CHECK(size_k % BLOCK_SIZE_K == 0,
"size_k must be divisible by BLOCK_SIZE_K");
TORCH_CHECK(BLOCK_SIZE_K % group_size == 0,
"BLOCK_SIZE_K must be divisible by group_size");
TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must be less than or equal to 64");
|
||||
TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 ||
|
||||
groups_per_block_row == 4 || groups_per_block_row == 8,
|
||||
"BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]");
|
||||
|
||||
if (input.scalar_type() == at::ScalarType::Half) {
|
||||
run_moe_wna16_gemm<half>(
|
||||
(const half*)input.data_ptr<at::Half>(),
|
||||
(half*)output.data_ptr<at::Half>(),
|
||||
(const uint32_t*)b_qweight.data_ptr<uint8_t>(),
|
||||
(const half*)b_scales.data_ptr<at::Half>(), b_qzeros_ptr,
|
||||
topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
|
||||
expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
|
||||
size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
|
||||
b_qzeros.has_value(), topk_weights.has_value());
|
||||
} else if (input.scalar_type() == at::ScalarType::BFloat16) {
|
||||
run_moe_wna16_gemm<nv_bfloat16>(
|
||||
(const nv_bfloat16*)input.data_ptr<at::BFloat16>(),
|
||||
(nv_bfloat16*)output.data_ptr<at::BFloat16>(),
|
||||
(const uint32_t*)b_qweight.data_ptr<uint8_t>(),
|
||||
(const nv_bfloat16*)b_scales.data_ptr<at::BFloat16>(), b_qzeros_ptr,
|
||||
topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
|
||||
expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
|
||||
size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
|
||||
b_qzeros.has_value(), topk_weights.has_value());
|
||||
} else {
|
||||
TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
return output;
|
||||
}
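For reference, the size computations above and the kernel's indexing imply roughly these shapes: b_qweight [num_experts, size_n, size_k / (32 / bit)] packed as uint32 along k, b_scales [num_experts, size_n, size_k / group_size], and an output with size_n columns addressed by sorted token id. A few hypothetical checks (not in the diff) that would document those assumptions:

TORCH_CHECK(b_scales.size(0) == num_experts && b_scales.size(1) == size_n);
TORCH_CHECK(b_qweight.size(2) * (32 / bit) == size_k);  // weights packed along k
TORCH_CHECK(output.size(-1) == size_n);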
csrc/moe/moe_wna16_utils.h (new file, 200 lines)
@ -0,0 +1,200 @@
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template <typename scalar_t>
|
||||
class ScalarType {};
|
||||
|
||||
template <>
|
||||
class ScalarType<half> {
|
||||
public:
|
||||
using scalar_t = half;
|
||||
using scalar_t2 = half2;
|
||||
|
||||
static __device__ float inline num2float(const half x) {
|
||||
return __half2float(x);
|
||||
}
|
||||
|
||||
static __device__ half2 inline num2num2(const half x) {
|
||||
return __half2half2(x);
|
||||
}
|
||||
|
||||
static __device__ half2 inline nums2num2(const half x1, const half x2) {
|
||||
return __halves2half2(x1, x2);
|
||||
}
|
||||
|
||||
static __host__ __device__ half inline float2num(const float x) {
|
||||
return __float2half(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ half inline int2num(const float x) {
|
||||
return __int2half_rn(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ float2 inline num22float2(const half2 x) {
|
||||
return __half22float2(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ half2 inline float22num2(const float2 x) {
|
||||
return __float22half2_rn(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class ScalarType<nv_bfloat16> {
|
||||
public:
|
||||
using scalar_t = nv_bfloat16;
|
||||
using scalar_t2 = nv_bfloat162;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
static __device__ float inline num2float(const nv_bfloat16 x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
|
||||
return __bfloat162bfloat162(x);
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
|
||||
const nv_bfloat16 x2) {
|
||||
return __halves2bfloat162(x1, x2);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
|
||||
return __float2bfloat16(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat16 inline int2num(const float x) {
|
||||
return __int2bfloat16_rn(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
|
||||
return __bfloat1622float2(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) {
|
||||
return __float22bfloat162_rn(x);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||
return res;
|
||||
}
|
||||
|
||||
template <int start_byte, int mask>
|
||||
__device__ inline uint32_t prmt(uint32_t a) {
|
||||
uint32_t res;
|
||||
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "n"(start_byte), "n"(mask));
|
||||
return res;
|
||||
}
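lop3.b32 evaluates an arbitrary three-input boolean function selected by the 8-bit immediate lookup table; the value used throughout this header, (0xf0 & 0xcc) | 0xaa, encodes (a & b) | c (0xf0, 0xcc and 0xaa being the truth tables of a, b and c). A scalar reference for what those call sites compute:

// Sketch only: lop3<(0xf0 & 0xcc) | 0xaa>(a, b, c) == (a & b) | c, bitwise.
__device__ inline int lop3_and_or_ref(int a, int b, int c) { return (a & b) | c; }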
template <typename scalar_t2, int bit>
|
||||
__device__ inline void dequant(int q, scalar_t2* res) {}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, 4>(int q, half2* res) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd400d400;
|
||||
|
||||
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
q >>= 8;
|
||||
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
|
||||
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
res[1] = __hfma2(*reinterpret_cast<half2*>(&hi0),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
res[2] = __hsub2(*reinterpret_cast<half2*>(&lo1),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
res[3] = __hfma2(*reinterpret_cast<half2*>(&hi1),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
}
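The constants implement the usual fp16 magic-number dequant: OR-ing a 4-bit value into the mantissa of 0x6400 (1024.0 in fp16) gives 1024 + n, so subtracting SUB (1024.0) recovers n; the high nibbles sit four bits higher, so the fused multiply-add scales by MUL (1/16) and adds ADD (-64, i.e. -1024/16). A scalar sketch of both paths (illustration only):

// Low-nibble lanes:  (1024 + n) - 1024                -> n
__device__ inline float dequant4_lo_ref(int n) { return (1024.0f + n) - 1024.0f; }
// High-nibble lanes: (1024 + 16 * n) * (1/16) + (-64) -> n
__device__ inline float dequant4_hi_ref(int n) {
  return (1024.0f + 16.0f * n) * (1.0f / 16.0f) - 64.0f;
}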
template <>
|
||||
__device__ inline void dequant<half2, 8>(int q, half2* res) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
res[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
q >>= 4;
|
||||
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC300C300;
|
||||
|
||||
res[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo0),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
res[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi0),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
res[2] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo1),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
res[3] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi1),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, 8>(int q, nv_bfloat162* res) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||
|
||||
fp32_intermediates[0] -= 8388608.f;
|
||||
fp32_intermediates[1] -= 8388608.f;
|
||||
fp32_intermediates[2] -= 8388608.f;
|
||||
fp32_intermediates[3] -= 8388608.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(res);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
||||
fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
||||
fp32_intermediates_casted[3], 0x7632);
|
||||
}
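The bf16 path expands each byte through fp32 instead: __byte_perm splices the byte into the low mantissa byte of 0x4B000000 (8388608.0f, i.e. 2^23), so the resulting float is exactly 2^23 + b and the subtraction recovers b; the final __byte_perm pairs then keep only the upper 16 bits of each fp32, which is the bf16 encoding of those small integers. A scalar sketch (illustration only):

__device__ inline float u8_via_fp32_magic_ref(unsigned char b) {
  // placing b in the low mantissa byte of 2^23 yields the float 2^23 + b
  return (8388608.0f + b) - 8388608.0f;  // == (float)b exactly
}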
#endif
|
||||
@ -32,6 +32,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
m.def(
|
||||
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
|
||||
"Tensor b_scales, Tensor? b_qzeros, "
|
||||
"Tensor? topk_weights, Tensor sorted_token_ids, "
|
||||
"Tensor expert_ids, Tensor num_tokens_post_pad, "
|
||||
"int top_k, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, "
|
||||
"int bit) -> Tensor");
|
||||
|
||||
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
|
||||
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
@ -42,6 +52,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
||||
" -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
csrc/ops.h (16 lines changed)
@ -151,15 +151,25 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||
int64_t row);
|
||||
|
||||
torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_padded, int64_t type,
|
||||
int64_t row, int64_t top_k, int64_t tokens);
|
||||
|
||||
int64_t ggml_moe_get_block_size(int64_t type);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
|
||||
@ -274,7 +274,7 @@ void advance_step_flashinfer(
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
|
||||
|
||||
int block_tables_stride = block_tables.stride(0);
|
||||
[[maybe_unused]] int block_tables_stride = block_tables.stride(0);
|
||||
TORCH_CHECK((blocks * threads > num_queries),
|
||||
"multi-step: not enough threads to map to num_queries = ",
|
||||
num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
|
||||
|
||||
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu (new file, 34 lines)
@ -0,0 +1,34 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm100 (Blackwell).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
TORCH_CHECK(
|
||||
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -5,9 +5,11 @@
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
NVIDIA GPUs with sm90a (Hopper).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -72,27 +74,4 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
azp, bias);
|
||||
}
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
||||
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
TORCH_CHECK(
|
||||
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -23,12 +23,15 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -60,7 +63,7 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -121,26 +124,21 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
// Hopper
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION < 12080
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#else
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
} else if (version_num >= 100) {
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
if (version_num >= 100) {
|
||||
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
// Hopper
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
@ -211,7 +209,7 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
|
||||
@ -36,3 +36,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
"be compiled using CUDA 12.8 and target "
|
||||
"compute capability 100 or above.");
|
||||
}
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
|
||||
int runtimeVersion;
|
||||
cudaRuntimeGetVersion(&runtimeVersion);
|
||||
return cuda_device_capability >= 100 && runtimeVersion >= 12080;
|
||||
}
|
||||
@ -201,10 +201,11 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) \
|
||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||
#define CHECK_INPUT(x, st, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m); \
|
||||
|
||||
@ -13,6 +13,40 @@ namespace vllm {
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
template <typename fp8_type>
|
||||
__device__ __forceinline__ fp8_type cvt_c10(float const r) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
|
||||
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
|
||||
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
|
||||
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
|
||||
// the new HW cvt with something reasonable that doesn't rely on the
|
||||
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
|
||||
template <>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
|
||||
#if HIP_FP8_TYPE_OCP
|
||||
return c10::Float8_e4m3fn(
|
||||
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
|
||||
__hip_fp8_e4m3::__default_interpret),
|
||||
c10::Float8_e4m3fn::from_bits());
|
||||
#else
|
||||
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
|
||||
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
|
||||
return c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
|
||||
__hip_fp8_e4m3_fnuz::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
}
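A minimal sketch of how a conversion routine might use cvt_c10 (the helper name is hypothetical; the surrounding scaled conversion helpers typically apply an inverse scale before converting):

template <typename fp8_type>
__device__ __forceinline__ fp8_type scaled_convert_sketch(float x,
                                                          float inverted_scale) {
  return cvt_c10<fp8_type>(x * inverted_scale);
}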
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x) {
|
||||
return x;
|
||||
@ -412,7 +446,7 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
|
||||
__half2_raw h2r =
|
||||
[[maybe_unused]] __half2_raw h2r =
|
||||
__hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
|
||||
@ -11,8 +11,8 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale,
|
||||
int64_t num_elems) {
|
||||
@ -25,12 +25,13 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
|
||||
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, float* __restrict__ scale,
|
||||
fp8_type* __restrict__ out, float* __restrict__ scale,
|
||||
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
|
||||
const int hidden_size) {
|
||||
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||
float const min_scaling_factor =
|
||||
1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
|
||||
|
||||
int const tid = threadIdx.x;
|
||||
int const token_idx = blockIdx.x;
|
||||
@ -38,7 +39,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
// Use int64 to avoid overflowing an int32 when calculating this offset
|
||||
int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
|
||||
scalar_t const* __restrict__ token_input = &input[offset];
|
||||
FP8_TYPE* __restrict__ token_output = &out[offset];
|
||||
fp8_type* __restrict__ token_output = &out[offset];
|
||||
|
||||
// For vectorization, token_input and token_output pointers need to be
|
||||
// aligned at 8-byte and 4-byte addresses respectively.
|
||||
@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
token_scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
|
||||
token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
|
||||
min_scaling_factor);
|
||||
scale[token_idx] = token_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
|
||||
} else {
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
token_output[i] = scaled_fp8_conversion<false>(
|
||||
token_output[i] = scaled_fp8_conversion<false, fp8_type>(
|
||||
static_cast<float>(token_input[i]), token_scale);
|
||||
}
|
||||
}
|
||||
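
The per-token path computes scale = max(absmax / fp8_max, 1 / (fp8_max * 512)), so a token of all zeros still gets a usable, non-zero scale. A host-side reference of the same arithmetic, with a plain float fp8_max standing in for fp8_e4m3_adjusted_max_v:

#include <algorithm>
#include <cmath>

// Sketch: scalar reference for the per-token dynamic scale computed above.
float reference_token_scale(const float* token, int hidden_size, float fp8_max) {
  float absmax = 0.0f;
  for (int i = 0; i < hidden_size; ++i)
    absmax = std::max(absmax, std::fabs(token[i]));
  const float min_scaling_factor = 1.0f / (fp8_max * 512.0f);
  return std::max(absmax / fp8_max, min_scaling_factor);
}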
@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
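
Nesting VLLM_DISPATCH_FP8_TYPES inside VLLM_DISPATCH_FLOATING_TYPES instantiates the kernel for every (scalar_t, fp8_t) combination. Roughly, the nested macros expand to switches over the two scalar types, something like the sketch below (simplified; the real macros cover more dtypes and error handling):

// Simplified expansion sketch -- not the actual macro output.
switch (input.scalar_type()) {
  case at::ScalarType::Float: {
    using scalar_t = float;
    switch (out.scalar_type()) {
      case at::ScalarType::Float8_e4m3fn: {
        using fp8_t = c10::Float8_e4m3fn;
        vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<fp8_t>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
        break;
      }
      // Float8_e4m3fnuz case for ROCm builds, ...
      default:
        break;
    }
    break;
  }
  // Half and BFloat16 cases, ...
  default:
    break;
}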
@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
|
||||
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
|
||||
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::segmented_max_reduction<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_elems);
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
hidden_size);
|
||||
input.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>()
|
||||
: nullptr,
|
||||
hidden_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -7,18 +7,52 @@

#ifndef USE_ROCM
  #include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
    std::numeric_limits<FP8_TYPE>::max();
  #define MAYBE_HOST_DEVICE C10_HOST_DEVICE
#else
  #include <ATen/hip/HIPContext.h>
  #include <c10/util/Float8_e4m3fn.h>
  #include <c10/util/Float8_e4m3fnuz.h>
  #include "amd/quant_utils.cuh"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
  // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
  #define MAYBE_HOST_DEVICE
#endif
constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;

// Determines the preferred FP8 type for the current platform.
// Note that for CUDA this just returns true,
// but on ROCm it will check device props.
static bool is_fp8_ocp() {
#ifndef USE_ROCM
  return true;
#else
  auto dprops = at::cuda::getCurrentDeviceProperties();
  std::string device_arch = dprops->gcnArchName;
  size_t substring = device_arch.find("gfx94");
  return substring == std::string::npos;
#endif
}

template <typename T>
struct fp8_e4m3_adjusted_max;

template <>
struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
  static constexpr c10::Float8_e4m3fn val() {
    return std::numeric_limits<c10::Float8_e4m3fn>::max();
  }
};

// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
template <>
struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
  static constexpr c10::Float8_e4m3fnuz val() {
    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
  }
};

template <typename T>
MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
    fp8_e4m3_adjusted_max<T>::val();

namespace vllm {

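fp8_e4m3_adjusted_max_v replaces the single FP8_E4M3_MAX constant with a per-type maximum: numeric_limits<Float8_e4m3fn>::max() (448) for the OCP type, and 0x7E = 224 for fnuz to sidestep the accuracy problem with 240. A small usage sketch:

// Sketch: the variable template resolves per FP8 type at compile time.
float ocp_max  = static_cast<float>(fp8_e4m3_adjusted_max_v<c10::Float8_e4m3fn>);
float fnuz_max = static_cast<float>(fp8_e4m3_adjusted_max_v<c10::Float8_e4m3fnuz>);
// ocp_max is roughly 448.0f; fnuz_max is exactly 224.0f (bit pattern 0x7E).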
@ -32,8 +66,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
return old;
|
||||
}
|
||||
|
||||
template <bool is_scale_inverted>
|
||||
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||
template <bool is_scale_inverted, typename fp8_type>
|
||||
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
float const scale) {
|
||||
float x = 0.0f;
|
||||
if constexpr (is_scale_inverted) {
|
||||
@ -42,15 +76,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||
x = val / scale;
|
||||
}
|
||||
|
||||
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
||||
float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
|
||||
fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
|
||||
#ifndef USE_ROCM
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
return static_cast<fp8_type>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation,
|
||||
fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
return fp8::cvt_c10<fp8_type>(r);
|
||||
#endif
|
||||
}
|
||||
|
||||
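
scaled_fp8_conversion divides by the scale (or multiplies when the scale is already inverted), saturates to +/- fp8_e4m3_adjusted_max_v, and only then converts. The same math as a host-side scalar reference, with a plain fp8_max in place of the trait:

#include <cmath>

// Sketch: scalar reference for the clamp-then-convert step above; the final
// FP8 cast is left out since it is type-specific.
float scaled_fp8_reference(float val, float scale, bool scale_inverted, float fp8_max) {
  float x = scale_inverted ? val * scale : val / scale;
  return std::fmax(-fp8_max, std::fmin(x, fp8_max));
}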
@ -60,7 +92,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||
// So to get the right answer, *scale needs to be initialized to
|
||||
// a value <= 0.0 and we need to wait for all thread blocks to
|
||||
// finish before consuming *scale.
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t num_elems) {
|
||||
@ -91,7 +123,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
// Finally, since cache[0] contains the maximum for this thread block,
|
||||
// atomically write the max to the target location
|
||||
if (threadIdx.x == 0) {
|
||||
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
|
||||
atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
|
||||
}
|
||||
}
|
||||
|
||||
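
segmented_max_reduction follows the usual two-stage pattern: each block reduces its slice to one absolute maximum in shared memory, then thread 0 publishes cache[0] / fp8_max with an atomic float max. A condensed sketch of that structure (fixed 256-thread blocks assumed; it reuses the file's atomicMaxFloat):

// Sketch: block-local absmax reduction followed by a global atomic max.
__global__ void block_absmax_sketch(float* __restrict__ scale,
                                    const float* __restrict__ input,
                                    int64_t num_elems, float fp8_max) {
  __shared__ float cache[256];  // assumes blockDim.x == 256
  float local = 0.0f;
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_elems;
       i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    local = fmaxf(local, fabsf(input[i]));
  }
  cache[threadIdx.x] = local;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {  // shared-memory tree reduction
    if (threadIdx.x < s)
      cache[threadIdx.x] = fmaxf(cache[threadIdx.x], cache[threadIdx.x + s]);
    __syncthreads();
  }
  if (threadIdx.x == 0) vllm::atomicMaxFloat(scale, cache[0] / fp8_max);
}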
@ -123,13 +155,13 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
||||
return absmax_val;
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool is_scale_inverted>
|
||||
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
|
||||
template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
|
||||
__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
|
||||
scalar_t const* __restrict__ input,
|
||||
float const scale,
|
||||
int64_t const num_elems,
|
||||
int const tid, int const step) {
|
||||
using float8x4_t = q8x4_t<FP8_TYPE>;
|
||||
using float8x4_t = q8x4_t<fp8_type>;
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
||||
auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
||||
@ -141,22 +173,22 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
float8x4_t out_vec;
|
||||
|
||||
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
|
||||
out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.x), scale);
|
||||
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
|
||||
out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.y), scale);
|
||||
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
|
||||
out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.z), scale);
|
||||
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
|
||||
out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.w), scale);
|
||||
vectorized_out[i] = out_vec;
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
out[i] = scaled_fp8_conversion<is_scale_inverted>(
|
||||
out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(input[i]), scale);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
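
scaled_fp8_conversion_vec reinterprets the pointers as 4-element vectors so each iteration does one 16-byte load and one 4-byte store, with a scalar tail for the remainder. The generic shape of that pattern, with a stand-in 4-byte output struct and a placeholder per-lane quantizer (neither name comes from this diff):

// Sketch: 4-wide vectorized load/store with a scalar tail.
struct byte4 { uint8_t x, y, z, w; };  // stand-in for q8x4_t<fp8_type>

__device__ void convert_vec4_sketch(uint8_t* __restrict__ out,
                                    const float* __restrict__ in,
                                    int64_t num_elems, int tid, int step) {
  auto const* vin = reinterpret_cast<const float4*>(in);
  auto* vout = reinterpret_cast<byte4*>(out);
  int64_t num_vec = num_elems / 4;
  for (int64_t i = tid; i < num_vec; i += step) {
    float4 v = vin[i];  // single 16-byte load
    byte4 o;
    o.x = static_cast<uint8_t>(v.x);  // placeholder per-lane quantization
    o.y = static_cast<uint8_t>(v.y);
    o.z = static_cast<uint8_t>(v.z);
    o.w = static_cast<uint8_t>(v.w);
    vout[i] = o;  // single 4-byte store
  }
  for (int64_t i = num_vec * 4 + tid; i < num_elems; i += step) {
    out[i] = static_cast<uint8_t>(in[i]);  // tail when num_elems % 4 != 0
  }
}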
@ -144,6 +144,9 @@ void rms_norm_dynamic_per_token_quant(
|
||||
torch::Tensor& scales, // [num_tokens]
|
||||
double const var_epsilon, // Variance epsilon used in norm calculation
|
||||
std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
|
||||
|
||||
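
With the compile-time FP8 dtype constant gone, the expected torch dtype is now chosen at runtime from the device architecture; gfx94x (MI300-class) reports the fnuz format, everything else the OCP one. A small sketch of the same selection for other call sites:

// Sketch: runtime FP8 dtype selection mirroring the kFp8Type static above.
c10::ScalarType preferred_fp8_dtype() {
  return is_fp8_ocp() ? c10::ScalarType::Float8_e4m3fn     // CUDA, non-gfx94 ROCm
                      : c10::ScalarType::Float8_e4m3fnuz;  // gfx94x devices
}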
@ -24,7 +24,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
|
||||
// sum of squares
|
||||
float ss = 0.0f;
|
||||
|
||||
for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
if constexpr (has_residual) {
|
||||
x += static_cast<float>(residual[token_offset + i]);
|
||||
@ -58,7 +58,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
|
||||
|
||||
float block_absmax_val_maybe = 0.0f;
|
||||
for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
if constexpr (has_residual) {
|
||||
x += static_cast<float>(residual[token_offset + i]);
|
||||
@ -103,7 +103,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
|
||||
;
|
||||
|
||||
for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
||||
float x = static_cast<float>(input[token_offset + i]);
|
||||
if constexpr (has_residual) {
|
||||
x += static_cast<float>(residual[token_offset + i]);
|
||||
@ -142,7 +142,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
|
||||
int32_t const num_vec_elems = hidden_size >> 2;
|
||||
|
||||
#pragma unroll 4
|
||||
for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
|
||||
for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
|
||||
vec4_t<scalar_t> in = vec_input[i];
|
||||
|
||||
vec4_t<float> x;
|
||||
@ -206,7 +206,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
float block_absmax_val_maybe = 0.0f;
|
||||
|
||||
#pragma unroll 4
|
||||
for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
|
||||
for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
|
||||
vec4_t<scalar_t> in = vec_input[i];
|
||||
vec4_t<scalar_t> const w = vec_weight[i];
|
||||
|
||||
@ -286,7 +286,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
|
||||
// TODO(luka/varun) extract into type-agnostic vectorized quant function to
|
||||
// replace scaled_fp8_conversion_vec
|
||||
#pragma unroll 4
|
||||
for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
|
||||
for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
|
||||
vec4_t<scalar_t> const in = vec_input[i];
|
||||
vec4_t<scalar_t> const w = vec_weight[i];
|
||||
|
||||
|
||||
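
threadIdx.x and blockIdx.x are unsigned, so switching the loop counters from int32_t to auto keeps the index in the native unsigned type; the likely motivation is avoiding narrowing/sign-compare warnings under hipcc, though the loops behave the same. Minimal illustration:

// Sketch: `auto` deduces the unsigned type of threadIdx.x for the counter.
__global__ void strided_copy_sketch(float* __restrict__ out,
                                    const float* __restrict__ in,
                                    unsigned hidden_size) {
  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    out[i] = in[i];
  }
}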
@@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
 #endif
 }
 
-static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) {
-  float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
-  return static_cast<FP8_TYPE>(r);
+template <typename fp8_type>
+static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
+  float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
+                       fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
+  return static_cast<fp8_type>(r);
 }
 
 template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
@@ -54,15 +56,16 @@ struct ScaledQuant<
 };
 
 template <typename quant_type_t, bool is_scale_inverted>
-struct ScaledQuant<
-    quant_type_t, is_scale_inverted,
-    typename std::enable_if_t<std::is_same_v<quant_type_t, FP8_TYPE>>> {
+struct ScaledQuant<quant_type_t, is_scale_inverted,
+                   typename std::enable_if_t<
+                       std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
+                       std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>>> {
   static __device__ __forceinline__ quant_type_t quant_fn(float const x,
                                                           float const scale) {
     if constexpr (is_scale_inverted) {
-      return float_to_fp8(x * scale);
+      return float_to_fp8<quant_type_t>(x * scale);
     } else {
-      return float_to_fp8(x / scale);
+      return float_to_fp8<quant_type_t>(x / scale);
     }
   }
 };

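The ScaledQuant specialization now matches either c10 FP8 type through an enable_if_t over a disjunction of is_same_v checks, instead of keying on the removed FP8_TYPE alias. The pattern in isolation:

#include <type_traits>

// Sketch: SFINAE partial specialization selected for a set of types.
template <typename T, typename Enable = void>
struct Handler;  // primary template, intentionally undefined

template <typename T>
struct Handler<T, std::enable_if_t<std::is_same_v<T, float> ||
                                   std::is_same_v<T, double>>> {
  static T twice(T x) { return x + x; }  // chosen for float or double only
};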
@ -101,10 +101,10 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
const block_q2_K * x = (const block_q2_K *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int n = tid/32;
|
||||
const int l = tid - 32*n;
|
||||
const int is = 8*n + l/16;
|
||||
@ -123,10 +123,10 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
const block_q3_K * x = (const block_q3_K *) vx;
|
||||
|
||||
const int r = threadIdx.x/4;
|
||||
const auto r = threadIdx.x/4;
|
||||
const int tid = r/2;
|
||||
const int is0 = r%2;
|
||||
const int l0 = 16*is0 + 4*(threadIdx.x%4);
|
||||
@ -164,10 +164,10 @@ template<typename dst_t>
|
||||
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const block_q4_K * x = (const block_q4_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
|
||||
// assume 32 threads
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int il = tid/8;
|
||||
const int ir = tid%8;
|
||||
const int is = 2*il;
|
||||
@ -197,10 +197,10 @@ template<typename dst_t>
|
||||
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const block_q5_K * x = (const block_q5_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
|
||||
// assume 64 threads - this is very slightly better than the one below
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int il = tid/16; // il is in 0...3
|
||||
const int ir = tid%16; // ir is in 0...15
|
||||
const int is = 2*il; // is is in 0...6
|
||||
@ -231,10 +231,10 @@ template<typename dst_t>
|
||||
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const block_q6_K * x = (const block_q6_K *) vx;
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
|
||||
// assume 64 threads - this is very slightly better than the one below
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int ip = tid/32; // ip is 0 or 1
|
||||
const int il = tid - 32*ip; // 0...32
|
||||
const int is = 8*ip + il/16;
|
||||
@ -256,10 +256,10 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
@ -275,10 +275,10 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
const block_iq2_xs * x = (const block_iq2_xs *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
@ -293,10 +293,10 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
const block_iq2_s * x = (const block_iq2_s *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
@ -309,10 +309,10 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
@ -332,10 +332,10 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
const block_iq3_s * x = (const block_iq3_s *) vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
||||
@ -399,10 +399,10 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
||||
@ -417,10 +417,10 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
|
||||
|
||||
template<typename dst_t>
|
||||
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
|
||||
const int i = blockIdx.x;
|
||||
const auto i = blockIdx.x;
|
||||
const block_iq4_xs * x = (const block_iq4_xs *)vx;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const auto tid = threadIdx.x;
|
||||
const int il = tid/8; // 0...3
|
||||
const int ib = tid%8; // 0...7
|
||||
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
|
||||
@ -565,4 +565,4 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,22 +5,25 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "ggml-common.h"
|
||||
#include "vecdotq.cuh"
|
||||
#include "dequantize.cuh"
|
||||
#include "mmvq.cuh"
|
||||
#include "mmq.cuh"
|
||||
#include "moe.cuh"
|
||||
|
||||
// Q8 gemv
|
||||
static __global__ void quantize_q8_1(const half* __restrict__ x,
|
||||
template <typename scalar_t>
|
||||
static __global__ void quantize_q8_1(const scalar_t* __restrict__ x,
|
||||
void* __restrict__ vy, const int kx,
|
||||
const int kx_padded) {
|
||||
const int ix = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const auto ix = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (ix >= kx_padded) {
|
||||
return;
|
||||
}
|
||||
const int iy = blockDim.y * blockIdx.y + threadIdx.y;
|
||||
const auto iy = blockDim.y * blockIdx.y + threadIdx.y;
|
||||
const int i_padded = iy * kx_padded + ix;
|
||||
|
||||
block_q8_1* y = (block_q8_1*)vy;
|
||||
@ -28,7 +31,7 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
|
||||
const int ib = i_padded / QK8_1; // block index
|
||||
const int iqs = i_padded % QK8_1; // quant index
|
||||
|
||||
const float xi = ix < kx ? __half2float(x[iy * kx + ix]) : 0.0f;
|
||||
const float xi = ix < kx ? static_cast<float>(x[iy * kx + ix]) : 0.0f;
|
||||
float amax = fabsf(xi);
|
||||
float sum = xi;
|
||||
|
||||
@@ -51,14 +54,20 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,
   y[ib].ds.y = __float2half(sum);
 }
 
-static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
+template <typename scalar_t>
+static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
                                    const int ky, cudaStream_t stream) {
   const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
   const int block_num_x =
       (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-  const dim3 num_blocks(block_num_x, ky, 1);
-  const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
-  quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
+  constexpr int MAX_BLOCK_SIZE = 65535;
+  for (int off = 0; off < ky; off += MAX_BLOCK_SIZE) {
+    const int num_blocks_y = std::min(ky, off + MAX_BLOCK_SIZE) - off;
+    const dim3 num_blocks(block_num_x, num_blocks_y, 1);
+    const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(
+        &x[off * kx], (int32_t*)vy + off * (kx_padded / 32 * 9), kx, kx_padded);
+  }
 }
 
 torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight

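The chunked launch keeps gridDim.y under the hardware limit of 65535. The pointer offsets rely on the block_q8_1 layout: each block packs 32 int8 quants plus a half2 (scale, sum), 36 bytes = 9 int32 words, which is where the kx_padded / 32 * 9 stride comes from (QK8_1 == 32 in ggml-common.h).

// Sketch: the per-row stride (in int32 words) of the q8_1 scratch buffer,
// matching the `padded / 32 * 9` allocations used by the callers below.
static int64_t q8_1_row_stride_in_int32(int kx_padded) {
  const int blocks_per_row = kx_padded / 32;        // QK8_1 values per block
  return static_cast<int64_t>(blocks_per_row) * 9;  // 36 bytes per block
}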
@ -79,101 +88,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, // quant weight
|
||||
int col = X.sizes()[1];
|
||||
const int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::empty({1, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
|
||||
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1,
|
||||
stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
mul_mat_vec_q4_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q4_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q5_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q5_1_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q8_0_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 10:
|
||||
mul_mat_vec_q2_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 11:
|
||||
mul_mat_vec_q3_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 12:
|
||||
mul_mat_vec_q4_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 13:
|
||||
mul_mat_vec_q5_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 14:
|
||||
mul_mat_vec_q6_K_q8_1_cuda((void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 16:
|
||||
mul_mat_vec_iq2_xxs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 17:
|
||||
mul_mat_vec_iq2_xs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 18:
|
||||
mul_mat_vec_iq3_xxs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 19:
|
||||
mul_mat_vec_iq1_s_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 20:
|
||||
mul_mat_vec_iq4_nl_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 21:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 22:
|
||||
mul_mat_vec_iq2_s_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 23:
|
||||
mul_mat_vec_iq4_xs_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 29:
|
||||
mul_mat_vec_iq1_m_q8_1_cuda((void*)W.data_ptr(),
|
||||
(void*)quant_X.data_ptr(),
|
||||
(half*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
}
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] {
|
||||
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
|
||||
(void*)quant_X.data_ptr(), col, 1, stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
mul_mat_vec_q4_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q4_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q5_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q5_1_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q8_0_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 10:
|
||||
mul_mat_vec_q2_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 11:
|
||||
mul_mat_vec_q3_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 12:
|
||||
mul_mat_vec_q4_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 13:
|
||||
mul_mat_vec_q5_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 14:
|
||||
mul_mat_vec_q6_K_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 16:
|
||||
mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 17:
|
||||
mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 18:
|
||||
mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 19:
|
||||
mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 20:
|
||||
mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 21:
|
||||
mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 22:
|
||||
mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 23:
|
||||
mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
case 29:
|
||||
mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
return Y;
|
||||
}
|
||||
|
||||
@ -184,66 +204,196 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, // quant weight
|
||||
int padded = (col + 512 - 1) / 512 * 512;
|
||||
int batch = X.sizes()[0];
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::empty({batch, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
|
||||
quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col,
|
||||
batch, stream);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_a8", [&] {
|
||||
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
col, batch, stream);
|
||||
|
||||
switch (type) {
|
||||
case 2:
|
||||
ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 3:
|
||||
ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 6:
|
||||
ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 7:
|
||||
ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 8:
|
||||
ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 10:
|
||||
ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 11:
|
||||
ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 12:
|
||||
ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 13:
|
||||
ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 14:
|
||||
ggml_mul_mat_q6_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(), (half*)Y.data_ptr(),
|
||||
col, row, batch, padded, row, stream);
|
||||
break;
|
||||
}
|
||||
switch (type) {
|
||||
case 2:
|
||||
ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 3:
|
||||
ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 6:
|
||||
ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 7:
|
||||
ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 8:
|
||||
ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 10:
|
||||
ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 11:
|
||||
ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 12:
|
||||
ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 13:
|
||||
ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
case 14:
|
||||
ggml_mul_mat_q6_K_q8_1_cuda(
|
||||
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), col, row, batch, padded, row, stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
return Y;
|
||||
}
|
||||
|
||||
torch::Tensor ggml_moe_a8(torch::Tensor X, // input
|
||||
torch::Tensor W, // expert weights
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_padded, int64_t type,
|
||||
int64_t row, int64_t top_k, int64_t tokens) {
|
||||
int col = X.sizes()[1];
|
||||
int padded = (col + 512 - 1) / 512 * 512;
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
|
||||
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
|
||||
at::Tensor Y = torch::empty({tokens * top_k, row}, options);
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
|
||||
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_a8", [&] {
|
||||
quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(),
|
||||
col, tokens, stream);
|
||||
switch (type) {
|
||||
case 2:
|
||||
ggml_moe_q4_0_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 3:
|
||||
ggml_moe_q4_1_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 6:
|
||||
ggml_moe_q5_0_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 7:
|
||||
ggml_moe_q5_1_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 8:
|
||||
ggml_moe_q8_0_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 10:
|
||||
ggml_moe_q2_K_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 11:
|
||||
ggml_moe_q3_K_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 12:
|
||||
ggml_moe_q4_K_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 13:
|
||||
ggml_moe_q5_K_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
case 14:
|
||||
ggml_moe_q6_K_q8_1_cuda(
|
||||
(void*)quant_X.data_ptr(), (void*)W.data_ptr(),
|
||||
(scalar_t*)Y.data_ptr(), (int*)sorted_token_ids.data_ptr(),
|
||||
(int*)expert_ids.data_ptr(),
|
||||
(int*)num_tokens_post_padded.data_ptr(), W.stride(0), col, row,
|
||||
tokens, padded, row, top_k, sorted_token_ids.sizes()[0], stream);
|
||||
break;
|
||||
}
|
||||
});
|
||||
return Y;
|
||||
}
|
||||
|
||||
int64_t ggml_moe_get_block_size(int64_t type) {
|
||||
switch (type) {
|
||||
case 2:
|
||||
return MMQ_X_Q4_0;
|
||||
case 3:
|
||||
return MMQ_X_Q4_1;
|
||||
case 6:
|
||||
return MMQ_X_Q5_0;
|
||||
case 7:
|
||||
return MMQ_X_Q5_1;
|
||||
case 8:
|
||||
return MMQ_X_Q8_0;
|
||||
case 10:
|
||||
return MMQ_X_Q2_K;
|
||||
case 11:
|
||||
return MMQ_X_Q3_K;
|
||||
case 12:
|
||||
return MMQ_X_Q4_K;
|
||||
case 13:
|
||||
return MMQ_X_Q5_K;
|
||||
case 14:
|
||||
return MMQ_X_Q6_K;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
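
ggml_moe_get_block_size exposes the MMQ tile width for each quantization type so callers can align the routed token count to the kernel's tile. A hypothetical use:

// Sketch (hypothetical caller): round the padded token count up to the
// MoE kernel's tile size for the given GGUF quant type.
int64_t tile = ggml_moe_get_block_size(type);  // e.g. MMQ_X_Q4_K for type 12
int64_t padded_tokens = (num_tokens + tile - 1) / tile * tile;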
@@ -1,8 +1,8 @@
 // copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmq.cu
-template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
+template <typename scalar_t, int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
-    const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
+    const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
 
     const block_q_t * x = (const block_q_t *) vx;
@ -14,10 +14,10 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||
|
||||
const int & ncols_dst = ncols_y;
|
||||
|
||||
const int row_dst_0 = blockIdx.x*mmq_y;
|
||||
const auto row_dst_0 = blockIdx.x*mmq_y;
|
||||
const int & row_x_0 = row_dst_0;
|
||||
|
||||
const int col_dst_0 = blockIdx.y*mmq_x;
|
||||
const auto col_dst_0 = blockIdx.y*mmq_x;
|
||||
const int & col_y_0 = col_dst_0;
|
||||
|
||||
int * tile_x_ql = nullptr;
|
||||
@ -38,8 +38,8 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||
threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
|
||||
|
||||
#pragma unroll
|
||||
for (int ir = 0; ir < qr; ++ir) {
|
||||
const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
|
||||
for (int ir = 0; ir < qr && ib0 + ir * blocks_per_warp/qr < blocks_per_row_x; ++ir) {
|
||||
const auto kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
|
||||
const int kbxd = kqs / QI8_1;
|
||||
|
||||
#pragma unroll
|
||||
@ -53,7 +53,7 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||
#pragma unroll
|
||||
for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
|
||||
const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF/QI8_1)) % mmq_x;
|
||||
const int kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1);
|
||||
const auto kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1);
|
||||
const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
|
||||
|
||||
// if the sum is not needed it's faster to transform the scale to f32 ahead of time
|
||||
@ -87,18 +87,18 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < mmq_x; j += nwarps) {
|
||||
const int col_dst = col_dst_0 + j + threadIdx.y;
|
||||
const auto col_dst = col_dst_0 + j + threadIdx.y;
|
||||
if (col_dst >= ncols_dst) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
|
||||
const int row_dst = row_dst_0 + threadIdx.x + i;
|
||||
const auto row_dst = row_dst_0 + threadIdx.x + i;
|
||||
if (row_dst >= nrows_dst) {
|
||||
continue;
|
||||
}
|
||||
dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE_GGUF][j/nwarps]);
|
||||
dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE_GGUF][j/nwarps];
|
||||
}
|
||||
}
|
||||
}
|
||||
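
Two inner-loop changes here: the ir loop gains the ib0 + ir * blocks_per_warp / qr < blocks_per_row_x guard so the q8_1 tile loads stop at the end of a row instead of reading past it, and the final store drops __float2half because dst is now scalar_t, letting the assignment convert the float accumulator to half, bfloat16, or float as appropriate. The store in isolation:

// Sketch: writing a float accumulator through a templated destination type;
// the assignment performs the float -> scalar_t conversion that used to be
// an explicit __float2half call.
template <typename scalar_t>
__device__ __forceinline__ void store_acc(scalar_t* __restrict__ dst,
                                          int idx, float acc) {
  dst[idx] = acc;
}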
@ -113,24 +113,25 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||
#define NWARPS_Q4_0 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2)
|
||||
#endif
|
||||
mul_mat_q4_0(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q4_0;
|
||||
const int mmq_y = MMQ_Y_Q4_0;
|
||||
const int nwarps = NWARPS_Q4_0;
|
||||
|
||||
mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
|
||||
load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
int mmq_x = MMQ_X_Q4_0;
|
||||
@ -144,11 +145,11 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q4_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -163,24 +164,25 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||
#define NWARPS_Q4_1 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2)
|
||||
#endif
|
||||
mul_mat_q4_1(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q4_1;
|
||||
const int mmq_y = MMQ_Y_Q4_1;
|
||||
const int nwarps = NWARPS_Q4_1;
|
||||
|
||||
mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
|
||||
load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
int mmq_x = MMQ_X_Q4_1;
|
||||
@ -194,11 +196,11 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q4_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -213,24 +215,25 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
|
||||
#define NWARPS_Q5_0 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2)
|
||||
#endif
|
||||
mul_mat_q5_0(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
const int mmq_y = MMQ_Y_Q5_0;
|
||||
const int nwarps = NWARPS_Q5_0;
|
||||
|
||||
mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
|
||||
load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
@ -244,11 +247,11 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q5_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -263,24 +266,25 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
|
||||
#define NWARPS_Q5_1 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2)
|
||||
#endif
|
||||
mul_mat_q5_1(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
const int nwarps = NWARPS_Q5_1;
|
||||
|
||||
mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
|
||||
load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
@ -293,11 +297,11 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q5_1<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -312,24 +316,25 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
|
||||
#define NWARPS_Q8_0 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2)
|
||||
#endif
|
||||
mul_mat_q8_0(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
const int nwarps = NWARPS_Q8_0;
|
||||
|
||||
mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
|
||||
load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
@ -342,11 +347,11 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q8_0<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -361,24 +366,25 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
|
||||
#define NWARPS_Q2_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2)
|
||||
#endif
|
||||
mul_mat_q2_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
const int nwarps = NWARPS_Q2_K;
|
||||
|
||||
mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
|
||||
load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
@ -391,11 +397,11 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q2_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -410,25 +416,26 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
|
||||
#define NWARPS_Q3_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2)
|
||||
#endif
|
||||
mul_mat_q3_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
const int mmq_y = MMQ_Y_Q3_K;
|
||||
const int nwarps = NWARPS_Q3_K;
|
||||
|
||||
mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
|
||||
load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
@ -442,11 +449,11 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -461,24 +468,25 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
|
||||
#define NWARPS_Q4_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2)
|
||||
#endif
|
||||
mul_mat_q4_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
const int nwarps = NWARPS_Q4_K;
|
||||
|
||||
mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
|
||||
load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
@ -491,11 +499,11 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q4_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -510,24 +518,25 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
|
||||
#define NWARPS_Q5_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2)
|
||||
#endif
|
||||
mul_mat_q5_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q5_K;
|
||||
const int mmq_y = MMQ_Y_Q5_K;
|
||||
const int nwarps = NWARPS_Q5_K;
|
||||
|
||||
mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
|
||||
load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q5_K;
|
||||
@ -541,11 +550,11 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q5_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
@ -560,24 +569,25 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
|
||||
#define NWARPS_Q6_K 4
|
||||
#endif
|
||||
|
||||
template <bool need_check> static __global__ void
|
||||
template<typename scalar_t, bool need_check> static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2)
|
||||
#endif
|
||||
mul_mat_q6_K(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
|
||||
const int mmq_x = MMQ_X_Q6_K;
|
||||
const int mmq_y = MMQ_Y_Q6_K;
|
||||
const int nwarps = NWARPS_Q6_K;
|
||||
|
||||
mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
|
||||
mul_mat_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
|
||||
load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
static void ggml_mul_mat_q6_K_q8_1_cuda(
|
||||
const void * vx, const void * vy, half * dst, const int ncols_x, const int nrows_x,
|
||||
const void * vx, const void * vy, scalar_t * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q6_K;
|
||||
const int mmq_y = MMQ_Y_Q6_K;
|
||||
@ -590,11 +600,11 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
const bool need_check = false;
|
||||
mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
} else {
|
||||
const bool need_check = true;
|
||||
mul_mat_q6_K<need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
mul_mat_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>
|
||||
(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
-template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
-static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, const int ncols, const int nrows) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
+static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) {
+    const auto row = blockIdx.x*blockDim.y + threadIdx.y;

    if (row >= nrows) {
        return;
@@ -16,7 +16,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
    const block_q_t * x = (const block_q_t *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

-    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
+    for (auto i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
        const int ibx = row*blocks_per_row + i; // x block index

        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
@ -33,158 +33,177 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
    }

    if (threadIdx.x == 0) {
-        dst[row] = __float2half(tmp);
+        dst[row] = tmp;
    }
}
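The only behavioural change inside the kernel body is the final store: `__float2half(tmp)` becomes a plain assignment, which relies on both `half` and `__nv_bfloat16` accepting an implicit conversion from `float` in device code. A minimal standalone illustration of that pattern (the kernel name and launch shape below are invented, not taken from the diff):

```cuda
// Toy kernel showing the float -> scalar_t store used by the templated
// mmvq kernels above; purely illustrative.
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename scalar_t>
__global__ void store_scaled(scalar_t* __restrict__ dst, float scale) {
  const float tmp = scale * threadIdx.x;  // accumulate in float, as mmvq does
  dst[threadIdx.x] = tmp;                 // implicit float -> half / bf16 conversion
}

// Example launches for both output types:
//   store_scaled<half><<<1, 32>>>(d_half_ptr, 0.5f);
//   store_scaled<__nv_bfloat16><<<1, 32>>>(d_bf16_ptr, 0.5f);
```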
-static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
+template<typename scalar_t>
+static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
+    mul_mat_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
|
||||
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, half * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
template<typename scalar_t>
|
||||
static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
const dim3 block_nums(block_num_y, 1, 1);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||
mul_mat_vec_q<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
||||
mul_mat_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
|
||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
csrc/quantization/gguf/moe.cuh (new file, 739 lines)
@@ -0,0 +1,739 @@
#include <cstdint>

/* Adapted from ./csrc/quantization/gguf/mmq.cuh
based on ./vllm/model_executor/layers/fused_moe/fused_moe.py */
template <typename scalar_t, int qk, int qr, int qi, bool need_sum,
          typename block_q_t, int mmq_x, int mmq_y, int nwarps,
          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles,
          int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void moe_q(
    const void* __restrict__ vx, const void* __restrict__ vy,
    scalar_t* __restrict__ dst, const int* __restrict__ sorted_token_ids,
    const int* __restrict__ expert_ids,
    const int* __restrict__ num_tokens_post_padded, const int exp_stride,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y,
    const int nrows_dst, const int top_k) {
  const int blocks_per_row_x = ncols_x / qk;
  const int blocks_per_col_y = nrows_y / QK8_1;
  const int blocks_per_warp = WARP_SIZE_GGUF / qi;

  const int ncols_dst = ncols_y * top_k;

  const auto row_dst_0 = blockIdx.x * mmq_y;
  const int& row_x_0 = row_dst_0;

  const auto col_dst_0 = blockIdx.y * mmq_x;

  int token_offs[mmq_x / nwarps];
  for (int i = 0; i < mmq_x; i += nwarps) {
    token_offs[i / nwarps] = sorted_token_ids[col_dst_0 + threadIdx.y + i];
  }

  const int exp_idx = expert_ids[blockIdx.y];
  if (exp_idx > 255 || exp_idx < 0) return;
  if (blockIdx.y * mmq_x > num_tokens_post_padded[0]) return;

  const block_q_t* x = (const block_q_t*)((char*)vx + exp_idx * exp_stride);
  const block_q8_1* y = (const block_q8_1*)(vy);

  int* tile_x_ql = nullptr;
  half2* tile_x_dm = nullptr;
  int* tile_x_qh = nullptr;
  int* tile_x_sc = nullptr;

  allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);

  __shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF];
  __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF / QI8_1];

  float sum[mmq_y / WARP_SIZE_GGUF][mmq_x / nwarps] = {{0.0f}};

  for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
    load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
               tile_x_qh, tile_x_sc, threadIdx.y, nrows_x - row_x_0 - 1,
               threadIdx.x, blocks_per_row_x);

    const int n_per_r = ((qk * blocks_per_warp) / qr);
#pragma unroll
    for (int ir = 0; ir < qr && ib0 * qk + ir * n_per_r < ncols_x; ++ir) {
      const auto kqs = ir * WARP_SIZE_GGUF + threadIdx.x;
      const int kbxd = kqs / QI8_1;

#pragma unroll
      for (int i = 0; i < mmq_x; i += nwarps) {
        const int col_y_eff = token_offs[i / nwarps] / top_k;
        const int block_x = ib0 * (qk / QK8_1) + kbxd;
        if (col_y_eff < ncols_y && block_x < blocks_per_col_y) {
          const block_q8_1* by0 = &y[col_y_eff * blocks_per_col_y + block_x];
          const int index_y =
              (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF;
          tile_y_qs[index_y] =
              get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
        }
      }

      if (threadIdx.x < n_per_r / QK8_1) {
        const auto kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1);
        const int col_y_eff = token_offs[threadIdx.y] / top_k;
        const int block_x =
            ib0 * (qk / QK8_1) + ir * (WARP_SIZE_GGUF / QI8_1) + kby;

        if (col_y_eff < ncols_y && block_x < blocks_per_col_y) {
          const half2* dsi_src = &y[col_y_eff * blocks_per_col_y + block_x].ds;
          half2* dsi_dst =
              &tile_y_ds[threadIdx.y * (WARP_SIZE_GGUF / QI8_1) + kby];

          if (need_sum) {
            *dsi_dst = *dsi_src;
          } else {
            float* dfi_dst = (float*)dsi_dst;
            *dfi_dst = __low2float(*dsi_src);
          }
        }
      }
      __syncthreads();

      // #pragma unroll // unrolling this loop causes too much register pressure
      for (int k = ir * WARP_SIZE_GGUF / qr; k < (ir + 1) * WARP_SIZE_GGUF / qr;
           k += vdr) {
#pragma unroll
        for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
          for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
            sum[i / WARP_SIZE_GGUF][j / nwarps] +=
                vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs,
                        tile_y_ds, threadIdx.x + i, threadIdx.y + j, k);
          }
        }
      }
      __syncthreads();
    }
  }

#pragma unroll
  for (int j = 0; j < mmq_x; j += nwarps) {
    const int col_dst = token_offs[j / nwarps];
    if (col_dst >= ncols_dst) {
      return;
    }

#pragma unroll
    for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
      const auto row_dst = row_dst_0 + threadIdx.x + i;
      if (row_dst >= nrows_dst) {
        continue;
      }
      dst[col_dst * nrows_dst + row_dst] = sum[i / WARP_SIZE_GGUF][j / nwarps];
    }
  }
}
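The device function above reuses the dense mmq tile machinery; the MoE-specific part is how each thread block resolves its expert and token slots: `expert_ids[blockIdx.y]` selects the expert weight matrix, `sorted_token_ids` supplies the padded destination columns, and `col_dst / top_k` maps each destination column back to the source activation row. A plain CPU sketch of that same output mapping, with quantization replaced by dense float buffers (all names other than the two index arrays are invented for illustration):

```cuda
// CPU reference for the output indexing used by moe_q (illustrative only).
// Requires sorted_token_ids.size() >= expert_ids.size() * mmq_x; padding
// entries hold values >= ncols_y * top_k and are skipped, as in the kernel.
#include <vector>

void moe_reference(const std::vector<float>& W,   // [n_expert][nrows_dst][ncols_x]
                   const std::vector<float>& X,   // [ncols_y][ncols_x]
                   std::vector<float>& Dst,       // indexed Dst[col_dst * nrows_dst + row]
                   const std::vector<int>& sorted_token_ids,
                   const std::vector<int>& expert_ids,
                   int ncols_x, int nrows_dst, int ncols_y,
                   int top_k, int mmq_x) {
  for (size_t blk = 0; blk < expert_ids.size(); ++blk) {
    const int e = expert_ids[blk];                            // expert for this token tile
    for (int j = 0; j < mmq_x; ++j) {
      const int col_dst = sorted_token_ids[blk * mmq_x + j];  // padded destination slot
      if (col_dst >= ncols_y * top_k) continue;               // padding entry
      const int token = col_dst / top_k;                      // source activation row
      for (int row = 0; row < nrows_dst; ++row) {
        float acc = 0.f;
        for (int k = 0; k < ncols_x; ++k) {
          acc += W[(e * nrows_dst + row) * ncols_x + k] * X[token * ncols_x + k];
        }
        Dst[col_dst * nrows_dst + row] = acc;                 // same layout as the kernel
      }
    }
  }
}
```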
#if defined(USE_ROCM)
  #define MMQ_X_Q4_0 64
  #define MMQ_Y_Q4_0 128
  #define NWARPS_Q4_0 8
#else
  #define MMQ_X_Q4_0 4
  #define MMQ_Y_Q4_0 32
  #define NWARPS_Q4_0 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
#endif
    moe_q4_0(const void* __restrict__ vx, const void* __restrict__ vy,
             scalar_t* __restrict__ dst, const int* sorted_token_ids,
             const int* expert_ids, const int* num_tokens_post_padded,
             const int exp_stride, const int ncols_x, const int nrows_x,
             const int ncols_y, const int nrows_y, const int nrows_dst,
             const int top_k) {
  const int mmq_x = MMQ_X_Q4_0;
  const int mmq_y = MMQ_Y_Q4_0;
  const int nwarps = NWARPS_Q4_0;

  moe_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
        allocate_tiles_q4_0<mmq_y>, load_tiles_q4_0<mmq_y, nwarps, need_check>,
        VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
      exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q4_0_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
    const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  int mmq_x = MMQ_X_Q4_0;
  int mmq_y = MMQ_Y_Q4_0;
  int nwarps = NWARPS_Q4_0;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
        exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q4_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
        exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
  }
}

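Each `ggml_moe_*_q8_1_cuda` helper derives its grid from the padded token count rather than the raw batch size, so `tokens_post_padded` is presumably already a multiple of the MMQ_X tile width (produced upstream by the block-alignment step of the fused MoE path the file header references). A tiny standalone sketch of that launch-geometry arithmetic, using made-up sizes:

```cuda
// Minimal sketch of the grid-shape arithmetic used by the ggml_moe_*_q8_1_cuda
// helpers; nrows_x and tokens_post_padded below are assumed example values.
#include <cstdio>

int main() {
  const int mmq_x = 4, mmq_y = 32;         // non-ROCm tile sizes from this file
  const int nrows_x = 4096;                // assumed expert weight rows
  const int tokens_post_padded = 64;       // assumed, already padded to mmq_x
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;  // ceil-div over rows
  const int block_num_y = tokens_post_padded / mmq_x;     // one block per token tile
  std::printf("grid = (%d, %d, 1)\n", block_num_x, block_num_y);  // grid = (128, 16, 1)
  return 0;
}
```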
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_1 64
|
||||
#define MMQ_Y_Q4_1 128
|
||||
#define NWARPS_Q4_1 8
|
||||
#else
|
||||
#define MMQ_X_Q4_1 4
|
||||
#define MMQ_Y_Q4_1 32
|
||||
#define NWARPS_Q4_1 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
|
||||
#endif
|
||||
moe_q4_1(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q4_1;
|
||||
const int mmq_y = MMQ_Y_Q4_1;
|
||||
const int nwarps = NWARPS_Q4_1;
|
||||
|
||||
moe_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q4_1<mmq_y>, load_tiles_q4_1<mmq_y, nwarps, need_check>,
|
||||
VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q4_1_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
int mmq_x = MMQ_X_Q4_1;
|
||||
int mmq_y = MMQ_Y_Q4_1;
|
||||
int nwarps = NWARPS_Q4_1;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q4_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_0 64
|
||||
#define MMQ_Y_Q5_0 128
|
||||
#define NWARPS_Q5_0 8
|
||||
#else
|
||||
#define MMQ_X_Q5_0 4
|
||||
#define MMQ_Y_Q5_0 32
|
||||
#define NWARPS_Q5_0 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
|
||||
#endif
|
||||
moe_q5_0(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
const int mmq_y = MMQ_Y_Q5_0;
|
||||
const int nwarps = NWARPS_Q5_0;
|
||||
|
||||
moe_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q5_0<mmq_y>, load_tiles_q5_0<mmq_y, nwarps, need_check>,
|
||||
VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q5_0_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_0;
|
||||
const int mmq_y = MMQ_Y_Q5_0;
|
||||
const int nwarps = NWARPS_Q5_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q5_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q5_1 64
|
||||
#define MMQ_Y_Q5_1 128
|
||||
#define NWARPS_Q5_1 8
|
||||
#else
|
||||
#define MMQ_X_Q5_1 4
|
||||
#define MMQ_Y_Q5_1 32
|
||||
#define NWARPS_Q5_1 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
|
||||
#endif
|
||||
moe_q5_1(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
const int nwarps = NWARPS_Q5_1;
|
||||
|
||||
moe_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q5_1<mmq_y>, load_tiles_q5_1<mmq_y, nwarps, need_check>,
|
||||
VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q5_1_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q5_1;
|
||||
const int mmq_y = MMQ_Y_Q5_1;
|
||||
const int nwarps = NWARPS_Q5_1;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q5_1<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q8_0 64
|
||||
#define MMQ_Y_Q8_0 128
|
||||
#define NWARPS_Q8_0 8
|
||||
#else
|
||||
#define MMQ_X_Q8_0 4
|
||||
#define MMQ_Y_Q8_0 32
|
||||
#define NWARPS_Q8_0 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
|
||||
#endif
|
||||
moe_q8_0(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
const int nwarps = NWARPS_Q8_0;
|
||||
|
||||
moe_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q8_0<mmq_y>, load_tiles_q8_0<mmq_y, nwarps, need_check>,
|
||||
VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q8_0_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q8_0;
|
||||
const int mmq_y = MMQ_Y_Q8_0;
|
||||
const int nwarps = NWARPS_Q8_0;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q8_0<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q2_K 64
|
||||
#define MMQ_Y_Q2_K 128
|
||||
#define NWARPS_Q2_K 8
|
||||
#else
|
||||
#define MMQ_X_Q2_K 4
|
||||
#define MMQ_Y_Q2_K 32
|
||||
#define NWARPS_Q2_K 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
|
||||
#endif
|
||||
moe_q2_K(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
const int nwarps = NWARPS_Q2_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q2_K<mmq_y>, load_tiles_q2_K<mmq_y, nwarps, need_check>,
|
||||
VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q2_K_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q2_K;
|
||||
const int mmq_y = MMQ_Y_Q2_K;
|
||||
const int nwarps = NWARPS_Q2_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q2_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q3_K 64
|
||||
#define MMQ_Y_Q3_K 128
|
||||
#define NWARPS_Q3_K 8
|
||||
#else
|
||||
#define MMQ_X_Q3_K 4
|
||||
#define MMQ_Y_Q3_K 32
|
||||
#define NWARPS_Q3_K 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
|
||||
#endif
|
||||
moe_q3_K(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
const int mmq_y = MMQ_Y_Q3_K;
|
||||
const int nwarps = NWARPS_Q3_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q3_K<mmq_y>, load_tiles_q3_K<mmq_y, nwarps, need_check>,
|
||||
VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q3_K_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q3_K;
|
||||
const int mmq_y = MMQ_Y_Q3_K;
|
||||
const int nwarps = NWARPS_Q3_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q3_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#define MMQ_X_Q4_K 64
|
||||
#define MMQ_Y_Q4_K 128
|
||||
#define NWARPS_Q4_K 8
|
||||
#else
|
||||
#define MMQ_X_Q4_K 4
|
||||
#define MMQ_Y_Q4_K 32
|
||||
#define NWARPS_Q4_K 4
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool need_check>
|
||||
static __global__ void
|
||||
#if defined(USE_ROCM)
|
||||
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
|
||||
#endif
|
||||
moe_q4_K(const void* __restrict__ vx, const void* __restrict__ vy,
|
||||
scalar_t* __restrict__ dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst,
|
||||
const int top_k) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
const int nwarps = NWARPS_Q4_K;
|
||||
|
||||
moe_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
|
||||
allocate_tiles_q4_K<mmq_y>, load_tiles_q4_K<mmq_y, nwarps, need_check>,
|
||||
VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>(
|
||||
vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
static void ggml_moe_q4_K_q8_1_cuda(
|
||||
const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
|
||||
const int* expert_ids, const int* num_tokens_post_padded,
|
||||
const int exp_stride, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
|
||||
const int tokens_post_padded, cudaStream_t stream) {
|
||||
const int mmq_x = MMQ_X_Q4_K;
|
||||
const int mmq_y = MMQ_Y_Q4_K;
|
||||
const int nwarps = NWARPS_Q4_K;
|
||||
|
||||
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
|
||||
const int block_num_y = (tokens_post_padded) / mmq_x;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);
|
||||
|
||||
if (nrows_x % mmq_y == 0) {
|
||||
constexpr bool need_check = false;
|
||||
moe_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
} else {
|
||||
constexpr bool need_check = true;
|
||||
moe_q4_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
|
||||
w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM)
#define MMQ_X_Q5_K 64
#define MMQ_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define MMQ_X_Q5_K 4
#define MMQ_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
#endif
    moe_q5_K(const void* __restrict__ vx, const void* __restrict__ vy,
             scalar_t* __restrict__ dst, const int* sorted_token_ids,
             const int* expert_ids, const int* num_tokens_post_padded,
             const int exp_stride, const int ncols_x, const int nrows_x,
             const int ncols_y, const int nrows_y, const int nrows_dst,
             const int top_k) {
  const int mmq_x = MMQ_X_Q5_K;
  const int mmq_y = MMQ_Y_Q5_K;
  const int nwarps = NWARPS_Q5_K;

  moe_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
        allocate_tiles_q5_K<mmq_y>, load_tiles_q5_K<mmq_y, nwarps, need_check>,
        VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
      exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q5_K_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
    const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  const int mmq_x = MMQ_X_Q5_K;
  const int mmq_y = MMQ_Y_Q5_K;
  const int nwarps = NWARPS_Q5_K;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
        exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q5_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
        exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
  }
}

#if defined(USE_ROCM)
#define MMQ_X_Q6_K 64
#define MMQ_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define MMQ_X_Q6_K 4
#define MMQ_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif

template <typename scalar_t, bool need_check>
static __global__ void
#if defined(USE_ROCM)
__launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
#endif
    moe_q6_K(const void* __restrict__ vx, const void* __restrict__ vy,
             scalar_t* __restrict__ dst, const int* sorted_token_ids,
             const int* expert_ids, const int* num_tokens_post_padded,
             const int exp_stride, const int ncols_x, const int nrows_x,
             const int ncols_y, const int nrows_y, const int nrows_dst,
             const int top_k) {
  const int mmq_x = MMQ_X_Q6_K;
  const int mmq_y = MMQ_Y_Q6_K;
  const int nwarps = NWARPS_Q6_K;

  moe_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
        allocate_tiles_q6_K<mmq_y>, load_tiles_q6_K<mmq_y, nwarps, need_check>,
        VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>(
      vx, vy, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
      exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
}

template <typename scalar_t>
static void ggml_moe_q6_K_q8_1_cuda(
    const void* inp, const void* w, scalar_t* dst, const int* sorted_token_ids,
    const int* expert_ids, const int* num_tokens_post_padded,
    const int exp_stride, const int ncols_x, const int nrows_x,
    const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
    const int tokens_post_padded, cudaStream_t stream) {
  const int mmq_x = MMQ_X_Q6_K;
  const int mmq_y = MMQ_Y_Q6_K;
  const int nwarps = NWARPS_Q6_K;

  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  const int block_num_y = (tokens_post_padded) / mmq_x;
  const dim3 block_nums(block_num_x, block_num_y, 1);
  const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

  if (nrows_x % mmq_y == 0) {
    constexpr bool need_check = false;
    moe_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
        exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
  } else {
    constexpr bool need_check = true;
    moe_q6_K<scalar_t, need_check><<<block_nums, block_dims, 0, stream>>>(
        w, inp, dst, sorted_token_ids, expert_ids, num_tokens_post_padded,
        exp_stride, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, top_k);
  }
}
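// Note (illustrative sketch, not part of the diff): all three launchers above
// compute block_num_y with plain integer division, so they implicitly assume
// `tokens_post_padded` is already a multiple of their mmq_x (64 in the ROCm
// branch of these macros, 4 otherwise). How the caller produces that padding
// is not shown in this diff; the helper below is only a hedged sketch of the
// usual round-up, with a made-up name.
static inline int pad_to_multiple(int num_scheduled_tokens, int block) {
  return ((num_scheduled_tokens + block - 1) / block) * block;  // round up
}
// pad_to_multiple(13, 4) == 16; with the ROCm tile, pad_to_multiple(13, 64) == 64.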
@@ -199,15 +199,15 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
  MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int t = threadIdx.x;
+  auto t = threadIdx.x;

  // Block
-  int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
-  int offset_m = blockIdx.y * m_count;
-  int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+  auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+  auto offset_m = blockIdx.y * m_count;
+  auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

-  int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
-  int end_m = min(offset_m + m_count, size_m);
+  [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+  [[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

  int n = offset_n + t * 4;
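// Note (illustrative sketch, not part of the diff): the int -> auto changes in
// these kernels matter because threadIdx/blockIdx components are `unsigned
// int`, so `auto` deduces an unsigned index and avoids signed/unsigned
// conversion warnings in products like `blockIdx.x * BLOCK_KN_SIZE * 4`. A
// minimal standalone illustration of the deduction (names are hypothetical):
__global__ void index_type_demo(int* out, int n) {
  auto t = threadIdx.x;                  // deduced as unsigned int
  auto i = blockIdx.x * blockDim.x + t;  // stays unsigned end to end
  if (i < static_cast<unsigned>(n)) out[i] = static_cast<int>(i);
}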
@@ -337,15 +337,15 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
  MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int t = threadIdx.x;
+  auto t = threadIdx.x;

  // Block
-  int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
-  int offset_m = blockIdx.y * m_count;
-  int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+  auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+  auto offset_m = blockIdx.y * m_count;
+  auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

-  int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
-  int end_m = min(offset_m + m_count, size_m);
+  [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+  [[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

  int n = offset_n + t * 4;
@@ -458,15 +458,15 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
  MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int t = threadIdx.x;
+  auto t = threadIdx.x;

  // Block
-  int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
-  int offset_m = blockIdx.y * m_count;
-  int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+  auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+  auto offset_m = blockIdx.y * m_count;
+  auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

-  int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
-  int end_m = min(offset_m + m_count, size_m);
+  [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+  [[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

  int n = offset_n + t * 4;
@@ -586,15 +586,15 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
  MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int t = threadIdx.x;
+  auto t = threadIdx.x;

  // Block
-  int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
-  int offset_m = blockIdx.y * m_count;
-  int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+  auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+  auto offset_m = blockIdx.y * m_count;
+  auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

-  int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
-  int end_m = min(offset_m + m_count, size_m);
+  [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
+  [[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

  int n = offset_n + t * 4;
@@ -765,14 +765,14 @@ __global__ void reconstruct_exllama_8bit_kernel(
  MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int offset_k = BLOCK_KN_SIZE * blockIdx.y;
-  int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+  auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+  auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

  // Preload remapping table
  __shared__ int perm[BLOCK_KN_SIZE];
-  int t = threadIdx.x;
+  auto t = threadIdx.x;

  if (b_q_perm) {
    if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -862,14 +862,14 @@ __global__ void reconstruct_exllama_4bit_kernel(
  MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int offset_k = BLOCK_KN_SIZE * blockIdx.y;
-  int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+  auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+  auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

  // Preload remapping table
  __shared__ int perm[BLOCK_KN_SIZE];
-  int t = threadIdx.x;
+  auto t = threadIdx.x;

  if (b_q_perm) {
    if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -967,14 +967,14 @@ __global__ void reconstruct_exllama_3bit_kernel(
  MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int offset_k = BLOCK_KN_SIZE * blockIdx.y;
-  int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+  auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+  auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

  // Preload remapping table
  __shared__ int perm[BLOCK_KN_SIZE];
-  int t = threadIdx.x;
+  auto t = threadIdx.x;

  if (b_q_perm) {
    if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -1065,14 +1065,14 @@ __global__ void reconstruct_exllama_2bit_kernel(
  MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
  MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

-  int offset_k = BLOCK_KN_SIZE * blockIdx.y;
-  int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+  auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+  auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

  int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

  // Preload remapping table
  __shared__ int perm[BLOCK_KN_SIZE];
-  int t = threadIdx.x;
+  auto t = threadIdx.x;

  if (b_q_perm) {
    if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -1181,11 +1181,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
  int zero_width = width / 8;
  int vec_height = height * 4;
  const int blockwidth2 = BLOCK_KN_SIZE / 2;
-  int b = blockIdx.y * BLOCK_M_SIZE_MAX;
+  auto b = blockIdx.y * BLOCK_M_SIZE_MAX;
  int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
-  int h = BLOCK_KN_SIZE * blockIdx.z / 8;
+  auto h = BLOCK_KN_SIZE * blockIdx.z / 8;
  int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
-  int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+  auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;

  __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
  if (threadIdx.x < h_end) {
@@ -1197,8 +1197,8 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
  }

  __shared__ half2 deq2[256][8];
-  int val = threadIdx.x / 8;
-  int off = threadIdx.x % 8;
+  auto val = threadIdx.x / 8;
+  auto off = threadIdx.x % 8;
  for (; val < 256; val += BLOCK_KN_SIZE / 8) {
    deq2[val][off] =
        __halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4));
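// Note (illustrative sketch, not part of the diff): the loop above fills a
// shared 256x8 lookup table mapping every packed byte to a half2 holding its
// two 4-bit nibbles, so dequantization later becomes a table read instead of
// per-element bit twiddling. A minimal standalone version of the same idea
// (assumes <cuda_fp16.h>; the kernel name and 256-thread launch are made up):
__global__ void build_nibble_lut(half2* lut /* 256 entries */) {
  auto val = threadIdx.x;  // one byte value per thread
  if (val < 256) {
    lut[val] = __halves2half2(__int2half_rn(val & 0xF),  // low nibble
                              __int2half_rn(val >> 4));  // high nibble
  }
}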
@@ -1280,11 +1280,11 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
  int zero_width = width / 4;
  int vec_height = height * 2;
  const int blockwidth2 = BLOCK_KN_SIZE / 2;
-  int b = blockIdx.y * BLOCK_M_SIZE_MAX;
+  auto b = blockIdx.y * BLOCK_M_SIZE_MAX;
  int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
-  int h = BLOCK_KN_SIZE * blockIdx.z / 4;
+  auto h = BLOCK_KN_SIZE * blockIdx.z / 4;
  int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
-  int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+  auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;

  __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
  if (threadIdx.x < h_end) {
@@ -1393,8 +1393,8 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
                                        half* __restrict__ out) {
  // Start of block

-  int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
-  int row = blockIdx.y * 32 / bit;
+  auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+  auto row = blockIdx.y * 32 / bit;
  if (column >= width) return;

  // Views
@@ -1425,8 +1425,8 @@ __global__ void reconstruct_gptq_3bit_kernel(
    const int height, const int width, const int group,
    half* __restrict__ out) {
  // Start of block
-  int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
-  int row = blockIdx.y * 32;
+  auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+  auto row = blockIdx.y * 32;
  if (column >= width) return;

  // Views
@@ -1542,7 +1542,7 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,

__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
                                    const int size_k, const int size_n) {
-  int n = blockIdx.x * THREADS_X + threadIdx.x;
+  auto n = blockIdx.x * THREADS_X + threadIdx.x;
  if (n >= size_n) return;
  int k = 0;
  uint32_t* b_ptr = b_q_weight + n;
@@ -1555,7 +1555,7 @@ __global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,

__global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight,
                                    const int size_k, const int size_n) {
-  int n = blockIdx.x * THREADS_X + threadIdx.x;
+  auto n = blockIdx.x * THREADS_X + threadIdx.x;
  if (n >= size_n) return;
  int k = 0;
  uint32_t* b_ptr = b_q_weight + n;
@@ -1568,7 +1568,7 @@ __global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight,

__global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight,
                                    const int size_k, const int size_n) {
-  int n = blockIdx.x * THREADS_X + threadIdx.x;
+  auto n = blockIdx.x * THREADS_X + threadIdx.x;
  if (n >= size_n) return;
  int k = 0;
  uint32_t* b_ptr = b_q_weight + n;
@@ -1581,7 +1581,7 @@ __global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight,

__global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight,
                                    const int size_k, const int size_n) {
-  int n = blockIdx.x * THREADS_X + threadIdx.x;
+  auto n = blockIdx.x * THREADS_X + threadIdx.x;
  if (n >= size_n) return;
  int k = 0;
  uint32_t* b_ptr = b_q_weight + n;
@@ -1599,9 +1599,9 @@ __global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w,
  const uint64_t* w2 = (uint64_t*)w;
  uint64_t* w_new2 = (uint64_t*)w_new;
  int w2_stride = w_width >> 1;
-  int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+  auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
  if (w2_column >= w2_stride) return;
-  int w_new2_row = blockIdx.y;
+  auto w_new2_row = blockIdx.y;
  int q_perm_idx = w_new2_row << 3;
  uint64_t dst = 0;

@@ -1630,9 +1630,9 @@ __global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w,
  const uint64_t* w2 = (uint64_t*)w;
  uint64_t* w_new2 = (uint64_t*)w_new;
  int w2_stride = w_width >> 1;
-  int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+  auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
  if (w2_column >= w2_stride) return;
-  int w_new2_row = blockIdx.y;
+  auto w_new2_row = blockIdx.y;
  int q_perm_idx = w_new2_row << 4;
  uint64_t dst = 0;

@@ -1658,10 +1658,10 @@ __global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w,
                                            uint32_t* __restrict__ w_new,
                                            const int* __restrict__ q_perm,
                                            const int w_width) {
-  int w_column = THREADS_X * blockIdx.x + threadIdx.x;
+  auto w_column = THREADS_X * blockIdx.x + threadIdx.x;
  if (w_column >= w_width) return;
-  int w_new_row = blockIdx.y * 3;
-  int q_perm_idx = blockIdx.y << 5;
+  auto w_new_row = blockIdx.y * 3;
+  auto q_perm_idx = blockIdx.y << 5;
  uint32_t dst[3] = {0, 0, 0};

  #pragma unroll
@@ -1744,9 +1744,9 @@ __global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w,
  const uint64_t* w2 = (uint64_t*)w;
  uint64_t* w_new2 = (uint64_t*)w_new;
  int w2_stride = w_width >> 1;
-  int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+  auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
  if (w2_column >= w2_stride) return;
-  int w_new2_row = blockIdx.y;
+  auto w_new2_row = blockIdx.y;
  int q_perm_idx = w_new2_row << 2;
  uint64_t dst = 0;

@@ -55,11 +55,11 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
    this_block_B_base_ptr = params.B_ptr + blockIdx.y * Ntile * params.K +
                            blockIdx.z * params.SplitK * 4;

-    const int lane_id = threadIdx.x % WARP_SIZE;
+    const auto lane_id = threadIdx.x % WARP_SIZE;

    // For matrix A, a block load/store Mtile(row) x 32(col) elements in
    // multiple iters, 8x4 warp load/store 8(row) x 32(col) elements per iter
-    const int Aldg_row_base_idx = threadIdx.x / 4;
+    const auto Aldg_row_base_idx = threadIdx.x / 4;
    Aldg_col_idx = (threadIdx.x % 4) * LDG_ELEMENT_CNT_A;
    const int Aldg_base_offset = Aldg_row_base_idx * params.K + Aldg_col_idx;

@@ -67,7 +67,7 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
    // elements of N32K16 packing in multiple iters, 4x8 warp load/store 4(row)
    // * 128(col) per iter
    Bldg_col_idx = (threadIdx.x % 8) * LDG_ELEMENT_CNT_B;
-    const int Bldg_row_base_idx = threadIdx.x / 8;
+    const auto Bldg_row_base_idx = threadIdx.x / 8;
    const int Bldg_base_offset =
        Bldg_row_base_idx * params.K * 4 + Bldg_col_idx;

@@ -89,7 +89,7 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
    B_ldg_guard = 0;
    #pragma unroll
    for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; ++i) {
-      int m_idx = blockIdx.x * Mtile + Aldg_row_base_idx + i * M_SIZE_ONE_LOAD;
+      auto m_idx = blockIdx.x * Mtile + Aldg_row_base_idx + i * M_SIZE_ONE_LOAD;
      if (m_idx < params.M) {
        A_ldg_guard |= (1u << i);
      }
@@ -98,8 +98,8 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
    const int N_padded = (params.N + 31) / 32 * 32;
    #pragma unroll
    for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; ++i) {
-      int n_idx = blockIdx.y * Ntile + (Bldg_row_base_idx / 8) * 32 +
-                  i * N_SIZE_ONE_LOAD;
+      auto n_idx = blockIdx.y * Ntile + (Bldg_row_base_idx / 8) * 32 +
+                   i * N_SIZE_ONE_LOAD;
      if (n_idx < N_padded) {
        B_ldg_guard |= (1u << i);
      }
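// Note (illustrative sketch, not part of the diff): A_ldg_guard / B_ldg_guard
// above are per-thread bitmasks with one bit per load iteration; a bit is set
// only when that iteration's row/column is in bounds, so the load loop can
// test a bit instead of redoing the bounds check. A compact version of the
// pattern with hypothetical names:
__device__ unsigned make_ldg_guard(int base, int step, int iters, int limit) {
  unsigned guard = 0;
  for (int i = 0; i < iters; ++i) {
    if (base + i * step < limit) guard |= (1u << i);  // iteration i is valid
  }
  return guard;
}
// later: if (guard & (1u << i)) { /* issue the global load for iteration i */ }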
@@ -355,7 +355,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
  __device__ void fused_splitk_reduce() {
    // need splitk-reduce if enable splitk
    if (gridDim.z > 1) {
-      int blk_red_idx = blockIdx.x * gridDim.y + blockIdx.y;
+      auto blk_red_idx = blockIdx.x * gridDim.y + blockIdx.y;
      // Wait for all previous blocks in the splitk direction to accumulate the
      // results into C_tmp
      if (threadIdx.x == 0) {
@@ -371,7 +371,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
      }
      __syncthreads();

-      int C_tmp_base_offset = blk_red_idx * Mtile * Ntile + threadIdx.x * 4;
+      auto C_tmp_base_offset = blk_red_idx * Mtile * Ntile + threadIdx.x * 4;
      if (blockIdx.z != 0) {
        // expecting that temporary register here reuses the previous A&B frag
        // register
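// Note (illustrative sketch, not part of the diff): fused_splitk_reduce above
// only runs when gridDim.z > 1, i.e. the K dimension was split across
// blockIdx.z and each split produced a partial GEMM result in C_tmp.
// Conceptually the reduction is an elementwise sum over the split axis; the
// kernel below is a simplified standalone sketch of that idea (names made up,
// the real code serializes the splits with a lock and accumulates in place):
__global__ void splitk_reduce_sketch(const float* partials, float* out,
                                     int num_splits, int elems_per_split) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= elems_per_split) return;
  float acc = 0.f;
  for (int s = 0; s < num_splits; ++s) {
    acc += partials[s * elems_per_split + i];  // sum the K-split partials
  }
  out[i] = acc;
}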
@@ -437,9 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
      for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
        #pragma unroll
        for (int k_idx = 0; k_idx < 2; ++k_idx) {
-          FType low16 = static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2]);
+          FType low16 =
+              ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]);
          FType high16 =
-              static_cast<FType>(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
+              ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
          uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
                         (reinterpret_cast<uint32_t&>(high16) << 16);
          int sts_offset =
@@ -455,7 +456,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {

    FType* C_base_ptr = this_block_C_base_ptr + store_c_base_offset;
    // C_tile lds and stg
-    int m_base_idx = store_c_row_base_idx + blockIdx.x * Mtile;
+    auto m_base_idx = store_c_row_base_idx + blockIdx.x * Mtile;
    bool n_guard = (store_c_col_idx + blockIdx.y * Ntile) < params.N;
    if (WARP_NTILE == 32) {
      int lds_c_base_offset = warp_id * Mtile * WARP_NTILE +
@@ -579,9 +580,9 @@ __global__ void __launch_bounds__(BLOCK)
  int sts_stage_idx = 0;
  int lds_stage_idx = 0;

-  int tb_k_slice = blockIdx.z * params.SplitK + params.SplitK <= params.K
-                       ? params.SplitK
-                       : params.K - blockIdx.z * params.SplitK;
+  auto tb_k_slice = blockIdx.z * params.SplitK + params.SplitK <= params.K
+                        ? params.SplitK
+                        : params.K - blockIdx.z * params.SplitK;
  int k_tiles = (tb_k_slice + 31) / 32;
  int first_k_tile = tb_k_slice - (k_tiles - 1) * 32;

@@ -776,13 +777,13 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
    const QT* qdata, const FT* scales, const FT* zeros, FT* fdata,
    const int N_32align, const int N, const int K) {
  __shared__ FT smem[64 * 32];
-  int warp_id = threadIdx.x / 32;
-  int lane_id = threadIdx.x % 32;
-  const int src_row_idx = blockIdx.x * 8 + lane_id / 4;
+  auto warp_id = threadIdx.x / 32;
+  auto lane_id = threadIdx.x % 32;
+  const auto src_row_idx = blockIdx.x * 8 + lane_id / 4;
  const int src_col_idx =
      blockIdx.y * 64 * 4 + warp_id * 16 * 4 + (lane_id % 4) * 16;
  const int src_offset = src_row_idx * K * 4 + src_col_idx;
-  int params_nidx = blockIdx.x * 32 + (lane_id / 4) * 4;
+  auto params_nidx = blockIdx.x * 32 + (lane_id / 4) * 4;

  QT qval_reg[16];
  const QT* pdata = qdata + src_offset;
@@ -793,7 +794,7 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
  FT scale_reg[4];
  *(reinterpret_cast<uint2*>(scale_reg)) =
      *(reinterpret_cast<const uint2*>(scales + params_nidx));
-  FT zero_reg[4] = {0};
+  FT zero_reg[4];
  if (zeros != nullptr) {
    *(reinterpret_cast<uint2*>(zero_reg)) =
        *(reinterpret_cast<const uint2*>(zeros + params_nidx));
@@ -809,8 +810,10 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
        reinterpret_cast<typename HalfType<FT>::T2*>(&(fval_reg[ni * 4])));
    #pragma unroll
    for (int ki = 0; ki < 4; ++ki) {
-      fval_reg[ni * 4 + ki] =
-          (fval_reg[ni * 4 + ki] - zero_reg[ni]) * scale_reg[ni];
+      if (zeros != nullptr) {
+        fval_reg[ni * 4 + ki] = __hsub(fval_reg[ni * 4 + ki], zero_reg[ni]);
+      }
+      fval_reg[ni * 4 + ki] = __hmul(fval_reg[ni * 4 + ki], scale_reg[ni]);
      int sts_offset = sts_base_offset + ((ki / 2) * 8 + (ki % 2)) * 32 +
                       ((ni + lane_id % 4) % 4) * 8;
      smem[sts_offset] = fval_reg[ni * 4 + ki];
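// Note (illustrative sketch, not part of the diff): the change above makes the
// zero point optional, i.e. dequantization becomes value = (q - zero) * scale
// when a zeros tensor is supplied and value = q * scale otherwise (symmetric
// quantization). A scalar restatement of the same rule; the kernel performs it
// on half registers with __hsub/__hmul instead of float:
__device__ __forceinline__ float dequant_w8(uint8_t q, float scale,
                                            const float* zero /* nullable */) {
  float v = static_cast<float>(q);
  if (zero != nullptr) v -= *zero;  // skip the shift for symmetric quant
  return v * scale;
}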
@@ -826,8 +829,8 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
        *reinterpret_cast<uint4*>(smem + lds_base_offset + i * 32 * 32);
  }

-  const int dst_row_base_kidx = blockIdx.y * 64 + threadIdx.x / 4;
-  const int dst_col_nidx = blockIdx.x * 32 + (threadIdx.x % 4) * 8;
+  const auto dst_row_base_kidx = blockIdx.y * 64 + threadIdx.x / 4;
+  const auto dst_col_nidx = blockIdx.x * 32 + (threadIdx.x % 4) * 8;
  #pragma unroll
  for (int i = 0; i < 2; ++i) {
    int dst_row_kidx = dst_row_base_kidx + i * 32;
@@ -1005,4 +1008,4 @@ torch::Tensor allspark_w8a16_gemm(

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("allspark_w8a16_gemm", &allspark_w8a16_gemm);
-}
+}

@@ -13,8 +13,8 @@ __global__ void __launch_bounds__(128)
    const uint8_t* B, const FType* B_scale, const FType* B_zero,
    uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
    const int K, const int N, const int N_32align) {
-  const int lane_id = threadIdx.x % 32;
-  const int warp_id = threadIdx.x / 32;
+  const auto lane_id = threadIdx.x % 32;
+  const auto warp_id = threadIdx.x / 32;

  if (blockIdx.x != gridDim.x - 1) {
    // Load B
@@ -50,7 +50,7 @@ __global__ void __launch_bounds__(128)
    }

    // Store B
-    const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
+    const auto dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
    const int dst_col_idx =
        blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8;
    for (int i = 0; i < 8; ++i) {
@@ -65,7 +65,7 @@ __global__ void __launch_bounds__(128)
  } else {
    // Load B_scale and B_zero
    FType b_scale_reg, b_zero_reg;
-    int src_offset = blockIdx.y * 128 + threadIdx.x;
+    auto src_offset = blockIdx.y * 128 + threadIdx.x;
    ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N);
    if (B_zero != nullptr)
      ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N);

@@ -7,6 +7,8 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <iostream>
+#include "../gptq_marlin/marlin_dtypes.cuh"
+using marlin::ScalarType;

namespace allspark {

@@ -60,20 +62,20 @@ template <typename FType, int BLOCK, int N_MATRIX>
__global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
                                              uint32_t n, uint32_t n_matrix,
                                              uint32_t matrix_size) {
-  int idx = blockIdx.x * BLOCK + threadIdx.x;
+  auto idx = blockIdx.x * BLOCK + threadIdx.x;

  if (idx >= matrix_size) {
    return;
  }

-  FType sum(0);
+  float sum = 0.f;

  int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
  for (int i = 0; i < n_mat; ++i) {
-    sum += C_split[idx + i * matrix_size];
+    sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
  }

-  C[idx] = sum;
+  C[idx] = ScalarType<FType>::float2num(sum);
}

template <typename FType>
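// Note (illustrative sketch, not part of the diff): the kernel above now
// accumulates in float and converts once at the end instead of summing in the
// half/bfloat16 output type, whose short mantissa loses precision under
// repeated +=. ScalarType<> in the diff wraps exactly such a widen/narrow
// conversion pair; the same pattern with plain <cuda_fp16.h> intrinsics:
__device__ __forceinline__ __half reduce_in_fp32(const __half* vals, int n) {
  float sum = 0.f;
  for (int i = 0; i < n; ++i) sum += __half2float(vals[i]);  // widen, then add
  return __float2half(sum);  // narrow once at the end
}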
@@ -405,4 +407,4 @@ static __device__ half2 inline num2num2(const half x) {
  return __half2half2(x);
}

-}  // namespace allspark
+}  // namespace allspark

@@ -538,6 +538,7 @@ __global__ void Marlin(
    int prob_n,           // output dimension n
    int prob_k,           // reduction dimension k
    int* locks,           // extra global storage for barrier synchronization
+    bool use_atomic_add,  // whether to use atomic add to reduce
    bool use_fp32_reduce  // whether to use fp32 global reduce
) {
  // Each threadblock processes one "stripe" of the B matrix with (roughly) the
@@ -1542,7 +1543,17 @@ __global__ void Marlin(
           i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
           i++) {
        if (c_gl_wr < c_gl_wr_end) {
-          C[c_gl_wr] = sh_red[c_sh_rd];
+          if (use_atomic_add && slice_count > 1) {
+            scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]);
+            scalar_t2* sh_red_half2 =
+                reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
+  #pragma unroll
+            for (int a = 0; a < 4; a++) {
+              atomicAdd(&C_half2[a], sh_red_half2[a]);
+            }
+          } else {
+            C[c_gl_wr] = sh_red[c_sh_rd];
+          }
          c_gl_wr += c_gl_wr_delta;
          c_sh_rd += c_sh_rd_delta;
        }
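// Note (illustrative sketch, not part of the diff): with use_atomic_add, every
// block that owns part of a slice adds its tile into C directly via atomicAdd
// on packed half2/bf162 pairs, so C must start at zero and no inter-block lock
// is needed; the original path instead lets only the last block of a slice
// store C after a lock-guarded global reduce. A condensed sketch of the two
// write-back strategies (names hypothetical; assumes a vector type for which
// atomicAdd is supported, e.g. __half2 on sm_60+):
template <typename T2>
__device__ void write_back(T2* C, const T2* tile, int n, bool use_atomic_add,
                           bool is_last_block) {
  if (use_atomic_add) {
    for (int i = 0; i < n; ++i) atomicAdd(&C[i], tile[i]);  // accumulate into C
  } else if (is_last_block) {
    for (int i = 0; i < n; ++i) C[i] = tile[i];  // sole writer after reduce
  }
}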
@@ -1644,7 +1655,7 @@ __global__ void Marlin(
          }
          cp_async_fence();
        } else {
-          if (last) {
+          if (last || use_atomic_add) {
            if (s_sh_wr_pred) {
              cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
            }
@@ -1664,7 +1675,7 @@ __global__ void Marlin(
          }

        } else {
-          if (last) {
+          if (last || use_atomic_add) {
            cp_async_wait<0>();
            __syncthreads();
            if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1703,8 +1714,8 @@ __global__ void Marlin(
          }
        }

-        if (slice_count > 1) {  // only globally reduce if there is more than one
-                                // block in a slice
+        if (slice_count > 1 && !use_atomic_add) {
+          // only globally reduce if there is more than one block in a slice
          barrier_acquire(&locks[slice_col], slice_idx);
          if (use_fp32_reduce) {
            global_reduce_fp32(slice_idx == 0, last);
@@ -1713,7 +1724,8 @@ __global__ void Marlin(
          }
          barrier_release(&locks[slice_col], last);
        }
-        if (last)  // only the last block in a slice actually writes the result
+        if (last || use_atomic_add)
+          // only the last block in a slice actually writes the result
          write_result();
        slice_row = 0;
        slice_col_par++;
@@ -1768,7 +1780,8 @@ __global__ void Marlin(
            HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
        <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
            A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
-            num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
+            num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, \
+            use_fp32_reduce); \
  } \
}

@@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
               vllm::ScalarType const& q_type, bool has_act_order,
               bool is_k_full, bool has_zp, int num_groups, int group_size,
               int dev, cudaStream_t stream, int thread_k, int thread_n,
-               int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
+               int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce,
+               bool is_zp_float) {
  if (has_zp) {
    TORCH_CHECK(
        q_type == vllm::kU4 || q_type == vllm::kU8,
@@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& workspace,
                               vllm::ScalarTypeId const& b_q_type_id,
                               int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp,
+                               bool is_k_full, bool has_zp, bool use_atomic_add,
                               bool use_fp32_reduce, bool is_zp_float) {
  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
  if (has_zp) {
@@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
-  torch::Tensor c = torch::empty({size_m, size_n}, options);
-  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
+  torch::Tensor c;
+  if (use_atomic_add) {
+    c = torch::zeros({size_m, size_n}, options);
+  } else {
+    c = torch::empty({size_m, size_n}, options);
+  }
+
+  torch::Tensor a_tmp;
+  bool has_act_order = g_idx.size(0) != 0;
+  if (has_act_order) {
+    a_tmp = torch::empty({size_m, size_k}, options);
+  } else {
+    a_tmp = torch::empty({0}, options);
+  }

  // Alloc C tmp buffer that is going to be used for the global reduce
+  torch::Tensor c_tmp;
  int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
  int reduce_n = size_n;
  auto options_fp32 =
      torch::TensorOptions().dtype(at::kFloat).device(a.device());
-  if (!use_fp32_reduce) {
+  if (use_fp32_reduce) {
+    c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+  } else {
    reduce_max_m = 0;
    reduce_n = 0;
+    c_tmp = torch::empty({0}, options_fp32);
  }
-  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
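// Note (illustrative sketch, not part of the diff): the allocation change
// above is what makes the atomic path correct: blocks accumulate into C with
// atomicAdd, so C has to start as zeros, while the classic path overwrites C
// and can keep the cheaper empty allocation. The same host-side choice,
// condensed into a helper (libtorch API, helper name made up; assumes the
// torch headers already included by this file):
torch::Tensor alloc_output(int64_t m, int64_t n, bool use_atomic_add,
                           const torch::TensorOptions& options) {
  return use_atomic_add ? torch::zeros({m, n}, options)   // accumulation target
                        : torch::empty({m, n}, options);  // plain overwrite
}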
@@ -2339,7 +2368,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;
-  bool has_act_order = g_idx.size(0) != 0;

  int rank = b_scales.sizes().size();
  TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
@@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    marlin::marlin_mm<nv_bfloat16>(
        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
@@ -2416,7 +2445,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
        perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
  } else {
    TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
  }

@@ -4,7 +4,6 @@
 */

// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>

@@ -127,7 +127,7 @@ __device__ __forceinline__ T from_float(const float& inp) {

template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
-  union tmpcvt {
+  [[maybe_unused]] union tmpcvt {
    uint16_t u;
    _Float16 f;
    __hip_bfloat16 b;
@@ -160,7 +160,7 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
template <typename T>
__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
                                        const _B16x4& inp2) {
-  union tmpcvt {
+  [[maybe_unused]] union tmpcvt {
    uint16_t u;
    _Float16 f;
    __hip_bfloat16 b;
@@ -308,8 +308,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(

  constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO, 4);

-  __shared__ float shared_qk_max[NWARPS][16 + 1];
-  __shared__ float shared_exp_sum[NWARPS][16 + 1];
+  [[maybe_unused]] __shared__ float shared_qk_max[NWARPS][16 + 1];
+  [[maybe_unused]] __shared__ float shared_exp_sum[NWARPS][16 + 1];
  // shared_logits is used for multiple purposes
  __shared__ _B16x4 shared_logits[NWARPS][4][16][4];

@@ -426,7 +426,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
    const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
    const int klocal_token_idx =
        TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
-    const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
+    [[maybe_unused]] const int kglobal_token_idx =
+        partition_start_token_idx + klocal_token_idx;
    const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
    const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;

@@ -1272,9 +1273,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
-  const int laneid = threadIdx.x % WARP_SIZE;
+  [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;

  __shared__ float shared_global_exp_sum;
  // max num partitions supported is warp_size * NPAR_LOOPS