Compare commits
517 Commits
bind_kv_ca
...
sampler-en
| Author | SHA1 | Date | |
|---|---|---|---|
| 4c42267293 | |||
| 24f68342b4 | |||
| c5d963835b | |||
| b313220727 | |||
| 15dac210f0 | |||
| 112b3e5b3b | |||
| 32d669275b | |||
| 4098b72210 | |||
| 46450b8d33 | |||
| 13ac9cab21 | |||
| 66aa4c0bf4 | |||
| 247181536f | |||
| 07bf813fb5 | |||
| 8958217ad5 | |||
| ac5bc615b0 | |||
| 8063dfc61a | |||
| 6278bc829e | |||
| 3f532cb6a6 | |||
| e6c9053f9e | |||
| 43ed4143c4 | |||
| f4c98b4d4c | |||
| e1e0fd7543 | |||
| df8d3d1287 | |||
| 619d3de8bd | |||
| ecff8309a3 | |||
| dcf2a590f5 | |||
| 54aa619459 | |||
| fb22be5817 | |||
| 7f301dd8ef | |||
| 8095341a01 | |||
| 69db16a46a | |||
| ce78f9af4e | |||
| 9239bf718e | |||
| 7a6d45bc8a | |||
| e74ff409e0 | |||
| 7a888271f5 | |||
| 9d119a86ae | |||
| b2e85e26f4 | |||
| dd8a29da99 | |||
| 27df5199d9 | |||
| 35fad35a48 | |||
| 733e7c9e95 | |||
| 0af4d764d6 | |||
| e64afa455c | |||
| 1711b929b6 | |||
| c091c0a588 | |||
| 1aa162e030 | |||
| cf5c8f1686 | |||
| 4ec2cee000 | |||
| 99f536f830 | |||
| 5ebf66748b | |||
| 781d056280 | |||
| 5aefd6ac31 | |||
| 6c663dfd5e | |||
| 33437bc6e7 | |||
| 23114d3364 | |||
| 997c8811d6 | |||
| e42389f9d7 | |||
| ff38f0a32c | |||
| a5cfbab3c8 | |||
| ac3cd6e83c | |||
| 082ab86f5f | |||
| 6aa196c8dc | |||
| a0dd7dcd49 | |||
| e977c11111 | |||
| 5f063a80bd | |||
| 5d8e1c9279 | |||
| 0a049c7d86 | |||
| d0cfec7ab9 | |||
| a608160027 | |||
| 3f04a7fbf2 | |||
| 5994430b84 | |||
| a9e879b316 | |||
| 3e2f37a69a | |||
| 4f044b1d67 | |||
| 4157f563b4 | |||
| 051da7efe3 | |||
| 25f560a62c | |||
| a09ad90a72 | |||
| 10b34e36b9 | |||
| b5269db959 | |||
| 6db94571d7 | |||
| 97cfa65df7 | |||
| 911c8eb000 | |||
| ebcebeeb6b | |||
| f533b5837f | |||
| 8279201ce6 | |||
| 23fdab00a8 | |||
| 623e2ed29f | |||
| 9d72daf4ce | |||
| 6dd55af6c9 | |||
| 3eb08ed9b1 | |||
| 5eeadc2642 | |||
| 3aee6573dc | |||
| 9cc645141d | |||
| 0893567db9 | |||
| 8abe69b499 | |||
| 761702fd19 | |||
| 9606d572ed | |||
| cbcdf2c609 | |||
| 038de04d7b | |||
| 6b3cc75be0 | |||
| 7ffcccfa5c | |||
| cc8accfd53 | |||
| 948ab03e7e | |||
| 5797fb97e9 | |||
| 3892e58ad7 | |||
| d20e261199 | |||
| f622dbcf39 | |||
| dccf535f8e | |||
| 9c5c81b0da | |||
| d6cd59f122 | |||
| bc8ed3c4ba | |||
| b9bd76ca14 | |||
| 6ebaf9ac71 | |||
| f90d34b498 | |||
| f68cce8e64 | |||
| 09b6a95551 | |||
| 50c9636d87 | |||
| 0661cfef7a | |||
| a827aa815d | |||
| b877031d80 | |||
| dd861b992f | |||
| eb63ea1e18 | |||
| 2f4bd358f1 | |||
| 8a8b30eac1 | |||
| 2fa0e1396b | |||
| 1c2bec0f82 | |||
| ec870fba9a | |||
| df1430265c | |||
| 4c69e228b3 | |||
| 790b79750b | |||
| cfbb8c930f | |||
| baec0d4de9 | |||
| c21b99b912 | |||
| 93a00d7dde | |||
| 61e8c18350 | |||
| 8afcd0f633 | |||
| 91ca929dc7 | |||
| 84e00adc8a | |||
| 47c7126213 | |||
| a989ca2bf6 | |||
| 0fa3970deb | |||
| da6ea29f7a | |||
| 7297941b38 | |||
| f8a08cb90d | |||
| b15fd2be2a | |||
| e588ac237c | |||
| 5df2da5b97 | |||
| 11b986b3fb | |||
| 296f927f24 | |||
| 0032903a5b | |||
| 47195057e9 | |||
| 6edbfa924d | |||
| 1e508343e1 | |||
| 2e0b4cfde0 | |||
| 10f55fe6c5 | |||
| d3ccbd6350 | |||
| 0cfe7d386d | |||
| 0c6f5023c3 | |||
| 06dd08256f | |||
| 2b22290ce0 | |||
| d8e82bc06d | |||
| 086b56824c | |||
| 5a0905ba2a | |||
| a8f12a63fd | |||
| 69ae2380c6 | |||
| 27261e40a6 | |||
| e3f813c33b | |||
| c607a2652b | |||
| 3d45e3d749 | |||
| 742369d35a | |||
| bfe2fe0af4 | |||
| a8652f4f0f | |||
| 2f726b241e | |||
| a597a57595 | |||
| ae65f3e237 | |||
| 34868b106a | |||
| 1f16b7fe74 | |||
| b88be22165 | |||
| d8c6d7d6b5 | |||
| 40828ce5fe | |||
| ffa443afed | |||
| 70e500cad9 | |||
| 4cb1c05c9e | |||
| c47aafa37c | |||
| cfbca8a2f2 | |||
| 0fe5609874 | |||
| 22d33baca2 | |||
| b0e96aaebb | |||
| 8310e0b59b | |||
| 26dd972adb | |||
| 61c7a1b856 | |||
| 374ee287d8 | |||
| a4d83661d7 | |||
| 8363cd093d | |||
| 6c5a3195db | |||
| 073d1ed354 | |||
| 3d446433ec | |||
| 1fe0fd12d3 | |||
| dafb4e504a | |||
| 68cf1601d3 | |||
| 61f412187d | |||
| 05ccd0aa35 | |||
| f690372b68 | |||
| 8b3e94a357 | |||
| 437f9162d0 | |||
| 4f065f12f5 | |||
| 228b768db6 | |||
| 027827cc1d | |||
| 72a8639b68 | |||
| 99abb8b650 | |||
| 3a1e648158 | |||
| 46c759c165 | |||
| 179a619c21 | |||
| 452e8fd968 | |||
| 8b793f7ec6 | |||
| af35d3a3cc | |||
| 3b457143d2 | |||
| ab656f2c2f | |||
| 64fc2193dc | |||
| dd732028f5 | |||
| 414919138b | |||
| db7c8ca910 | |||
| f863ffc965 | |||
| 400d483e87 | |||
| d1695758b2 | |||
| 53a0cf8b95 | |||
| 5eeabc2a44 | |||
| 18551e820c | |||
| e41e160263 | |||
| b89fb2a4a1 | |||
| 5340b0e221 | |||
| 37e3806132 | |||
| c0efdd655b | |||
| aaaec52ad9 | |||
| e1eb45d397 | |||
| 89fca671fb | |||
| d20b0c139c | |||
| 166a168b0f | |||
| 2bb0e1a799 | |||
| 6eaf1e5c52 | |||
| 868a8c5b2c | |||
| b4ad56c1bd | |||
| 69698f257e | |||
| cd0cd85102 | |||
| 0a74bfce9c | |||
| dd3b865854 | |||
| 9b87a579aa | |||
| b539222d4e | |||
| 8d6cf89526 | |||
| 583a9778e0 | |||
| a73e183e36 | |||
| 1e799b7ec1 | |||
| 7f6c5ee06c | |||
| faa0275730 | |||
| 8a5a9b70d7 | |||
| bb3aeddfaf | |||
| aecc780dba | |||
| 90df7f23aa | |||
| b9b5bdfc7d | |||
| 31060b2757 | |||
| fc1f67715d | |||
| f6137adbcb | |||
| e53b1350f2 | |||
| d30aa7e9e6 | |||
| d1ad2a57af | |||
| b82662d952 | |||
| 71c1e07107 | |||
| b30c75dda4 | |||
| def232e122 | |||
| 3453b964a3 | |||
| 61c6a5a796 | |||
| 74bc397b0a | |||
| f58aea002c | |||
| 3556a41434 | |||
| 9ed6ee92d6 | |||
| ee3778d5fc | |||
| aaacf17324 | |||
| 4c7629cae9 | |||
| e0fdfa1608 | |||
| 5952d8ab61 | |||
| a2ae496589 | |||
| 877e352262 | |||
| d4d93db2c5 | |||
| 8c0d15d5c5 | |||
| 97ac781c62 | |||
| 776dcec8fe | |||
| ccf02fcbae | |||
| acaea3bb07 | |||
| 9f37422779 | |||
| dd344e0342 | |||
| 54a8804455 | |||
| bbd94a19fc | |||
| 233ffce1eb | |||
| 40677783aa | |||
| 14f301b541 | |||
| 46f98893dd | |||
| fe66b34728 | |||
| 270a5da495 | |||
| 7097b4cc1c | |||
| 977a16772c | |||
| 73deea2fdb | |||
| 9d2b4a70f4 | |||
| 0b0d6421b2 | |||
| 1140991a7b | |||
| 613c5bb945 | |||
| fd8e055ffb | |||
| ab93f1360f | |||
| 40253bab44 | |||
| c77620d22d | |||
| 989ecd2007 | |||
| 54cc46f3eb | |||
| 601bd3268e | |||
| 09269b3127 | |||
| 27b50f1fe6 | |||
| 9532c49836 | |||
| 0c2af17c76 | |||
| a6e0d096dd | |||
| d3d4956261 | |||
| 4059adc31b | |||
| f1f632d9ec | |||
| 95d680b862 | |||
| fb4c7f8ef0 | |||
| 0b1cfa6180 | |||
| 32ef4983cd | |||
| ad19c8a003 | |||
| 2a602b055a | |||
| 7888e1d0a3 | |||
| 60c872d4b6 | |||
| 3fb17d26c8 | |||
| d47807ba08 | |||
| 02fcaa3d0a | |||
| 8a4a2efc6f | |||
| 8e9ffd37d6 | |||
| 01b3fd0af7 | |||
| f53a0586b9 | |||
| b1cc4dfef5 | |||
| 382403921f | |||
| a73122de96 | |||
| bd44b812cb | |||
| 55211b01e8 | |||
| 5d043c1685 | |||
| 36d1ccb286 | |||
| 1bc3b739c4 | |||
| 1bd32bc8dd | |||
| 128bf75283 | |||
| a94a699c3f | |||
| ab426ec9c0 | |||
| 165290d357 | |||
| ce20124671 | |||
| 53be4a8634 | |||
| f5d3acd474 | |||
| 916836bbfb | |||
| d9f83d6206 | |||
| 4a754fcf15 | |||
| c0c25e25fa | |||
| 45f3f3f59e | |||
| ff47aab056 | |||
| debd6bbf09 | |||
| 5c538c37b2 | |||
| e22ee1e7a2 | |||
| e392d85831 | |||
| 77a318bd01 | |||
| 80e78d02ac | |||
| 4a42b9f5d6 | |||
| 47532cd9f4 | |||
| 36e0c8f7da | |||
| 9f583e360c | |||
| b706d898af | |||
| 863d315c86 | |||
| d374f04a33 | |||
| 61a01b27a7 | |||
| 53056731fd | |||
| 4cbf286794 | |||
| c6e14a61ab | |||
| 07b4b7a37f | |||
| 07964e2f30 | |||
| 4bf82d4b90 | |||
| 9ab326713f | |||
| af295e9b01 | |||
| a1c8f3796c | |||
| 08a1a1121d | |||
| 1477ffc381 | |||
| 70b808fe1a | |||
| 63d635d179 | |||
| 1fc973c0b5 | |||
| c982ac5722 | |||
| 4290b704ff | |||
| c91b64f749 | |||
| d6123170d5 | |||
| 485afdd3cb | |||
| 90e88ab756 | |||
| 04421dff8a | |||
| 432d6dad15 | |||
| 5ff0d32580 | |||
| 0967110e42 | |||
| fb0acb6c72 | |||
| 92b0ce2ac7 | |||
| bc2d4473bf | |||
| 3b352a2f92 | |||
| dea985aef0 | |||
| 39be30351f | |||
| 001a9c7b0d | |||
| 89cdaa83e7 | |||
| b0746fae3d | |||
| 60a98b2de5 | |||
| 460f553a6d | |||
| 1253b15774 | |||
| dc74613fa2 | |||
| a21076ed3a | |||
| 212007b168 | |||
| fb16eea48b | |||
| 73ae0b44e9 | |||
| 6d7f037748 | |||
| 10f7552789 | |||
| b0d541947a | |||
| 5f0b53c6ea | |||
| eb8b5eb183 | |||
| 9513290032 | |||
| 0d5e73d30e | |||
| 609ef61fea | |||
| db84f5eb3b | |||
| 206e2577fa | |||
| e02883c400 | |||
| 9085aabd62 | |||
| 8d5aa466fb | |||
| 0b7f06b447 | |||
| 03fe18ae0f | |||
| cb8bdfade2 | |||
| 33f227e16b | |||
| cfd0ae8234 | |||
| 7caff01a7b | |||
| be0b399d74 | |||
| b8b0ccbd2d | |||
| c908a07f57 | |||
| 7b6fd6e486 | |||
| 47512b3200 | |||
| 3b9c6c6947 | |||
| 4aae667668 | |||
| 9f3bc0f58c | |||
| 980385f8c1 | |||
| ca7a2d5f28 | |||
| 333681408f | |||
| ef64044079 | |||
| 66e16a038e | |||
| e1f0835ae0 | |||
| 8ed5421aaa | |||
| c6359e8ca6 | |||
| 952a074980 | |||
| d0feea31c7 | |||
| 58abe35455 | |||
| f7ebad2307 | |||
| 80e9afb5bc | |||
| 1e3598edeb | |||
| f7a6bd0fa1 | |||
| 0ca3b8e01c | |||
| cc10281498 | |||
| 05fb6718f0 | |||
| 12c29a881f | |||
| 70da0c0748 | |||
| c1588a2c94 | |||
| 8ca7a71df7 | |||
| 63137cd922 | |||
| ddd1ef66ec | |||
| e5e03c2c1b | |||
| e1744502c2 | |||
| dae6896977 | |||
| c34eeec58d | |||
| ad60bbb2b2 | |||
| 0578e5a462 | |||
| 04222984f8 | |||
| 6832707e90 | |||
| 6b2ef5cd17 | |||
| 958adce478 | |||
| 99b0915d3b | |||
| 8ca2b21c98 | |||
| d9292786e1 | |||
| cc2f9b32c8 | |||
| cd579352bf | |||
| 9f1710f1ac | |||
| e642ec962c | |||
| ada19210a3 | |||
| bf0560bda9 | |||
| 151b08e0fe | |||
| 81b2f4a45f | |||
| 82551ad616 | |||
| caac5c2e59 | |||
| 6bd1dd9d26 | |||
| 4f27044aab | |||
| 0ddc991f5c | |||
| fa82b93853 | |||
| 69ff99fdcd | |||
| 5d802522a7 | |||
| 1769928079 | |||
| ed6ea06577 | |||
| 5ee10e990d | |||
| 3dbd2d813a | |||
| f5f7f00cd9 | |||
| abcc61e0af | |||
| f6bb18fd9a | |||
| 71eaf8969b | |||
| ca100c90fe | |||
| ffad94397d | |||
| 4dacaa4a83 | |||
| a7ea35aa67 | |||
| 1e3e76b6cc | |||
| 53ea6ad830 | |||
| 1b7624bf5c | |||
| ac60dc7fe1 | |||
| a4f1ee35d6 | |||
| a32c8669ca | |||
| ca2ca8de57 | |||
| f71b00a19e | |||
| 8f808cf86e | |||
| 7bab4bb048 | |||
| e17e4488bd |
@ -4,8 +4,8 @@ tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.233
|
||||
value: 0.231
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.236
|
||||
value: 0.22
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
||||
|
||||
@ -13,6 +13,7 @@ from pathlib import Path
|
||||
|
||||
import lm_eval
|
||||
import numpy
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
RTOL = 0.05
|
||||
@ -46,6 +47,10 @@ def test_lm_eval_correctness():
|
||||
eval_config = yaml.safe_load(
|
||||
Path(TEST_DATA_FILE).read_text(encoding="utf-8"))
|
||||
|
||||
if eval_config[
|
||||
"model_name"] == "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform": #noqa: E501
|
||||
pytest.skip("FBGEMM is currently failing on main.")
|
||||
|
||||
# Launch eval requests.
|
||||
results = launch_lm_eval(eval_config)
|
||||
|
||||
|
||||
@ -426,7 +426,7 @@ main() {
|
||||
|
||||
pip install -U transformers
|
||||
|
||||
pip install -r requirements-dev.txt
|
||||
pip install -r requirements/dev.txt
|
||||
which genai-perf
|
||||
|
||||
# check storage
|
||||
|
||||
@ -361,7 +361,7 @@ main() {
|
||||
# get the current IP address, required by benchmark_serving.py
|
||||
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
|
||||
# turn of the reporting of the status of each request, to clean up the terminal output
|
||||
export VLLM_LOG_LEVEL="WARNING"
|
||||
export VLLM_LOGGING_LEVEL="WARNING"
|
||||
|
||||
# prepare for benchmarking
|
||||
cd benchmarks || exit 1
|
||||
|
||||
@ -82,7 +82,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --progress plain -f Dockerfile.cpu ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain -f Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -101,16 +101,30 @@ if [[ $commands == *" kernels "* ]]; then
|
||||
--ignore=kernels/test_permute_cols.py"
|
||||
fi
|
||||
|
||||
#ignore certain Entrypoints tests
|
||||
#ignore certain Entrypoints/openai tests
|
||||
if [[ $commands == *" entrypoints/openai "* ]]; then
|
||||
commands=${commands//" entrypoints/openai "/" entrypoints/openai \
|
||||
--ignore=entrypoints/openai/test_accuracy.py \
|
||||
--ignore=entrypoints/openai/test_audio.py \
|
||||
--ignore=entrypoints/openai/test_encoder_decoder.py \
|
||||
--ignore=entrypoints/openai/test_embedding.py \
|
||||
--ignore=entrypoints/openai/test_oot_registration.py "}
|
||||
--ignore=entrypoints/openai/test_chat.py \
|
||||
--ignore=entrypoints/openai/test_shutdown.py \
|
||||
--ignore=entrypoints/openai/test_completion.py \
|
||||
--ignore=entrypoints/openai/test_sleep.py \
|
||||
--ignore=entrypoints/openai/test_models.py \
|
||||
--ignore=entrypoints/openai/test_prompt_validation.py "}
|
||||
fi
|
||||
|
||||
#ignore certain Entrypoints/llm tests
|
||||
if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then
|
||||
commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "}
|
||||
fi
|
||||
|
||||
# --ignore=entrypoints/openai/test_encoder_decoder.py \
|
||||
# --ignore=entrypoints/openai/test_embedding.py \
|
||||
# --ignore=entrypoints/openai/test_oot_registration.py
|
||||
# --ignore=entrypoints/openai/test_accuracy.py \
|
||||
# --ignore=entrypoints/openai/test_models.py <= Fails on MI250 but passes on MI300 as of 2025-03-13
|
||||
|
||||
|
||||
PARALLEL_JOB_COUNT=8
|
||||
# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
|
||||
if [[ $commands == *"--shard-id="* ]]; then
|
||||
@ -120,9 +134,10 @@ if [[ $commands == *"--shard-id="* ]]; then
|
||||
# assign shard-id for each shard
|
||||
commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "}
|
||||
echo "Shard ${GPU} commands:$commands_gpu"
|
||||
echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES"
|
||||
docker run \
|
||||
--device /dev/kfd --device /dev/dri \
|
||||
--network host \
|
||||
--device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \
|
||||
--network=host \
|
||||
--shm-size=16gb \
|
||||
--rm \
|
||||
-e HIP_VISIBLE_DEVICES="${GPU}" \
|
||||
@ -149,9 +164,10 @@ if [[ $commands == *"--shard-id="* ]]; then
|
||||
fi
|
||||
done
|
||||
else
|
||||
echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES"
|
||||
docker run \
|
||||
--device /dev/kfd --device /dev/dri \
|
||||
--network host \
|
||||
--device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \
|
||||
--network=host \
|
||||
--shm-size=16gb \
|
||||
--rm \
|
||||
-e HIP_VISIBLE_DEVICES=0 \
|
||||
|
||||
@ -19,13 +19,14 @@ remove_docker_container
|
||||
|
||||
# Run the image, setting --shm-size=4g for tensor parallel.
|
||||
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
|
||||
--cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"
|
||||
--cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"
|
||||
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
|
||||
--cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2
|
||||
--cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2
|
||||
|
||||
function cpu_tests() {
|
||||
set -e
|
||||
export NUMA_NODE=$2
|
||||
export BUILDKITE_BUILD_NUMBER=$3
|
||||
|
||||
# offline inference
|
||||
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c "
|
||||
@ -35,7 +36,10 @@ function cpu_tests() {
|
||||
# Run basic model test
|
||||
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pip install -r vllm/requirements-test.txt
|
||||
pip install -r vllm/requirements/test.txt
|
||||
pip install -r vllm/requirements/cpu.txt
|
||||
pytest -v -s tests/kernels/test_cache.py -m cpu_model
|
||||
pytest -v -s tests/kernels/test_mla_decode_cpu.py -m cpu_model
|
||||
pytest -v -s tests/models/decoder_only/language -m cpu_model
|
||||
pytest -v -s tests/models/embedding/language -m cpu_model
|
||||
pytest -v -s tests/models/encoder_decoder/language -m cpu_model
|
||||
@ -85,4 +89,4 @@ function cpu_tests() {
|
||||
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
export -f cpu_tests
|
||||
timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE $BUILDKITE_BUILD_NUMBER"
|
||||
|
||||
@ -14,6 +14,7 @@ DOCKER_BUILDKIT=1 docker build . \
|
||||
-t gh200-test \
|
||||
--build-arg max_jobs=66 \
|
||||
--build-arg nvcc_threads=2 \
|
||||
--build-arg RUN_WHEEL_CHECK=false \
|
||||
--build-arg torch_cuda_arch_list="9.0+PTX" \
|
||||
--build-arg vllm_fa_cmake_gpu_arches="90-real"
|
||||
|
||||
@ -23,6 +24,6 @@ trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image and test offline inference
|
||||
docker run -e HF_TOKEN -v /root/.cache/huggingface:/root/.cache/huggingface --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c '
|
||||
docker run -e HF_TOKEN -e VLLM_WORKER_MULTIPROC_METHOD=spawn -v /root/.cache/huggingface:/root/.cache/huggingface --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c '
|
||||
python3 examples/offline_inference/basic/generate.py --model meta-llama/Llama-3.2-1B
|
||||
'
|
||||
|
||||
@ -44,11 +44,11 @@ remove_docker_container() {
|
||||
trap remove_docker_container EXIT
|
||||
|
||||
# Run the image
|
||||
docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \
|
||||
docker run --rm -it --device=/dev/neuron0 --network bridge \
|
||||
-v "${HF_CACHE}:${HF_MOUNT}" \
|
||||
-e "HF_HOME=${HF_MOUNT}" \
|
||||
-v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \
|
||||
-e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
|
||||
--name "${container_name}" \
|
||||
${image_name} \
|
||||
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/ -v --capture=tee-sys"
|
||||
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys && python3 -m pytest /workspace/vllm/tests/neuron/2_core/ -v --capture=tee-sys"
|
||||
|
||||
@ -1,16 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script build the OpenVINO docker image and run the offline inference inside the container.
|
||||
# It serves a sanity check for compilation and basic model usage.
|
||||
set -ex
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t openvino-test -f Dockerfile.openvino .
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f openvino-test || true; }
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image and launch offline inference
|
||||
docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
@ -1,26 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# Build the docker image.
|
||||
docker build -f Dockerfile.tpu -t vllm-tpu .
|
||||
|
||||
# Set up cleanup.
|
||||
remove_docker_container() { docker rm -f tpu-test || true; }
|
||||
trap remove_docker_container EXIT
|
||||
# Remove the container that might not be cleaned up in the previous run.
|
||||
remove_docker_container
|
||||
|
||||
# For HF_TOKEN.
|
||||
source /etc/environment
|
||||
# Run a simple end-to-end example.
|
||||
docker run --privileged --net host --shm-size=16G -it \
|
||||
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
|
||||
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install pytest \
|
||||
&& python3 -m pip install lm_eval[api]==0.4.4 \
|
||||
&& pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \
|
||||
&& pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
|
||||
&& python3 /workspace/vllm/tests/tpu/test_compilation.py \
|
||||
&& python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
|
||||
&& python3 /workspace/vllm/examples/offline_inference/tpu.py"
|
||||
42
.buildkite/run-tpu-v1-test.sh
Executable file
42
.buildkite/run-tpu-v1-test.sh
Executable file
@ -0,0 +1,42 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
# Build the docker image.
|
||||
docker build -f Dockerfile.tpu -t vllm-tpu .
|
||||
|
||||
# Set up cleanup.
|
||||
remove_docker_container() { docker rm -f tpu-test || true; }
|
||||
trap remove_docker_container EXIT
|
||||
# Remove the container that might not be cleaned up in the previous run.
|
||||
remove_docker_container
|
||||
|
||||
# For HF_TOKEN.
|
||||
source /etc/environment
|
||||
# Run a simple end-to-end example.
|
||||
docker run --privileged --net host --shm-size=16G -it \
|
||||
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
|
||||
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install pytest \
|
||||
&& python3 -m pip install lm_eval[api]==0.4.4 \
|
||||
&& export VLLM_USE_V1=1 \
|
||||
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
|
||||
&& echo TEST_1 \
|
||||
&& pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
|
||||
&& echo TEST_2 \
|
||||
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
|
||||
&& echo TEST_3 \
|
||||
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
|
||||
&& echo TEST_4 \
|
||||
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
|
||||
&& echo TEST_5 \
|
||||
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
|
||||
&& echo TEST_6 \
|
||||
&& pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py \
|
||||
&& echo TEST_7 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
|
||||
|
||||
|
||||
# TODO: This test fails because it uses RANDOM_SEED sampling
|
||||
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
|
||||
|
||||
@ -4,16 +4,28 @@
|
||||
# It serves a sanity check for compilation and basic model usage.
|
||||
set -ex
|
||||
|
||||
image_name="xpu/vllm-ci:${BUILDKITE_COMMIT}"
|
||||
container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t xpu-test -f Dockerfile.xpu .
|
||||
docker build -t ${image_name} -f Dockerfile.xpu .
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f xpu-test || true; }
|
||||
remove_docker_container() {
|
||||
docker rm -f "${container_name}" || true;
|
||||
docker image rm -f "${image_name}" || true;
|
||||
docker system prune -f || true;
|
||||
}
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image and test offline inference/tensor parallel
|
||||
docker run --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test sh -c '
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
|
||||
docker run \
|
||||
--device /dev/dri \
|
||||
-v /dev/dri/by-path:/dev/dri/by-path \
|
||||
--entrypoint="" \
|
||||
--name "${container_name}" \
|
||||
"${image_name}" \
|
||||
sh -c '
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
|
||||
'
|
||||
|
||||
@ -35,13 +35,12 @@ steps:
|
||||
fast_check: true
|
||||
no_gpu: True
|
||||
commands:
|
||||
- pip install -r requirements-docs.txt
|
||||
- pip install -r ../../requirements/docs.txt
|
||||
- SPHINXOPTS=\"-W\" make html
|
||||
# Check API reference (if it fails, you may have missing mock imports)
|
||||
- grep \"sig sig-object py\" build/html/api/inference_params.html
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker Test # 24min
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/mq_llm_engine
|
||||
@ -78,6 +77,7 @@ steps:
|
||||
- tests/basic_correctness/test_preemption
|
||||
- tests/basic_correctness/test_cumem.py
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s basic_correctness/test_cumem.py
|
||||
- pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||
@ -112,19 +112,19 @@ steps:
|
||||
- tests/entrypoints/test_chat_utils
|
||||
- tests/entrypoints/offline_mode
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/
|
||||
- vllm/core/
|
||||
@ -135,20 +135,27 @@ steps:
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
commands:
|
||||
- VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py
|
||||
# test with tp=2 and external_dp=2
|
||||
- VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- python3 ../examples/offline_inference/rlhf.py
|
||||
- RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
|
||||
- pushd ../examples/offline_inference
|
||||
- VLLM_ENABLE_V1_MULTIPROCESSING=0 python3 rlhf.py
|
||||
- VLLM_ENABLE_V1_MULTIPROCESSING=0 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||
- popd
|
||||
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
num_gpus: 2
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/metrics
|
||||
@ -196,15 +203,19 @@ steps:
|
||||
- tests/v1
|
||||
commands:
|
||||
# split the test to avoid interference
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/core
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/engine
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/sample
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/worker
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
|
||||
- pytest -v -s v1/core
|
||||
- pytest -v -s v1/entrypoints
|
||||
- pytest -v -s v1/engine
|
||||
- pytest -v -s v1/entrypoints
|
||||
- pytest -v -s v1/sample
|
||||
- pytest -v -s v1/worker
|
||||
- pytest -v -s v1/structured_output
|
||||
- pytest -v -s v1/test_stats.py
|
||||
- pytest -v -s v1/test_utils.py
|
||||
- pytest -v -s v1/test_oracle.py
|
||||
# TODO: accuracy does not match, whether setting
|
||||
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/e2e
|
||||
- pytest -v -s v1/e2e
|
||||
# Integration test for streaming correctness (requires special branch).
|
||||
- pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
|
||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||
@ -222,14 +233,17 @@ steps:
|
||||
- python3 offline_inference/basic/chat.py
|
||||
- python3 offline_inference/prefix_caching.py
|
||||
- python3 offline_inference/llm_engine_example.py
|
||||
- python3 offline_inference/vision_language.py
|
||||
- python3 offline_inference/vision_language_multi_image.py
|
||||
- python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference/audio_language.py --seed 0
|
||||
- python3 offline_inference/vision_language.py --seed 0
|
||||
- python3 offline_inference/vision_language_embedding.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference/encoder_decoder.py
|
||||
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
||||
- python3 offline_inference/basic/classify.py
|
||||
- python3 offline_inference/basic/embed.py
|
||||
- python3 offline_inference/basic/score.py
|
||||
- python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
|
||||
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
|
||||
|
||||
- label: Prefix Caching Test # 9min
|
||||
mirror_hardwares: [amd]
|
||||
@ -275,11 +289,10 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
- tests/lora
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
@ -288,6 +301,7 @@ steps:
|
||||
# these tests need to be separated, cannot combine
|
||||
- pytest -v -s compile/piecewise/test_simple.py
|
||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||
- pytest -v -s compile/test_pass_manager.py
|
||||
|
||||
- label: PyTorch Fullgraph Test # 18min
|
||||
source_file_dependencies:
|
||||
@ -374,7 +388,8 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s models/test_transformers.py
|
||||
- pytest -v -s models/test_registry.py
|
||||
- pytest -v -s models/test_initialization.py
|
||||
# V1 Test: https://github.com/vllm-project/vllm/issues/14531
|
||||
- VLLM_USE_V1=0 pytest -v -s models/test_initialization.py
|
||||
|
||||
- label: Language Models Test (Standard) # 32min
|
||||
#mirror_hardwares: [amd]
|
||||
@ -501,10 +516,11 @@ steps:
|
||||
- vllm/worker/worker.py
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- vllm/v1/engine/
|
||||
commands:
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
|
||||
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- VLLM_ENABLE_V1_MULTIPROCESSING=0 pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
- pytest -v -s ./compile/test_wrapper.py
|
||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
@ -517,13 +533,12 @@ steps:
|
||||
# this test fails consistently.
|
||||
# TODO: investigate and fix
|
||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
|
||||
|
||||
- label: Plugin Tests (2 GPUs) # 40min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/plugins/
|
||||
- tests/plugins/
|
||||
@ -582,8 +597,6 @@ steps:
|
||||
# FIXIT: find out which code initialize cuda before running the test
|
||||
# before the fix, we need to use spawn to test it
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
# This test runs llama 13B, so it is required to run on 4 GPUs.
|
||||
- pytest -v -s -x lora/test_long_context.py
|
||||
# There is some Tensor Parallelism related processing logic in LoRA that
|
||||
# requires multi-GPU testing for validation.
|
||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||
|
||||
27
.github/CODEOWNERS
vendored
27
.github/CODEOWNERS
vendored
@ -10,27 +10,32 @@
|
||||
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
|
||||
/vllm/model_executor/guided_decoding @mgoin
|
||||
/vllm/model_executor/guided_decoding @mgoin @russellb
|
||||
/vllm/multimodal @DarkLight1337 @ywang96
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
|
||||
# vLLM V1
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||
/vllm/v1/structured_output @mgoin @russellb
|
||||
|
||||
# Test ownership
|
||||
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
|
||||
/tests/test_inputs.py @DarkLight1337 @ywang96
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo
|
||||
/tests/models @DarkLight1337 @ywang96
|
||||
/tests/multimodal @DarkLight1337 @ywang96
|
||||
/tests/prefix_caching @comaniac @KuntaiDu
|
||||
/tests/spec_decode @njhill @LiuXiaoxuanPKU
|
||||
/tests/kernels @tlrmchlsmth @WoosukKwon
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat
|
||||
/.buildkite/lm-eval-harness @mgoin @simon-mo
|
||||
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
|
||||
/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac
|
||||
/tests/distributed/test_multi_node_assignment.py @youkaichao
|
||||
/tests/distributed/test_pipeline_parallel.py @youkaichao
|
||||
/tests/distributed/test_same_node.py @youkaichao
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo
|
||||
/tests/entrypoints/llm/test_guided_generate.py @mgoin @russellb
|
||||
/tests/kernels @tlrmchlsmth @WoosukKwon
|
||||
/tests/model_executor/test_guided_processors.py @mgoin @russellb
|
||||
/tests/models @DarkLight1337 @ywang96
|
||||
/tests/multi_step @alexm-redhat @comaniac
|
||||
/tests/multimodal @DarkLight1337 @ywang96
|
||||
/tests/prefix_caching @comaniac @KuntaiDu
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat
|
||||
/tests/spec_decode @njhill @LiuXiaoxuanPKU
|
||||
/tests/test_inputs.py @DarkLight1337 @ywang96
|
||||
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb
|
||||
/tests/v1/structured_output @mgoin @russellb
|
||||
/tests/weight_loading @mgoin @youkaichao
|
||||
/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac
|
||||
|
||||
28
.github/ISSUE_TEMPLATE/800-misc-discussion.yml
vendored
28
.github/ISSUE_TEMPLATE/800-misc-discussion.yml
vendored
@ -1,28 +0,0 @@
|
||||
name: 🎲 Misc/random discussions that do not fit into the above categories.
|
||||
description: Submit a discussion as you like. Note that developers are heavily overloaded and we mainly rely on community users to answer these issues.
|
||||
title: "[Misc]: "
|
||||
labels: ["misc"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Anything you want to discuss about vllm.
|
||||
description: >
|
||||
Anything you want to discuss about vllm.
|
||||
validations:
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||
required: true
|
||||
4
.github/ISSUE_TEMPLATE/config.yml
vendored
4
.github/ISSUE_TEMPLATE/config.yml
vendored
@ -1 +1,5 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: Questions
|
||||
url: https://discuss.vllm.ai
|
||||
about: Ask questions and discuss with other vLLM community members
|
||||
|
||||
45
.github/mergify.yml
vendored
45
.github/mergify.yml
vendored
@ -36,6 +36,21 @@ pull_request_rules:
|
||||
add:
|
||||
- frontend
|
||||
|
||||
- name: label-multi-modality
|
||||
description: Automatically apply multi-modality label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/multimodal/
|
||||
- files~=^tests/multimodal/
|
||||
- files~=^tests/models/multimodal/
|
||||
- files~=^tests/models/*/audio_language/
|
||||
- files~=^tests/models/*/vision_language/
|
||||
- files=tests/models/test_vision.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- multi-modality
|
||||
|
||||
- name: label-structured-output
|
||||
description: Automatically apply structured-output label
|
||||
conditions:
|
||||
@ -73,6 +88,36 @@ pull_request_rules:
|
||||
add:
|
||||
- v1
|
||||
|
||||
- name: label-tpu
|
||||
description: Automatically apply tpu label
|
||||
# Keep this list in sync with `label-tpu-remove` conditions
|
||||
conditions:
|
||||
- or:
|
||||
- files~=tpu.py
|
||||
- files~=_tpu
|
||||
- files~=tpu_
|
||||
- files~=/tpu/
|
||||
- files~=pallas
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- tpu
|
||||
|
||||
- name: label-tpu-remove
|
||||
description: Automatically remove tpu label
|
||||
# Keep this list in sync with `label-tpu` conditions
|
||||
conditions:
|
||||
- and:
|
||||
- -files~=tpu.py
|
||||
- -files~=_tpu
|
||||
- -files~=tpu_
|
||||
- -files~=/tpu/
|
||||
- -files~=pallas
|
||||
actions:
|
||||
label:
|
||||
remove:
|
||||
- tpu
|
||||
|
||||
- name: ping author on conflicts and add 'needs-rebase' label
|
||||
conditions:
|
||||
- conflict
|
||||
|
||||
4
.github/workflows/publish.yml
vendored
4
.github/workflows/publish.yml
vendored
@ -39,7 +39,7 @@ jobs:
|
||||
const script = require('.github/workflows/scripts/create_release.js')
|
||||
await script(github, context, core)
|
||||
|
||||
# NOTE(simon): No longer build wheel using Github Actions. See buildkite's release workflow.
|
||||
# NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow.
|
||||
# wheel:
|
||||
# name: Build Wheel
|
||||
# runs-on: ${{ matrix.os }}
|
||||
@ -50,7 +50,7 @@ jobs:
|
||||
# matrix:
|
||||
# os: ['ubuntu-20.04']
|
||||
# python-version: ['3.9', '3.10', '3.11', '3.12']
|
||||
# pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
|
||||
# pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt.
|
||||
# cuda-version: ['11.8', '12.1']
|
||||
|
||||
# steps:
|
||||
|
||||
2
.github/workflows/scripts/build.sh
vendored
2
.github/workflows/scripts/build.sh
vendored
@ -9,7 +9,7 @@ PATH=${cuda_home}/bin:$PATH
|
||||
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# Install requirements
|
||||
$python_executable -m pip install -r requirements-build.txt -r requirements-cuda.txt
|
||||
$python_executable -m pip install -r requirements/build.txt -r requirements/cuda.txt
|
||||
|
||||
# Limit the number of parallel jobs to avoid OOM
|
||||
export MAX_JOBS=1
|
||||
|
||||
2
.github/workflows/scripts/create_release.js
vendored
2
.github/workflows/scripts/create_release.js
vendored
@ -1,4 +1,4 @@
|
||||
// Uses Github's API to create the release and wait for result.
|
||||
// Uses GitHub's API to create the release and wait for result.
|
||||
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.
|
||||
|
||||
module.exports = async (github, context, core) => {
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@ -2,7 +2,8 @@
|
||||
/vllm/_version.py
|
||||
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/
|
||||
vllm/vllm_flash_attn/*
|
||||
!vllm/vllm_flash_attn/fa_utils.py
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
@ -197,7 +198,7 @@ _build/
|
||||
hip_compat.h
|
||||
|
||||
# Benchmark dataset
|
||||
benchmarks/*.json
|
||||
benchmarks/**/*.json
|
||||
|
||||
# Linting
|
||||
actionlint
|
||||
|
||||
@ -44,8 +44,8 @@ repos:
|
||||
rev: 0.6.2
|
||||
hooks:
|
||||
- id: pip-compile
|
||||
args: [requirements-test.in, -o, requirements-test.txt]
|
||||
files: ^requirements-test\.(in|txt)$
|
||||
args: [requirements/test.in, -o, requirements/test.txt]
|
||||
files: ^requirements/test\.(in|txt)$
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy-local
|
||||
@ -53,7 +53,7 @@ repos:
|
||||
entry: tools/mypy.sh 0 "local"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
|
||||
stages: [pre-commit] # Don't run in CI
|
||||
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.9
|
||||
|
||||
@ -18,4 +18,4 @@ formats: []
|
||||
# Optionally declare the Python requirements required to build your docs
|
||||
python:
|
||||
install:
|
||||
- requirements: docs/requirements-docs.txt
|
||||
- requirements: requirements/docs.txt
|
||||
|
||||
118
CMakeLists.txt
118
CMakeLists.txt
@ -46,8 +46,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101")
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.5.1")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.1")
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.6.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.6.0")
|
||||
|
||||
#
|
||||
# Try to find python package with an executable that exactly matches
|
||||
@ -319,7 +319,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# Only build AllSpark kernels if we are building for at least some compatible archs.
|
||||
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
|
||||
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND ALLSPARK_ARCHS)
|
||||
if (ALLSPARK_ARCHS)
|
||||
set(ALLSPARK_SRCS
|
||||
"csrc/quantization/gptq_allspark/allspark_repack.cu"
|
||||
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
|
||||
@ -330,39 +330,67 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building AllSpark kernels as no compatible archs found"
|
||||
" in CUDA target architectures, or CUDA not >= 12.0")
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
|
||||
set(SCALED_MM_3X_ARCHS)
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
# CUDA 12.0 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
|
||||
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
|
||||
# build any 3x kernels
|
||||
set(SCALED_MM_3X_ARCHS)
|
||||
# The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
@ -394,17 +422,18 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# 2:4 Sparse Kernels
|
||||
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper and Blackwell).
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||
# require CUDA 12.2 or later (and only work on Hopper).
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
|
||||
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
|
||||
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
|
||||
"if you intend on running FP8 sparse quantized models on Hopper.")
|
||||
@ -432,22 +461,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
# FP8 Blackwell Archs
|
||||
cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
|
||||
)
|
||||
#
|
||||
# CUTLASS MoE kernels
|
||||
|
||||
# The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
|
||||
# on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
|
||||
# to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${BLACKWELL_ARCHS}")
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
# clear BLACKWELL_ARCHS
|
||||
set(BLACKWELL_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
|
||||
@ -548,11 +588,23 @@ set(VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/moe_align_sum_kernels.cu"
|
||||
"csrc/moe/topk_softmax_kernels.cu")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
|
||||
endif()
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_MOE_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(VLLM_MOE_WNA16_SRC
|
||||
"csrc/moe/moe_wna16.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_MOE_WNA16_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
set(MARLIN_MOE_SRC
|
||||
|
||||
63
Dockerfile
63
Dockerfile
@ -31,6 +31,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
|
||||
# as it was causing spam when compiling the CUTLASS kernels
|
||||
RUN apt-get install -y gcc-10 g++-10
|
||||
@ -55,13 +59,14 @@ WORKDIR /workspace
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \
|
||||
fi
|
||||
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
COPY requirements-cuda.txt requirements-cuda.txt
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY requirements/cuda.txt requirements/cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements-cuda.txt
|
||||
uv pip install --system -r requirements/cuda.txt
|
||||
|
||||
# cuda arch list used by torch
|
||||
# can be useful for both `dev` and `test`
|
||||
@ -79,15 +84,19 @@ FROM base AS build
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
# install build dependencies
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
COPY requirements/build.txt requirements/build.txt
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements-build.txt
|
||||
uv pip install --system -r requirements/build.txt
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
|
||||
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
# max jobs used by Ninja to build extensions
|
||||
ARG max_jobs=2
|
||||
@ -124,6 +133,9 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" != "1" ]; then \
|
||||
# Clean any existing CMake artifacts
|
||||
rm -rf .deps && \
|
||||
mkdir -p .deps && \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
|
||||
fi
|
||||
|
||||
@ -143,11 +155,15 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
|
||||
#################### DEV IMAGE ####################
|
||||
FROM base as dev
|
||||
|
||||
COPY requirements-lint.txt requirements-lint.txt
|
||||
COPY requirements-test.txt requirements-test.txt
|
||||
COPY requirements-dev.txt requirements-dev.txt
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
COPY requirements/lint.txt requirements/lint.txt
|
||||
COPY requirements/test.txt requirements/test.txt
|
||||
COPY requirements/dev.txt requirements/dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements-dev.txt
|
||||
uv pip install --system -r requirements/dev.txt
|
||||
#################### DEV IMAGE ####################
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
@ -181,6 +197,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
# this won't be needed for future versions of this docker image
|
||||
@ -193,7 +213,8 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \
|
||||
fi
|
||||
|
||||
# Install vllm wheel first, so that torch etc will be installed.
|
||||
@ -216,7 +237,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post1/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl ; \
|
||||
uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
|
||||
fi
|
||||
COPY examples examples
|
||||
|
||||
@ -224,9 +245,9 @@ COPY examples examples
|
||||
# some issues w.r.t. JIT compilation. Therefore we need to
|
||||
# install build dependencies for JIT compilation.
|
||||
# TODO: Remove this once FlashInfer AOT wheel is fixed
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
COPY requirements/build.txt requirements/build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements-build.txt
|
||||
uv pip install --system -r requirements/build.txt
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
@ -237,9 +258,13 @@ FROM vllm-base AS test
|
||||
|
||||
ADD . /vllm-workspace/
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements-dev.txt
|
||||
uv pip install --system -r requirements/dev.txt
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
@ -265,12 +290,16 @@ RUN mv vllm test_docs/
|
||||
# base openai image with additional requirements, for any subsequent openai-style images
|
||||
FROM vllm-base AS vllm-openai-base
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
else \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.3' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
fi
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
@ -26,18 +26,18 @@ WORKDIR /workspace
|
||||
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
|
||||
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
|
||||
--mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
|
||||
pip install --upgrade pip && \
|
||||
pip install -r requirements-build.txt
|
||||
pip install -r requirements/build.txt
|
||||
|
||||
FROM cpu-test-arm AS build
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
|
||||
--mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
|
||||
pip install -v -r requirements-cpu.txt
|
||||
--mount=type=bind,src=requirements/common.txt,target=requirements/common.txt \
|
||||
--mount=type=bind,src=requirements/cpu.txt,target=requirements/cpu.txt \
|
||||
pip install -v -r requirements/cpu.txt
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
|
||||
@ -22,25 +22,25 @@ ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/li
|
||||
|
||||
RUN echo 'ulimit -c 0' >> ~/.bashrc
|
||||
|
||||
RUN pip install intel_extension_for_pytorch==2.5.0
|
||||
RUN pip install intel_extension_for_pytorch==2.6.0
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
|
||||
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
|
||||
--mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
|
||||
pip install --upgrade pip && \
|
||||
pip install -r requirements-build.txt
|
||||
pip install -r requirements/build.txt
|
||||
|
||||
FROM cpu-test-1 AS build
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
|
||||
--mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
|
||||
pip install -v -r requirements-cpu.txt
|
||||
--mount=type=bind,src=requirements/common.txt,target=requirements/common.txt \
|
||||
--mount=type=bind,src=requirements/cpu.txt,target=requirements/cpu.txt \
|
||||
pip install -v -r requirements/cpu.txt
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
|
||||
@ -4,7 +4,7 @@ COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN pip install -v -r requirements-hpu.txt
|
||||
RUN pip install -v -r requirements/hpu.txt
|
||||
|
||||
ENV no_proxy=localhost,127.0.0.1
|
||||
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
|
||||
|
||||
@ -36,7 +36,7 @@ RUN --mount=type=bind,source=.git,target=.git \
|
||||
|
||||
RUN python3 -m pip install -U \
|
||||
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
|
||||
-r requirements-neuron.txt
|
||||
-r requirements/neuron.txt
|
||||
|
||||
ENV VLLM_TARGET_DEVICE neuron
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
|
||||
@ -1,29 +0,0 @@
|
||||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
FROM ubuntu:22.04 AS dev
|
||||
|
||||
RUN apt-get update -y && \
|
||||
apt-get install -y \
|
||||
git python3-pip \
|
||||
ffmpeg libsm6 libxext6 libgl1
|
||||
WORKDIR /workspace
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
RUN python3 -m pip install -U pip
|
||||
# install build requirements
|
||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/requirements-build.txt
|
||||
# build vLLM with OpenVINO backend
|
||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace
|
||||
|
||||
COPY examples/ /workspace/examples
|
||||
COPY benchmarks/ /workspace/benchmarks
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@ -1,37 +1,267 @@
|
||||
FROM mambaorg/micromamba
|
||||
ARG MAMBA_DOCKERFILE_ACTIVATE=1
|
||||
USER root
|
||||
ARG BASE_UBI_IMAGE_TAG=9.5-1741850109
|
||||
|
||||
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
|
||||
###############################################################
|
||||
# base stage with basic dependencies
|
||||
###############################################################
|
||||
|
||||
RUN apt-get update -y && apt-get install -y git wget kmod curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev
|
||||
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base-builder
|
||||
|
||||
# Some packages in requirements-cpu are installed here
|
||||
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
|
||||
# Currently these may not be available for venv or pip directly
|
||||
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 rust && micromamba clean --all --yes
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG OPENBLAS_VERSION=0.3.29
|
||||
|
||||
# Set Environment Variables for venv, cargo & openblas
|
||||
ENV VIRTUAL_ENV=/opt/vllm
|
||||
ENV PATH=${VIRTUAL_ENV}/bin:/root/.cargo/bin:$PATH
|
||||
ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig/
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64:/usr/local/lib:/usr/lib64:/usr/lib
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# install gcc-13, python, rust, openblas
|
||||
# Note: A symlink for libatomic.so is created for gcc-13 (linker fails to find libatomic otherwise - reqd. for sentencepiece)
|
||||
# Note: A dummy file 'control' is created in /tmp/ to artificially create dependencies between stages when building stages in parallel
|
||||
# when `--jobs=<N>` is passed with podman build command
|
||||
RUN microdnf install -y openssl-devel dnf \
|
||||
&& dnf install -y https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-gpg-keys-9.0-24.el9.noarch.rpm \
|
||||
https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-stream-repos-9.0-24.el9.noarch.rpm \
|
||||
https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm \
|
||||
&& dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os \
|
||||
&& dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/AppStream/`arch`/os \
|
||||
&& dnf config-manager --set-enabled crb \
|
||||
&& dnf install -y \
|
||||
git tar gcc-toolset-13 automake libtool numactl-devel lapack-devel \
|
||||
pkgconfig xsimd zeromq-devel kmod findutils protobuf* \
|
||||
libtiff-devel libjpeg-devel openjpeg2-devel zlib-devel \
|
||||
freetype-devel lcms2-devel libwebp-devel tcl-devel tk-devel \
|
||||
harfbuzz-devel fribidi-devel libraqm-devel libimagequant-devel libxcb-devel \
|
||||
python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \
|
||||
&& dnf clean all \
|
||||
&& ln -sf /usr/lib64/libatomic.so.1 /usr/lib64/libatomic.so \
|
||||
&& python${PYTHON_VERSION} -m venv ${VIRTUAL_ENV} \
|
||||
&& python -m pip install -U pip uv \
|
||||
&& uv pip install wheel build "setuptools<70" setuptools_scm setuptools_rust meson-python cmake ninja cython scikit_build_core scikit_build \
|
||||
&& curl -sL https://ftp2.osuosl.org/pub/ppc64el/openblas/latest/Openblas_${OPENBLAS_VERSION}_ppc64le.tar.gz | tar xvf - -C /usr/local \
|
||||
&& curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y \
|
||||
&& cd /tmp && touch control
|
||||
|
||||
###############################################################
|
||||
# Stage to build torch family
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS torch-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG TORCH_VERSION=2.6.0
|
||||
ARG _GLIBCXX_USE_CXX11_ABI=1
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/pytorch/pytorch.git -b v${TORCH_VERSION} && \
|
||||
cd pytorch && \
|
||||
uv pip install -r requirements.txt && \
|
||||
python setup.py develop && \
|
||||
rm -f dist/torch*+git*whl && \
|
||||
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
|
||||
PYTORCH_BUILD_VERSION=${TORCH_VERSION} PYTORCH_BUILD_NUMBER=1 uv build --wheel --out-dir /torchwheels/
|
||||
|
||||
ARG TORCHVISION_VERSION=0.21.0
|
||||
ARG TORCHVISION_USE_NVJPEG=0
|
||||
ARG TORCHVISION_USE_FFMPEG=0
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/pytorch/vision.git -b v${TORCHVISION_VERSION} && \
|
||||
cd vision && \
|
||||
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
|
||||
BUILD_VERSION=${TORCHVISION_VERSION} \
|
||||
uv build --wheel --out-dir /torchwheels/ --no-build-isolation
|
||||
|
||||
ARG TORCHAUDIO_VERSION=2.6.0
|
||||
ARG BUILD_SOX=1
|
||||
ARG BUILD_KALDI=1
|
||||
ARG BUILD_RNNT=1
|
||||
ARG USE_FFMPEG=0
|
||||
ARG USE_ROCM=0
|
||||
ARG USE_CUDA=0
|
||||
ARG TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_FFMPEG=1
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/pytorch/audio.git -b v${TORCHAUDIO_VERSION} && \
|
||||
cd audio && \
|
||||
MAX_JOBS=${MAX_JOBS:-$(nproc)} \
|
||||
BUILD_VERSION=${TORCHAUDIO_VERSION} \
|
||||
uv build --wheel --out-dir /torchwheels/ --no-build-isolation
|
||||
|
||||
###############################################################
|
||||
# Stage to build pyarrow
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS arrow-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG PYARROW_PARALLEL
|
||||
ARG PYARROW_VERSION=19.0.1
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/apache/arrow.git -b apache-arrow-${PYARROW_VERSION} && \
|
||||
cd arrow/cpp && \
|
||||
mkdir build && cd build && \
|
||||
cmake -DCMAKE_BUILD_TYPE=release \
|
||||
-DCMAKE_INSTALL_PREFIX=/usr/local \
|
||||
-DARROW_PYTHON=ON \
|
||||
-DARROW_BUILD_TESTS=OFF \
|
||||
-DARROW_JEMALLOC=ON \
|
||||
-DARROW_BUILD_STATIC="OFF" \
|
||||
-DARROW_PARQUET=ON \
|
||||
.. && \
|
||||
make install -j ${MAX_JOBS:-$(nproc)} && \
|
||||
cd ../../python/ && \
|
||||
uv pip install -v -r requirements-wheel-build.txt && \
|
||||
PYARROW_PARALLEL=${PYARROW_PARALLEL:-$(nproc)} \
|
||||
python setup.py build_ext \
|
||||
--build-type=release --bundle-arrow-cpp \
|
||||
bdist_wheel --dist-dir /arrowwheels/
|
||||
|
||||
###############################################################
|
||||
# Stage to build opencv
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS cv-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG OPENCV_VERSION=84
|
||||
ARG ENABLE_HEADLESS=1
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/opencv/opencv-python.git -b ${OPENCV_VERSION} && \
|
||||
cd opencv-python && \
|
||||
sed -i 's/"setuptools==59.2.0",/"setuptools<70.0",/g' pyproject.toml && \
|
||||
python -m build --wheel --installer=uv --outdir /opencvwheels/
|
||||
|
||||
###############################################################
|
||||
# Stage to build vllm - this stage builds and installs
|
||||
# vllm, tensorizer and vllm-tgis-adapter and builds uv cache
|
||||
# for transitive dependencies - eg. grpcio
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS vllmcache-builder
|
||||
|
||||
COPY --from=torch-builder /tmp/control /dev/null
|
||||
COPY --from=arrow-builder /tmp/control /dev/null
|
||||
COPY --from=cv-builder /tmp/control /dev/null
|
||||
|
||||
ARG VLLM_TARGET_DEVICE=cpu
|
||||
|
||||
# this step installs vllm and populates uv cache
|
||||
# with all the transitive dependencies
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
|
||||
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
|
||||
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
|
||||
--mount=type=bind,src=.,dst=/src/,rw \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \
|
||||
sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \
|
||||
uv pip install pandas pythran pybind11 && \
|
||||
# sentencepiece.pc is in some pkgconfig inside uv cache
|
||||
export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \
|
||||
uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \
|
||||
cd /src/ && \
|
||||
uv build --wheel --out-dir /vllmwheel/ --no-build-isolation && \
|
||||
uv pip install /vllmwheel/*.whl
|
||||
|
||||
|
||||
###############################################################
|
||||
# Stage to build numactl
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS numa-builder
|
||||
|
||||
# Note: Building numactl with gcc-11. Compiling with gcc-13 in this builder stage will
|
||||
# trigger recompilation with gcc-11 (and require libtool) in the final stage where we do not have gcc-13
|
||||
ARG MAX_JOBS
|
||||
ARG NUMACTL_VERSION=2.0.19
|
||||
RUN git clone --recursive https://github.com/numactl/numactl.git -b v${NUMACTL_VERSION} \
|
||||
&& cd numactl \
|
||||
&& autoreconf -i && ./configure \
|
||||
&& make -j ${MAX_JOBS:-$(nproc)}
|
||||
|
||||
###############################################################
|
||||
# Stage to build lapack
|
||||
###############################################################
|
||||
|
||||
FROM base-builder AS lapack-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG LAPACK_VERSION=3.12.1
|
||||
RUN git clone --recursive https://github.com/Reference-LAPACK/lapack.git -b v${LAPACK_VERSION} \
|
||||
&& cd lapack && source /opt/rh/gcc-toolset-13/enable \
|
||||
&& cmake -B build -S . \
|
||||
&& cmake --build build -j ${MAX_JOBS:-$(nproc)}
|
||||
|
||||
|
||||
###############################################################
|
||||
# FINAL VLLM IMAGE STAGE #
|
||||
###############################################################
|
||||
|
||||
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS vllm-openai
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG OPENBLAS_VERSION=0.3.29
|
||||
|
||||
# Set Environment Variables for venv & openblas
|
||||
ENV VIRTUAL_ENV=/opt/vllm
|
||||
ENV PATH=${VIRTUAL_ENV}/bin:$PATH
|
||||
ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig/
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64:/usr/local/lib:/usr/lib64:/usr/lib
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# create artificial dependencies between stages for independent stages to build in parallel
|
||||
COPY --from=torch-builder /tmp/control /dev/null
|
||||
COPY --from=arrow-builder /tmp/control /dev/null
|
||||
COPY --from=cv-builder /tmp/control /dev/null
|
||||
COPY --from=vllmcache-builder /tmp/control /dev/null
|
||||
COPY --from=numa-builder /tmp/control /dev/null
|
||||
COPY --from=lapack-builder /tmp/control /dev/null
|
||||
|
||||
# install gcc-11, python, openblas, numactl, lapack
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=numa-builder,source=/numactl/,target=/numactl/,rw \
|
||||
--mount=type=bind,from=lapack-builder,source=/lapack/,target=/lapack/,rw \
|
||||
rpm -ivh https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm && \
|
||||
microdnf install --nodocs -y \
|
||||
tar findutils openssl \
|
||||
pkgconfig xsimd g++ gcc-fortran libsndfile \
|
||||
libtiff libjpeg openjpeg2 zlib zeromq \
|
||||
freetype lcms2 libwebp tcl tk utf8proc \
|
||||
harfbuzz fribidi libraqm libimagequant libxcb \
|
||||
python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip \
|
||||
&& microdnf clean all \
|
||||
&& python${PYTHON_VERSION} -m venv ${VIRTUAL_ENV} \
|
||||
&& python -m pip install -U pip uv --no-cache \
|
||||
&& curl -sL https://ftp2.osuosl.org/pub/ppc64el/openblas/latest/Openblas_${OPENBLAS_VERSION}_ppc64le.tar.gz | tar xvf - -C /usr/local \
|
||||
&& make -C /numactl install \
|
||||
&& uv pip install cmake \
|
||||
&& cmake --install /lapack/build \
|
||||
&& uv pip uninstall cmake
|
||||
|
||||
# consume previously built wheels (including vllm)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
|
||||
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
|
||||
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
|
||||
--mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \
|
||||
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /vllmwheel/*.whl
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
|
||||
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
|
||||
-r requirements-cpu.txt \
|
||||
xformers uvloop==0.20.0
|
||||
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
VLLM_TARGET_DEVICE=cpu python3 setup.py install
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install -e tests/vllm_test_utils
|
||||
|
||||
WORKDIR /workspace/
|
||||
|
||||
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
|
||||
|
||||
ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
|
||||
@ -12,7 +12,8 @@ ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}}
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update -q -y && apt-get install -q -y \
|
||||
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev
|
||||
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev \
|
||||
apt-transport-https ca-certificates wget curl
|
||||
# Remove sccache
|
||||
RUN python3 -m pip install --upgrade pip && pip install setuptools_scm
|
||||
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
|
||||
@ -38,14 +39,14 @@ FROM fetch_vllm AS build_vllm
|
||||
ARG USE_CYTHON
|
||||
# Build vLLM
|
||||
RUN cd vllm \
|
||||
&& python3 -m pip install -r requirements-rocm.txt \
|
||||
&& python3 -m pip install -r requirements/rocm.txt \
|
||||
&& python3 setup.py clean --all \
|
||||
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \
|
||||
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 tests/build_cython.py build_ext --inplace; fi \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch AS export_vllm
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt /
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements /requirements
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
|
||||
@ -60,7 +61,8 @@ RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
|
||||
# Install vLLM
|
||||
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
cd /install \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& pip install -U -r requirements/rocm.txt \
|
||||
&& pip install -U -r requirements/rocm-test.txt \
|
||||
&& pip uninstall -y vllm \
|
||||
&& pip install *.whl
|
||||
|
||||
@ -99,7 +101,7 @@ RUN if [ ${BUILD_RPD} -eq "1" ]; then \
|
||||
# Install vLLM
|
||||
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
cd /install \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& pip install -U -r requirements/rocm.txt \
|
||||
&& pip uninstall -y vllm \
|
||||
&& pip install *.whl
|
||||
|
||||
|
||||
@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="b7d29fb"
|
||||
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
|
||||
ARG AITER_BRANCH="21d47a9"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
@ -129,8 +131,18 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
|
||||
ARG AITER_REPO
|
||||
ARG AITER_BRANCH
|
||||
RUN git clone --recursive ${AITER_REPO}
|
||||
RUN cd aiter \
|
||||
&& git checkout ${AITER_BRANCH} \
|
||||
&& git submodule update --init --recursive \
|
||||
&& pip install -r requirements.txt \
|
||||
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter
|
||||
|
||||
ARG BASE_IMAGE
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG HIPBLAS_COMMON_BRANCH
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
@ -155,4 +167,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
|
||||
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
|
||||
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
|
||||
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
|
||||
|
||||
152
Dockerfile.s390x
Normal file
152
Dockerfile.s390x
Normal file
@ -0,0 +1,152 @@
|
||||
# Base UBI image for s390x architecture
|
||||
ARG BASE_UBI_IMAGE_TAG=9.5-1736404155
|
||||
ARG PYTHON_VERSION=3.12
|
||||
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base
|
||||
|
||||
# Install basic dependencies
|
||||
ARG PYTHON_VERSION
|
||||
ENV PYTHON_VERSION=${PYTHON_VERSION}
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
ENV LANG=C.UTF-8 \
|
||||
LC_ALL=C.UTF-8
|
||||
|
||||
# Install development utilities
|
||||
RUN microdnf install -y \
|
||||
which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \
|
||||
libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \
|
||||
openssl-devel openblas openblas-devel autoconf automake libtool cmake && \
|
||||
microdnf clean all
|
||||
|
||||
# Python Installation
|
||||
FROM base AS python-install
|
||||
ARG PYTHON_VERSION
|
||||
|
||||
ENV VIRTUAL_ENV=/opt/vllm
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
ENV PYTHON_VERSION=${PYTHON_VERSION}
|
||||
RUN microdnf install -y \
|
||||
python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel && \
|
||||
python${PYTHON_VERSION} -m venv $VIRTUAL_ENV && pip install --no-cache -U pip wheel uv && microdnf clean all
|
||||
|
||||
FROM python-install AS pyarrow
|
||||
|
||||
# Build Apache Arrow
|
||||
WORKDIR /tmp
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
git clone https://github.com/apache/arrow.git && \
|
||||
cd arrow/cpp && \
|
||||
mkdir release && cd release && \
|
||||
cmake -DCMAKE_BUILD_TYPE=Release \
|
||||
-DCMAKE_INSTALL_PREFIX=/usr/local \
|
||||
-DARROW_PYTHON=ON \
|
||||
-DARROW_PARQUET=ON \
|
||||
-DARROW_ORC=ON \
|
||||
-DARROW_FILESYSTEM=ON \
|
||||
-DARROW_WITH_LZ4=ON \
|
||||
-DARROW_WITH_ZSTD=ON \
|
||||
-DARROW_WITH_SNAPPY=ON \
|
||||
-DARROW_JSON=ON \
|
||||
-DARROW_CSV=ON \
|
||||
-DARROW_DATASET=ON \
|
||||
-DPROTOBUF_PROTOC_EXECUTABLE=/usr/bin/protoc \
|
||||
-DARROW_DEPENDENCY_SOURCE=BUNDLED \
|
||||
.. && \
|
||||
make -j$(nproc) && \
|
||||
make install && \
|
||||
cd ../../python && \
|
||||
export PYARROW_PARALLEL=4 && \
|
||||
export ARROW_BUILD_TYPE=release && \
|
||||
uv pip install -r requirements/build.txt && \
|
||||
python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel
|
||||
|
||||
FROM python-install AS numa-build
|
||||
# Install numactl (needed for numa.h dependency)
|
||||
WORKDIR /tmp
|
||||
RUN curl -LO https://github.com/numactl/numactl/archive/refs/tags/v2.0.16.tar.gz && \
|
||||
tar -xvzf v2.0.16.tar.gz && \
|
||||
cd numactl-2.0.16 && \
|
||||
./autogen.sh && \
|
||||
./configure && \
|
||||
make
|
||||
|
||||
# Set include path
|
||||
ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH"
|
||||
|
||||
FROM python-install AS rust
|
||||
ENV CARGO_HOME=/root/.cargo
|
||||
ENV RUSTUP_HOME=/root/.rustup
|
||||
ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"
|
||||
|
||||
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y && \
|
||||
. "$CARGO_HOME/env" && \
|
||||
rustup default stable && \
|
||||
rustup show
|
||||
|
||||
FROM python-install AS torch-vision
|
||||
# Install torchvision
|
||||
ARG TORCH_VERSION=2.7.0.dev20250304
|
||||
ARG TORCH_VISION_VERSION=v0.20.1
|
||||
WORKDIR /tmp
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
git clone https://github.com/pytorch/vision.git && \
|
||||
cd vision && \
|
||||
git checkout $TORCH_VISION_VERSION && \
|
||||
uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
# Final build stage
|
||||
FROM python-install AS vllm-cpu
|
||||
ARG PYTHON_VERSION
|
||||
|
||||
# Set correct library path for torch and numactl
|
||||
ENV LD_LIBRARY_PATH="/opt/vllm/lib64/python${PYTHON_VERSION}/site-packages/torch/lib:/usr/local/lib:$LD_LIBRARY_PATH"
|
||||
ENV C_INCLUDE_PATH="/usr/local/include:$C_INCLUDE_PATH"
|
||||
ENV UV_LINK_MODE=copy
|
||||
ENV CARGO_HOME=/root/.cargo
|
||||
ENV RUSTUP_HOME=/root/.rustup
|
||||
ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"
|
||||
|
||||
COPY . /workspace/vllm
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN --mount=type=bind,from=numa-build,src=/tmp/numactl-2.0.16,target=/numactl \
|
||||
make -C /numactl install
|
||||
|
||||
# Install dependencies, including PyTorch and Apache Arrow
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \
|
||||
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
|
||||
--mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \
|
||||
--mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \
|
||||
sed -i '/^torch/d' requirements/build.txt && \
|
||||
ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \
|
||||
VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \
|
||||
uv pip install -v \
|
||||
$ARROW_WHL_FILE \
|
||||
$VISION_WHL_FILE \
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
|
||||
--index-strategy unsafe-best-match \
|
||||
-r requirements/build.txt \
|
||||
-r requirements/cpu.txt
|
||||
|
||||
# Build and install vllm
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \
|
||||
uv pip install "$(echo dist/*.whl)[tensorizer]"
|
||||
|
||||
# setup non-root user for vllm
|
||||
RUN umask 002 && \
|
||||
useradd --uid 2000 --gid 0 vllm && \
|
||||
mkdir -p /home/vllm && \
|
||||
chmod g+rwx /home/vllm
|
||||
|
||||
COPY LICENSE /licenses/vllm.md
|
||||
COPY examples/*.jinja /app/data/template/
|
||||
|
||||
USER 2000
|
||||
WORKDIR /home/vllm
|
||||
|
||||
# Set the default entrypoint
|
||||
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
@ -15,11 +15,14 @@ ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
|
||||
|
||||
# Remove existing versions of dependencies
|
||||
RUN pip uninstall -y torch torch_xla torchvision
|
||||
|
||||
ENV VLLM_TARGET_DEVICE="tpu"
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
python3 -m pip install \
|
||||
-r requirements-tpu.txt
|
||||
-r requirements/tpu.txt
|
||||
RUN python3 setup.py develop
|
||||
|
||||
# install development dependencies (for testing)
|
||||
|
||||
@ -1,11 +1,7 @@
|
||||
FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS vllm-base
|
||||
# oneapi 2025.0.2 docker base image use rolling 2448 package. https://dgpu-docs.intel.com/releases/packages.html?release=Rolling+2448.13&os=Ubuntu+22.04, and we don't need install driver manually.
|
||||
FROM intel/deep-learning-essentials:2025.0.2-0-devel-ubuntu22.04 AS vllm-base
|
||||
|
||||
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
|
||||
chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \
|
||||
wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \
|
||||
echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \
|
||||
chmod 644 /usr/share/keyrings/intel-graphics.gpg
|
||||
RUN rm /etc/apt/sources.list.d/intel-graphics.list
|
||||
|
||||
RUN apt-get update -y && \
|
||||
apt-get install -y --no-install-recommends --fix-missing \
|
||||
@ -21,30 +17,20 @@ RUN apt-get update -y && \
|
||||
python3 \
|
||||
python3-dev \
|
||||
python3-pip \
|
||||
# vim \
|
||||
wget
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
COPY requirements-xpu.txt /workspace/vllm/requirements-xpu.txt
|
||||
COPY requirements-common.txt /workspace/vllm/requirements-common.txt
|
||||
COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt
|
||||
COPY requirements/common.txt /workspace/vllm/requirements/common.txt
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install --no-cache-dir \
|
||||
-r requirements-xpu.txt
|
||||
|
||||
RUN git clone https://github.com/intel/pti-gpu && \
|
||||
cd pti-gpu/sdk && \
|
||||
git checkout 6c491f07a777ed872c2654ca9942f1d0dde0a082 && \
|
||||
mkdir build && \
|
||||
cd build && \
|
||||
cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \
|
||||
make -j && \
|
||||
cmake --install . --config Release --prefix "/usr/local"
|
||||
-r requirements/xpu.txt
|
||||
|
||||
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/"
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
|
||||
|
||||
@ -54,6 +40,12 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
python3 setup.py install
|
||||
|
||||
# Please refer xpu doc, we need manually install intel-extension-for-pytorch 2.6.10+xpu due to there are some conflict dependencies with torch 2.6.0+xpu
|
||||
# FIXME: This will be fix in ipex 2.7. just leave this here for awareness.
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install intel-extension-for-pytorch==2.6.10+xpu \
|
||||
--extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
FROM vllm-base AS vllm-openai
|
||||
|
||||
10
MANIFEST.in
10
MANIFEST.in
@ -1,9 +1,9 @@
|
||||
include LICENSE
|
||||
include requirements-common.txt
|
||||
include requirements-cuda.txt
|
||||
include requirements-rocm.txt
|
||||
include requirements-neuron.txt
|
||||
include requirements-cpu.txt
|
||||
include requirements/common.txt
|
||||
include requirements/cuda.txt
|
||||
include requirements/rocm.txt
|
||||
include requirements/neuron.txt
|
||||
include requirements/cpu.txt
|
||||
include CMakeLists.txt
|
||||
|
||||
recursive-include cmake *
|
||||
|
||||
29
README.md
29
README.md
@ -10,24 +10,29 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
We’re excited to invite you to the first **vLLM China Meetup** on **March 16** in **Beijing**!
|
||||
[2025/03] We are collaborating with Ollama to host an [Inference Night](https://lu.ma/vllm-ollama) at Y Combinator in San Francisco on Thursday, March 27, at 6 PM. Discuss all things inference local or data center!
|
||||
|
||||
Join us to connect with the **vLLM team** and explore how vLLM is leveraged in **post-training, fine-tuning, and deployment**, including [verl](https://github.com/volcengine/verl), [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), and [vllm-ascend](https://github.com/vllm-project/vllm-ascend).
|
||||
|
||||
👉 **[Register Now](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)** to be part of the discussion!
|
||||
[2025/04] We're hosting our first-ever *vLLM Asia Developer Day* in Singapore on *April 3rd*! This is a full-day event (9 AM - 9 PM SGT) in partnership with SGInnovate, AMD, and Embedded LLM. Meet the vLLM team and learn about LLM inference for RL, MI300X, and more! [Register Now](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
|
||||
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
|
||||
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
|
||||
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
|
||||
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
|
||||
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing).
|
||||
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
|
||||
|
||||
<details>
|
||||
<summary>Previous News</summary>
|
||||
|
||||
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
|
||||
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
|
||||
- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users!
|
||||
@ -41,8 +46,9 @@ Join us to connect with the **vLLM team** and explore how vLLM is leveraged in *
|
||||
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||
|
||||
---
|
||||
</details>
|
||||
|
||||
---
|
||||
## About
|
||||
|
||||
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||
@ -90,7 +96,7 @@ pip install vllm
|
||||
```
|
||||
|
||||
Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
|
||||
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation/index.html)
|
||||
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
|
||||
- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
|
||||
- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
|
||||
|
||||
@ -150,10 +156,11 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
|
||||
|
||||
## Contact Us
|
||||
|
||||
- For technical questions and feature requests, please use Github issues or discussions.
|
||||
- For discussing with fellow users and coordinating contributions and development, please use Slack.
|
||||
- For security disclosures, please use Github's security advisory feature.
|
||||
- For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu.
|
||||
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions)
|
||||
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
|
||||
- coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
|
||||
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
|
||||
- For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu)
|
||||
|
||||
## Media Kit
|
||||
|
||||
|
||||
@ -1,29 +1,268 @@
|
||||
# Benchmarking vLLM
|
||||
|
||||
## Downloading the ShareGPT dataset
|
||||
This README guides you through running benchmark tests with the extensive
|
||||
datasets supported on vLLM. It’s a living document, updated as new features and datasets
|
||||
become available.
|
||||
|
||||
You can download the dataset by running:
|
||||
## Dataset Overview
|
||||
|
||||
<table style="width:100%; border-collapse: collapse;">
|
||||
<thead>
|
||||
<tr>
|
||||
<th style="width:15%; text-align: left;">Dataset</th>
|
||||
<th style="width:10%; text-align: center;">Online</th>
|
||||
<th style="width:10%; text-align: center;">Offline</th>
|
||||
<th style="width:65%; text-align: left;">Data Path</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td><strong>ShareGPT</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>BurstGPT</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Sonnet</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td>Local file: <code>benchmarks/sonnet.txt</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Random</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>synthetic</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>HuggingFace</strong></td>
|
||||
<td style="text-align: center;">🟡</td>
|
||||
<td style="text-align: center;">🟡</td>
|
||||
<td>Specify your dataset path on HuggingFace</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>VisionArena</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>lmarena-ai/vision-arena-bench-v0.1</code> (a HuggingFace dataset)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
✅: supported
|
||||
|
||||
🚧: to be supported
|
||||
|
||||
🟡: Partial support. Currently, HuggingFaceDataset only supports dataset formats
|
||||
similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`.
|
||||
If you need support for other dataset formats, please consider contributing.
|
||||
|
||||
**Note**: VisionArena’s `dataset-name` should be set to `hf`
|
||||
|
||||
---
|
||||
## Example - Online Benchmark
|
||||
|
||||
First start serving your model
|
||||
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
vllm serve ${MODEL_NAME} --disable-log-requests
|
||||
```
|
||||
|
||||
## Downloading the ShareGPT4V dataset
|
||||
|
||||
The json file refers to several image datasets (coco, llava, etc.). The benchmark scripts
|
||||
will ignore a datapoint if the referred image is missing.
|
||||
Then run the benchmarking script
|
||||
|
||||
```bash
|
||||
wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
|
||||
mkdir coco -p
|
||||
wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
|
||||
unzip coco/train2017.zip -d coco/
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
NUM_PROMPTS=10
|
||||
BACKEND="vllm"
|
||||
DATASET_NAME="sharegpt"
|
||||
DATASET_PATH="<your data path>/ShareGPT_V3_unfiltered_cleaned_split.json"
|
||||
python3 vllm/benchmarks/benchmark_serving.py --backend ${BACKEND} --model ${MODEL_NAME} --endpoint /v1/completions --dataset-name ${DATASET_NAME} --dataset-path ${DATASET_PATH} --num-prompts ${NUM_PROMPTS}
|
||||
```
|
||||
|
||||
# Downloading the BurstGPT dataset
|
||||
If successful, you will see the following output
|
||||
|
||||
You can download the BurstGPT v1.1 dataset by running:
|
||||
```
|
||||
============ Serving Benchmark Result ============
|
||||
Successful requests: 10
|
||||
Benchmark duration (s): 5.78
|
||||
Total input tokens: 1369
|
||||
Total generated tokens: 2212
|
||||
Request throughput (req/s): 1.73
|
||||
Output token throughput (tok/s): 382.89
|
||||
Total Token throughput (tok/s): 619.85
|
||||
---------------Time to First Token----------------
|
||||
Mean TTFT (ms): 71.54
|
||||
Median TTFT (ms): 73.88
|
||||
P99 TTFT (ms): 79.49
|
||||
-----Time per Output Token (excl. 1st token)------
|
||||
Mean TPOT (ms): 7.91
|
||||
Median TPOT (ms): 7.96
|
||||
P99 TPOT (ms): 8.03
|
||||
---------------Inter-token Latency----------------
|
||||
Mean ITL (ms): 7.74
|
||||
Median ITL (ms): 7.70
|
||||
P99 ITL (ms): 8.39
|
||||
==================================================
|
||||
```
|
||||
|
||||
### VisionArena Benchmark for Vision Language Models
|
||||
|
||||
```bash
|
||||
wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv
|
||||
# need a model with vision capability here
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
|
||||
```
|
||||
|
||||
```bash
|
||||
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
|
||||
NUM_PROMPTS=10
|
||||
BACKEND="openai-chat"
|
||||
DATASET_NAME="hf"
|
||||
DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
|
||||
DATASET_SPLIT='train'
|
||||
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--backend "${BACKEND}" \
|
||||
--model "${MODEL_NAME}" \
|
||||
--endpoint "/v1/chat/completions" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-path "${DATASET_PATH}" \
|
||||
--hf-split "${DATASET_SPLIT}" \
|
||||
--num-prompts "${NUM_PROMPTS}"
|
||||
```
|
||||
|
||||
### HuggingFaceDataset Examples
|
||||
|
||||
Currently, HuggingFaceDataset only supports dataset formats
|
||||
similar to `lmms-lab/LLaVA-OneVision-Data` and `Aeala/ShareGPT_Vicuna_unfiltered`. If you need support for other dataset
|
||||
formats, please consider contributing.
|
||||
|
||||
```bash
|
||||
# need a model with vision capability here
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
|
||||
```
|
||||
|
||||
**`lmms-lab/LLaVA-OneVision-Data`**
|
||||
|
||||
```bash
|
||||
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
|
||||
NUM_PROMPTS=10
|
||||
BACKEND="openai-chat"
|
||||
DATASET_NAME="hf"
|
||||
DATASET_PATH="lmms-lab/LLaVA-OneVision-Data"
|
||||
DATASET_SPLIT='train'
|
||||
DATASET_SUBSET='chart2text(cauldron)'
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--backend "${BACKEND}" \
|
||||
--model "${MODEL_NAME}" \
|
||||
--endpoint "/v1/chat/completions" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-path "${DATASET_PATH}" \
|
||||
--hf-split "${DATASET_SPLIT}" \
|
||||
--num-prompts "${NUM_PROMPTS}" \
|
||||
--hf-subset "${DATASET_SUBSET}"
|
||||
```
|
||||
|
||||
**`Aeala/ShareGPT_Vicuna_unfiltered`**
|
||||
|
||||
```bash
|
||||
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
|
||||
NUM_PROMPTS=10
|
||||
BACKEND="openai-chat"
|
||||
DATASET_NAME="hf"
|
||||
DATASET_PATH="Aeala/ShareGPT_Vicuna_unfiltered"
|
||||
DATASET_SPLIT='train'
|
||||
python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--backend "${BACKEND}" \
|
||||
--model "${MODEL_NAME}" \
|
||||
--endpoint "/v1/chat/completions" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-path "${DATASET_PATH}" \
|
||||
--hf-split "${DATASET_SPLIT}" \
|
||||
--num-prompts "${NUM_PROMPTS}" \
|
||||
```
|
||||
|
||||
---
|
||||
## Example - Offline Throughput Benchmark
|
||||
|
||||
```bash
|
||||
MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
NUM_PROMPTS=10
|
||||
DATASET_NAME="sonnet"
|
||||
DATASET_PATH="vllm/benchmarks/sonnet.txt"
|
||||
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--model "${MODEL_NAME}" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-path "${DATASET_PATH}" \
|
||||
--num-prompts "${NUM_PROMPTS}"
|
||||
```
|
||||
|
||||
If successful, you will see the following output
|
||||
|
||||
```
|
||||
Throughput: 7.15 requests/s, 4656.00 total tokens/s, 1072.15 output tokens/s
|
||||
Total num prompt tokens: 5014
|
||||
Total num output tokens: 1500
|
||||
```
|
||||
|
||||
### VisionArena Benchmark for Vision Language Models
|
||||
|
||||
``` bash
|
||||
MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
|
||||
NUM_PROMPTS=10
|
||||
DATASET_NAME="hf"
|
||||
DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
|
||||
DATASET_SPLIT="train"
|
||||
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--model "${MODEL_NAME}" \
|
||||
--backend "vllm-chat" \
|
||||
--dataset-name "${DATASET_NAME}" \
|
||||
--dataset-path "${DATASET_PATH}" \
|
||||
--num-prompts "${NUM_PROMPTS}" \
|
||||
--hf-split "${DATASET_SPLIT}"
|
||||
```
|
||||
|
||||
The `num prompt tokens` now includes image token counts
|
||||
|
||||
```
|
||||
Throughput: 2.55 requests/s, 4036.92 total tokens/s, 326.90 output tokens/s
|
||||
Total num prompt tokens: 14527
|
||||
Total num output tokens: 1280
|
||||
```
|
||||
|
||||
### Benchmark with LoRA Adapters
|
||||
|
||||
``` bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
MODEL_NAME="meta-llama/Llama-2-7b-hf"
|
||||
BACKEND="vllm"
|
||||
DATASET_NAME="sharegpt"
|
||||
DATASET_PATH="<your data path>/ShareGPT_V3_unfiltered_cleaned_split.json"
|
||||
NUM_PROMPTS=10
|
||||
MAX_LORAS=2
|
||||
MAX_LORA_RANK=8
|
||||
ENABLE_LORA="--enable-lora"
|
||||
LORA_PATH="yard1/llama-2-7b-sql-lora-test"
|
||||
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--model "${MODEL_NAME}" \
|
||||
--backend "${BACKEND}" \
|
||||
--dataset_path "${DATASET_PATH}" \
|
||||
--dataset_name "${DATASET_NAME}" \
|
||||
--num-prompts "${NUM_PROMPTS}" \
|
||||
--max-loras "${MAX_LORAS}" \
|
||||
--max-lora-rank "${MAX_LORA_RANK}" \
|
||||
${ENABLE_LORA} \
|
||||
--lora-path "${LORA_PATH}"
|
||||
```
|
||||
|
||||
@ -14,7 +14,8 @@ from tqdm.asyncio import tqdm
|
||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import get_lock
|
||||
# NOTE(simon): do not import vLLM here so the benchmark script
|
||||
# can run without vLLM installed.
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
@ -27,7 +28,6 @@ class RequestFuncInput:
|
||||
output_len: int
|
||||
model: str
|
||||
model_name: Optional[str] = None
|
||||
best_of: int = 1
|
||||
logprobs: Optional[int] = None
|
||||
extra_body: Optional[dict] = None
|
||||
multi_modal_content: Optional[dict] = None
|
||||
@ -58,13 +58,12 @@ async def async_request_tgi(
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
params = {
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"do_sample": True,
|
||||
"temperature": 0.01, # TGI does not accept 0.0 temperature.
|
||||
"top_p": 0.99, # TGI does not accept 1.0 top_p.
|
||||
"truncate": request_func_input.prompt_len,
|
||||
# TGI does not accept ignore_eos flag.
|
||||
"ignore_eos_token": request_func_input.ignore_eos,
|
||||
}
|
||||
payload = {
|
||||
"inputs": request_func_input.prompt,
|
||||
@ -72,6 +71,10 @@ async def async_request_tgi(
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
if request_func_input.ignore_eos:
|
||||
output.output_tokens = request_func_input.output_len
|
||||
else:
|
||||
output.output_tokens = None
|
||||
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
@ -130,7 +133,6 @@ async def async_request_trt_llm(
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
payload = {
|
||||
"accumulate_tokens": True,
|
||||
"text_input": request_func_input.prompt,
|
||||
@ -195,7 +197,6 @@ async def async_request_deepspeed_mii(
|
||||
) -> RequestFuncOutput:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
|
||||
payload = {
|
||||
"prompt": request_func_input.prompt,
|
||||
@ -249,7 +250,6 @@ async def async_request_openai_completions(
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
@ -338,7 +338,7 @@ async def async_request_openai_chat_completions(
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
"chat/completions"
|
||||
("chat/completions", "profile")
|
||||
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
@ -432,6 +432,8 @@ def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
||||
from modelscope import snapshot_download
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import get_lock
|
||||
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(pretrained_model_name_or_path):
|
||||
|
||||
717
benchmarks/benchmark_dataset.py
Normal file
717
benchmarks/benchmark_dataset.py
Normal file
@ -0,0 +1,717 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This module defines a framework for sampling benchmark requests from various
|
||||
datasets. Each dataset subclass of BenchmarkDataset must implement sample
|
||||
generation. Supported dataset types include:
|
||||
- ShareGPT
|
||||
- Random (synthetic)
|
||||
- Sonnet
|
||||
- BurstGPT
|
||||
- HuggingFace
|
||||
- VisionArena
|
||||
|
||||
TODO: Implement CustomDataset to parse a JSON file and convert its contents into
|
||||
SampleRequest instances, similar to the approach used in ShareGPT.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Data Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class SampleRequest:
|
||||
"""
|
||||
Represents a single inference request for benchmarking.
|
||||
"""
|
||||
|
||||
prompt: Union[str, Any]
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Benchmark Dataset Base Class
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BenchmarkDataset(ABC):
|
||||
DEFAULT_SEED = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_path: Optional[str] = None,
|
||||
random_seed: int = DEFAULT_SEED,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the BenchmarkDataset with an optional dataset path and random
|
||||
seed. Args:
|
||||
dataset_path (Optional[str]): Path to the dataset. If None, it
|
||||
indicates that a default or random dataset might be used.
|
||||
random_seed (int): Seed value for reproducible shuffling or
|
||||
sampling. Defaults to DEFAULT_SEED.
|
||||
"""
|
||||
self.dataset_path = dataset_path
|
||||
# Set the random seed, ensuring that a None value is replaced with the
|
||||
# default seed.
|
||||
self.random_seed = (random_seed
|
||||
if random_seed is not None else self.DEFAULT_SEED)
|
||||
self.data = None
|
||||
|
||||
def apply_multimodal_chat_transformation(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
|
||||
"""
|
||||
Transform a prompt and optional multimodal content into a chat format.
|
||||
This method is used for chat models that expect a specific conversation
|
||||
format.
|
||||
"""
|
||||
content = [{"text": prompt, "type": "text"}]
|
||||
if mm_content is not None:
|
||||
content.append(mm_content)
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def load_data(self) -> None:
|
||||
"""
|
||||
Load data from the dataset path into self.data.
|
||||
|
||||
This method must be overridden by subclasses since the method to load
|
||||
data will vary depending on the dataset format and source.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If a subclass does not implement this method.
|
||||
"""
|
||||
# TODO (jenniferzhao): add support for downloading data
|
||||
raise NotImplementedError(
|
||||
"load_data must be implemented in subclasses.")
|
||||
|
||||
def get_random_lora_request(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_loras: Optional[int] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
) -> tuple[Optional[LoRARequest], AnyTokenizer]:
|
||||
"""
|
||||
Optionally select a random LoRA request and return its associated
|
||||
tokenizer.
|
||||
|
||||
This method is used when LoRA parameters are provided. It randomly
|
||||
selects a LoRA based on max_loras and retrieves a cached tokenizer for
|
||||
that LoRA if available. Otherwise, it returns the base tokenizer.
|
||||
|
||||
Args:
|
||||
tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
|
||||
LoRA is selected. max_loras (Optional[int]): The maximum number of
|
||||
LoRAs available. If None, LoRA is not used. lora_path
|
||||
(Optional[str]): Path to the LoRA parameters on disk. If None, LoRA
|
||||
is not used.
|
||||
|
||||
Returns:
|
||||
tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first
|
||||
element is a LoRARequest (or None if not applicable) and the second
|
||||
element is the tokenizer associated with the LoRA request (or the
|
||||
base tokenizer).
|
||||
"""
|
||||
if max_loras is None or lora_path is None:
|
||||
return None, tokenizer
|
||||
|
||||
# Generate a random LoRA ID in the range [1, max_loras].
|
||||
lora_id = random.randint(1, max_loras)
|
||||
lora_request = LoRARequest(
|
||||
lora_name=str(lora_id),
|
||||
lora_int_id=lora_id,
|
||||
lora_path=lora_path_on_disk(lora_path),
|
||||
)
|
||||
if lora_id not in lora_tokenizer_cache:
|
||||
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
|
||||
# Return lora_request and the cached tokenizer if available; otherwise,
|
||||
# return the base tokenizer
|
||||
return lora_request, lora_tokenizer_cache[lora_id] or tokenizer
|
||||
|
||||
@abstractmethod
|
||||
def sample(self, tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int) -> list[SampleRequest]:
|
||||
"""
|
||||
Abstract method to generate sample requests from the dataset.
|
||||
|
||||
Subclasses must override this method to implement dataset-specific logic
|
||||
for generating a list of SampleRequest objects.
|
||||
|
||||
Args:
|
||||
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
|
||||
for processing the dataset's text.
|
||||
num_requests (int): The number of sample requests to generate.
|
||||
|
||||
Returns:
|
||||
list[SampleRequest]: A list of sample requests generated from the
|
||||
dataset.
|
||||
"""
|
||||
raise NotImplementedError("sample must be implemented in subclasses.")
|
||||
|
||||
def maybe_oversample_requests(self, requests: list[SampleRequest],
|
||||
num_requests: int) -> None:
|
||||
"""
|
||||
Oversamples the list of requests if its size is less than the desired
|
||||
number.
|
||||
|
||||
Args:
|
||||
requests (List[SampleRequest]): The current list of sampled
|
||||
requests. num_requests (int): The target number of requests.
|
||||
"""
|
||||
if len(requests) < num_requests:
|
||||
random.seed(self.random_seed)
|
||||
additional = random.choices(requests,
|
||||
k=num_requests - len(requests))
|
||||
requests.extend(additional)
|
||||
logger.info("Oversampled requests to reach %d total samples.",
|
||||
num_requests)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utility Functions and Global Caches
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def is_valid_sequence(
|
||||
prompt_len: int,
|
||||
output_len: int,
|
||||
min_len: int = 4,
|
||||
max_prompt_len: int = 1024,
|
||||
max_total_len: int = 2048,
|
||||
skip_min_output_len_check: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate a sequence based on prompt and output lengths.
|
||||
|
||||
Default pruning criteria are copied from the original `sample_hf_requests`
|
||||
and `sample_sharegpt_requests` functions in benchmark_serving.py, as well as
|
||||
from `sample_requests` in benchmark_throughput.py.
|
||||
"""
|
||||
# Check for invalid conditions
|
||||
prompt_too_short = prompt_len < min_len
|
||||
output_too_short = (not skip_min_output_len_check) and (output_len
|
||||
< min_len)
|
||||
prompt_too_long = prompt_len > max_prompt_len
|
||||
combined_too_long = (prompt_len + output_len) > max_total_len
|
||||
|
||||
# Return True if none of the invalid conditions are met
|
||||
return not (prompt_too_short or output_too_short or prompt_too_long
|
||||
or combined_too_long)
|
||||
|
||||
|
||||
@cache
|
||||
def lora_path_on_disk(lora_path: str) -> str:
|
||||
return get_adapter_absolute_path(lora_path)
|
||||
|
||||
|
||||
# Global cache for LoRA tokenizers.
|
||||
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
|
||||
|
||||
|
||||
def process_image(image: Any) -> Mapping[str, Any]:
|
||||
"""
|
||||
Process a single image input and return a multimedia content dictionary.
|
||||
|
||||
For a PIL.Image.Image input:
|
||||
- Converts the image to RGB.
|
||||
- Saves the image as a JPEG in-memory.
|
||||
- Encodes the JPEG data as a base64 string.
|
||||
- Returns a dictionary with the image as a base64 data URL.
|
||||
|
||||
For a string input:
|
||||
- Treats the string as a URL or file path.
|
||||
- Prepends "file://" if the string doesn't start with "http://" or
|
||||
"file://".
|
||||
- Returns a dictionary with the image URL.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input is neither a PIL.Image.Image nor a string.
|
||||
"""
|
||||
if isinstance(image, Image.Image):
|
||||
image = image.convert("RGB")
|
||||
with io.BytesIO() as image_data:
|
||||
image.save(image_data, format="JPEG")
|
||||
image_base64 = base64.b64encode(
|
||||
image_data.getvalue()).decode("utf-8")
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
},
|
||||
}
|
||||
|
||||
if isinstance(image, str):
|
||||
image_url = (image if image.startswith(
|
||||
("http://", "file://")) else f"file://{image}")
|
||||
return {"type": "image_url", "image_url": {"url": image_url}}
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid image input {image}. Must be a PIL.Image.Image or str.")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Random Dataset Implementation (Synthetic Data)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class RandomDataset(BenchmarkDataset):
|
||||
# Default values copied from benchmark_serving.py for the random dataset.
|
||||
DEFAULT_PREFIX_LEN = 0
|
||||
DEFAULT_RANGE_RATIO = 1.0
|
||||
DEFAULT_INPUT_LEN = 1024
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
prefix_len: int = DEFAULT_PREFIX_LEN,
|
||||
range_ratio: float = DEFAULT_RANGE_RATIO,
|
||||
input_len: int = DEFAULT_INPUT_LEN,
|
||||
output_len: int = DEFAULT_OUTPUT_LEN,
|
||||
**kwargs,
|
||||
) -> list[SampleRequest]:
|
||||
vocab_size = tokenizer.vocab_size
|
||||
|
||||
prefix_token_ids = (np.random.randint(
|
||||
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
|
||||
|
||||
input_low = int(input_len * range_ratio)
|
||||
output_low = int(output_len * range_ratio)
|
||||
|
||||
input_lens = np.random.randint(input_low,
|
||||
input_len + 1,
|
||||
size=num_requests)
|
||||
output_lens = np.random.randint(output_low,
|
||||
output_len + 1,
|
||||
size=num_requests)
|
||||
offsets = np.random.randint(0, vocab_size, size=num_requests)
|
||||
|
||||
requests = []
|
||||
for i in range(num_requests):
|
||||
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
|
||||
vocab_size).tolist()
|
||||
token_sequence = prefix_token_ids + inner_seq
|
||||
prompt = tokenizer.decode(token_sequence)
|
||||
total_input_len = prefix_len + int(input_lens[i])
|
||||
requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=total_input_len,
|
||||
expected_output_len=int(output_lens[i]),
|
||||
))
|
||||
return requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# ShareGPT Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ShareGPTDataset(BenchmarkDataset):
|
||||
"""
|
||||
Implements the ShareGPT dataset. Loads data from a JSON file and generates
|
||||
sample requests based on conversation turns.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if self.dataset_path is None:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
with open(self.dataset_path, encoding="utf-8") as f:
|
||||
self.data = json.load(f)
|
||||
# Filter entries with at least two conversation turns.
|
||||
self.data = [
|
||||
entry for entry in self.data
|
||||
if "conversations" in entry and len(entry["conversations"]) >= 2
|
||||
]
|
||||
random.seed(self.random_seed)
|
||||
random.shuffle(self.data)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
lora_path: Optional[str] = None,
|
||||
max_loras: Optional[int] = None,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
samples: list = []
|
||||
for entry in self.data:
|
||||
if len(samples) >= num_requests:
|
||||
break
|
||||
prompt, completion = (
|
||||
entry["conversations"][0]["value"],
|
||||
entry["conversations"][1]["value"],
|
||||
)
|
||||
|
||||
lora_request, tokenizer = self.get_random_lora_request(
|
||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
|
||||
prompt_ids = tokenizer(prompt).input_ids
|
||||
completion_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_ids)
|
||||
new_output_len = (len(completion_ids)
|
||||
if output_len is None else output_len)
|
||||
if not is_valid_sequence(prompt_len,
|
||||
new_output_len,
|
||||
skip_min_output_len_check=output_len
|
||||
is not None):
|
||||
continue
|
||||
if enable_multimodal_chat:
|
||||
prompt = self.apply_multimodal_chat_transformation(
|
||||
prompt, None)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=new_output_len,
|
||||
lora_request=lora_request,
|
||||
))
|
||||
self.maybe_oversample_requests(samples, num_requests)
|
||||
return samples
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Sonnet Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SonnetDataset(BenchmarkDataset):
|
||||
"""
|
||||
Simplified implementation of the Sonnet dataset. Loads poem lines from a
|
||||
text file and generates sample requests. Default values here copied from
|
||||
`benchmark_serving.py` for the sonnet dataset.
|
||||
"""
|
||||
|
||||
DEFAULT_PREFIX_LEN = 200
|
||||
DEFAULT_INPUT_LEN = 550
|
||||
DEFAULT_OUTPUT_LEN = 150
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if not self.dataset_path:
|
||||
raise ValueError("dataset_path must be provided.")
|
||||
with open(self.dataset_path, encoding="utf-8") as f:
|
||||
self.data = f.readlines()
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer,
|
||||
num_requests: int,
|
||||
prefix_len: int = DEFAULT_PREFIX_LEN,
|
||||
input_len: int = DEFAULT_INPUT_LEN,
|
||||
output_len: int = DEFAULT_OUTPUT_LEN,
|
||||
return_prompt_formatted: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
# Calculate average token length for a poem line.
|
||||
tokenized_lines = [tokenizer(line).input_ids for line in self.data]
|
||||
avg_len = sum(len(tokens)
|
||||
for tokens in tokenized_lines) / len(tokenized_lines)
|
||||
|
||||
# Build the base prompt.
|
||||
base_prompt = "Pick as many lines as you can from these poem lines:\n"
|
||||
base_msg = [{"role": "user", "content": base_prompt}]
|
||||
base_fmt = tokenizer.apply_chat_template(base_msg,
|
||||
add_generation_prompt=True,
|
||||
tokenize=False)
|
||||
base_offset = len(tokenizer(base_fmt).input_ids)
|
||||
if input_len <= base_offset:
|
||||
raise ValueError(
|
||||
f"'input_len' must be higher than the base prompt length "
|
||||
f"({base_offset}).")
|
||||
|
||||
# Determine how many poem lines to use.
|
||||
num_input_lines = round((input_len - base_offset) / avg_len)
|
||||
num_prefix_lines = round((prefix_len - base_offset) / avg_len)
|
||||
prefix_lines = self.data[:num_prefix_lines]
|
||||
|
||||
samples = []
|
||||
for _ in range(num_requests):
|
||||
extra_lines = random.choices(self.data,
|
||||
k=num_input_lines - num_prefix_lines)
|
||||
prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
|
||||
msg = [{"role": "user", "content": prompt}]
|
||||
prompt_formatted = tokenizer.apply_chat_template(
|
||||
msg, add_generation_prompt=True, tokenize=False)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt_formatted
|
||||
if return_prompt_formatted else prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
))
|
||||
return samples
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# BurstGPT Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BurstGPTDataset(BenchmarkDataset):
|
||||
"""
|
||||
Implements the BurstGPT dataset. Loads data from a CSV file and generates
|
||||
sample requests based on synthetic prompt generation. Only rows with Model
|
||||
"GPT-4" and positive response tokens are used.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.load_data()
|
||||
|
||||
def load_data(self, ):
|
||||
if self.dataset_path is None:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
df = pd.read_csv(self.dataset_path)
|
||||
# Filter to keep only GPT-4 rows.
|
||||
gpt4_df = df[df["Model"] == "GPT-4"]
|
||||
# Remove failed requests (where Response tokens is 0 or less).
|
||||
gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
|
||||
# Sample the desired number of rows.
|
||||
self.data = gpt4_df
|
||||
|
||||
def _sample_loaded_data(self, num_requests: int) -> list:
|
||||
if num_requests <= len(self.data):
|
||||
data = self.data.sample(n=num_requests,
|
||||
random_state=self.random_seed)
|
||||
else:
|
||||
data = self.data.sample(
|
||||
n=num_requests,
|
||||
random_state=self.random_seed,
|
||||
replace=True,
|
||||
)
|
||||
# Convert the dataframe to a list of lists.
|
||||
return data.values.tolist()
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
max_loras: Optional[int] = None,
|
||||
lora_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> list[SampleRequest]:
|
||||
samples = []
|
||||
data = self._sample_loaded_data(num_requests=num_requests)
|
||||
for i in range(num_requests):
|
||||
input_len = int(data[i][2])
|
||||
output_len = int(data[i][3])
|
||||
lora_req, tokenizer = self.get_random_lora_request(
|
||||
tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
|
||||
vocab_size = tokenizer.vocab_size
|
||||
# Generate a synthetic prompt: a list of token IDs computed as (i +
|
||||
# j) modulo vocab_size.
|
||||
token_ids = [(i + j) % vocab_size for j in range(input_len)]
|
||||
prompt = tokenizer.decode(token_ids)
|
||||
samples.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=output_len,
|
||||
lora_request=lora_req,
|
||||
))
|
||||
return samples
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HuggingFace Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HuggingFaceDataset(BenchmarkDataset):
|
||||
"""
|
||||
Dataset class for processing a HuggingFace dataset with conversation data
|
||||
and optional images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_split: str,
|
||||
dataset_subset: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.dataset_split = dataset_split
|
||||
self.dataset_subset = dataset_subset
|
||||
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
if not self.dataset_path:
|
||||
raise ValueError("dataset_path must be provided for loading data.")
|
||||
|
||||
self.data = load_dataset(
|
||||
self.dataset_path,
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=True,
|
||||
)
|
||||
if self.data.features is None or "conversations" \
|
||||
not in self.data.features:
|
||||
raise ValueError(
|
||||
"HuggingFaceDataset currently only supports datasets with "
|
||||
"a 'conversations' column like lmms-lab/LLaVA-OneVision-Data. "
|
||||
"Please consider contributing if you would like to add "
|
||||
"support for additional dataset formats.")
|
||||
# Shuffle and filter examples with at least 2 conversations.
|
||||
self.data = self.data.shuffle(seed=self.random_seed).filter(
|
||||
lambda x: len(x["conversations"]) >= 2)
|
||||
|
||||
def sample(self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs) -> list:
|
||||
sampled_requests = []
|
||||
dynamic_output = output_len is None
|
||||
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
conv = item["conversations"]
|
||||
prompt, completion = conv[0]["value"], conv[1]["value"]
|
||||
|
||||
prompt_ids = tokenizer(prompt).input_ids
|
||||
completion_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_ids)
|
||||
completion_len = len(completion_ids)
|
||||
output_len = completion_len if dynamic_output else output_len
|
||||
assert isinstance(output_len, int) and output_len > 0
|
||||
if dynamic_output and not is_valid_sequence(
|
||||
prompt_len, completion_len):
|
||||
continue
|
||||
mm_content = process_image(
|
||||
item["image"]) if "image" in item else None
|
||||
if enable_multimodal_chat:
|
||||
# Note: when chat is enabled the request prompt_len is no longer
|
||||
# accurate and we will be using request output to count the
|
||||
# actual prompt len and output len
|
||||
prompt = self.apply_multimodal_chat_transformation(
|
||||
prompt, mm_content)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Vision Arena Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class VisionArenaDataset(HuggingFaceDataset):
|
||||
"""
|
||||
Vision Arena Dataset.
|
||||
"""
|
||||
|
||||
DEFAULT_OUTPUT_LEN = 128
|
||||
VISION_ARENA_DATASET_PATH = "lmarena-ai/vision-arena-bench-v0.1"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if self.dataset_path != self.VISION_ARENA_DATASET_PATH:
|
||||
raise ValueError(f"Only support Vision Arena dataset.\
|
||||
This data path {self.dataset_path} is not valid.")
|
||||
if self.dataset_subset is None and self.dataset_split != "train":
|
||||
raise ValueError("Dataset split must be 'train'.")
|
||||
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
dataset = load_dataset(
|
||||
self.dataset_path,
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=True,
|
||||
)
|
||||
self.data = dataset.shuffle(seed=self.random_seed)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int,
|
||||
output_len: Optional[int] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
output_len = (output_len
|
||||
if output_len is not None else self.DEFAULT_OUTPUT_LEN)
|
||||
sampled_requests = []
|
||||
for item in self.data:
|
||||
if len(sampled_requests) >= num_requests:
|
||||
break
|
||||
prompt = item["turns"][0][0]["content"]
|
||||
mm_content = process_image(item["images"][0])
|
||||
prompt_len = len(tokenizer(prompt).input_ids)
|
||||
if enable_multimodal_chat:
|
||||
# Note: when chat is enabled the request prompt_len is no longer
|
||||
# accurate and we will be using request output to count the
|
||||
# actual prompt len
|
||||
prompt = self.apply_multimodal_chat_transformation(
|
||||
prompt, mm_content)
|
||||
sampled_requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=mm_content,
|
||||
))
|
||||
self.maybe_oversample_requests(sampled_requests, num_requests)
|
||||
return sampled_requests
|
||||
@ -1,507 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Benchmark guided decoding throughput."""
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import uvloop
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SampleRequest:
|
||||
"""A class representing a single inference request for benchmarking.
|
||||
|
||||
Attributes:
|
||||
prompt: The input text prompt for the model.
|
||||
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
|
||||
images).
|
||||
prompt_len: The length of the prompt in tokens.
|
||||
expected_output_len: The expected length of the output in tokens.
|
||||
"""
|
||||
prompt: str
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
schema: dict
|
||||
structure_type: str = 'json'
|
||||
completion: str = None
|
||||
|
||||
|
||||
def run_vllm(requests: list[SampleRequest],
|
||||
engine_args: EngineArgs,
|
||||
n: int,
|
||||
guided_decoding_rate: float = 1.0,
|
||||
warmup: bool = False) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**vars(engine_args))
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[str] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
# create a list containing random selected true or false
|
||||
guided_decoding_req_idx = random.sample(
|
||||
range(len(requests)), int(len(requests) * guided_decoding_rate))
|
||||
|
||||
if warmup:
|
||||
print(">>>>> Running warmup prompt, for the first 5")
|
||||
# We setup the first 5 requests to warmup FSM
|
||||
# if using xgrammar dataset, we will skip warmup
|
||||
warmup_requests = requests[:5]
|
||||
for i, request in enumerate(warmup_requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(json=request.schema)
|
||||
if guided_decoding_rate > 0 else None,
|
||||
))
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
|
||||
print(">>>>> Benchmark started...")
|
||||
prompts = []
|
||||
sampling_params = []
|
||||
for i, request in enumerate(requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
**{request.structure_type: request.schema})
|
||||
if i in guided_decoding_req_idx else None,
|
||||
))
|
||||
|
||||
start = time.perf_counter()
|
||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
ret = []
|
||||
for output, request in zip(outputs, requests):
|
||||
generated_text = output.outputs[0].text
|
||||
ret.append({
|
||||
"generated": generated_text,
|
||||
"expected": request.completion
|
||||
})
|
||||
end = time.perf_counter()
|
||||
return end - start, ret
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: list[SampleRequest],
|
||||
engine_args: AsyncEngineArgs,
|
||||
n: int,
|
||||
guided_decoding_rate: float = 1.0,
|
||||
warmup: bool = False,
|
||||
disable_frontend_multiprocessing: bool = False) -> float:
|
||||
from vllm import SamplingParams
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
|
||||
assert all(
|
||||
llm.model_config.max_model_len >= (request.prompt_len +
|
||||
request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[str] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
guided_decoding_req_idx = random.sample(
|
||||
range(len(requests)), int(len(requests) * guided_decoding_rate))
|
||||
|
||||
if warmup:
|
||||
print(">>>>>> Running warmup prompt, for the first 5")
|
||||
# We setup the first 5 requests to warmup FSM
|
||||
# if using xgrammar dataset, we will skip warmup
|
||||
warmup_requests = requests[:5]
|
||||
for i, request in enumerate(warmup_requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=request.schema)
|
||||
if guided_decoding_rate > 0 else None,
|
||||
))
|
||||
generators = []
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
|
||||
print(">>>>> Benchmark started...")
|
||||
prompts = []
|
||||
sampling_params = []
|
||||
for i, request in enumerate(requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(json=request.schema)
|
||||
if i in guided_decoding_req_idx else None,
|
||||
))
|
||||
|
||||
generators = []
|
||||
start_time = []
|
||||
latencies = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
start_time.append(time.perf_counter())
|
||||
latencies.append([])
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
generated_texts = [''] * len(requests)
|
||||
async for i, res in all_gens:
|
||||
generated_texts[i] = res.outputs[0].text
|
||||
lat = time.perf_counter() - start_time[i]
|
||||
latencies[i].append(lat)
|
||||
ret = [{
|
||||
'generated': gt,
|
||||
'expected': req.completion
|
||||
} for gt, req in zip(generated_texts, requests)]
|
||||
end = time.perf_counter()
|
||||
first_latency = pd.Series([lat[0] * 1000 for lat in latencies])
|
||||
next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000
|
||||
for lat in latencies])
|
||||
return end - start, ret, (first_latency, next_latency)
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
if args.dataset == 'json':
|
||||
if args.json_schema_path is None:
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
args.json_schema_path = os.path.join(dir_path,
|
||||
"structured_schemas",
|
||||
"structured_schema_1.json")
|
||||
with open(args.json_schema_path) as f:
|
||||
schema = json.load(f)
|
||||
prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "grammar":
|
||||
schema = """
|
||||
?start: select_statement
|
||||
|
||||
?select_statement: "SELECT " column_list " FROM " table_name
|
||||
|
||||
?column_list: column_name ("," column_name)*
|
||||
|
||||
?table_name: identifier
|
||||
|
||||
?column_name: identifier
|
||||
|
||||
?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
|
||||
"""
|
||||
prompt = "Generate an SQL query to show the 'username' \
|
||||
and 'email' from the 'users' table."
|
||||
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "regex":
|
||||
regex = r"\w+@\w+\.com\n"
|
||||
args.regex = regex
|
||||
prompt = "Generate an email address for Alan Turing, \
|
||||
who works in Enigma. End in .com and new line. \
|
||||
Example result: alan.turing@enigma.com\n"
|
||||
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=regex,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "choice":
|
||||
choice = ["Positive", "Negative"]
|
||||
args.choice = choice
|
||||
prompt = "Classify this sentiment: vLLM is wonderful!"
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=choice,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "xgrammar_bench":
|
||||
args.warmup = False
|
||||
requests: list[SampleRequest] = []
|
||||
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
|
||||
split="train")
|
||||
print(f"dataset has {len(dataset)} entries")
|
||||
len_dataset = len(dataset)
|
||||
for data_point_idx in range(args.num_prompts):
|
||||
idx = data_point_idx
|
||||
while idx >= len_dataset:
|
||||
idx -= len_dataset
|
||||
schema = dataset["schema"][idx]
|
||||
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
|
||||
tokenize=False)
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
completion = dataset["completion"][idx]
|
||||
|
||||
requests.append(
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
completion=completion))
|
||||
|
||||
return requests
|
||||
|
||||
|
||||
def evaluate(ret, args):
|
||||
|
||||
def _eval_correctness_json(expected, actual):
|
||||
# extract json string from string using regex
|
||||
import re
|
||||
actual = actual.replace('\n', '').replace(' ', '').strip()
|
||||
try:
|
||||
actual = re.search(r'\{.*\}', actual).group()
|
||||
actual = json.loads(actual)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _eval_correctness_choice(expected, actual):
|
||||
return actual in args.choice
|
||||
|
||||
def _eval_correctness_regex(expected, actual):
|
||||
import re
|
||||
return re.match(args.regex, actual) is not None
|
||||
|
||||
def _eval_correctness(expected, actual):
|
||||
if args.structure_type == 'json':
|
||||
return _eval_correctness_json(expected, actual)
|
||||
elif args.structure_type == 'regex':
|
||||
return _eval_correctness_regex(expected, actual)
|
||||
elif args.structure_type == 'choice':
|
||||
return _eval_correctness_choice(expected, actual)
|
||||
else:
|
||||
return None
|
||||
|
||||
scores = []
|
||||
for res in ret:
|
||||
score = _eval_correctness(res['expected'], res['generated'])
|
||||
res['correctness'] = score
|
||||
scores.append(score)
|
||||
|
||||
not_none_scores = [score for score in scores if score is not None]
|
||||
|
||||
return (sum(not_none_scores) / len(not_none_scores) *
|
||||
100) if len(not_none_scores) > 0 else None
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
|
||||
# async engine is working for 'regex', 'choice' and 'grammar'
|
||||
if args.dataset == 'grammar':
|
||||
args.structure_type = 'grammar'
|
||||
args.async_engine = False
|
||||
elif args.dataset == 'regex':
|
||||
args.structure_type = 'regex'
|
||||
args.async_engine = False
|
||||
elif args.dataset == 'choice':
|
||||
args.structure_type = 'choice'
|
||||
args.async_engine = False
|
||||
else:
|
||||
args.structure_type = 'json'
|
||||
|
||||
if args.no_guided_decoding:
|
||||
args.guided_decoding_ratio = 0
|
||||
if args.save_results:
|
||||
result_file_name = f'{args.guided_decoding_ratio}guided'
|
||||
result_file_name += f"_{args.model.split('/')[-1]}"
|
||||
result_file_name += f"_{args.dataset}"
|
||||
result_file_name += f"_{args.num_prompts}"
|
||||
result_file_name += f"_out{args.output_len}"
|
||||
result_file_name += f"_async{args.async_engine}"
|
||||
result_file_name += f"_warmup{args.warmup}"
|
||||
result_file_name += f"_chunkedprefill{args.enable_chunked_prefill}"
|
||||
result_file_name += ".txt"
|
||||
else:
|
||||
result_file_name = None
|
||||
|
||||
# Synthesize a prompt with the given input length.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
requests = sample_requests(tokenizer, args)
|
||||
|
||||
if args.async_engine:
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
elapsed_time, ret, (first_latency, next_latency) = uvloop.run(
|
||||
run_vllm_async(requests, engine_args, args.n,
|
||||
args.guided_decoding_ratio, args.warmup,
|
||||
args.disable_frontend_multiprocessing))
|
||||
else:
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
elapsed_time, ret = run_vllm(requests, engine_args, args.n,
|
||||
args.guided_decoding_ratio, args.warmup)
|
||||
first_latency, next_latency = None, None
|
||||
|
||||
score = evaluate(ret, args)
|
||||
total_num_tokens = sum(request.prompt_len + request.expected_output_len
|
||||
for request in requests)
|
||||
total_output_tokens = sum(request.expected_output_len
|
||||
for request in requests)
|
||||
if first_latency is not None:
|
||||
latency_breakdown = "\nFirst token latency(msecs):\n"
|
||||
latency_breakdown += f"{first_latency.describe()}"
|
||||
latency_breakdown += "\nNext token latency(msecs):\n"
|
||||
latency_breakdown += f"{next_latency.describe()}"
|
||||
print(
|
||||
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s",
|
||||
f"Correct rate is {score} %",
|
||||
f"{latency_breakdown if first_latency is not None else ''}")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json or result_file_name:
|
||||
results = {
|
||||
"elapsed_time": elapsed_time,
|
||||
"num_requests": len(requests),
|
||||
"total_num_tokens": total_num_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"requests_per_second": len(requests) / elapsed_time,
|
||||
"tokens_per_second": f"{total_num_tokens / elapsed_time:.2f}",
|
||||
"output_tokens_per_second":
|
||||
f"{total_output_tokens / elapsed_time:.2f}",
|
||||
"correct_rate(%)": score
|
||||
}
|
||||
results = {"outputs": ret, **results}
|
||||
if first_latency is not None:
|
||||
results["first_token_latency(msecs)"] = first_latency.describe(
|
||||
).to_dict()
|
||||
results["next_token_latency(msecs)"] = next_latency.describe(
|
||||
).to_dict()
|
||||
if args.output_json:
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
elif result_file_name:
|
||||
with open(result_file_name, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark guided decoding.")
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument("--output-len",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
default='json',
|
||||
choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
|
||||
parser.add_argument("--json_schema_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to json schema.")
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the throughput results in JSON format.')
|
||||
parser.add_argument("--async-engine",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.")
|
||||
parser.add_argument("--no-guided-decoding",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Whether to disable JSON decoding or not.")
|
||||
parser.add_argument("--guided-decoding-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Ratio of Guided Decoding requests")
|
||||
parser.add_argument("--disable-frontend-multiprocessing",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
parser.add_argument("--warmup",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run warmup prompts before benchmark.")
|
||||
parser.add_argument("--save-results",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="save output results.")
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
main(args)
|
||||
@ -52,6 +52,7 @@ def main(args: argparse.Namespace):
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=args.output_len,
|
||||
detokenize=not args.disable_detokenize,
|
||||
)
|
||||
print(sampling_params)
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
@ -173,6 +174,12 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="Path to save the latency results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -194,7 +194,9 @@ def main(args):
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||
sampling_params = SamplingParams(temperature=0,
|
||||
max_tokens=args.output_len,
|
||||
detokenize=not args.disable_detokenize)
|
||||
|
||||
print("Testing filtered requests")
|
||||
prompts = repeat_and_sort_requests(filtered_requests,
|
||||
@ -243,6 +245,12 @@ if __name__ == "__main__":
|
||||
"subtract this length when filtering prompts. Only used "
|
||||
"when dataset-path is not provided.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--disable-detokenize',
|
||||
action='store_true',
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -23,7 +23,7 @@ def sample_requests(
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
) -> list[tuple[str, int, int]]:
|
||||
) -> list[tuple[str, int, int, int]]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
@ -71,6 +71,7 @@ def run_vllm(
|
||||
requests: list[tuple[str, int, int]],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
@ -95,6 +96,7 @@ def run_vllm(
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
|
||||
start = time.perf_counter()
|
||||
@ -121,7 +123,8 @@ def main(args: argparse.Namespace):
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(requests, args.n,
|
||||
EngineArgs.from_cli_args(args))
|
||||
EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(prompt_len + output_len
|
||||
@ -174,6 +177,12 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the throughput results in JSON format.')
|
||||
parser.add_argument(
|
||||
'--disable-detokenize',
|
||||
action='store_true',
|
||||
help=("Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -25,25 +25,20 @@ On the client side, run:
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator, Collection
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
from datasets import load_dataset
|
||||
from PIL.Image import Image
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
@ -57,6 +52,9 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
@ -92,325 +90,18 @@ class BenchmarkMetrics:
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
def sample_sharegpt_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> list[tuple[str, int, int, None]]:
|
||||
# Load the dataset.
|
||||
with open(dataset_path, encoding='utf-8') as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: list[tuple[str, int, int]] = []
|
||||
for i in range(len(dataset)):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
break
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = dataset[i][0]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
completion = dataset[i][1]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
if prompt_len < 4 or (fixed_output_len is None and output_len < 4):
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
filtered_dataset.append((prompt, prompt_len, output_len, None))
|
||||
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def sample_burstgpt_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
random_seed: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> list[tuple[str, int, int, None]]:
|
||||
df = pd.read_csv(dataset_path)
|
||||
gpt4_df = df[df["Model"] == "GPT-4"]
|
||||
# Remove the failed requests (i.e., response length is 0)
|
||||
gpt4_df = gpt4_df[gpt4_df["Response tokens"] > 0]
|
||||
# Randomly sample num_requests from the dataset
|
||||
if num_requests <= len(gpt4_df):
|
||||
gpt4_df = gpt4_df.sample(n=num_requests, random_state=random_seed)
|
||||
else:
|
||||
gpt4_df = gpt4_df.sample(n=num_requests,
|
||||
random_state=random_seed,
|
||||
replace=True)
|
||||
# Convert the dataframe to a list of tuples
|
||||
dataset = gpt4_df.values.tolist()
|
||||
input_requests = []
|
||||
for i in range(num_requests):
|
||||
input_len = int(dataset[i][2])
|
||||
output_len = int(dataset[i][3])
|
||||
prompt = tokenizer.decode([(i + j) % tokenizer.vocab_size
|
||||
for j in range(input_len)])
|
||||
input_requests.append((prompt, input_len, output_len, None))
|
||||
return input_requests
|
||||
|
||||
|
||||
def sample_sonnet_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
input_len: int,
|
||||
output_len: int,
|
||||
prefix_len: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> list[tuple[str, str, int, int, None]]:
|
||||
assert (
|
||||
input_len > prefix_len
|
||||
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path, encoding='utf-8') as f:
|
||||
poem_lines = f.readlines()
|
||||
|
||||
# Tokenize the poem lines.
|
||||
poem_token_ids = tokenizer(poem_lines).input_ids
|
||||
average_poem_len = sum(
|
||||
len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)
|
||||
|
||||
# Base prefix for all requests.
|
||||
base_prompt = "Pick as many lines as you can from these poem lines:\n"
|
||||
base_message = [{
|
||||
"role": "user",
|
||||
"content": base_prompt,
|
||||
}]
|
||||
base_prompt_formatted = tokenizer.apply_chat_template(
|
||||
base_message, add_generation_prompt=True, tokenize=False)
|
||||
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
|
||||
|
||||
assert (
|
||||
input_len > base_prompt_offset
|
||||
), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
|
||||
num_input_lines = round(
|
||||
(input_len - base_prompt_offset) / average_poem_len)
|
||||
|
||||
# First approximately `prefix_len` number of tokens in the
|
||||
# prompt are fixed poem lines.
|
||||
assert (
|
||||
prefix_len > base_prompt_offset
|
||||
), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."
|
||||
|
||||
num_prefix_lines = round(
|
||||
(prefix_len - base_prompt_offset) / average_poem_len)
|
||||
prefix_lines = poem_lines[:num_prefix_lines]
|
||||
|
||||
# Sample the rest of lines per request.
|
||||
sampled_requests: list[tuple[str, int, int]] = []
|
||||
for _ in range(num_requests):
|
||||
num_lines_needed = num_input_lines - num_prefix_lines
|
||||
sampled_lines = "".join(prefix_lines +
|
||||
random.choices(poem_lines, k=num_lines_needed))
|
||||
|
||||
prompt = f"{base_prompt}{sampled_lines}"
|
||||
message = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
]
|
||||
prompt_formatted = tokenizer.apply_chat_template(
|
||||
message, add_generation_prompt=True, tokenize=False)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
sampled_requests.append(
|
||||
(prompt, prompt_formatted, prompt_len, output_len, None))
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_vision_arena_requests(
|
||||
dataset,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]:
|
||||
sampled_requests: list[tuple[str, int, int, dict[str,
|
||||
Collection[str]]]] = []
|
||||
for data in dataset:
|
||||
if len(sampled_requests) == num_requests:
|
||||
break
|
||||
|
||||
prompt = data["turns"][0][0]['content']
|
||||
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
if fixed_output_len is None:
|
||||
# Default max output len is set to 128
|
||||
print("--hf-output-len is not provided. Using default value 128.")
|
||||
fixed_output_len = 128
|
||||
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = fixed_output_len
|
||||
|
||||
assert isinstance(
|
||||
data["images"][0],
|
||||
Image), ("Input image format must be `PIL.Image.Image`, "
|
||||
f"given {type(data['image'])}.")
|
||||
image: Image = data["images"][0]
|
||||
image = image.convert("RGB")
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='JPEG')
|
||||
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
|
||||
mm_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
},
|
||||
}
|
||||
|
||||
sampled_requests.append((prompt, prompt_len, output_len, mm_content))
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_hf_requests(
|
||||
dataset_path: str,
|
||||
dataset_subset: Optional[str],
|
||||
dataset_split: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
random_seed: int,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]:
|
||||
|
||||
# Special case for vision_arena dataset
|
||||
if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
|
||||
and dataset_subset is None:
|
||||
assert dataset_split == "train"
|
||||
dataset = load_dataset(dataset_path,
|
||||
name=dataset_subset,
|
||||
split=dataset_split,
|
||||
streaming=True)
|
||||
dataset = dataset.shuffle(seed=random_seed)
|
||||
return sample_vision_arena_requests(dataset, num_requests, tokenizer,
|
||||
fixed_output_len)
|
||||
|
||||
dataset = load_dataset(dataset_path,
|
||||
name=dataset_subset,
|
||||
split=dataset_split,
|
||||
streaming=True)
|
||||
assert "conversations" in dataset.features, (
|
||||
"HF Dataset must have 'conversations' column.")
|
||||
filter_func = lambda x: len(x["conversations"]) >= 2
|
||||
filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
|
||||
sampled_requests: list[tuple[str, int, int, dict[str,
|
||||
Collection[str]]]] = []
|
||||
for data in filtered_dataset:
|
||||
if len(sampled_requests) == num_requests:
|
||||
break
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = data["conversations"][0]["value"]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
completion = data["conversations"][1]["value"]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
if fixed_output_len is None and (prompt_len < 4 or output_len < 4):
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if fixed_output_len is None and \
|
||||
(prompt_len > 1024 or prompt_len + output_len > 2048):
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
|
||||
if "image" in data and isinstance(data["image"], Image):
|
||||
image: Image = data["image"]
|
||||
image = image.convert("RGB")
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='JPEG')
|
||||
image_base64 = base64.b64encode(
|
||||
image_data.getvalue()).decode("utf-8")
|
||||
mm_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
},
|
||||
}
|
||||
elif "image" in data and isinstance(data["image"], str):
|
||||
if (data["image"].startswith("http://") or \
|
||||
data["image"].startswith("file://")):
|
||||
image_url = data["image"]
|
||||
else:
|
||||
image_url = f"file://{data['image']}"
|
||||
|
||||
mm_content = {
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
},
|
||||
}
|
||||
else:
|
||||
mm_content = None
|
||||
|
||||
sampled_requests.append((prompt, prompt_len, output_len, mm_content))
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_random_requests(
|
||||
prefix_len: int,
|
||||
input_len: int,
|
||||
output_len: int,
|
||||
num_prompts: int,
|
||||
range_ratio: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> list[tuple[str, int, int]]:
|
||||
prefix_token_ids = np.random.randint(0,
|
||||
tokenizer.vocab_size,
|
||||
size=prefix_len).tolist()
|
||||
|
||||
input_lens = np.random.randint(
|
||||
int(input_len * range_ratio),
|
||||
input_len + 1,
|
||||
size=num_prompts,
|
||||
)
|
||||
output_lens = np.random.randint(
|
||||
int(output_len * range_ratio),
|
||||
output_len + 1,
|
||||
size=num_prompts,
|
||||
)
|
||||
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
|
||||
input_requests = []
|
||||
for i in range(num_prompts):
|
||||
prompt = tokenizer.decode(prefix_token_ids +
|
||||
[(offsets[i] + i + j) % tokenizer.vocab_size
|
||||
for j in range(input_lens[i])])
|
||||
|
||||
input_requests.append((prompt, int(prefix_len + input_lens[i]),
|
||||
int(output_lens[i]), None))
|
||||
|
||||
return input_requests
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
input_requests: list[SampleRequest],
|
||||
request_rate: float,
|
||||
burstiness: float = 1.0,
|
||||
) -> AsyncGenerator[tuple[str, int, int], None]:
|
||||
) -> AsyncGenerator[SampleRequest, None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness.
|
||||
|
||||
Args:
|
||||
input_requests:
|
||||
A list of input requests, each represented as a tuple.
|
||||
A list of input requests, each represented as a SampleRequest.
|
||||
request_rate:
|
||||
The rate at which requests are generated (requests/s).
|
||||
burstiness (optional):
|
||||
@ -422,7 +113,7 @@ async def get_request(
|
||||
in more bursty requests, while a higher burstiness value
|
||||
(burstiness > 1) results in a more uniform arrival of requests.
|
||||
"""
|
||||
input_requests = iter(input_requests)
|
||||
input_requests: Iterable[SampleRequest] = iter(input_requests)
|
||||
|
||||
# Calculate scale parameter theta to maintain the desired request_rate.
|
||||
assert burstiness > 0, (
|
||||
@ -444,7 +135,7 @@ async def get_request(
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
input_requests: list[SampleRequest],
|
||||
outputs: list[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -475,7 +166,7 @@ def calculate_metrics(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
actual_output_lens.append(output_len)
|
||||
total_input += input_requests[i][1]
|
||||
total_input += input_requests[i].prompt_len
|
||||
tpot = 0
|
||||
if output_len > 1:
|
||||
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
|
||||
@ -558,19 +249,18 @@ async def benchmark(
|
||||
model_id: str,
|
||||
model_name: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
input_requests: list[SampleRequest],
|
||||
logprobs: Optional[int],
|
||||
best_of: int,
|
||||
request_rate: float,
|
||||
burstiness: float,
|
||||
disable_tqdm: bool,
|
||||
profile: bool,
|
||||
selected_percentile_metrics: list[str],
|
||||
selected_percentiles: list[str],
|
||||
selected_percentiles: list[float],
|
||||
ignore_eos: bool,
|
||||
goodput_config_dict: dict[str, float],
|
||||
max_concurrency: Optional[int],
|
||||
lora_modules: Optional[list[str]],
|
||||
lora_modules: Optional[Iterable[str]],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -578,12 +268,16 @@ async def benchmark(
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
|
||||
input_requests[0])
|
||||
test_prompt, test_prompt_len, test_output_len, test_mm_content = \
|
||||
input_requests[0].prompt, input_requests[0].prompt_len, \
|
||||
input_requests[0].expected_output_len, \
|
||||
input_requests[0].multi_modal_data
|
||||
|
||||
if backend != "openai-chat" and test_mm_content is not None:
|
||||
# multi-modal benchmark is only available on OpenAI Chat backend.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' backend.")
|
||||
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
model_name=model_name,
|
||||
@ -592,7 +286,6 @@ async def benchmark(
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
)
|
||||
@ -608,7 +301,8 @@ async def benchmark(
|
||||
if lora_modules:
|
||||
# For each input request, choose a LoRA module at random.
|
||||
lora_modules = iter(
|
||||
[random.choice(lora_modules) for _ in range(len(input_requests))])
|
||||
[random.choice(lora_modules) \
|
||||
for _ in range(len(input_requests))])
|
||||
|
||||
if profile:
|
||||
print("Starting profiler...")
|
||||
@ -619,7 +313,6 @@ async def benchmark(
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
@ -655,7 +348,9 @@ async def benchmark(
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: list[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
prompt, prompt_len, output_len, mm_content = request
|
||||
prompt, prompt_len, output_len, mm_content = request.prompt, \
|
||||
request.prompt_len, request.expected_output_len, \
|
||||
request.multi_modal_data
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
if lora_modules:
|
||||
req_lora_module = next(lora_modules)
|
||||
@ -668,7 +363,6 @@ async def benchmark(
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos)
|
||||
tasks.append(
|
||||
@ -686,7 +380,6 @@ async def benchmark(
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
best_of=best_of,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
@ -872,76 +565,72 @@ def main(args: argparse.Namespace):
|
||||
"Please specify '--dataset-name' and the corresponding "
|
||||
"'--dataset-path' if required.")
|
||||
|
||||
elif args.dataset_name == "sharegpt":
|
||||
input_requests = sample_sharegpt_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
fixed_output_len=args.sharegpt_output_len,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "burstgpt":
|
||||
input_requests = sample_burstgpt_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
random_seed=args.seed,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sonnet":
|
||||
# Do not format the prompt, pass to message directly
|
||||
if args.dataset_name == "sonnet":
|
||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
||||
# For the "sonnet" dataset, formatting depends on the backend.
|
||||
if args.backend == "openai-chat":
|
||||
input_requests = sample_sonnet_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
input_requests = [(prompt, prompt_len, output_len, None)
|
||||
for prompt, prompt_formatted, prompt_len,
|
||||
output_len, _ in input_requests]
|
||||
input_requests = dataset.sample(num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=False)
|
||||
else:
|
||||
assert (
|
||||
tokenizer.chat_template or tokenizer.default_chat_template
|
||||
), "Tokenizer/model must have chat template for sonnet dataset."
|
||||
input_requests = sample_sonnet_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
input_requests = [(prompt_formatted, prompt_len, output_len, None)
|
||||
for prompt, prompt_formatted, prompt_len,
|
||||
output_len, _ in input_requests]
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
input_requests = dataset.sample(num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
output_len=args.sonnet_output_len,
|
||||
prefix_len=args.sonnet_prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
return_prompt_formatted=True)
|
||||
|
||||
elif args.dataset_name == "hf":
|
||||
input_requests = sample_hf_requests(
|
||||
# Choose between VisionArenaDataset
|
||||
# and HuggingFaceDataset based on provided parameters.
|
||||
dataset_class = (VisionArenaDataset if args.dataset_path
|
||||
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
|
||||
and args.hf_subset is None else HuggingFaceDataset)
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
dataset_split=args.hf_split,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
random_seed=args.seed,
|
||||
fixed_output_len=args.hf_output_len,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "random":
|
||||
input_requests = sample_random_requests(
|
||||
prefix_len=args.random_prefix_len,
|
||||
input_len=args.random_input_len,
|
||||
output_len=args.random_output_len,
|
||||
num_prompts=args.num_prompts,
|
||||
range_ratio=args.random_range_ratio,
|
||||
tokenizer=tokenizer,
|
||||
output_len=args.hf_output_len,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
# For datasets that follow a similar structure, use a mapping.
|
||||
dataset_mapping = {
|
||||
"sharegpt":
|
||||
lambda: ShareGPTDataset(random_seed=args.seed,
|
||||
dataset_path=args.dataset_path).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
output_len=args.sharegpt_output_len,
|
||||
),
|
||||
"burstgpt":
|
||||
lambda: BurstGPTDataset(random_seed=args.seed,
|
||||
dataset_path=args.dataset_path).
|
||||
sample(tokenizer=tokenizer, num_requests=args.num_prompts),
|
||||
"random":
|
||||
lambda: RandomDataset(dataset_path=args.dataset_path).sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_prompts,
|
||||
prefix_len=args.random_prefix_len,
|
||||
input_len=args.random_input_len,
|
||||
output_len=args.random_output_len,
|
||||
range_ratio=args.random_range_ratio,
|
||||
)
|
||||
}
|
||||
|
||||
try:
|
||||
input_requests = dataset_mapping[args.dataset_name]()
|
||||
except KeyError as err:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
|
||||
goodput_config_dict = check_goodput_args(args)
|
||||
|
||||
# Avoid GC processing "static" data - reduce pause times.
|
||||
@ -958,7 +647,6 @@ def main(args: argparse.Namespace):
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
logprobs=args.logprobs,
|
||||
best_of=args.best_of,
|
||||
request_rate=args.request_rate,
|
||||
burstiness=args.burstiness,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
@ -983,7 +671,6 @@ def main(args: argparse.Namespace):
|
||||
result_json["backend"] = backend
|
||||
result_json["model_id"] = model_id
|
||||
result_json["tokenizer_id"] = tokenizer_id
|
||||
result_json["best_of"] = args.best_of
|
||||
result_json["num_prompts"] = args.num_prompts
|
||||
|
||||
# Metadata
|
||||
@ -997,6 +684,15 @@ def main(args: argparse.Namespace):
|
||||
"Invalid metadata format. Please use KEY=VALUE format."
|
||||
)
|
||||
|
||||
if not args.save_detailed:
|
||||
# Remove fields with too many data points
|
||||
for field in [
|
||||
"input_lens", "output_lens", "ttfts", "itls",
|
||||
"generated_texts", "errors"
|
||||
]:
|
||||
if field in result_json:
|
||||
del result_json[field]
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
@ -1081,13 +777,6 @@ if __name__ == "__main__":
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best-of",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Generates `best_of` sequences per prompt and "
|
||||
"returns the best one.",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
@ -1148,6 +837,12 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Specify to save benchmark results to a json file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-detailed",
|
||||
action="store_true",
|
||||
help="When saving the results, whether to include per request "
|
||||
"information such as response, error, ttfs, tpots, etc.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
metavar="KEY=VALUE",
|
||||
@ -1312,4 +1007,5 @@ if __name__ == "__main__":
|
||||
"script chooses a LoRA module at random.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
r"""Benchmark online serving throughput with guided decoding.
|
||||
r"""Benchmark online serving throughput with structured outputs.
|
||||
|
||||
On the server side, run one of the following commands:
|
||||
(vLLM OpenAI API server)
|
||||
@ -9,12 +9,12 @@ On the server side, run one of the following commands:
|
||||
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
|
||||
|
||||
On the client side, run:
|
||||
python benchmarks/benchmark_serving_guided.py \
|
||||
python benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend <backend> \
|
||||
--model <your_model> \
|
||||
--dataset json \
|
||||
--guided-decoding-ratio 1.0 \
|
||||
--guided-decoding-backend xgrammar \
|
||||
--structured-output-ratio 1.0 \
|
||||
--structured-output-backend xgrammar \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
|
||||
@ -24,11 +24,13 @@ On the client side, run:
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import copy
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
@ -52,6 +54,9 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from vllm.v1.structured_output.utils import (
|
||||
has_xgrammar_unsupported_json_features)
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
|
||||
|
||||
@ -106,24 +111,43 @@ class SampleRequest:
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
if args.dataset == 'json':
|
||||
if args.dataset == 'json' or args.dataset == 'json-unique':
|
||||
if args.json_schema_path is None:
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
args.json_schema_path = os.path.join(dir_path,
|
||||
"structured_schemas",
|
||||
"structured_schema_1.json")
|
||||
json_schemas = []
|
||||
with open(args.json_schema_path) as f:
|
||||
schema = json.load(f)
|
||||
prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
|
||||
if args.dataset == 'json-unique':
|
||||
json_schemas = [
|
||||
copy.deepcopy(schema) for _ in range(args.num_prompts)
|
||||
]
|
||||
for i in range(len(json_schemas)):
|
||||
json_schemas[i]["properties"][
|
||||
f"__optional_field_{uuid.uuid4()}"] = {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"An unique optional field to avoid cached schemas"
|
||||
}
|
||||
|
||||
def gen_prompt(index: int):
|
||||
schema = json_schemas[index % len(json_schemas)]
|
||||
return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||
|
||||
def get_schema(index: int):
|
||||
return json_schemas[index % len(json_schemas)]
|
||||
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
SampleRequest(prompt=gen_prompt(i),
|
||||
prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
schema=get_schema(i),
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
for i in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "grammar":
|
||||
@ -191,7 +215,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
requests: list[SampleRequest] = []
|
||||
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
|
||||
split="train")
|
||||
print(f"dataset has {len(dataset)} entries")
|
||||
full_dataset_len = len(dataset)
|
||||
|
||||
def _filter_func(item):
|
||||
import json
|
||||
schema = json.loads(item["schema"])
|
||||
return not has_xgrammar_unsupported_json_features(schema)
|
||||
|
||||
dataset = dataset.filter(_filter_func)
|
||||
num_filtered_out = full_dataset_len - len(dataset)
|
||||
print(f"dataset has {len(dataset)} entries after filtering "
|
||||
f"out {num_filtered_out} entries with unsupported features")
|
||||
len_dataset = len(dataset)
|
||||
for data_point_idx in range(args.num_prompts):
|
||||
idx = data_point_idx
|
||||
@ -220,21 +254,21 @@ async def get_request(
|
||||
burstiness: float = 1.0,
|
||||
) -> AsyncGenerator[tuple[int, SampleRequest], None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness.
|
||||
|
||||
|
||||
Args:
|
||||
input_requests:
|
||||
input_requests:
|
||||
A list of input requests, each represented as a tuple.
|
||||
request_rate:
|
||||
request_rate:
|
||||
The rate at which requests are generated (requests/s).
|
||||
burstiness (optional):
|
||||
The burstiness factor of the request generation.
|
||||
burstiness (optional):
|
||||
The burstiness factor of the request generation.
|
||||
Only takes effect when request_rate is not inf.
|
||||
Default value is 1, which follows a Poisson process.
|
||||
Otherwise, the request intervals follow a gamma distribution.
|
||||
A lower burstiness value (0 < burstiness < 1) results
|
||||
in more bursty requests, while a higher burstiness value
|
||||
A lower burstiness value (0 < burstiness < 1) results
|
||||
in more bursty requests, while a higher burstiness value
|
||||
(burstiness > 1) results in a more uniform arrival of requests.
|
||||
"""
|
||||
input_requests = iter(input_requests)
|
||||
@ -378,8 +412,8 @@ async def benchmark(
|
||||
selected_percentiles: list[str],
|
||||
ignore_eos: bool,
|
||||
max_concurrency: Optional[int],
|
||||
guided_decoding_ratio: float,
|
||||
guided_decoding_backend: str,
|
||||
structured_output_ratio: float,
|
||||
structured_output_backend: str,
|
||||
goodput_config_dict: Optional[dict[str, float]] = None,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
@ -391,16 +425,18 @@ async def benchmark(
|
||||
extra_body = {}
|
||||
# Add the schema to the extra_body
|
||||
extra_body[request.structure_type] = request.schema
|
||||
# Add the specific guided_decoding_backend
|
||||
extra_body["guided_decoding_backend"] = guided_decoding_backend
|
||||
# Add the specific structured_output_backend
|
||||
extra_body["guided_decoding_backend"] = structured_output_backend
|
||||
return extra_body
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
guided_decoding_req_idx = random.sample(
|
||||
structured_output_req_idx = random.sample(
|
||||
range(len(input_requests)),
|
||||
int(len(input_requests) * guided_decoding_ratio))
|
||||
int(len(input_requests) * structured_output_ratio))
|
||||
|
||||
test_request = input_requests[0]
|
||||
test_req_extra_body = (prepare_extra_body(test_request)
|
||||
if 0 in structured_output_req_idx else None)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=test_request.prompt,
|
||||
@ -408,7 +444,7 @@ async def benchmark(
|
||||
prompt_len=test_request.prompt_len,
|
||||
output_len=test_request.expected_output_len,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=prepare_extra_body(test_request),
|
||||
extra_body=test_req_extra_body,
|
||||
)
|
||||
test_output = await request_func(request_func_input=test_input)
|
||||
if not test_output.success:
|
||||
@ -427,7 +463,7 @@ async def benchmark(
|
||||
prompt_len=test_request.prompt_len,
|
||||
output_len=test_request.expected_output_len,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=prepare_extra_body(test_request),
|
||||
extra_body=test_req_extra_body,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
@ -465,7 +501,7 @@ async def benchmark(
|
||||
async for i, request in get_request(input_requests, request_rate,
|
||||
burstiness):
|
||||
extra_body = prepare_extra_body(
|
||||
request) if i in guided_decoding_req_idx else None
|
||||
request) if i in structured_output_req_idx else None
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=request.prompt,
|
||||
@ -696,8 +732,11 @@ def main(args: argparse.Namespace):
|
||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
tokenizer = get_tokenizer(
|
||||
tokenizer_id,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
tokenizer_mode=args.tokenizer_mode,
|
||||
)
|
||||
|
||||
if args.dataset == 'grammar':
|
||||
args.structure_type = 'guided_grammar'
|
||||
@ -708,10 +747,10 @@ def main(args: argparse.Namespace):
|
||||
else:
|
||||
args.structure_type = 'guided_json'
|
||||
|
||||
if args.no_guided_decoding:
|
||||
args.guided_decoding_ratio = 0
|
||||
if args.no_structured_output:
|
||||
args.structured_output_ratio = 0
|
||||
if args.save_results:
|
||||
result_file_name = f'{args.guided_decoding_ratio}guided'
|
||||
result_file_name = f'{args.structured_output_ratio}guided'
|
||||
result_file_name += f"_{backend}"
|
||||
result_file_name += f"_{args.request_rate}qps"
|
||||
result_file_name += f"_{args.model.split('/')[-1]}"
|
||||
@ -744,8 +783,8 @@ def main(args: argparse.Namespace):
|
||||
],
|
||||
ignore_eos=args.ignore_eos,
|
||||
max_concurrency=args.max_concurrency,
|
||||
guided_decoding_ratio=args.guided_decoding_ratio,
|
||||
guided_decoding_backend=args.guided_decoding_backend,
|
||||
structured_output_ratio=args.structured_output_ratio,
|
||||
structured_output_backend=args.structured_output_backend,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
))
|
||||
|
||||
@ -806,10 +845,12 @@ if __name__ == "__main__":
|
||||
default="/v1/completions",
|
||||
help="API endpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
default='json',
|
||||
choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
|
||||
parser.add_argument("--dataset",
|
||||
default='json',
|
||||
choices=[
|
||||
'json', 'json-unique', 'grammar', 'regex',
|
||||
'choice', 'xgrammar_bench'
|
||||
])
|
||||
parser.add_argument("--json_schema_path",
|
||||
type=str,
|
||||
default=None,
|
||||
@ -838,6 +879,13 @@ if __name__ == "__main__":
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer-mode",
|
||||
type=str,
|
||||
default="auto",
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
type=int,
|
||||
@ -943,19 +991,20 @@ if __name__ == "__main__":
|
||||
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
|
||||
|
||||
parser.add_argument("--no-guided-decoding",
|
||||
parser.add_argument("--no-structured-output",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Whether to disable JSON decoding or not.")
|
||||
parser.add_argument("--guided-decoding-ratio",
|
||||
parser.add_argument("--structured-output-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Ratio of Guided Decoding requests")
|
||||
parser.add_argument("--guided-decoding-backend",
|
||||
type=str,
|
||||
choices=["outlines", "lm-format-enforcer", "xgrammar"],
|
||||
default="xgrammar",
|
||||
help="Backend to use for guided decoding")
|
||||
help="Ratio of Structured Outputs requests")
|
||||
parser.add_argument(
|
||||
"--structured-output-backend",
|
||||
type=str,
|
||||
choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
|
||||
default="xgrammar",
|
||||
help="Backend to use for structured outputs")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@ -6,13 +6,15 @@ import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from functools import cache
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
|
||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||
SonnetDataset, VisionArenaDataset)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
@ -20,155 +22,19 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SampleRequest:
|
||||
"""A class representing a single inference request for benchmarking.
|
||||
|
||||
Attributes:
|
||||
prompt: The input text prompt for the model.
|
||||
prompt_len: The length of the prompt in tokens.
|
||||
expected_output_len: The expected length of the output in tokens.
|
||||
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
|
||||
images).
|
||||
lora_request: Optional LoRARequest specifying the LoRA to use.
|
||||
"""
|
||||
prompt: str
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
multi_modal_data: Optional[MultiModalDataDict] = None
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
|
||||
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
|
||||
"""Prepend and append special tokens around the question to form a prompt.
|
||||
|
||||
Args:
|
||||
question: The input question text to wrap with special tokens
|
||||
model: The name of the model being used, to determine which special
|
||||
tokens to add
|
||||
|
||||
Returns:
|
||||
The formatted prompt string with appropriate special tokens for the
|
||||
model
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported model name is provided
|
||||
"""
|
||||
model = model.lower()
|
||||
if "pixtral" in model:
|
||||
return f"<s>[INST]{question}\n[IMG][/INST]"
|
||||
raise ValueError(f"Unsupported model {model}")
|
||||
|
||||
|
||||
@cache
|
||||
def lora_path_on_disk(lora_path: str) -> str:
|
||||
return get_adapter_absolute_path(lora_path)
|
||||
|
||||
|
||||
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
|
||||
|
||||
|
||||
def get_random_lora_request(
|
||||
args: argparse.Namespace
|
||||
) -> tuple[LoRARequest, Optional[AnyTokenizer]]:
|
||||
global lora_tokenizer_cache
|
||||
lora_id = random.randint(1, args.max_loras)
|
||||
lora_request = LoRARequest(lora_name=str(lora_id),
|
||||
lora_int_id=lora_id,
|
||||
lora_path=lora_path_on_disk(args.lora_path))
|
||||
if lora_id not in lora_tokenizer_cache:
|
||||
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
|
||||
return lora_request, lora_tokenizer_cache[lora_id]
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
|
||||
dataset_path: str = args.dataset
|
||||
num_requests: int = args.num_prompts
|
||||
fixed_output_len: Optional[int] = args.output_len
|
||||
model: str = args.model
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: list[SampleRequest] = []
|
||||
for data in tqdm(dataset,
|
||||
total=len(filtered_dataset),
|
||||
desc="sampling requests"):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
break
|
||||
|
||||
# Only keep the first two turns of each conversation.
|
||||
prompt = data["conversations"][0]["value"]
|
||||
completion = data["conversations"][1]["value"]
|
||||
|
||||
multi_modal_data: Optional[MultiModalDataDict] = None
|
||||
if "image" in data:
|
||||
multi_modal_data = multi_modal_data or {}
|
||||
image_path = data["image"]
|
||||
# TODO(vllm-project/vllm/issues/9778): Support multiple images.
|
||||
assert isinstance(image_path,
|
||||
str), "Only support single image input"
|
||||
try:
|
||||
multi_modal_data["image"] = Image.open(image_path).convert(
|
||||
"RGB")
|
||||
except FileNotFoundError:
|
||||
# Ignore datapoint where asset is missing
|
||||
continue
|
||||
prompt = _get_prompt_for_image_model(question=prompt, model=model)
|
||||
|
||||
request_tokenizer = tokenizer
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
if args.enable_lora:
|
||||
lora_request, lora_tokenizer = get_random_lora_request(args)
|
||||
if lora_tokenizer:
|
||||
request_tokenizer = lora_tokenizer
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt_token_ids = request_tokenizer(prompt).input_ids
|
||||
completion_token_ids = request_tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
if prompt_len < 4 or output_len < 4:
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
filtered_dataset.append(
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=multi_modal_data,
|
||||
lora_request=lora_request))
|
||||
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def run_vllm(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
) -> float:
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, Optional[list[RequestOutput]]]:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert all(
|
||||
@ -178,10 +44,13 @@ def run_vllm(
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
# Add the requests to the engine.
|
||||
prompts: list[TextPrompt] = []
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data)
|
||||
if "prompt_token_ids" in request.prompt else \
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
sampling_params.append(
|
||||
@ -191,6 +60,7 @@ def run_vllm(
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
if engine_args.enable_lora:
|
||||
@ -198,12 +68,13 @@ def run_vllm(
|
||||
|
||||
use_beam_search = False
|
||||
|
||||
outputs = None
|
||||
if not use_beam_search:
|
||||
start = time.perf_counter()
|
||||
llm.generate(prompts,
|
||||
sampling_params,
|
||||
lora_request=lora_requests,
|
||||
use_tqdm=True)
|
||||
outputs = llm.generate(prompts,
|
||||
sampling_params,
|
||||
lora_request=lora_requests,
|
||||
use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
else:
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
@ -221,7 +92,46 @@ def run_vllm(
|
||||
ignore_eos=True,
|
||||
))
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
|
||||
"""
|
||||
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
|
||||
multimodal models as it properly handles multimodal inputs and chat
|
||||
formatting. For non-multimodal models, use run_vllm() instead.
|
||||
"""
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of "
|
||||
"prompt_len and expected_output_len for all requests.")
|
||||
|
||||
prompts = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
start = time.perf_counter()
|
||||
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
@ -229,6 +139,7 @@ async def run_vllm_async(
|
||||
n: int,
|
||||
engine_args: AsyncEngineArgs,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
|
||||
@ -242,11 +153,14 @@ async def run_vllm_async(
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[TextPrompt] = []
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
lora_requests: list[Optional[LoRARequest]] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data)
|
||||
if "prompt_token_ids" in request.prompt else \
|
||||
TextPrompt(prompt=request.prompt,
|
||||
multi_modal_data=request.multi_modal_data))
|
||||
sampling_params.append(
|
||||
@ -256,6 +170,7 @@ async def run_vllm_async(
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
))
|
||||
lora_requests.append(request.lora_request)
|
||||
|
||||
@ -282,6 +197,7 @@ def run_hf(
|
||||
n: int,
|
||||
max_batch_size: int,
|
||||
trust_remote_code: bool,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
@ -321,8 +237,9 @@ def run_hf(
|
||||
use_cache=True,
|
||||
max_new_tokens=max_output_len,
|
||||
)
|
||||
# Include the decoding time.
|
||||
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
|
||||
if not disable_detokenize:
|
||||
# Include the decoding time.
|
||||
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
|
||||
pbar.update(len(batch))
|
||||
|
||||
# Clear the batch.
|
||||
@ -369,58 +286,68 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def get_requests(args, tokenizer):
|
||||
# Common parameters for all dataset types.
|
||||
common_kwargs = {
|
||||
"dataset_path": args.dataset_path,
|
||||
"random_seed": args.seed,
|
||||
}
|
||||
sample_kwargs = {
|
||||
"tokenizer": tokenizer,
|
||||
"lora_path": args.lora_path,
|
||||
"max_loras": args.max_loras,
|
||||
"num_requests": args.num_prompts,
|
||||
"input_len": args.input_len,
|
||||
"output_len": args.output_len,
|
||||
}
|
||||
if args.dataset_path is None or args.dataset_name == "random":
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
dataset_cls = RandomDataset
|
||||
elif args.dataset_name == "sharegpt":
|
||||
dataset_cls = ShareGPTDataset
|
||||
if args.backend == "vllm-chat":
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_name == "sonnet":
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset.")
|
||||
dataset_cls = SonnetDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
sample_kwargs["return_prompt_formatted"] = True
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
if args.backend != "vllm-chat":
|
||||
raise ValueError(
|
||||
"hf datasets only are supported by vllm-chat backend")
|
||||
# Choose between VisionArenaDataset and HuggingFaceDataset based on
|
||||
# provided parameters.
|
||||
dataset_cls = (VisionArenaDataset if args.dataset_path
|
||||
== VisionArenaDataset.VISION_ARENA_DATASET_PATH
|
||||
and args.hf_subset is None else HuggingFaceDataset)
|
||||
common_kwargs['dataset_subset'] = args.hf_subset
|
||||
common_kwargs['dataset_split'] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
|
||||
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
if args.seed is None:
|
||||
args.seed = 0
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
|
||||
# Sample the requests.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
if args.dataset is None:
|
||||
vocab_size = tokenizer.vocab_size
|
||||
requests = []
|
||||
for _ in range(args.num_prompts):
|
||||
|
||||
request_tokenizer = tokenizer
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
if args.enable_lora:
|
||||
lora_request, lora_tokenizer = get_random_lora_request(args)
|
||||
if lora_tokenizer:
|
||||
request_tokenizer = lora_tokenizer
|
||||
|
||||
# Synthesize a prompt with the given input length.
|
||||
candidate_ids = [
|
||||
random.randint(0, vocab_size - 1)
|
||||
for _ in range(args.input_len)
|
||||
]
|
||||
# As tokenizer may add additional tokens like BOS, we need to try
|
||||
# different lengths to get the desired input length.
|
||||
for _ in range(5): # Max attempts to correct
|
||||
candidate_prompt = request_tokenizer.decode(candidate_ids)
|
||||
tokenized_len = len(request_tokenizer.encode(candidate_prompt))
|
||||
|
||||
if tokenized_len == args.input_len:
|
||||
break
|
||||
|
||||
# Adjust length based on difference
|
||||
diff = args.input_len - tokenized_len
|
||||
if diff > 0:
|
||||
candidate_ids.extend([
|
||||
random.randint(100, vocab_size - 100)
|
||||
for _ in range(diff)
|
||||
])
|
||||
else:
|
||||
candidate_ids = candidate_ids[:diff]
|
||||
requests.append(
|
||||
SampleRequest(prompt=candidate_prompt,
|
||||
prompt_len=args.input_len,
|
||||
expected_output_len=args.output_len,
|
||||
lora_request=lora_request))
|
||||
else:
|
||||
requests = sample_requests(tokenizer, args)
|
||||
|
||||
requests = get_requests(args, tokenizer)
|
||||
is_multi_modal = any(request.multi_modal_data is not None
|
||||
for request in requests)
|
||||
request_outputs: Optional[list[RequestOutput]] = None
|
||||
if args.backend == "vllm":
|
||||
if args.async_engine:
|
||||
elapsed_time = uvloop.run(
|
||||
@ -429,31 +356,59 @@ def main(args: argparse.Namespace):
|
||||
args.n,
|
||||
AsyncEngineArgs.from_cli_args(args),
|
||||
args.disable_frontend_multiprocessing,
|
||||
args.disable_detokenize,
|
||||
))
|
||||
else:
|
||||
elapsed_time = run_vllm(requests, args.n,
|
||||
EngineArgs.from_cli_args(args))
|
||||
elapsed_time, request_outputs = run_vllm(
|
||||
requests, args.n, EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
args.hf_max_batch_size, args.trust_remote_code)
|
||||
args.hf_max_batch_size, args.trust_remote_code,
|
||||
args.disable_detokenize)
|
||||
elif args.backend == "mii":
|
||||
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
|
||||
args.output_len)
|
||||
elif args.backend == "vllm-chat":
|
||||
elapsed_time, request_outputs = run_vllm_chat(
|
||||
requests, args.n, EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(request.prompt_len + request.expected_output_len
|
||||
for request in requests)
|
||||
total_output_tokens = sum(request.expected_output_len
|
||||
for request in requests)
|
||||
if is_multi_modal:
|
||||
print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
|
||||
|
||||
if request_outputs:
|
||||
# Note: with the vllm and vllm-chat backends,
|
||||
# we have request_outputs, which we use to count tokens.
|
||||
total_prompt_tokens = 0
|
||||
total_output_tokens = 0
|
||||
for ro in request_outputs:
|
||||
if not isinstance(ro, RequestOutput):
|
||||
continue
|
||||
total_prompt_tokens += len(
|
||||
ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
||||
total_output_tokens += sum(
|
||||
len(o.token_ids) for o in ro.outputs if o)
|
||||
total_num_tokens = total_prompt_tokens + total_output_tokens
|
||||
else:
|
||||
total_num_tokens = sum(r.prompt_len + r.expected_output_len
|
||||
for r in requests)
|
||||
total_output_tokens = sum(r.expected_output_len for r in requests)
|
||||
total_prompt_tokens = total_num_tokens - total_output_tokens
|
||||
|
||||
if is_multi_modal and args.backend != "vllm-chat":
|
||||
print("\033[91mWARNING\033[0m: Multi-modal request with "
|
||||
f"{args.backend} backend detected. The "
|
||||
"following metrics are not accurate because image tokens are not"
|
||||
" counted. See vllm-project/vllm/issues/9778 for details.")
|
||||
# TODO(vllm-project/vllm/issues/9778): Count molti-modal token length.
|
||||
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
||||
# vllm-chat backend counts the image tokens now
|
||||
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
|
||||
print(f"Total num prompt tokens: {total_prompt_tokens}")
|
||||
print(f"Total num output tokens: {total_output_tokens}")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
@ -469,18 +424,112 @@ def main(args: argparse.Namespace):
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
|
||||
|
||||
def validate_args(args):
|
||||
"""
|
||||
Validate command-line arguments.
|
||||
"""
|
||||
|
||||
# === Deprecation and Defaulting ===
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next release. "
|
||||
"Please use '--dataset-name' and '--dataset-path' instead.",
|
||||
stacklevel=2)
|
||||
args.dataset_path = args.dataset
|
||||
|
||||
if not getattr(args, "tokenizer", None):
|
||||
args.tokenizer = args.model
|
||||
|
||||
# === Backend Validation ===
|
||||
valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
|
||||
if args.backend not in valid_backends:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
# === Dataset Configuration ===
|
||||
if not args.dataset and not args.dataset_path:
|
||||
print(
|
||||
"When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = 'random'
|
||||
if args.input_len is None:
|
||||
raise ValueError("input_len must be provided for a random dataset")
|
||||
|
||||
# === Dataset Name Specific Checks ===
|
||||
# --hf-subset and --hf-split: only used
|
||||
# when dataset_name is 'hf'
|
||||
if args.dataset_name != "hf" and (
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None):
|
||||
warnings.warn("--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2)
|
||||
elif args.dataset_name == "hf" and args.backend != "vllm-chat":
|
||||
raise ValueError(
|
||||
"When --dataset-name is 'hf', backend must be 'vllm-chat'")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random'
|
||||
if args.dataset_name != 'random' and args.random_range_ratio is not None:
|
||||
warnings.warn("--random-range-ratio will be ignored since \
|
||||
--dataset-name is not 'random'.",
|
||||
stacklevel=2)
|
||||
|
||||
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
||||
# set.
|
||||
if args.dataset_name not in {"random", "sonnet", None
|
||||
} and args.prefix_len is not None:
|
||||
warnings.warn("--prefix-len will be ignored since --dataset-name\
|
||||
is not 'random', 'sonnet', or not set.",
|
||||
stacklevel=2)
|
||||
|
||||
# === LoRA Settings ===
|
||||
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
||||
raise ValueError(
|
||||
"LoRA benchmarking is only supported for vLLM backend")
|
||||
if getattr(args, "enable_lora", False) and args.lora_path is None:
|
||||
raise ValueError("LoRA path must be provided when enable_lora is True")
|
||||
|
||||
# === Backend-specific Validations ===
|
||||
if args.backend == "hf" and args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend")
|
||||
if args.backend != "hf" and args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
|
||||
if args.backend in {"hf", "mii"} and getattr(args, "quantization",
|
||||
None) is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
|
||||
if args.backend == "mii" and args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.backend == "mii" and args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.backend == "mii" and args.tokenizer != args.model:
|
||||
raise ValueError(
|
||||
"Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii"],
|
||||
choices=["vllm", "hf", "mii", "vllm-chat"],
|
||||
default="vllm")
|
||||
parser.add_argument("--dataset",
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the ShareGPT dataset, will be deprecated in\
|
||||
the next release. The dataset is expected to "
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]")
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset. The dataset is expected to "
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]")
|
||||
help="Path to the dataset")
|
||||
parser.add_argument("--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
@ -515,6 +564,11 @@ if __name__ == "__main__":
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=("Do not detokenize the response (i.e. do not include "
|
||||
"detokenization time in the measurement)"))
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
@ -522,43 +576,33 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.")
|
||||
parser.add_argument("--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of prefix tokens per request."
|
||||
"This is for the RandomDataset and SonnetDataset")
|
||||
# random dataset
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Range of sampled ratio of input/output length, "
|
||||
"used only for RandomDataSet.",
|
||||
)
|
||||
|
||||
# hf dtaset
|
||||
parser.add_argument("--hf-subset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subset of the HF dataset.")
|
||||
parser.add_argument("--hf-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Split of the HF dataset.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
if args.dataset is None:
|
||||
assert args.input_len is not None
|
||||
assert args.output_len is not None
|
||||
else:
|
||||
assert args.input_len is None
|
||||
if args.enable_lora:
|
||||
assert args.lora_path is not None
|
||||
|
||||
if args.backend == "vllm":
|
||||
if args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
elif args.backend == "hf":
|
||||
if args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.enable_lora is not None:
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM"
|
||||
" backend")
|
||||
elif args.backend == "mii":
|
||||
if args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
if args.tokenizer != args.model:
|
||||
raise ValueError("Tokenizer must be the same as the model for MII "
|
||||
"backend.")
|
||||
if args.enable_lora is not None:
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM"
|
||||
" backend")
|
||||
validate_args(args)
|
||||
main(args)
|
||||
|
||||
340
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Normal file
340
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Normal file
@ -0,0 +1,340 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
from benchmark_shapes import WEIGHT_SHAPES_MOE
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
|
||||
fused_experts,
|
||||
fused_topk)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = [
|
||||
"nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
|
||||
"ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
|
||||
]
|
||||
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
PER_ACT_TOKEN_OPTS = [False]
|
||||
PER_OUT_CH_OPTS = [False]
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor):
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
num_experts: int, topk: int, per_act_token: bool,
|
||||
per_out_ch: bool, mkn: tuple[int, int, int]):
|
||||
label = "Quant Matmul"
|
||||
|
||||
sub_label = (
|
||||
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
|
||||
"MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
|
||||
mkn))
|
||||
|
||||
print(f"Testing: {sub_label}")
|
||||
|
||||
(m, k, n) = mkn
|
||||
|
||||
dtype = torch.half
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
_, a_scale = ops.scaled_fp8_quant(a)
|
||||
|
||||
w1_q = torch.empty((num_experts, 2 * n, k),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w2_q = torch.empty((num_experts, k, n),
|
||||
device="cuda",
|
||||
dtype=torch.float8_e4m3fn)
|
||||
w1_scale = torch.empty((num_experts, 1, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((num_experts, 1, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
ab_strides1 = torch.full((num_experts, ),
|
||||
k,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_experts, ),
|
||||
2 * n,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_experts, ),
|
||||
n,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_experts, ),
|
||||
k,
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
|
||||
for expert in range(num_experts):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
|
||||
w1_q_notransp = w1_q.clone()
|
||||
w2_q_notransp = w2_q.clone()
|
||||
w1_q = w1_q.transpose(1, 2)
|
||||
w2_q = w2_q.transpose(1, 2)
|
||||
|
||||
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
|
||||
|
||||
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
a_scale: torch.Tensor, num_repeats: int):
|
||||
for _ in range(num_repeats):
|
||||
fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale)
|
||||
|
||||
def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
|
||||
w1: torch.Tensor, w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
|
||||
num_repeats: int):
|
||||
for _ in range(num_repeats):
|
||||
cutlass_moe_fp8(a,
|
||||
w1,
|
||||
w2,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
|
||||
w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
return cutlass_moe_fp8(a,
|
||||
w1_q,
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
ab_strides1,
|
||||
c_strides1,
|
||||
ab_strides2,
|
||||
c_strides2,
|
||||
a1_scale=a_scale)
|
||||
|
||||
def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
|
||||
w2: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor, w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor, a_scale: torch.Tensor):
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale)
|
||||
|
||||
def replay_graph(graph, num_repeats):
|
||||
for _ in range(num_repeats):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cutlass_stream = torch.cuda.Stream()
|
||||
cutlass_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
||||
run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
|
||||
topk_weights, topk_ids, ab_strides1, c_strides1,
|
||||
ab_strides2, c_strides2)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
triton_stream = torch.cuda.Stream()
|
||||
triton_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
||||
run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
|
||||
topk_ids, w1_scale, w2_scale, a_scale)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
min_run_time = 5
|
||||
num_warmup = 5
|
||||
num_runs = 25
|
||||
|
||||
globals = {
|
||||
# Baseline params
|
||||
"w1": w1,
|
||||
"w2": w2,
|
||||
"score": score,
|
||||
"topk": topk,
|
||||
"w1_q_notransp": w1_q_notransp,
|
||||
"w2_q_notransp": w2_q_notransp,
|
||||
# Cutlass params
|
||||
"a_scale": a_scale,
|
||||
"w1_q": w1_q,
|
||||
"w2_q": w2_q,
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"ab_strides1": ab_strides1,
|
||||
"c_strides1": c_strides1,
|
||||
"ab_strides2": ab_strides2,
|
||||
"c_strides2": c_strides2,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
# Gen params
|
||||
"a": a,
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"num_runs": num_runs,
|
||||
# Kernels
|
||||
"run_triton_moe": run_triton_moe,
|
||||
"run_cutlass_moe": run_cutlass_moe,
|
||||
"replay_graph": replay_graph,
|
||||
}
|
||||
|
||||
# Warmup
|
||||
run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
|
||||
w1_scale, w2_scale, a_scale, num_warmup)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="triton_moe",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
# Warmup
|
||||
replay_graph(triton_graph, num_warmup)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="replay_graph(triton_graph, num_runs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="triton_moe_cuda_graphs",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
# Warmup
|
||||
run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
|
||||
topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
|
||||
num_warmup)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="grouped_gemm_moe",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
# Warmup
|
||||
replay_graph(cutlass_graph, num_warmup)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="replay_graph(cutlass_graph, num_runs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="grouped_gemm_moe_cuda_graphs",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
|
||||
def main(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
results: list[benchmark.Measurement] = []
|
||||
|
||||
for model in args.models:
|
||||
for tp in args.tp_sizes:
|
||||
for layer in WEIGHT_SHAPES_MOE[model]:
|
||||
num_experts = layer[0]
|
||||
topk = layer[1]
|
||||
size_k = layer[2]
|
||||
size_n = layer[3] // tp
|
||||
|
||||
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
||||
continue
|
||||
|
||||
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
||||
continue
|
||||
|
||||
for per_act_token in PER_ACT_TOKEN_OPTS:
|
||||
for per_out_ch in PER_OUT_CH_OPTS:
|
||||
for size_m in DEFAULT_BATCH_SIZES:
|
||||
mkn = (size_m, size_k, size_n)
|
||||
bench_run(results, model, num_experts, topk,
|
||||
per_act_token, per_out_ch, mkn)
|
||||
|
||||
compare = benchmark.Compare(results)
|
||||
compare.print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark Marlin across specified models/shapes/batches")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES_MOE.keys(),
|
||||
)
|
||||
parser.add_argument("--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_TP_SIZES)
|
||||
parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
|
||||
parser.add_argument("--limit-per-act-token",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[])
|
||||
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@ -40,7 +40,7 @@ def main(num_tokens: int,
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
|
||||
@ -17,11 +17,7 @@ from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import ArgPool, Bench, CudaGraphBenchParams
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
|
||||
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -153,7 +149,6 @@ def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
|
||||
result = torch.nn.functional.linear(x, w)
|
||||
result *= scaling
|
||||
out_list.append(result)
|
||||
torch.cat(out_list, dim=0)
|
||||
|
||||
cat_result = torch.cat(out_list, dim=0)
|
||||
|
||||
@ -167,52 +162,25 @@ class OpType(Enum):
|
||||
"""
|
||||
LoRA Ops to benchmark and its properties.
|
||||
"""
|
||||
SGMV_SHRINK = auto()
|
||||
BGMV_SHRINK = auto()
|
||||
SGMV_EXPAND = auto()
|
||||
BGMV_EXPAND = auto()
|
||||
BGMV_EXPAND_SLICE = auto()
|
||||
LORA_SHRINK = auto()
|
||||
LORA_EXPAND = auto()
|
||||
|
||||
@staticmethod
|
||||
def from_str(s: str) -> "OpType":
|
||||
if s.lower() == 'sgmv_shrink':
|
||||
return OpType.SGMV_SHRINK
|
||||
if s.lower() == 'sgmv_expand':
|
||||
return OpType.SGMV_EXPAND
|
||||
if s.lower() == 'bgmv_shrink':
|
||||
return OpType.BGMV_SHRINK
|
||||
if s.lower() == 'bgmv_expand':
|
||||
return OpType.BGMV_EXPAND
|
||||
if s.lower() == "bgmv_expand_slice":
|
||||
return OpType.BGMV_EXPAND_SLICE
|
||||
if s.lower() == "lora_shrink":
|
||||
return OpType.LORA_SHRINK
|
||||
if s.lower() == "lora_expand":
|
||||
return OpType.LORA_EXPAND
|
||||
raise ValueError(f"Unrecognized str {s} to convert to OpType")
|
||||
|
||||
def is_shrink_fn(self) -> bool:
|
||||
return self in [OpType.SGMV_SHRINK, OpType.BGMV_SHRINK]
|
||||
return self in [OpType.LORA_SHRINK]
|
||||
|
||||
def is_expand_fn(self) -> bool:
|
||||
return self in [OpType.SGMV_EXPAND, OpType.BGMV_EXPAND]
|
||||
|
||||
def is_prefill_op(self) -> bool:
|
||||
return self in [OpType.SGMV_SHRINK, OpType.SGMV_EXPAND]
|
||||
|
||||
def is_decode_op(self) -> bool:
|
||||
return self in [
|
||||
OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE
|
||||
]
|
||||
|
||||
def is_expand_slice_fn(self) -> bool:
|
||||
return self in [OpType.BGMV_EXPAND_SLICE]
|
||||
return self in [OpType.LORA_EXPAND]
|
||||
|
||||
def num_slices(self) -> list[int]:
|
||||
if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
|
||||
# SGMV kernels supports slices
|
||||
return [1, 2, 3]
|
||||
if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
|
||||
return [1]
|
||||
if self in [OpType.BGMV_EXPAND_SLICE]:
|
||||
return [2, 3]
|
||||
raise ValueError(f"Unrecognized OpType {self}")
|
||||
return [1, 2, 3]
|
||||
|
||||
def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
|
||||
lora_rank: int) -> tuple[int, int, int]:
|
||||
@ -222,7 +190,7 @@ class OpType(Enum):
|
||||
k = hidden_size
|
||||
n = lora_rank
|
||||
else:
|
||||
assert self.is_expand_fn() or self.is_expand_slice_fn()
|
||||
assert self.is_expand_fn()
|
||||
m = num_tokens
|
||||
k = lora_rank
|
||||
n = hidden_size
|
||||
@ -237,7 +205,7 @@ class OpType(Enum):
|
||||
if self.is_shrink_fn():
|
||||
return op_dtype, op_dtype, torch.float32
|
||||
else:
|
||||
assert self.is_expand_fn() or self.is_expand_slice_fn()
|
||||
assert self.is_expand_fn()
|
||||
return torch.float32, op_dtype, op_dtype
|
||||
|
||||
def matmul_shapes(
|
||||
@ -251,56 +219,39 @@ class OpType(Enum):
|
||||
m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)
|
||||
|
||||
b_shape = (num_loras, n, k) # col-major
|
||||
if self == OpType.SGMV_SHRINK:
|
||||
# SGMV shrink supports num_slices inherently in the kernel
|
||||
if self in [OpType.LORA_SHRINK]:
|
||||
# LoRA shrink kernels support num_slices inherently in the kernel.
|
||||
return ((m, k), b_shape, (num_slices, m, n))
|
||||
if self == OpType.SGMV_EXPAND:
|
||||
# SGMV expand supports num_slices inherently in the kernel
|
||||
if self in [OpType.LORA_EXPAND]:
|
||||
# LoRA expand kernels support num_slices inherently in the kernel
|
||||
return ((num_slices, m, k), b_shape, (m, n * num_slices))
|
||||
if self == OpType.BGMV_SHRINK:
|
||||
return ((m, k), b_shape, (m, n))
|
||||
if self == OpType.BGMV_EXPAND:
|
||||
return ((m, k), b_shape, (m, n))
|
||||
if self == OpType.BGMV_EXPAND_SLICE:
|
||||
return ((num_slices, m, k), b_shape, (m, n * num_slices))
|
||||
|
||||
raise ValueError(f"Unrecognized op_type {self}")
|
||||
|
||||
def bench_fn(self) -> Callable:
|
||||
if self == OpType.LORA_SHRINK:
|
||||
return lora_shrink
|
||||
if self == OpType.LORA_EXPAND:
|
||||
return lora_expand
|
||||
|
||||
def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
|
||||
for x in kwargs_list:
|
||||
bgmv_expand_slice(**x)
|
||||
|
||||
if self == OpType.SGMV_SHRINK:
|
||||
return sgmv_shrink
|
||||
if self == OpType.SGMV_EXPAND:
|
||||
return sgmv_expand
|
||||
if self == OpType.BGMV_SHRINK:
|
||||
return bgmv_shrink
|
||||
if self == OpType.BGMV_EXPAND:
|
||||
return bgmv_expand
|
||||
if self == OpType.BGMV_EXPAND_SLICE:
|
||||
return emulate_bgmv_expand_slice
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
|
||||
lora_weights: list[torch.Tensor],
|
||||
**kwargs) -> Callable:
|
||||
"""Each benchmark operation expected the input, lora_weights and outputs
|
||||
"""Each benchmark operation expects the input, lora_weights and outputs
|
||||
in a slightly different format. Refer to self.matmul_shapes().
|
||||
run_ref_group_gemm accounts for those differences in executing a
|
||||
reference group gemm for correctness testing.
|
||||
"""
|
||||
w_dtype = lora_weights[0].dtype
|
||||
num_slices = len(lora_weights)
|
||||
if self == OpType.SGMV_SHRINK:
|
||||
if self in [OpType.LORA_SHRINK]:
|
||||
for slice_idx in range(num_slices):
|
||||
ref_group_gemm(ref_out=output[slice_idx, :],
|
||||
input=input,
|
||||
lora_weights=lora_weights[slice_idx],
|
||||
**kwargs)
|
||||
if self == OpType.SGMV_EXPAND:
|
||||
elif self in [OpType.LORA_EXPAND]:
|
||||
hidden_size = lora_weights[0].shape[1]
|
||||
for slice_idx in range(num_slices):
|
||||
slice_offset = slice_idx * hidden_size
|
||||
@ -309,28 +260,8 @@ class OpType(Enum):
|
||||
input=input[slice_idx].clone().to(dtype=w_dtype),
|
||||
lora_weights=lora_weights[slice_idx],
|
||||
**kwargs)
|
||||
if self == OpType.BGMV_SHRINK:
|
||||
assert num_slices == 1
|
||||
ref_group_gemm(ref_out=output,
|
||||
input=input,
|
||||
lora_weights=lora_weights[0],
|
||||
**kwargs)
|
||||
if self == OpType.BGMV_EXPAND:
|
||||
assert num_slices == 1
|
||||
ref_group_gemm(ref_out=output,
|
||||
input=input.clone().to(dtype=w_dtype),
|
||||
lora_weights=lora_weights[0],
|
||||
**kwargs)
|
||||
if self == OpType.BGMV_EXPAND_SLICE:
|
||||
hidden_size = lora_weights[0].shape[1]
|
||||
for slice_idx in range(num_slices):
|
||||
slice_offset = slice_idx * hidden_size
|
||||
ref_group_gemm(
|
||||
ref_out=output[:, slice_offset:slice_offset + hidden_size],
|
||||
input=input[slice_idx].clone().to(dtype=w_dtype),
|
||||
lora_weights=lora_weights[slice_idx],
|
||||
**kwargs)
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
else:
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -386,11 +317,11 @@ class BenchmarkTensors:
|
||||
input: torch.Tensor
|
||||
lora_weights_lst: list[torch.Tensor]
|
||||
output: torch.Tensor
|
||||
# metadata tensors
|
||||
# LoRA kernel metadata
|
||||
lora_kernel_meta: LoRAKernelMeta
|
||||
# Metadata tensors used in testing correctness
|
||||
seq_lens: torch.Tensor
|
||||
seq_start_loc: torch.Tensor
|
||||
prompt_lora_mapping: torch.Tensor
|
||||
token_lora_mapping: torch.Tensor
|
||||
|
||||
def io_types(self) -> str:
|
||||
return (f"{dtype_to_str(self.input.dtype)}x"
|
||||
@ -417,26 +348,29 @@ class BenchmarkTensors:
|
||||
assert ctx.num_active_loras <= ctx.num_loras
|
||||
total_tokens = ctx.batch_size * ctx.seq_length
|
||||
|
||||
# Make metadata tensors involved in correctness testing.
|
||||
# Prepare seq lens tensor
|
||||
seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
|
||||
(ctx.batch_size, ))
|
||||
# Prepare seq_start_loc tensor
|
||||
seq_start_loc_tensor = torch.cumsum(torch.tensor(
|
||||
[0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
|
||||
dim=0)
|
||||
assert total_tokens == seq_len_tensor.sum()
|
||||
# Prepare prompt lora indices tensor
|
||||
prompt_lora_indices_tensor = make_prompt_lora_mapping(
|
||||
ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
|
||||
# Prepare token lora indices tensor
|
||||
|
||||
# Make LoRAKernelMeta
|
||||
token_lora_indices_tensor = make_token_lora_mapping(
|
||||
total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
|
||||
seq_len_tensor, "cpu")
|
||||
lora_kernel_meta = LoRAKernelMeta.make(
|
||||
max_loras=ctx.num_loras,
|
||||
max_num_tokens=token_lora_indices_tensor.size(0),
|
||||
device="cpu")
|
||||
lora_kernel_meta.prepare_tensors(
|
||||
token_lora_mapping=token_lora_indices_tensor)
|
||||
|
||||
return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
|
||||
seq_len_tensor, seq_start_loc_tensor,
|
||||
prompt_lora_indices_tensor,
|
||||
token_lora_indices_tensor)
|
||||
lora_kernel_meta, seq_len_tensor,
|
||||
prompt_lora_indices_tensor)
|
||||
|
||||
def sanity_check(self) -> None:
|
||||
"""
|
||||
@ -446,9 +380,9 @@ class BenchmarkTensors:
|
||||
# check metadata tensors
|
||||
assert torch.sum(self.seq_lens) == num_tokens
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
assert self.seq_start_loc.shape[0] == num_seqs
|
||||
#assert self.seq_start_loc.shape[0] == num_seqs
|
||||
assert self.prompt_lora_mapping.shape[0] == num_seqs
|
||||
assert self.token_lora_mapping.shape[0] == num_tokens
|
||||
assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens
|
||||
|
||||
def to_device(self, device: str):
|
||||
"""
|
||||
@ -463,54 +397,31 @@ class BenchmarkTensors:
|
||||
self.input = to_device(self.input)
|
||||
self.output = to_device(self.output)
|
||||
self.seq_lens = to_device(self.seq_lens)
|
||||
self.seq_start_loc = to_device(self.seq_start_loc)
|
||||
self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
|
||||
self.token_lora_mapping = to_device(self.token_lora_mapping)
|
||||
for i in range(len(self.lora_weights_lst)):
|
||||
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
|
||||
|
||||
# LoRA meta
|
||||
for field_name in LoRAKernelMeta.__dataclass_fields__:
|
||||
field = getattr(self.lora_kernel_meta, field_name)
|
||||
assert isinstance(field, torch.Tensor)
|
||||
setattr(self.lora_kernel_meta, field_name, to_device(field))
|
||||
|
||||
def metadata(self) -> tuple[int, int, int]:
|
||||
"""
|
||||
Return num_seqs, num_tokens and max_seq_len
|
||||
"""
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
num_tokens = self.token_lora_mapping.shape[0]
|
||||
num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
|
||||
max_seq_len = torch.max(self.seq_lens).item()
|
||||
num_slices = len(self.lora_weights_lst)
|
||||
return num_seqs, num_tokens, max_seq_len, num_slices
|
||||
|
||||
def convert_to_sgmv_benchmark_tensors(self):
|
||||
"""
|
||||
For sgmv punica kernels, when consecutive sequences have the
|
||||
same LoRA ID, we just merge them together.
|
||||
This happens in punica.py::compute_metadata
|
||||
"""
|
||||
|
||||
# Collapse seq_lens and seq_start_loc
|
||||
_, seq_lens = torch.unique_consecutive(self.token_lora_mapping,
|
||||
return_counts=True)
|
||||
cum_result = torch.cumsum(seq_lens, dim=0)
|
||||
seq_start_loc = torch.zeros_like(seq_lens)
|
||||
seq_start_loc[1:].copy_(cum_result[:-1])
|
||||
|
||||
# Collapse prompt mapping
|
||||
prompt_lora_mapping = torch.unique_consecutive(
|
||||
self.prompt_lora_mapping)
|
||||
|
||||
assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \
|
||||
f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}"
|
||||
|
||||
self.prompt_lora_mapping = prompt_lora_mapping.to(
|
||||
dtype=self.prompt_lora_mapping.dtype)
|
||||
self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
|
||||
self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)
|
||||
|
||||
def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
|
||||
self.convert_to_sgmv_benchmark_tensors()
|
||||
def as_lora_shrink_kwargs(self) -> dict[str, Any]:
|
||||
self.sanity_check()
|
||||
self.to_device(self.input.device)
|
||||
|
||||
num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
|
||||
_, num_tokens, _, num_slices = self.metadata()
|
||||
|
||||
# Sanity check matrix shapes.
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
@ -531,22 +442,20 @@ class BenchmarkTensors:
|
||||
'inputs': self.input,
|
||||
'lora_a_weights': self.lora_weights_lst,
|
||||
'output_tensor': self.output,
|
||||
'b_seq_start_loc': self.seq_start_loc,
|
||||
'seq_len_tensor': self.seq_lens,
|
||||
'lora_indices_tensor': self.prompt_lora_mapping,
|
||||
'batches': num_seqs,
|
||||
'max_seq_length': max_seq_len,
|
||||
'token_nums': num_tokens,
|
||||
'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
|
||||
'token_indices_sorted_by_lora_ids':
|
||||
self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
|
||||
'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
|
||||
'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
|
||||
'lora_ids': self.lora_kernel_meta.active_lora_ids,
|
||||
'scaling': 1.0,
|
||||
}
|
||||
|
||||
def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
|
||||
self.convert_to_sgmv_benchmark_tensors()
|
||||
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
self.sanity_check()
|
||||
self.to_device(self.input.device)
|
||||
|
||||
num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
|
||||
_, num_tokens, _, num_slices = self.metadata()
|
||||
|
||||
# Sanity check matrix shapes.
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
@ -568,106 +477,16 @@ class BenchmarkTensors:
|
||||
'inputs': self.input,
|
||||
'lora_b_weights': self.lora_weights_lst,
|
||||
'output_tensor': self.output,
|
||||
'b_seq_start_loc': self.seq_start_loc,
|
||||
'seq_len_tensor': self.seq_lens,
|
||||
'lora_indices_tensor': self.prompt_lora_mapping,
|
||||
'batches': num_seqs,
|
||||
'max_seq_length': max_seq_len,
|
||||
'token_nums': num_tokens,
|
||||
'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
|
||||
'token_indices_sorted_by_lora_ids':
|
||||
self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
|
||||
'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
|
||||
'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
|
||||
'lora_ids': self.lora_kernel_meta.active_lora_ids,
|
||||
'offset_start': 0,
|
||||
'add_inputs': add_inputs,
|
||||
}
|
||||
|
||||
def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
|
||||
assert len(self.lora_weights_lst) == 1
|
||||
self.to_device(self.input.device)
|
||||
|
||||
_, num_tokens, _, _ = self.metadata()
|
||||
# Sanity check shapes
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape [num_tokens, hidden_size]
|
||||
assert len(i_shape) == 2
|
||||
assert i_shape[0] == num_tokens
|
||||
hidden_size = i_shape[1]
|
||||
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == hidden_size
|
||||
lora_rank = lw_shape[1]
|
||||
# Expected output shape [num_tokens, lora_rank]
|
||||
assert len(o_shape) == 2
|
||||
assert o_shape == (num_tokens, lora_rank)
|
||||
|
||||
return {
|
||||
'inputs': self.input,
|
||||
'lora_a_weights': self.lora_weights_lst[0],
|
||||
'output_tensor': self.output,
|
||||
'lora_indices_tensor': self.token_lora_mapping,
|
||||
'scaling': 1.0
|
||||
}
|
||||
|
||||
def as_bgmv_expand_kwargs(self, add_inputs: bool):
|
||||
assert len(self.lora_weights_lst) == 1
|
||||
self.to_device(self.input.device)
|
||||
|
||||
_, num_tokens, _, _ = self.metadata()
|
||||
# Sanity check shapes
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape [num_tokens, lora_rank]
|
||||
assert len(i_shape) == 2
|
||||
assert i_shape[0] == num_tokens
|
||||
lora_rank = i_shape[1]
|
||||
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == lora_rank
|
||||
hidden_size = lw_shape[1]
|
||||
# Expected output shape [num_tokens, hidden_size]
|
||||
assert len(o_shape) == 2
|
||||
assert o_shape == (num_tokens, hidden_size)
|
||||
|
||||
return {
|
||||
'inputs': self.input,
|
||||
'lora_b_weights': self.lora_weights_lst[0],
|
||||
'output_tensor': self.output,
|
||||
'lora_indices_tensor': self.token_lora_mapping,
|
||||
'add_inputs': add_inputs
|
||||
}
|
||||
|
||||
def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
|
||||
_, num_tokens, _, num_slices = self.metadata()
|
||||
# Sanity check shapes
|
||||
i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
|
||||
0].shape, self.output.shape
|
||||
# Expected input shape [num_slices, num_tokens, lora_rank]
|
||||
assert len(i_shape) == 3
|
||||
assert i_shape[0] == num_slices
|
||||
assert i_shape[1] == num_tokens
|
||||
lora_rank = i_shape[2]
|
||||
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
|
||||
assert len(lw_shape) == 3
|
||||
assert lw_shape[2] == lora_rank
|
||||
hidden_size = lw_shape[1]
|
||||
# Expected output shape [num_tokens, hidden_size * num_slices]
|
||||
assert len(o_shape) == 2
|
||||
assert o_shape == (num_tokens, hidden_size * num_slices)
|
||||
|
||||
self.to_device(self.input.device)
|
||||
|
||||
kwargs_list = []
|
||||
for i in range(num_slices):
|
||||
kwargs_list.append({
|
||||
'inputs': self.input[i],
|
||||
'lora_b_weights': self.lora_weights_lst[i],
|
||||
'output_tensor': self.output,
|
||||
'lora_indices_tensor': self.token_lora_mapping,
|
||||
'slice_offset': i * hidden_size,
|
||||
'slice_size': hidden_size,
|
||||
'add_inputs': add_inputs,
|
||||
})
|
||||
return {'kwargs_list': kwargs_list}
|
||||
|
||||
def bench_fn_kwargs(self,
|
||||
op_type: OpType,
|
||||
add_inputs: Optional[bool] = None) -> dict[str, Any]:
|
||||
@ -676,16 +495,10 @@ class BenchmarkTensors:
|
||||
else:
|
||||
assert add_inputs is not None
|
||||
|
||||
if op_type == OpType.SGMV_SHRINK:
|
||||
return self.as_sgmv_shrink_kwargs()
|
||||
if op_type == OpType.SGMV_EXPAND:
|
||||
return self.as_sgmv_expand_kwargs(add_inputs)
|
||||
if op_type == OpType.BGMV_SHRINK:
|
||||
return self.as_bgmv_shrink_kwargs()
|
||||
if op_type == OpType.BGMV_EXPAND:
|
||||
return self.as_bgmv_expand_kwargs(add_inputs)
|
||||
if op_type == OpType.BGMV_EXPAND_SLICE:
|
||||
return self.as_bgmv_expand_slice_kwargs(add_inputs)
|
||||
if op_type == OpType.LORA_SHRINK:
|
||||
return self.as_lora_shrink_kwargs()
|
||||
if op_type == OpType.LORA_EXPAND:
|
||||
return self.as_lora_expand_kwargs(add_inputs)
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
def test_correctness(self, op_type: OpType,
|
||||
@ -873,14 +686,7 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
|
||||
timers = []
|
||||
for bench_ctx in bench_ctxs:
|
||||
for seq_len in args.seq_lengths:
|
||||
bench_ops: list[OpType] = []
|
||||
if seq_len == 1:
|
||||
# bench all decode ops
|
||||
bench_ops = [op for op in args.op_types if op.is_decode_op()]
|
||||
else:
|
||||
# bench all prefill ops
|
||||
bench_ops = [op for op in args.op_types if op.is_prefill_op()]
|
||||
|
||||
bench_ops: list[OpType] = args.op_types
|
||||
seq_len_timers = []
|
||||
for bench_op in bench_ops:
|
||||
for num_slices in bench_op.num_slices():
|
||||
@ -1090,13 +896,13 @@ Benchmark LoRA kernels:
|
||||
{use_cuda_graph_recommendation()}
|
||||
|
||||
list_bench example:
|
||||
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
|
||||
model_bench example:
|
||||
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
|
||||
|
||||
range_bench example:
|
||||
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
|
||||
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
|
||||
""", # noqa: E501
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
|
||||
@ -45,7 +45,6 @@ def terse_type_name(dt):
|
||||
torch.float16: "fp16",
|
||||
torch.int8: "int8",
|
||||
torch.float8_e4m3fn: "fp8",
|
||||
torch.bfloat16: "bf16",
|
||||
torch.float: "float",
|
||||
torch.int: "int",
|
||||
}[dt]
|
||||
@ -259,7 +258,7 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
|
||||
|
||||
return lambda: ops.machete_mm(
|
||||
a=bt.a,
|
||||
b_q=bt.w_q,
|
||||
b_q=w_q,
|
||||
b_type=bt.wtype,
|
||||
b_group_scales=bt.w_g_s,
|
||||
b_group_zeros=w_g_zp,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
@ -17,8 +18,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
|
||||
) else torch.float8_e4m3fn
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
@ -365,6 +365,7 @@ class BenchmarkWorker:
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_quant_shape: List[int] = None,
|
||||
) -> tuple[dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
dtype_str = get_config_dtype_str(dtype,
|
||||
@ -385,10 +386,17 @@ class BenchmarkWorker:
|
||||
else:
|
||||
config = op_config[min(op_config.keys(),
|
||||
key=lambda x: abs(x - num_tokens))]
|
||||
kernel_time = benchmark_config(config, num_tokens, num_experts,
|
||||
shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8_w8a8,
|
||||
use_int8_w8a16)
|
||||
kernel_time = benchmark_config(config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=100,
|
||||
block_quant_shape=block_quant_shape)
|
||||
return config, kernel_time
|
||||
|
||||
def tune(
|
||||
@ -487,6 +495,14 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def get_weight_block_size_safety(config, default_value=None):
|
||||
|
||||
quantization_config = getattr(config, 'quantization_config', {})
|
||||
if isinstance(quantization_config, dict):
|
||||
return quantization_config.get('weight_block_size', default_value)
|
||||
return default_value
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
block_quant_shape = None
|
||||
@ -508,7 +524,12 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
block_quant_shape = config.quantization_config['weight_block_size']
|
||||
block_quant_shape = get_weight_block_size_safety(config)
|
||||
elif config.architectures[0] == "Qwen2MoeForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
|
||||
@ -7,10 +7,13 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
|
||||
create_kv_caches_with_random)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NUM_BLOCKS = 128 * 1024
|
||||
PARTITION_SIZE = 512
|
||||
PARTITION_SIZE_ROCM = 256
|
||||
@ -176,7 +179,7 @@ def main(
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
@ -193,6 +196,9 @@ def main(
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logger.warning("This script benchmarks the paged attention kernel. "
|
||||
"By default this is no longer used in vLLM inference.")
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the paged attention kernel.")
|
||||
parser.add_argument("--version",
|
||||
|
||||
@ -40,7 +40,7 @@ def main(num_tokens: int,
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
|
||||
@ -139,7 +139,7 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||
|
||||
print(f"Naive output={output_naive}")
|
||||
print(f"FlashInfer output={output_flashinfer}")
|
||||
print(f"VLLM output={output_vllm}")
|
||||
print(f"vLLM output={output_vllm}")
|
||||
|
||||
if torch.allclose(output_naive, output_flashinfer, atol=1e-2,
|
||||
rtol=1e-2) and torch.allclose(
|
||||
|
||||
@ -75,3 +75,19 @@ WEIGHT_SHAPES = {
|
||||
[7168, 8192],
|
||||
],
|
||||
}
|
||||
|
||||
WEIGHT_SHAPES_MOE = {
|
||||
"nm-testing/Mixtral-8x7B-Instruct-v0.1": [
|
||||
[8, 2, 4096, 28672],
|
||||
[8, 2, 14336, 4096],
|
||||
],
|
||||
"nm-testing/deepseekv2-lite": [
|
||||
[64, 6, 2048, 1408],
|
||||
],
|
||||
"ibm-granite/granite-3.0-1b-a400m": [
|
||||
[32, 8, 1024, 1024],
|
||||
],
|
||||
"ibm-granite/granite-3.0-3b-a800m": [
|
||||
[40, 8, 1024, 1536],
|
||||
],
|
||||
}
|
||||
|
||||
420
benchmarks/kernels/benchmark_w8a8_block_fp8.py
Normal file
420
benchmarks/kernels/benchmark_w8a8_block_fp8.py
Normal file
@ -0,0 +1,420 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from sglang quantization/tuning_block_wise_kernel.py
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
import triton
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_w8a8_block_fp8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
assert current_platform.is_cuda(
|
||||
), "Only support tune w8a8 block fp8 kernel on CUDA device."
|
||||
|
||||
DTYPE_MAP = {
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
def w8a8_block_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
config: dict[str, Any],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with
|
||||
block-wise quantization.
|
||||
|
||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||
The output is returned in the specified `output_dtype`.
|
||||
|
||||
Args:
|
||||
A: The input tensor, e.g., activation.
|
||||
B: The input tensor, e.g., weight.
|
||||
As: The per-token-group quantization scale for `A`.
|
||||
Bs: The per-block quantization scale for `B`.
|
||||
block_size: The block size for per-block quantization.
|
||||
It should be 2-dim, e.g., [128, 128].
|
||||
output_dytpe: The dtype of the returned tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The result of matmul.
|
||||
"""
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
M = A.numel() // A.shape[-1]
|
||||
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
N, K = B.shape
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N, )
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
def grid(META):
|
||||
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||
|
||||
if A.dtype == torch.float8_e4m3fn:
|
||||
kernel = _w8a8_block_fp8_matmul
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Currently, only support tune w8a8 block fp8 kernel.")
|
||||
|
||||
kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
A.stride(-2),
|
||||
A.stride(-1),
|
||||
B.stride(1),
|
||||
B.stride(0),
|
||||
C.stride(-2),
|
||||
C.stride(-1),
|
||||
As.stride(-2),
|
||||
As.stride(-1),
|
||||
Bs.stride(1),
|
||||
Bs.stride(0),
|
||||
**config,
|
||||
)
|
||||
|
||||
return C
|
||||
|
||||
|
||||
def get_configs_compute_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5]:
|
||||
for block_m in [16, 32, 64, 128, 256]:
|
||||
for block_k in [64, 128]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 16, 32, 64]:
|
||||
configs.append({
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
})
|
||||
return configs
|
||||
|
||||
|
||||
def get_weight_shapes(tp_size):
|
||||
# NOTE(HandH1998): The weight shapes only works for DeepSeek-V3.
|
||||
# Modify them, if you tune for another different model.
|
||||
# cannot TP
|
||||
total = [
|
||||
(512 + 64, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(7168, 16384),
|
||||
(7168, 18432),
|
||||
]
|
||||
# N can TP
|
||||
n_tp = [
|
||||
(18432 * 2, 7168),
|
||||
((128 + 64) * 128, 7168),
|
||||
(128 * (128 + 128), 512),
|
||||
(24576, 1536),
|
||||
(12288, 7168),
|
||||
(4096, 7168),
|
||||
]
|
||||
# K can TP
|
||||
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||
|
||||
weight_shapes = []
|
||||
for t in total:
|
||||
weight_shapes.append(t)
|
||||
for n_t in n_tp:
|
||||
new_t = (n_t[0] // tp_size, n_t[1])
|
||||
weight_shapes.append(new_t)
|
||||
for k_t in k_tp:
|
||||
new_t = (k_t[0], k_t[1] // tp_size)
|
||||
weight_shapes.append(new_t)
|
||||
return weight_shapes
|
||||
|
||||
|
||||
def benchmark_config(A,
|
||||
B,
|
||||
As,
|
||||
Bs,
|
||||
block_size,
|
||||
config,
|
||||
out_dtype=torch.float16,
|
||||
num_iters=10):
|
||||
|
||||
def run():
|
||||
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# JIT complication & warmup
|
||||
for _ in range(5):
|
||||
run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
torch.cuda.synchronize()
|
||||
start_event.record()
|
||||
run()
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
latencies.append(start_event.elapsed_time(end_event))
|
||||
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||
return avg
|
||||
|
||||
|
||||
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
||||
factor_for_scale = 1e-2
|
||||
|
||||
if input_type == "fp8":
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (
|
||||
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
|
||||
fp8_max)
|
||||
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
B_fp32 = (
|
||||
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
|
||||
fp8_max)
|
||||
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Currently, only support tune w8a8 block fp8 kernel.")
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32,
|
||||
device="cuda") * factor_for_scale
|
||||
Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") *
|
||||
factor_for_scale)
|
||||
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(
|
||||
A,
|
||||
B,
|
||||
As,
|
||||
Bs,
|
||||
block_size,
|
||||
config,
|
||||
out_dtype,
|
||||
num_iters=10,
|
||||
)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
|
||||
if kernel_time < best_time:
|
||||
best_time = kernel_time
|
||||
best_config = config
|
||||
now = datetime.now()
|
||||
print(f"{now.ctime()}] Completed tuning for batch_size={M}")
|
||||
assert best_config is not None
|
||||
return best_config
|
||||
|
||||
|
||||
def save_configs(
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
configs,
|
||||
save_path,
|
||||
input_type="fp8",
|
||||
) -> None:
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
device_name = current_platform.get_device_name().replace(" ", "_")
|
||||
json_file_name = (
|
||||
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
|
||||
f"block_shape=[{block_n},{block_k}].json")
|
||||
|
||||
config_file_path = os.path.join(save_path, json_file_name)
|
||||
print(f"Writing best config to {config_file_path}...")
|
||||
|
||||
with open(config_file_path, "w") as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def tune_on_gpu(args_dict):
|
||||
"""Run tuning on a specific GPU."""
|
||||
gpu_id = args_dict["gpu_id"]
|
||||
batch_sizes = args_dict["batch_sizes"]
|
||||
weight_shapes = args_dict["weight_shapes"]
|
||||
args = args_dict["args"]
|
||||
|
||||
torch.cuda.set_device(gpu_id)
|
||||
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
|
||||
|
||||
block_n = args.block_n
|
||||
block_k = args.block_k
|
||||
out_dtype = DTYPE_MAP[args.out_dtype]
|
||||
save_path = args.save_path
|
||||
input_type = args.input_type
|
||||
|
||||
search_space = get_configs_compute_bound()
|
||||
search_space = [
|
||||
config for config in search_space
|
||||
if block_k % config["BLOCK_SIZE_K"] == 0
|
||||
]
|
||||
|
||||
start = time.time()
|
||||
for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
|
||||
N, K = shape[0], shape[1]
|
||||
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
|
||||
benchmark_results = [
|
||||
tune(
|
||||
batch_size,
|
||||
N,
|
||||
K,
|
||||
[block_n, block_k],
|
||||
out_dtype,
|
||||
search_space,
|
||||
input_type,
|
||||
) for batch_size in tqdm(batch_sizes,
|
||||
desc=f"GPU {gpu_id} - Batch sizes")
|
||||
]
|
||||
best_configs = {
|
||||
M: config
|
||||
for M, config in zip(batch_sizes, benchmark_results)
|
||||
}
|
||||
save_configs(N, K, block_n, block_k, best_configs, save_path,
|
||||
input_type)
|
||||
|
||||
end = time.time()
|
||||
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
|
||||
|
||||
|
||||
def distribute_batch_sizes(batch_sizes, num_gpus):
|
||||
"""Distribute batch sizes across available GPUs."""
|
||||
batches_per_gpu = []
|
||||
for i in range(num_gpus):
|
||||
start_idx = i * len(batch_sizes) // num_gpus
|
||||
end_idx = (i + 1) * len(batch_sizes) // num_gpus
|
||||
batches_per_gpu.append(batch_sizes[start_idx:end_idx])
|
||||
return batches_per_gpu
|
||||
|
||||
|
||||
def main(args):
|
||||
print(args)
|
||||
num_gpus = torch.cuda.device_count()
|
||||
if num_gpus == 0:
|
||||
raise RuntimeError("No GPU available for tuning")
|
||||
print(f"Found {num_gpus} GPUs for parallel tuning")
|
||||
|
||||
torch.cuda.init()
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
1,
|
||||
2,
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
24,
|
||||
32,
|
||||
48,
|
||||
64,
|
||||
96,
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1536,
|
||||
2048,
|
||||
3072,
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
num_gpus = 1 # If only one batch size, use only one GPU
|
||||
|
||||
weight_shapes = get_weight_shapes(args.tp_size)
|
||||
|
||||
batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)
|
||||
|
||||
process_args = []
|
||||
for gpu_id in range(num_gpus):
|
||||
process_args.append({
|
||||
"gpu_id": gpu_id,
|
||||
"batch_sizes": batches_per_gpu[gpu_id],
|
||||
"weight_shapes":
|
||||
weight_shapes, # Each GPU processes all weight shapes
|
||||
"args": args,
|
||||
})
|
||||
|
||||
ctx = mp.get_context("spawn")
|
||||
with ctx.Pool(num_gpus) as pool:
|
||||
pool.map(tune_on_gpu, process_args)
|
||||
|
||||
print("Multi-GPU tuning completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="""
|
||||
Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
|
||||
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
|
||||
Then copy to model_executor/layers/quantization/utils/configs
|
||||
""",
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
parser.add_argument("--tp-size", "-tp", type=int, default=8)
|
||||
parser.add_argument("--input-type",
|
||||
type=str,
|
||||
choices=["fp8"],
|
||||
default="fp8")
|
||||
parser.add_argument(
|
||||
"--out-dtype",
|
||||
type=str,
|
||||
choices=["float32", "float16", "bfloat16", "half"],
|
||||
default="float16",
|
||||
)
|
||||
parser.add_argument("--block-n", type=int, default=128)
|
||||
parser.add_argument("--block-k", type=int, default=128)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--save-path", type=str, default="./")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
129
benchmarks/kernels/deepgemm/README.md
Normal file
129
benchmarks/kernels/deepgemm/README.md
Normal file
@ -0,0 +1,129 @@
|
||||
# DeepSeek DeepGEMM Kernels Benchmark
|
||||
|
||||
This directory includes benchmarks between DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing triton and CUTLASS-based kernels.
|
||||
|
||||
Currently this just includes dense GEMMs and only works on Hopper GPUs.
|
||||
|
||||
## Setup
|
||||
|
||||
You need to install vLLM in your usual fashion, then install DeepGEMM from source in its own directory:
|
||||
|
||||
```
|
||||
git clone --recursive https://github.com/deepseek-ai/DeepGEMM
|
||||
cd DeepGEMM
|
||||
python setup.py install
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
python benchmark_fp8_block_dense_gemm.py
|
||||
INFO 02-26 21:55:13 [__init__.py:207] Automatically detected platform cuda.
|
||||
===== STARTING FP8 GEMM BENCHMARK =====
|
||||
PyTorch version: 2.5.1+cu124
|
||||
CUDA version: 12.4
|
||||
Triton version: 3.1.0
|
||||
Using device: NVIDIA H100 80GB HBM3
|
||||
WARNING 02-26 21:55:15 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||
INFO 02-26 21:55:15 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||
WARNING 02-26 21:55:16 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||
WARNING 02-26 21:55:17 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||
|
||||
===== PERFORMANCE COMPARISON =====
|
||||
|
||||
DeepGEMM Implementation:
|
||||
+------+-------+-------+-----------+--------+--------+
|
||||
| m | n | k | Time (μs) | TFLOPS | GB/s |
|
||||
+------+-------+-------+-----------+--------+--------+
|
||||
| 8 | 4096 | 7168 | 102.9 | 4.6 | 286.4 |
|
||||
| 8 | 7168 | 18432 | 70.8 | 29.8 | 1868.8 |
|
||||
| 8 | 18432 | 7168 | 69.3 | 30.5 | 1911.8 |
|
||||
| 64 | 4096 | 7168 | 69.1 | 54.4 | 439.0 |
|
||||
| 64 | 7168 | 18432 | 69.4 | 243.6 | 1933.6 |
|
||||
| 64 | 18432 | 7168 | 70.4 | 240.3 | 1917.2 |
|
||||
| 64 | 24576 | 1536 | 70.1 | 68.9 | 584.6 |
|
||||
| 64 | 32768 | 512 | 68.4 | 31.4 | 307.1 |
|
||||
| 64 | 7168 | 16384 | 69.5 | 216.3 | 1718.5 |
|
||||
| 128 | 4096 | 7168 | 141.1 | 53.3 | 222.1 |
|
||||
| 128 | 7168 | 18432 | 71.9 | 470.5 | 1896.1 |
|
||||
| 128 | 18432 | 7168 | 69.3 | 488.2 | 1988.2 |
|
||||
| 1024 | 4096 | 7168 | 89.7 | 670.1 | 502.5 |
|
||||
| 1024 | 18432 | 7168 | 279.0 | 969.8 | 635.2 |
|
||||
| 2048 | 4096 | 7168 | 175.1 | 687.0 | 347.4 |
|
||||
| 4096 | 4096 | 7168 | 335.4 | 717.0 | 275.1 |
|
||||
+------+-------+-------+-----------+--------+--------+
|
||||
|
||||
vLLM Triton Implementation:
|
||||
+------+-------+-------+-----------+--------+--------+--------------+
|
||||
| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+
|
||||
| 8 | 4096 | 7168 | 74.0 | 6.3 | 398.2 | 1.39x faster |
|
||||
| 8 | 7168 | 18432 | 89.6 | 23.6 | 1478.1 | 0.79x slower |
|
||||
| 8 | 18432 | 7168 | 113.2 | 18.7 | 1170.4 | 0.61x slower |
|
||||
| 64 | 4096 | 7168 | 79.4 | 47.3 | 382.2 | 0.87x slower |
|
||||
| 64 | 7168 | 18432 | 98.5 | 171.7 | 1363.0 | 0.70x slower |
|
||||
| 64 | 18432 | 7168 | 119.5 | 141.5 | 1129.4 | 0.59x slower |
|
||||
| 64 | 24576 | 1536 | 37.6 | 128.4 | 1089.7 | 1.86x faster |
|
||||
| 64 | 32768 | 512 | 38.7 | 55.5 | 542.6 | 1.77x faster |
|
||||
| 64 | 7168 | 16384 | 86.1 | 174.5 | 1386.4 | 0.81x slower |
|
||||
| 128 | 4096 | 7168 | 90.7 | 82.9 | 345.4 | 1.56x faster |
|
||||
| 128 | 7168 | 18432 | 144.0 | 234.9 | 946.9 | 0.50x slower |
|
||||
| 128 | 18432 | 7168 | 229.5 | 147.4 | 600.1 | 0.30x slower |
|
||||
| 1024 | 4096 | 7168 | 242.3 | 248.2 | 186.1 | 0.37x slower |
|
||||
| 1024 | 18432 | 7168 | 897.8 | 301.4 | 197.4 | 0.31x slower |
|
||||
| 2048 | 4096 | 7168 | 463.0 | 259.7 | 131.4 | 0.38x slower |
|
||||
| 4096 | 4096 | 7168 | 901.8 | 266.7 | 102.3 | 0.37x slower |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+
|
||||
|
||||
vLLM CUTLASS Implementation:
|
||||
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||
| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | vs Triton |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||
| 8 | 4096 | 7168 | 34.6 | 13.6 | 852.3 | 2.98x faster | 2.14x faster |
|
||||
| 8 | 7168 | 18432 | 78.9 | 26.8 | 1677.3 | 0.90x slower | 1.13x faster |
|
||||
| 8 | 18432 | 7168 | 81.2 | 26.0 | 1631.1 | 0.85x slower | 1.39x faster |
|
||||
| 64 | 4096 | 7168 | 36.9 | 101.9 | 822.9 | 1.87x faster | 2.15x faster |
|
||||
| 64 | 7168 | 18432 | 87.4 | 193.4 | 1535.2 | 0.79x slower | 1.13x faster |
|
||||
| 64 | 18432 | 7168 | 85.0 | 199.0 | 1587.6 | 0.83x slower | 1.41x faster |
|
||||
| 64 | 24576 | 1536 | 28.0 | 172.8 | 1465.8 | 2.51x faster | 1.35x faster |
|
||||
| 64 | 32768 | 512 | 28.8 | 74.5 | 728.5 | 2.37x faster | 1.34x faster |
|
||||
| 64 | 7168 | 16384 | 77.9 | 193.0 | 1532.8 | 0.89x slower | 1.11x faster |
|
||||
| 128 | 4096 | 7168 | 39.1 | 192.4 | 802.0 | 3.61x faster | 2.32x faster |
|
||||
| 128 | 7168 | 18432 | 93.7 | 360.8 | 1454.2 | 0.77x slower | 1.54x faster |
|
||||
| 128 | 18432 | 7168 | 85.7 | 394.8 | 1608.0 | 0.81x slower | 2.68x faster |
|
||||
| 1024 | 4096 | 7168 | 99.7 | 603.1 | 452.2 | 0.90x slower | 2.43x faster |
|
||||
| 1024 | 18432 | 7168 | 331.3 | 816.7 | 534.9 | 0.84x slower | 2.71x faster |
|
||||
| 2048 | 4096 | 7168 | 198.3 | 606.6 | 306.7 | 0.88x slower | 2.34x faster |
|
||||
| 4096 | 4096 | 7168 | 392.2 | 613.2 | 235.3 | 0.86x slower | 2.30x faster |
|
||||
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||
|
||||
===== AVERAGE PERFORMANCE =====
|
||||
+----------------+------------+----------+---------------+
|
||||
| Implementation | Avg TFLOPS | Avg GB/s | Avg Time (ms) |
|
||||
+----------------+------------+----------+---------------+
|
||||
| DeepGEMM | 310.98 | 1052.10 | 0.11 |
|
||||
| vLLM Triton | 144.30 | 715.60 | 0.23 |
|
||||
| vLLM CUTLASS | 286.78 | 1076.67 | 0.11 |
|
||||
+----------------+------------+----------+---------------+
|
||||
|
||||
===== AVERAGE SPEEDUPS =====
|
||||
+-----------------------------+--------------+
|
||||
| Comparison | Speedup |
|
||||
+-----------------------------+--------------+
|
||||
| DeepGEMM vs vLLM Triton | 1.71x faster |
|
||||
| DeepGEMM vs vLLM CUTLASS | 0.94x slower |
|
||||
| vLLM CUTLASS vs vLLM Triton | 1.84x faster |
|
||||
+-----------------------------+--------------+
|
||||
|
||||
===== ACCURACY COMPARISON =====
|
||||
+----------------+-----------------------+
|
||||
| Implementation | Avg Diff vs Reference |
|
||||
+----------------+-----------------------+
|
||||
| DeepGEMM | 0.000684 |
|
||||
| vLLM Triton | 0.000684 |
|
||||
| vLLM CUTLASS | 0.000684 |
|
||||
+----------------+-----------------------+
|
||||
```
|
||||
464
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
Normal file
464
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
Normal file
@ -0,0 +1,464 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# fmt: off
|
||||
# ruff: noqa: E501
|
||||
import time
|
||||
|
||||
# Import DeepGEMM functions
|
||||
import deep_gemm
|
||||
import torch
|
||||
import triton
|
||||
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
|
||||
|
||||
# Import vLLM functions
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
|
||||
|
||||
# Copied from
|
||||
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9
|
||||
def per_token_cast_to_fp8(
|
||||
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert tensor to FP8 format with per-token scaling."""
|
||||
assert x.dim() == 2 and x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
|
||||
torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
# Copied from
|
||||
# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Convert tensor to FP8 format with per-block scaling."""
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
|
||||
x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
|
||||
|
||||
def benchmark_shape(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
warmup: int = 100,
|
||||
repeat: int = 10000,
|
||||
verbose: bool = False) -> dict:
|
||||
"""Benchmark all implementations for a specific (m, n, k) shape."""
|
||||
if verbose:
|
||||
print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
|
||||
|
||||
# Create test tensors
|
||||
A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
|
||||
# Reference result in BF16
|
||||
torch.cuda.synchronize()
|
||||
C_ref = A @ B.t()
|
||||
|
||||
# Pre-quantize B for all implementations
|
||||
# (weights can be pre-quantized offline)
|
||||
B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B)
|
||||
B_vllm, B_scale_vllm = per_block_cast_to_fp8(B)
|
||||
|
||||
# Block size configuration
|
||||
block_size = [128, 128]
|
||||
|
||||
# Pre-quantize A for all implementations
|
||||
A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
|
||||
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
|
||||
C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
|
||||
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
|
||||
A, block_size[1], column_major_scales=True)
|
||||
|
||||
# === DeepGEMM Implementation ===
|
||||
def deepgemm_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
|
||||
# A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
|
||||
# A, block_size[1])
|
||||
# A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
|
||||
# C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
|
||||
(B_deepgemm, B_scale_deepgemm),
|
||||
C_deepgemm)
|
||||
return C_deepgemm
|
||||
|
||||
# === vLLM Triton Implementation ===
|
||||
def vllm_triton_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
|
||||
return w8a8_block_fp8_matmul(A_vllm,
|
||||
B_vllm,
|
||||
A_scale_vllm,
|
||||
B_scale_vllm,
|
||||
block_size,
|
||||
output_dtype=torch.bfloat16)
|
||||
|
||||
# === vLLM CUTLASS Implementation ===
|
||||
def vllm_cutlass_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
|
||||
# A, block_size[1], column_major_scales=True)
|
||||
return ops.cutlass_scaled_mm(A_vllm_cutlass,
|
||||
B_vllm.T,
|
||||
scale_a=A_scale_vllm_cutlass,
|
||||
scale_b=B_scale_vllm.T,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
# Run correctness check first
|
||||
if verbose:
|
||||
print("Running correctness check...")
|
||||
C_deepgemm = deepgemm_gemm()
|
||||
C_vllm_triton = vllm_triton_gemm()
|
||||
C_vllm_cutlass = vllm_cutlass_gemm()
|
||||
|
||||
deepgemm_diff = calc_diff(C_deepgemm, C_ref)
|
||||
vllm_triton_diff = calc_diff(C_vllm_triton, C_ref)
|
||||
vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref)
|
||||
|
||||
if verbose:
|
||||
print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
|
||||
print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
|
||||
print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
|
||||
print("vLLM Triton vs DeepGEMM difference: "
|
||||
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
|
||||
print("vLLM CUTLASS vs DeepGEMM difference: "
|
||||
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
|
||||
|
||||
# Benchmark implementations
|
||||
implementations = {
|
||||
"DeepGEMM": deepgemm_gemm,
|
||||
"vLLM Triton": vllm_triton_gemm,
|
||||
"vLLM CUTLASS": vllm_cutlass_gemm
|
||||
}
|
||||
|
||||
benchmark_results = {
|
||||
"shape": {
|
||||
"m": m,
|
||||
"n": n,
|
||||
"k": k
|
||||
},
|
||||
"implementations": {}
|
||||
}
|
||||
|
||||
for name, func in implementations.items():
|
||||
# Warmup
|
||||
for _ in range(warmup):
|
||||
func()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Timing loop
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
for _ in range(repeat):
|
||||
func()
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
# Calculate timing and TFLOPS
|
||||
avg_time_ms = (end - start) / repeat * 1000
|
||||
avg_time_us = avg_time_ms * 1000
|
||||
tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12
|
||||
gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3)
|
||||
|
||||
benchmark_results["implementations"][name] = {
|
||||
"time_ms": avg_time_ms,
|
||||
"time_us": avg_time_us,
|
||||
"tflops": tflops,
|
||||
"gb_s": gb_s,
|
||||
"diff": {
|
||||
"DeepGEMM":
|
||||
0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
|
||||
"Reference":
|
||||
deepgemm_diff if name == "DeepGEMM" else
|
||||
(vllm_triton_diff
|
||||
if name == "vLLM Triton" else vllm_cutlass_diff)
|
||||
}
|
||||
}
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
|
||||
)
|
||||
|
||||
# Calculate speedups
|
||||
baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
|
||||
for name, data in benchmark_results["implementations"].items():
|
||||
if name != "DeepGEMM":
|
||||
speedup = baseline / data["time_ms"]
|
||||
benchmark_results["implementations"][name][
|
||||
"speedup_vs_deepgemm"] = speedup
|
||||
if verbose:
|
||||
print(f"DeepGEMM is {1/speedup:.2f}x "
|
||||
f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")
|
||||
|
||||
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
|
||||
"time_ms"]
|
||||
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
|
||||
"time_ms"]
|
||||
cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
|
||||
benchmark_results["implementations"]["vLLM CUTLASS"][
|
||||
"speedup_vs_triton"] = cutlass_vs_triton
|
||||
if verbose:
|
||||
print(
|
||||
f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
|
||||
f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton"
|
||||
)
|
||||
|
||||
return benchmark_results
|
||||
|
||||
|
||||
def format_table_row(values, widths):
|
||||
"""Format a row with specified column widths."""
|
||||
return "| " + " | ".join(f"{val:{w}}"
|
||||
for val, w in zip(values, widths)) + " |"
|
||||
|
||||
|
||||
def print_table(headers, rows, title=None):
|
||||
"""Print a table with headers and rows."""
|
||||
if title:
|
||||
print(f"\n{title}")
|
||||
|
||||
# Calculate column widths based on headers and data
|
||||
widths = [
|
||||
max(len(str(h)), max(len(str(row[i])) for row in rows))
|
||||
for i, h in enumerate(headers)
|
||||
]
|
||||
|
||||
# Create separator line
|
||||
separator = "+-" + "-+-".join("-" * w for w in widths) + "-+"
|
||||
|
||||
# Print table
|
||||
print(separator)
|
||||
print(format_table_row(headers, widths))
|
||||
print(separator)
|
||||
for row in rows:
|
||||
print(format_table_row(row, widths))
|
||||
print(separator)
|
||||
|
||||
|
||||
def format_speedup(value):
|
||||
"""Format speedup value with indicator if it's faster or slower."""
|
||||
return f"{value:.2f}x {'faster' if value > 1.0 else 'slower'}"
|
||||
|
||||
|
||||
def run_benchmarks(verbose: bool = False):
|
||||
"""Run benchmarks for a set of common shapes."""
|
||||
print("===== STARTING FP8 GEMM BENCHMARK =====")
|
||||
|
||||
# Make sure we're using the GPU
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA not available! Tests require GPU.")
|
||||
return
|
||||
|
||||
# Print system information
|
||||
print(f"PyTorch version: {torch.__version__}")
|
||||
print(f"CUDA version: {torch.version.cuda}")
|
||||
print(f"Triton version: {triton.__version__}")
|
||||
print(f"Using device: {torch.cuda.get_device_name()}")
|
||||
|
||||
# Enable TF32 for better performance
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
# Set seeds for reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
|
||||
# Define benchmark shapes (m, n, k)
|
||||
shapes = [
|
||||
(8, 4096, 7168),
|
||||
(8, 7168, 18432),
|
||||
(8, 18432, 7168),
|
||||
(64, 4096, 7168),
|
||||
(64, 7168, 18432),
|
||||
(64, 18432, 7168),
|
||||
(64, 24576, 1536),
|
||||
(64, 32768, 512),
|
||||
(64, 7168, 16384),
|
||||
(128, 4096, 7168),
|
||||
(128, 7168, 18432),
|
||||
(128, 18432, 7168),
|
||||
(1024, 4096, 7168),
|
||||
(1024, 18432, 7168),
|
||||
(2048, 4096, 7168),
|
||||
(4096, 4096, 7168),
|
||||
]
|
||||
shapes = [
|
||||
# (64, 2112, 7168),
|
||||
(64, 24576, 1536),
|
||||
(64, 32768, 512),
|
||||
(64, 7168, 16384),
|
||||
(64, 4096, 7168),
|
||||
(64, 7168, 2048),
|
||||
# (128, 2112, 7168),
|
||||
(128, 24576, 1536),
|
||||
(128, 32768, 512),
|
||||
(128, 7168, 16384),
|
||||
(128, 4096, 7168),
|
||||
(128, 7168, 2048),
|
||||
# (4096, 2112, 7168),
|
||||
(4096, 24576, 1536),
|
||||
(4096, 32768, 512),
|
||||
(4096, 7168, 16384),
|
||||
(4096, 4096, 7168),
|
||||
(4096, 7168, 2048),
|
||||
]
|
||||
|
||||
all_results = []
|
||||
for m, n, k in shapes:
|
||||
result = benchmark_shape(m, n, k, verbose=verbose)
|
||||
all_results.append(result)
|
||||
|
||||
# Print results in a nicely formatted table
|
||||
print("\n===== PERFORMANCE COMPARISON =====")
|
||||
|
||||
# Print DeepGEMM table
|
||||
deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"]
|
||||
deepgemm_rows = []
|
||||
for result in all_results:
|
||||
shape = result["shape"]
|
||||
impl_data = result["implementations"]["DeepGEMM"]
|
||||
deepgemm_rows.append([
|
||||
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
|
||||
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
|
||||
])
|
||||
|
||||
print_table(deepgemm_headers,
|
||||
deepgemm_rows,
|
||||
title="DeepGEMM Implementation:")
|
||||
|
||||
# Print vLLM Triton table
|
||||
triton_headers = [
|
||||
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
|
||||
]
|
||||
triton_rows = []
|
||||
for result in all_results:
|
||||
shape = result["shape"]
|
||||
impl_data = result["implementations"]["vLLM Triton"]
|
||||
speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
|
||||
triton_rows.append([
|
||||
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
|
||||
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
|
||||
format_speedup(speedup)
|
||||
])
|
||||
|
||||
print_table(triton_headers,
|
||||
triton_rows,
|
||||
title="vLLM Triton Implementation:")
|
||||
|
||||
# Print vLLM CUTLASS table
|
||||
cutlass_headers = [
|
||||
"m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
|
||||
"vs Triton"
|
||||
]
|
||||
cutlass_rows = []
|
||||
for result in all_results:
|
||||
shape = result["shape"]
|
||||
impl_data = result["implementations"]["vLLM CUTLASS"]
|
||||
vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
|
||||
vs_triton = impl_data.get("speedup_vs_triton", 1.0)
|
||||
cutlass_rows.append([
|
||||
shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
|
||||
f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
|
||||
format_speedup(vs_deepgemm),
|
||||
format_speedup(vs_triton)
|
||||
])
|
||||
|
||||
print_table(cutlass_headers,
|
||||
cutlass_rows,
|
||||
title="vLLM CUTLASS Implementation:")
|
||||
|
||||
# Calculate and print averages
|
||||
print("\n===== AVERAGE PERFORMANCE =====")
|
||||
|
||||
implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
|
||||
avg_metrics = {
|
||||
impl: {
|
||||
"tflops": 0,
|
||||
"gb_s": 0,
|
||||
"time_ms": 0
|
||||
}
|
||||
for impl in implementations
|
||||
}
|
||||
|
||||
for result in all_results:
|
||||
for impl in implementations:
|
||||
impl_data = result["implementations"][impl]
|
||||
avg_metrics[impl]["tflops"] += impl_data["tflops"]
|
||||
avg_metrics[impl]["gb_s"] += impl_data["gb_s"]
|
||||
avg_metrics[impl]["time_ms"] += impl_data["time_ms"]
|
||||
|
||||
num_shapes = len(all_results)
|
||||
avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"]
|
||||
avg_rows = []
|
||||
|
||||
for impl in implementations:
|
||||
avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
|
||||
avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
|
||||
avg_time = avg_metrics[impl]["time_ms"] / num_shapes
|
||||
avg_rows.append([
|
||||
impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
|
||||
])
|
||||
|
||||
print_table(avg_headers, avg_rows)
|
||||
|
||||
# Calculate average speedups
|
||||
avg_speedups = {
|
||||
"DeepGEMM vs vLLM Triton": 0,
|
||||
"DeepGEMM vs vLLM CUTLASS": 0,
|
||||
"vLLM CUTLASS vs vLLM Triton": 0
|
||||
}
|
||||
|
||||
for result in all_results:
|
||||
deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
|
||||
vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
|
||||
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
|
||||
"time_ms"]
|
||||
|
||||
avg_speedups[
|
||||
"DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
|
||||
avg_speedups[
|
||||
"DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
|
||||
avg_speedups[
|
||||
"vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
|
||||
|
||||
print("\n===== AVERAGE SPEEDUPS =====")
|
||||
speedup_headers = ["Comparison", "Speedup"]
|
||||
speedup_rows = []
|
||||
for comparison, total in avg_speedups.items():
|
||||
avg_speedup = total / num_shapes
|
||||
status = "faster" if avg_speedup > 1 else "slower"
|
||||
speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"])
|
||||
|
||||
print_table(speedup_headers, speedup_rows)
|
||||
|
||||
# Average accuracy comparison
|
||||
print("\n===== ACCURACY COMPARISON =====")
|
||||
avg_diff = {impl: 0 for impl in implementations}
|
||||
|
||||
for result in all_results:
|
||||
for impl in implementations:
|
||||
avg_diff[impl] += result["implementations"][impl]["diff"][
|
||||
"Reference"]
|
||||
|
||||
diff_headers = ["Implementation", "Avg Diff vs Reference"]
|
||||
diff_rows = []
|
||||
for impl in implementations:
|
||||
diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"])
|
||||
|
||||
print_table(diff_headers, diff_rows)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_benchmarks(verbose=False)
|
||||
65
benchmarks/run_structured_output_benchmark.sh
Executable file
65
benchmarks/run_structured_output_benchmark.sh
Executable file
@ -0,0 +1,65 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Define the model to use
|
||||
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"}
|
||||
|
||||
# Define the backend to use
|
||||
BACKEND=${2:-"vllm"}
|
||||
|
||||
# Define the dataset to use
|
||||
DATASET=${3:-"xgrammar_bench"}
|
||||
|
||||
# Define the guided decoding backend
|
||||
GUIDED_BACKEND=${4:-"xgrammar"}
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"}
|
||||
|
||||
GUIDED_RATIO=${6:-0.5}
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Define QPS values to test
|
||||
QPS_VALUES=(70 60 50 25 20 15 10)
|
||||
|
||||
# Common parameters
|
||||
COMMON_PARAMS="--backend $BACKEND \
|
||||
--model $MODEL \
|
||||
--dataset $DATASET \
|
||||
--structured-output-backend $GUIDED_BACKEND \
|
||||
--structured-output-ratio $GUIDED_RATIO \
|
||||
--save-results \
|
||||
--result-dir $OUTPUT_DIR"
|
||||
|
||||
echo "Starting structured output benchmark with model: $MODEL"
|
||||
echo "Backend: $BACKEND"
|
||||
echo "Dataset: $DATASET"
|
||||
echo "Structured output backend: $GUIDED_BACKEND"
|
||||
echo "Results will be saved to: $OUTPUT_DIR"
|
||||
echo "----------------------------------------"
|
||||
|
||||
# Run benchmarks with different QPS values
|
||||
for qps in "${QPS_VALUES[@]}"; do
|
||||
echo "Running benchmark with QPS: $qps"
|
||||
|
||||
# Get git hash and branch for the filename
|
||||
GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
|
||||
|
||||
# Construct filename for this run
|
||||
FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
|
||||
|
||||
# Run the benchmark
|
||||
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
|
||||
--request-rate $qps \
|
||||
--result-filename "$FILENAME" \
|
||||
--tokenizer-mode ${TOKENIZER_MODE:-"auto"} \
|
||||
--port ${PORT:-8000}
|
||||
|
||||
echo "Completed benchmark with QPS: $qps"
|
||||
echo "----------------------------------------"
|
||||
done
|
||||
|
||||
echo "All benchmarks completed!"
|
||||
echo "Results saved to: $OUTPUT_DIR"
|
||||
@ -1,113 +1,19 @@
|
||||
{
|
||||
"$schema":
|
||||
"https://json-schema.org/draft/2020-12/schema",
|
||||
"title":
|
||||
"User Profile",
|
||||
"type":
|
||||
"object",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"userId": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the user."
|
||||
},
|
||||
"personalInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"firstName": {
|
||||
"type": "string",
|
||||
"description": "The user's first name."
|
||||
},
|
||||
"lastName": {
|
||||
"type": "string",
|
||||
"description": "The user's last name."
|
||||
},
|
||||
"age": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"description": "The user's age."
|
||||
},
|
||||
"phoneNumbers": {
|
||||
"type":
|
||||
"array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["home", "work", "mobile"],
|
||||
"description": "Type of phone number."
|
||||
},
|
||||
"number": {
|
||||
"type": "string",
|
||||
"pattern": "^\\+?[1-9]\\d{1,14}$",
|
||||
"description": "Phone number in E.164 format."
|
||||
}
|
||||
},
|
||||
"required": ["type", "number"]
|
||||
},
|
||||
"description":
|
||||
"List of phone numbers associated with the user."
|
||||
}
|
||||
},
|
||||
"required": ["firstName", "lastName"]
|
||||
},
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {
|
||||
"type": "string",
|
||||
"description": "Street address."
|
||||
},
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name."
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "State or province."
|
||||
},
|
||||
"postalCode": {
|
||||
"type": "string",
|
||||
"pattern": "^\\d{5}(-\\d{4})?$",
|
||||
"description": "Postal code."
|
||||
},
|
||||
"country": {
|
||||
"type": "string",
|
||||
"description": "Country name."
|
||||
}
|
||||
},
|
||||
"required": ["street", "city", "state", "postalCode", "country"]
|
||||
},
|
||||
"preferences": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"newsletterSubscribed": {
|
||||
"type":
|
||||
"boolean",
|
||||
"description":
|
||||
"Indicates if the user is subscribed to the newsletter."
|
||||
},
|
||||
"favoriteCategories": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"description": "List of user's favorite categories."
|
||||
}
|
||||
},
|
||||
"required": ["newsletterSubscribed"]
|
||||
},
|
||||
"accountStatus": {
|
||||
"type": "string",
|
||||
"enum": ["active", "inactive", "suspended"],
|
||||
"description": "Current status of the user's account."
|
||||
},
|
||||
"registrationDate": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"description": "ISO 8601 formatted date-time of user registration."
|
||||
}
|
||||
"name": { "type": "string" },
|
||||
"email": { "type": "string" },
|
||||
"street": { "type": "string" },
|
||||
"city": { "type": "string" },
|
||||
"state": { "type": "string" },
|
||||
"zip": { "type": "string" },
|
||||
"phone": { "type": "string" },
|
||||
"website": { "type": "string" },
|
||||
"company": { "type": "string" },
|
||||
"age": { "type": "integer" }
|
||||
},
|
||||
"required":
|
||||
["userId", "personalInfo", "address", "accountStatus", "registrationDate"]
|
||||
}
|
||||
"required": [
|
||||
"name",
|
||||
"email"
|
||||
]
|
||||
}
|
||||
|
||||
@ -81,6 +81,7 @@ else()
|
||||
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
|
||||
find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
|
||||
find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
|
||||
find_isa(${CPUINFO} "S390" S390_FOUND)
|
||||
endif()
|
||||
|
||||
|
||||
@ -129,8 +130,16 @@ elseif (ASIMD_FOUND)
|
||||
elseif(APPLE_SILICON_FOUND)
|
||||
message(STATUS "Apple Silicon Detected")
|
||||
set(ENABLE_NUMA OFF)
|
||||
elseif (S390_FOUND)
|
||||
message(STATUS "S390 detected")
|
||||
# Check for S390 VXE support
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-mvx"
|
||||
"-mzvector"
|
||||
"-march=native"
|
||||
"-mtune=native")
|
||||
else()
|
||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.")
|
||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA or ARMv8 support.")
|
||||
endif()
|
||||
|
||||
#
|
||||
@ -140,7 +149,7 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.6
|
||||
GIT_TAG v3.7.1
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
@ -181,6 +190,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cpu/cache.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
"csrc/cpu/mla_decode.cpp"
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp")
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
|
||||
GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
@ -64,4 +64,4 @@ install(
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
)
|
||||
|
||||
@ -350,8 +350,8 @@ __global__ void concat_and_cache_mla_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
@ -393,8 +393,8 @@ void reshape_and_cache(
|
||||
CALL_RESHAPE_AND_CACHE)
|
||||
}
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
@ -446,8 +446,8 @@ void reshape_and_cache_flash(
|
||||
CALL_RESHAPE_AND_CACHE_FLASH);
|
||||
}
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
|
||||
@ -24,8 +24,8 @@ struct KernelVecType<float> {
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#ifdef __powerpc64__
|
||||
// Power architecture-specific vector types
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power and s390x architecture-specific vector types
|
||||
using q_load_vec_type = vec_op::FP32Vec8;
|
||||
using k_load_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP32Vec16;
|
||||
|
||||
@ -3,6 +3,12 @@
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
#if defined(__x86_64__)
|
||||
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2
|
||||
#else
|
||||
#define DISPATCH_MACRO VLLM_DISPATCH_FLOATING_TYPES
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
|
||||
@ -82,6 +88,48 @@ void reshape_and_cache_cpu_impl(
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void concat_and_cache_mla_cpu_impl(
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int num_tokens, //
|
||||
const int block_stride, //
|
||||
const int entry_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size //
|
||||
) {
|
||||
#pragma omp parallel for
|
||||
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
continue;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
|
||||
auto copy = [&](const scalar_t* __restrict__ src,
|
||||
scalar_t* __restrict__ dst, int src_stride, int dst_stride,
|
||||
int size, int offset) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
const int64_t src_idx = token_idx * src_stride + i;
|
||||
const int64_t dst_idx =
|
||||
block_idx * block_stride + block_offset * entry_stride + i + offset;
|
||||
dst[dst_idx] = src[src_idx];
|
||||
}
|
||||
};
|
||||
|
||||
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: the key_caches and value_caches vectors are constant but
|
||||
// not the Tensors they contain. The vectors need to be const refs
|
||||
// in order to satisfy pytorch's C++ operator registration code.
|
||||
@ -95,13 +143,12 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
}
|
||||
|
||||
const int element_num_per_block = key_caches[0][0].numel();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
|
||||
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
|
||||
element_num_per_block, num_layers);
|
||||
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
|
||||
});
|
||||
DISPATCH_MACRO(key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
|
||||
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
|
||||
element_num_per_block, num_layers);
|
||||
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
@ -118,15 +165,46 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
|
||||
DISPATCH_MACRO(key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
|
||||
reshape_and_cache_cpu_impl<scalar_t>(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
|
||||
num_heads, head_size, block_size, x);
|
||||
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||
// pe_dim)]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, torch::Tensor& scale) {
|
||||
int num_tokens = slot_mapping.size(0);
|
||||
int kv_lora_rank = kv_c.size(1);
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
TORCH_CHECK(kv_cache_dtype != "fp8");
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
int entry_stride = kv_cache.stride(1);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
|
||||
reshape_and_cache_cpu_impl<scalar_t>(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
|
||||
value_stride, num_heads, head_size, block_size, x);
|
||||
CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
|
||||
kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
|
||||
concat_and_cache_mla_cpu_impl<scalar_t>(
|
||||
kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
|
||||
kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
|
||||
num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
|
||||
kv_lora_rank, pe_dim, block_size);
|
||||
CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -7,6 +7,9 @@
|
||||
#elif defined(__POWER9_VECTOR__)
|
||||
// ppc implementation
|
||||
#include "cpu_types_vsx.hpp"
|
||||
#elif defined(__s390x__)
|
||||
// s390 implementation
|
||||
#include "cpu_types_vxe.hpp"
|
||||
#elif defined(__aarch64__)
|
||||
// arm implementation
|
||||
#include "cpu_types_arm.hpp"
|
||||
|
||||
480
csrc/cpu/cpu_types_vxe.hpp
Normal file
480
csrc/cpu/cpu_types_vxe.hpp
Normal file
@ -0,0 +1,480 @@
|
||||
|
||||
#ifndef CPU_TYPES_VXE_HPP
|
||||
#define CPU_TYPES_VXE_HPP
|
||||
|
||||
#include <vecintrin.h>
|
||||
#include <cmath>
|
||||
#include <torch/all.h>
|
||||
namespace vec_op {
|
||||
|
||||
#define vec_neg(a) (-(a))
|
||||
#define vec_add(a, b) ((a) + (b))
|
||||
#define vec_sub(a, b) ((a) - (b))
|
||||
#define vec_mul(a, b) ((a) * (b))
|
||||
#define vec_div(a, b) ((a) / (b))
|
||||
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebaic
|
||||
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
|
||||
|
||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) \
|
||||
std::cout << #NAME << " exit." << std::endl;
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
|
||||
};
|
||||
|
||||
typedef struct ss16x8x2_t {
|
||||
__vector signed short val[2];
|
||||
} ss16x8x2_t;
|
||||
|
||||
typedef struct ss16x8x4_t {
|
||||
__vector signed short val[4];
|
||||
} ss16x8x4_t;
|
||||
|
||||
typedef struct f32x4x2_t {
|
||||
__vector float val[2];
|
||||
} f32x4x2_t;
|
||||
|
||||
typedef struct f32x4x4_t {
|
||||
__vector float val[4];
|
||||
} f32x4x4_t;
|
||||
|
||||
struct FP32Vec8;
|
||||
struct FP32Vec16;
|
||||
|
||||
struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
constexpr static int VEC_ELEM_NUM = 8;
|
||||
|
||||
__vector signed short reg;
|
||||
|
||||
explicit BF16Vec8(const void* ptr) : reg(*(__vector signed short*)ptr) {}
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void* ptr) const {
|
||||
*reinterpret_cast<__vector signed short*>(ptr) = reg;
|
||||
}
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
|
||||
ss16x8x2_t reg;
|
||||
|
||||
explicit BF16Vec16(const void* ptr) {
|
||||
// Load 256 bits in two parts
|
||||
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
|
||||
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
|
||||
}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void* ptr) const {
|
||||
// Save 256 bits in two parts
|
||||
vec_xst(reg.val[0], 0, (signed short*)ptr);
|
||||
vec_xst(reg.val[1], 16, (signed short*)ptr);
|
||||
}
|
||||
};
|
||||
|
||||
const static __vector signed short zero = vec_splats((signed short)0);
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
constexpr static int VEC_ELEM_NUM = 32;
|
||||
|
||||
ss16x8x4_t reg;
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
|
||||
|
||||
explicit BF16Vec32(const BF16Vec8& vec8_data)
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
constexpr static int VEC_ELEM_NUM = 4;
|
||||
union AliasReg {
|
||||
__vector float reg;
|
||||
float values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
__vector float reg;
|
||||
|
||||
explicit FP32Vec4(float v) : reg(vec_splats(v)) {}
|
||||
|
||||
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
|
||||
|
||||
explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
|
||||
|
||||
explicit FP32Vec4(__vector float data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
constexpr static int VEC_ELEM_NUM = 8;
|
||||
union AliasReg {
|
||||
f32x4x2_t reg;
|
||||
float values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
f32x4x2_t reg;
|
||||
|
||||
explicit FP32Vec8(float v) {
|
||||
reg.val[0] = vec_splats(v);
|
||||
reg.val[1] = vec_splats(v);
|
||||
}
|
||||
|
||||
explicit FP32Vec8() {
|
||||
reg.val[0] = vec_splats(0.0f);
|
||||
reg.val[1] = vec_splats(0.0f);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const float* ptr) {
|
||||
reg.val[0] = vec_xl(0, ptr);
|
||||
reg.val[1] = vec_xl(16, ptr);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8& v) {
|
||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
|
||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
FP32Vec8 exp() const {
|
||||
// TODO: Vectorize this
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
f32x4x4_t ret;
|
||||
ret.val[0][0] = std::exp(ar.values[0]);
|
||||
ret.val[0][1] = std::exp(ar.values[1]);
|
||||
ret.val[0][2] = std::exp(ar.values[2]);
|
||||
ret.val[0][3] = std::exp(ar.values[3]);
|
||||
ret.val[1][0] = std::exp(ar.values[4]);
|
||||
ret.val[1][1] = std::exp(ar.values[5]);
|
||||
ret.val[1][2] = std::exp(ar.values[6]);
|
||||
ret.val[1][3] = std::exp(ar.values[7]);
|
||||
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
|
||||
}
|
||||
|
||||
FP32Vec8 tanh() const {
|
||||
// TODO: Vectorize this
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
f32x4x4_t ret;
|
||||
ret.val[0][0] = std::tanh(ar.values[0]);
|
||||
ret.val[0][1] = std::tanh(ar.values[1]);
|
||||
ret.val[0][2] = std::tanh(ar.values[2]);
|
||||
ret.val[0][3] = std::tanh(ar.values[3]);
|
||||
ret.val[1][0] = std::tanh(ar.values[4]);
|
||||
ret.val[1][1] = std::tanh(ar.values[5]);
|
||||
ret.val[1][2] = std::tanh(ar.values[6]);
|
||||
ret.val[1][3] = std::tanh(ar.values[7]);
|
||||
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
|
||||
}
|
||||
|
||||
FP32Vec8 er() const {
|
||||
// TODO: Vectorize this
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
f32x4x4_t ret;
|
||||
ret.val[0][0] = std::erf(ar.values[0]);
|
||||
ret.val[0][1] = std::erf(ar.values[1]);
|
||||
ret.val[0][2] = std::erf(ar.values[2]);
|
||||
ret.val[0][3] = std::erf(ar.values[3]);
|
||||
ret.val[1][0] = std::erf(ar.values[4]);
|
||||
ret.val[1][1] = std::erf(ar.values[5]);
|
||||
ret.val[1][2] = std::erf(ar.values[6]);
|
||||
ret.val[1][3] = std::erf(ar.values[7]);
|
||||
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
void save(float* ptr) const {
|
||||
vec_xst(reg.val[0], 0, ptr);
|
||||
vec_xst(reg.val[1], 16, ptr);
|
||||
}
|
||||
};
|
||||
|
||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
f32x4x4_t reg;
|
||||
float values[VEC_ELEM_NUM];
|
||||
};
|
||||
|
||||
f32x4x4_t reg;
|
||||
|
||||
explicit FP32Vec16(float v) {
|
||||
reg.val[0] = vec_splats(v);
|
||||
reg.val[1] = vec_splats(v);
|
||||
reg.val[2] = vec_splats(v);
|
||||
reg.val[3] = vec_splats(v);
|
||||
}
|
||||
|
||||
explicit FP32Vec16() {
|
||||
reg.val[0] = vec_splats(0.0f);
|
||||
reg.val[1] = vec_splats(0.0f);
|
||||
reg.val[2] = vec_splats(0.0f);
|
||||
reg.val[3] = vec_splats(0.0f);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const float* ptr) {
|
||||
reg.val[0] = vec_xl(0, ptr);
|
||||
reg.val[1] = vec_xl(16, ptr);
|
||||
reg.val[2] = vec_xl(32, ptr);
|
||||
reg.val[3] = vec_xl(48, ptr);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[2];
|
||||
reg.val[3] = data.reg.val[3];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4& data) {
|
||||
reg.val[0] = data.reg;
|
||||
reg.val[1] = data.reg;
|
||||
reg.val[2] = data.reg;
|
||||
reg.val[3] = data.reg;
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16& v) {
|
||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
|
||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
|
||||
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
|
||||
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
|
||||
vec_mul(reg.val[1], b.reg.val[1]),
|
||||
vec_mul(reg.val[2], b.reg.val[2]),
|
||||
vec_mul(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
|
||||
vec_add(reg.val[1], b.reg.val[1]),
|
||||
vec_add(reg.val[2], b.reg.val[2]),
|
||||
vec_add(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
|
||||
vec_sub(reg.val[1], b.reg.val[1]),
|
||||
vec_sub(reg.val[2], b.reg.val[2]),
|
||||
vec_sub(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
|
||||
vec_div(reg.val[1], b.reg.val[1]),
|
||||
vec_div(reg.val[2], b.reg.val[2]),
|
||||
vec_div(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
const int start = idx * group_size;
|
||||
unroll_loop<int, group_size>(
|
||||
[&result, &start, ar](int i) { result += ar.values[start + i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void save(float* ptr) const {
|
||||
vec_xst(reg.val[0], 0, ptr);
|
||||
vec_xst(reg.val[1], 16, ptr);
|
||||
vec_xst(reg.val[2], 32, ptr);
|
||||
vec_xst(reg.val[3], 48, ptr);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc = acc + a * b;
|
||||
}
|
||||
|
||||
namespace c10 {
|
||||
struct BFloat16 {
|
||||
uint16_t value; // Assume BFloat16 is defined as a struct containing a 16-bit
|
||||
// value.
|
||||
};
|
||||
} // namespace c10
|
||||
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
|
||||
reinterpret_cast<c10::BFloat16*>(&v);
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
#ifndef __VEC_CLASS_FP_NAN
|
||||
#define __VEC_CLASS_FP_NAN (1 << 6)
|
||||
#endif
|
||||
|
||||
const static __vector unsigned char omask = {2, 3, 6, 7, 10, 11, 14, 15,
|
||||
18, 19, 22, 23, 26, 27, 30, 31};
|
||||
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
|
||||
0x00007fff};
|
||||
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
|
||||
0x7fc00000};
|
||||
const static __vector unsigned int sh16 = {16, 16, 16, 16};
|
||||
const static __vector unsigned int one = {1, 1, 1, 1};
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
|
||||
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
|
||||
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
|
||||
int cc;
|
||||
__vector __bool int sel0 =
|
||||
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
|
||||
__vector __bool int sel1 =
|
||||
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
|
||||
inp0 = vec_sel(inp0, nan, sel0) >> sh16;
|
||||
inp1 = vec_sel(inp1, nan, sel1) >> sh16;
|
||||
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
|
||||
}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
|
||||
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
|
||||
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
|
||||
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
|
||||
int cc;
|
||||
__vector __bool int sel0 =
|
||||
vec_fp_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN, &cc);
|
||||
__vector __bool int sel1 =
|
||||
vec_fp_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN, &cc);
|
||||
__vector __bool int sel2 =
|
||||
vec_fp_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN, &cc);
|
||||
__vector __bool int sel3 =
|
||||
vec_fp_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN, &cc);
|
||||
inp0 = vec_sel(inp0, nan, sel0) >> sh16;
|
||||
inp1 = vec_sel(inp1, nan, sel1) >> sh16;
|
||||
inp2 = vec_sel(inp2, nan, sel2) >> sh16;
|
||||
inp3 = vec_sel(inp3, nan, sel3) >> sh16;
|
||||
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
|
||||
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
|
||||
}
|
||||
|
||||
inline void prefetch(const void* addr) { void __dcbt(const void* addr); }
|
||||
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
@ -16,9 +16,18 @@ namespace vec_op {
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES_WITH_E5M2(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, \
|
||||
VLLM_DISPATCH_CASE_FLOATING_TYPES_FP8(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
@ -121,6 +130,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
|
||||
__m512i reg;
|
||||
|
||||
explicit BF16Vec32() : reg(_mm512_setzero_si512()) {}
|
||||
|
||||
explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(__m512i data) : reg(data) {}
|
||||
|
||||
393
csrc/cpu/mla_decode.cpp
Normal file
393
csrc/cpu/mla_decode.cpp
Normal file
@ -0,0 +1,393 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include <float.h>
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using qk_load_vec_type = void;
|
||||
using qk_vec_type = void;
|
||||
using v_load_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using qk_load_vec_type = vec_op::FP32Vec16;
|
||||
using qk_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power and s390x architecture-specific vector types
|
||||
using qk_load_vec_type = vec_op::FP32Vec16;
|
||||
using qk_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
// Fallback for other architectures, including x86
|
||||
using qk_load_vec_type = vec_op::FP16Vec16;
|
||||
using qk_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::FP16Vec16;
|
||||
#endif
|
||||
};
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using qk_load_vec_type = vec_op::BF16Vec32;
|
||||
using qk_vec_type = vec_op::BF16Vec32;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
|
||||
// pass
|
||||
#else
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using qk_load_vec_type = vec_op::BF16Vec16;
|
||||
using qk_vec_type = vec_op::FP32Vec16;
|
||||
using v_load_vec_type = vec_op::BF16Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, int HEAD_UNROLL,
|
||||
typename qk_vec_type>
|
||||
void mla_decode_block_head(
|
||||
const qk_vec_type* __restrict__ q_vecs, // [HEAD_UNROLL, head_dim]
|
||||
const qk_vec_type* __restrict__ k_vecs, // [block_size, head_dim]
|
||||
const vec_op::FP32Vec16* __restrict v_vecs_f32, // [block_size, v_head_dim]
|
||||
float* __restrict__ acc_out, // [HEAD_UNROLL, v_head_dim]
|
||||
float* __restrict__ acc_lse, // [HEAD_UNROLL]
|
||||
const float scale, const int num_tokens) {
|
||||
using f32_vec_type = vec_op::FP32Vec16;
|
||||
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
|
||||
constexpr int V_NUM_ELEM = f32_vec_type::VEC_ELEM_NUM;
|
||||
|
||||
float logits[BLOCK_SIZE][HEAD_UNROLL] = {}; // initialize to zeros
|
||||
float max_val[HEAD_UNROLL];
|
||||
std::fill(max_val, max_val + HEAD_UNROLL, -FLT_MAX);
|
||||
|
||||
f32_vec_type acc_vec[BLOCK_SIZE][HEAD_UNROLL];
|
||||
for (int i = 0; i < HEAD_DIM; i += QK_NUM_ELEM) {
|
||||
// load to registers
|
||||
qk_vec_type q_vec[HEAD_UNROLL];
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||
q_vec[unroll] =
|
||||
qk_vec_type{q_vecs[(i + unroll * HEAD_DIM) / QK_NUM_ELEM]};
|
||||
|
||||
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||
qk_vec_type k_vec(k_vecs[(block_offset * HEAD_DIM + i) / QK_NUM_ELEM]);
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||
vec_op::fma(acc_vec[block_offset][unroll], q_vec[unroll], k_vec);
|
||||
}
|
||||
}
|
||||
|
||||
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||
const float acc = acc_vec[block_offset][unroll].reduce_sum() * scale;
|
||||
logits[block_offset][unroll] = acc;
|
||||
max_val[unroll] = std::max(max_val[unroll], acc);
|
||||
}
|
||||
}
|
||||
|
||||
float sum_exp[HEAD_UNROLL] = {};
|
||||
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||
const float val =
|
||||
std::exp(logits[block_offset][unroll] - max_val[unroll]);
|
||||
logits[block_offset][unroll] = val;
|
||||
sum_exp[unroll] += val;
|
||||
}
|
||||
}
|
||||
|
||||
f32_vec_type this_out[V_HEAD_DIM / V_NUM_ELEM][HEAD_UNROLL];
|
||||
|
||||
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||
// load to registers
|
||||
f32_vec_type scale_[HEAD_UNROLL];
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||
scale_[unroll] =
|
||||
f32_vec_type{logits[block_offset][unroll] / sum_exp[unroll]};
|
||||
|
||||
for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
|
||||
f32_vec_type v_vec(
|
||||
v_vecs_f32[(block_offset * HEAD_DIM + i) / V_NUM_ELEM]);
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||
vec_op::fma(this_out[i / V_NUM_ELEM][unroll], v_vec, scale_[unroll]);
|
||||
}
|
||||
}
|
||||
|
||||
// merge attention state
|
||||
// section 2.2 in https://arxiv.org/pdf/2501.01005
|
||||
f32_vec_type prev_scale[HEAD_UNROLL];
|
||||
f32_vec_type curr_scale[HEAD_UNROLL];
|
||||
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||
const float prev_lse = acc_lse[unroll];
|
||||
const float curr_lse = std::log(sum_exp[unroll]) +
|
||||
max_val[unroll]; // add back max_val to get true lse
|
||||
// softmax trick
|
||||
const float max_lse = std::max(prev_lse, curr_lse);
|
||||
const float prev_sum_exp = std::exp(prev_lse - max_lse);
|
||||
const float curr_sum_exp = std::exp(curr_lse - max_lse);
|
||||
|
||||
const float new_sum_exp = prev_sum_exp + curr_sum_exp;
|
||||
acc_lse[unroll] = std::log(new_sum_exp) + max_lse;
|
||||
|
||||
prev_scale[unroll] = f32_vec_type{prev_sum_exp / new_sum_exp};
|
||||
curr_scale[unroll] = f32_vec_type{curr_sum_exp / new_sum_exp};
|
||||
}
|
||||
|
||||
for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
|
||||
#pragma unroll
|
||||
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||
f32_vec_type o_vec(acc_out + i + V_HEAD_DIM * unroll);
|
||||
o_vec = o_vec * prev_scale[unroll] +
|
||||
this_out[i / V_NUM_ELEM][unroll] * curr_scale[unroll];
|
||||
o_vec.save(acc_out + i + V_HEAD_DIM * unroll);
|
||||
}
|
||||
}
|
||||
|
||||
q_vecs += HEAD_DIM / QK_NUM_ELEM * HEAD_UNROLL;
|
||||
acc_out += V_HEAD_DIM * HEAD_UNROLL;
|
||||
}
|
||||
|
||||
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE,
|
||||
typename qk_vec_type>
|
||||
void mla_decode_block(
|
||||
const qk_vec_type* __restrict__ q_vecs, // [num_heads, head_dim]
|
||||
const scalar_t* __restrict__ kv_cache, // [block_size, head_dim]
|
||||
float* __restrict__ acc_out, // [num_heads, v_head_dim]
|
||||
float* __restrict__ acc_lse, // [num_heads]
|
||||
const int num_heads, const float scale, const int num_tokens) {
|
||||
using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
|
||||
static_assert(
|
||||
std::is_same<qk_vec_type,
|
||||
typename KernelVecType<scalar_t>::qk_vec_type>::value);
|
||||
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||
using f32_vec_type = vec_op::FP32Vec16;
|
||||
static_assert(qk_load_vec_type::VEC_ELEM_NUM == qk_vec_type::VEC_ELEM_NUM);
|
||||
static_assert(v_load_vec_type::VEC_ELEM_NUM == f32_vec_type::VEC_ELEM_NUM);
|
||||
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
|
||||
constexpr int V_NUM_ELEM = v_load_vec_type::VEC_ELEM_NUM;
|
||||
|
||||
const qk_vec_type* k_vecs;
|
||||
const f32_vec_type* v_vecs_f32;
|
||||
float* kv_cache_f32 = nullptr;
|
||||
|
||||
if constexpr (!std::is_same<scalar_t, float>::value) {
|
||||
// convert KV cache block to FP32 to reuse it across query heads and
|
||||
// attn @ V computation, since FP16/BF16->FP32 is expensive.
|
||||
// TODO: move malloc outside of this fn to reuse across iterations.
|
||||
const int nbytes = BLOCK_SIZE * HEAD_DIM * sizeof(float);
|
||||
kv_cache_f32 = static_cast<float*>(std::aligned_alloc(64, nbytes));
|
||||
|
||||
for (int block_offset = 0; block_offset < num_tokens; ++block_offset)
|
||||
for (int i = 0; i < HEAD_DIM; i += V_NUM_ELEM) {
|
||||
v_load_vec_type kv_load_vec(kv_cache + block_offset * HEAD_DIM + i);
|
||||
f32_vec_type kv_vec_f32(kv_load_vec);
|
||||
kv_vec_f32.save(kv_cache_f32 + block_offset * HEAD_DIM + i);
|
||||
}
|
||||
|
||||
if constexpr (std::is_same<qk_load_vec_type, qk_vec_type>::value) {
|
||||
// for AVX512_BF16, Q @ K.T uses BF16 for K (no conversion)
|
||||
// NOTE: in this case, we only need to convert the V section to FP32.
|
||||
// But for simplicity, we will convert the whole KV block to FP32.
|
||||
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
|
||||
} else {
|
||||
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache_f32);
|
||||
}
|
||||
|
||||
// attn @ V always use FP32 for V, since attn is FP32.
|
||||
v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache_f32);
|
||||
|
||||
} else {
|
||||
// KV cache is FP32. don't need to do anything.
|
||||
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
|
||||
v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache);
|
||||
}
|
||||
|
||||
// compute 2 heads at the same time to improve ILP and
|
||||
// take advantage of register cache for K and V.
|
||||
constexpr int HEAD_UNROLL = 2;
|
||||
for (int iter = 0; iter < num_heads / HEAD_UNROLL; ++iter) {
|
||||
mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, HEAD_UNROLL>(
|
||||
q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
|
||||
|
||||
q_vecs += HEAD_UNROLL * HEAD_DIM / QK_NUM_ELEM;
|
||||
acc_out += HEAD_UNROLL * V_HEAD_DIM;
|
||||
acc_lse += HEAD_UNROLL;
|
||||
}
|
||||
|
||||
// take care of the remaining heads
|
||||
for (int iter = 0; iter < num_heads % HEAD_UNROLL; ++iter) {
|
||||
mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, 1>(
|
||||
q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
|
||||
|
||||
q_vecs += HEAD_DIM / QK_NUM_ELEM;
|
||||
acc_out += V_HEAD_DIM;
|
||||
acc_lse += 1;
|
||||
}
|
||||
|
||||
if (kv_cache_f32 != nullptr) {
|
||||
std::free(kv_cache_f32);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE>
|
||||
void mla_decode_kvcache_cpu_impl(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, v_head_dim]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_dim]
|
||||
const scalar_t* __restrict__ kv_cache, // [num_blocks, block_size,
|
||||
// head_dim]
|
||||
const int num_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq, const int o_stride, const int q_stride,
|
||||
const int kv_stride, const int num_seqs) {
|
||||
using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
|
||||
using qk_vec_type = typename KernelVecType<scalar_t>::qk_vec_type;
|
||||
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
|
||||
|
||||
// shared across threads
|
||||
const int max_threads = omp_get_max_threads();
|
||||
const int acc_out_nbytes =
|
||||
max_threads * num_heads * V_HEAD_DIM * sizeof(float);
|
||||
float* acc_out = static_cast<float*>(std::aligned_alloc(64, acc_out_nbytes));
|
||||
std::vector<float> acc_lse(max_threads * num_heads);
|
||||
|
||||
// allocate memory to pre-convert query to FP32 later
|
||||
float* q_f32;
|
||||
constexpr bool PRE_CONVERT_QUERY =
|
||||
!std::is_same<scalar_t, float>::value &&
|
||||
std::is_same<qk_vec_type, vec_op::FP32Vec16>::value;
|
||||
if constexpr (PRE_CONVERT_QUERY) {
|
||||
const int q_f32_nbytes = num_heads * HEAD_DIM * sizeof(float);
|
||||
q_f32 = static_cast<float*>(std::aligned_alloc(64, q_f32_nbytes));
|
||||
}
|
||||
|
||||
#pragma omp parallel
|
||||
{
|
||||
const int num_threads = omp_get_num_threads();
|
||||
const int thread_id = omp_get_thread_num();
|
||||
float* __restrict__ acc_out_thread =
|
||||
acc_out + thread_id * num_heads * V_HEAD_DIM;
|
||||
float* __restrict__ acc_lse_thread = acc_lse.data() + thread_id * num_heads;
|
||||
|
||||
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||
// reset accumulator
|
||||
std::fill(acc_out_thread, acc_out_thread + num_heads * V_HEAD_DIM, 0.0f);
|
||||
std::fill(acc_lse_thread, acc_lse_thread + num_heads, -FLT_MAX);
|
||||
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
const int last_block_size = seq_len - (block_num - 1) * BLOCK_SIZE;
|
||||
|
||||
const qk_vec_type* q_vecs;
|
||||
if constexpr (PRE_CONVERT_QUERY) {
|
||||
// pre-convert query to FP32 since FP16/BF16->FP32 is slow.
|
||||
#pragma omp for
|
||||
for (int i = 0; i < num_heads * HEAD_DIM; i += QK_NUM_ELEM) {
|
||||
qk_load_vec_type q_load_vec(q + seq_idx * q_stride + i);
|
||||
qk_vec_type q_vec(q_load_vec);
|
||||
q_vec.save(q_f32 + i);
|
||||
}
|
||||
q_vecs = reinterpret_cast<const qk_vec_type*>(q_f32);
|
||||
} else {
|
||||
q_vecs = reinterpret_cast<const qk_vec_type*>(q + seq_idx * q_stride);
|
||||
}
|
||||
|
||||
#pragma omp for
|
||||
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||
const int physical_block_idx =
|
||||
block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
|
||||
const int num_tokens =
|
||||
block_idx < block_num - 1 ? BLOCK_SIZE : last_block_size;
|
||||
|
||||
mla_decode_block<scalar_t, HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE>(
|
||||
q_vecs, kv_cache + physical_block_idx * kv_stride, acc_out_thread,
|
||||
acc_lse_thread, num_heads, scale, num_tokens);
|
||||
}
|
||||
|
||||
// merge attention states across threads
|
||||
// section 2.2 in https://arxiv.org/pdf/2501.01005
|
||||
// each thread is responsible for 1 head
|
||||
#pragma omp for
|
||||
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||
float* acc_lse_head = acc_lse.data() + head_idx;
|
||||
float* acc_out_head = acc_out + head_idx * V_HEAD_DIM;
|
||||
|
||||
float max_val = -FLT_MAX;
|
||||
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
|
||||
max_val = std::max(max_val, acc_lse_head[thread_id_ * num_heads]);
|
||||
}
|
||||
|
||||
float sum_exp = 0.0f;
|
||||
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
|
||||
float val = std::exp(acc_lse_head[thread_id_ * num_heads] - max_val);
|
||||
acc_lse_head[thread_id_ * num_heads] = val;
|
||||
sum_exp += val;
|
||||
}
|
||||
|
||||
float inv_sum = 1.0f / sum_exp;
|
||||
float out_head[V_HEAD_DIM] = {};
|
||||
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
|
||||
float scale_ = acc_lse_head[thread_id_ * num_heads] * inv_sum;
|
||||
for (int i = 0; i < V_HEAD_DIM; ++i) {
|
||||
out_head[i] +=
|
||||
acc_out_head[thread_id_ * num_heads * V_HEAD_DIM + i] * scale_;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < V_HEAD_DIM; ++i) {
|
||||
vec_op::storeFP32(out_head[i], out + seq_idx * o_stride +
|
||||
head_idx * V_HEAD_DIM + i);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (PRE_CONVERT_QUERY) {
|
||||
std::free(q_f32);
|
||||
}
|
||||
std::free(acc_out);
|
||||
}
|
||||
|
||||
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& kv_cache, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens) {
|
||||
const int num_seqs = query.size(0);
|
||||
const int num_heads = query.size(1);
|
||||
const int head_dim = query.size(2);
|
||||
const int block_size = kv_cache.size(1);
|
||||
const int v_head_dim = out.size(2);
|
||||
|
||||
const int max_num_blocks_per_seq = block_tables.size(1);
|
||||
const int o_stride = out.stride(0);
|
||||
const int q_stride = query.stride(0);
|
||||
const int kv_stride = kv_cache.stride(0);
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(), "mla_decode_kvcache_cpu_impl", [&] {
|
||||
CPU_KERNEL_GUARD_IN(mla_decode_kvcache_cpu_impl)
|
||||
if (head_dim == 576 && v_head_dim == 512 && block_size == 16)
|
||||
mla_decode_kvcache_cpu_impl<scalar_t, 576, 512, 16>(
|
||||
out.data_ptr<scalar_t>(), query.data_ptr<scalar_t>(),
|
||||
kv_cache.data_ptr<scalar_t>(), num_heads, scale,
|
||||
block_tables.data_ptr<int>(), seq_lens.data_ptr<int>(),
|
||||
max_num_blocks_per_seq, o_stride, q_stride, kv_stride, num_seqs);
|
||||
else
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size);
|
||||
CPU_KERNEL_GUARD_OUT(mla_decode_kvcache_cpu_impl)
|
||||
});
|
||||
}
|
||||
@ -170,7 +170,7 @@ void rotary_embedding_gptj_impl(
|
||||
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
torch::Tensor& key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox) {
|
||||
int num_tokens = query.numel() / query.size(-1);
|
||||
int num_tokens = positions.numel();
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
|
||||
@ -25,7 +25,7 @@ struct KernelVecType<c10::BFloat16> {
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#ifdef __powerpc64__
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power architecture-specific vector type
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
|
||||
@ -18,6 +18,10 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
|
||||
const std::optional<torch::Tensor>& azp,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
|
||||
torch::Tensor& kv_cache, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -150,6 +154,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
||||
" Tensor! kv_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor scale) -> ()");
|
||||
cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
@ -157,4 +169,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
|
||||
cpu_ops.def(
|
||||
"mla_decode_kvcache("
|
||||
" Tensor! out, Tensor query, Tensor kv_cache,"
|
||||
" float scale, Tensor block_tables, Tensor seq_lens) -> ()");
|
||||
cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
@ -48,4 +48,14 @@ struct enable_sm90_or_later : Kernel {
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_only : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@ -0,0 +1,457 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
//
|
||||
// This file is a modified excerpt of
|
||||
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
|
||||
// from https://github.com/NVIDIA/cutlass v3.5.0
|
||||
// It has been modified to support either row/column or scalar broadcasting
|
||||
// where the tensor being loaded from is always passed in via a device pointer.
|
||||
// This lets one compiled kernel handle all cases of per-tensor or
|
||||
// per-channel/per-token quantization.
|
||||
//
|
||||
// This interface also allows the scales to be passed in as tensors that
|
||||
// consistently reside on the device, which avoids an issue with a previous
|
||||
// implementation where scalars needed to be on the CPU since they
|
||||
// were passed in via float values. This created a potential performance hazard
|
||||
// if scales were initially on the device, and caused torch.compile graphs
|
||||
// breaks when moving scales to the CPU.
|
||||
//
|
||||
#pragma once
|
||||
|
||||
// Turn off clang-format for the entire file to keep it close to upstream
|
||||
// clang-format off
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/barrier.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||
|
||||
namespace cutlass::epilogue::fusion {
|
||||
|
||||
using namespace cute;
|
||||
using namespace detail;
|
||||
|
||||
// Row vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_0,_1,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90RowOrScalarBroadcastArray {
|
||||
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
|
||||
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
|
||||
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
|
||||
|
||||
struct SharedStorage {
|
||||
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
|
||||
};
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_row is null.
|
||||
struct Arguments {
|
||||
const Element* const* ptr_row_array = nullptr;
|
||||
bool row_broadcast = true;
|
||||
StrideMNL dRow = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcastArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params)
|
||||
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
|
||||
|
||||
Params params;
|
||||
Element *smem = nullptr;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
|
||||
}
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
|
||||
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
|
||||
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
|
||||
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
|
||||
int group, Params const& params_)
|
||||
: tGS_gRow(tGS_gRow_)
|
||||
, tGS_sRow(tGS_sRow_)
|
||||
, tGS_cRow(tGS_cRow_)
|
||||
, tiled_G2S(tiled_g2s_)
|
||||
, tSR_sRow(tSR_sRow_)
|
||||
, tSR_rRow(tSR_rRow_)
|
||||
, tCcRow(tCcRow_)
|
||||
, residue_tCcRow(residue_tCcRow_)
|
||||
, group(group)
|
||||
, params(params_) {}
|
||||
|
||||
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
|
||||
Tiled_G2S tiled_G2S;
|
||||
|
||||
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
ThrResidue residue_tCcRow; // (m, n)
|
||||
ThrNum thr_num;
|
||||
int group;
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
if (!params.row_broadcast) {
|
||||
fill(tSR_rRow, *(params.ptr_row_array[group]));
|
||||
return;
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
continue; // OOB of SMEM,
|
||||
}
|
||||
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
|
||||
tGS_sRow_flt(i) = tGS_gRow_flt(i);
|
||||
}
|
||||
else {
|
||||
tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
|
||||
}
|
||||
}
|
||||
synchronize();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_row;
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_row;
|
||||
}
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
Layout< Shape<_1, ThreadCount>,
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
tGS_sRow,
|
||||
tGS_cRow, tiled_g2s,
|
||||
tSR_sRow,
|
||||
tSR_rRow,
|
||||
args.tCcD,
|
||||
args.residue_cD,
|
||||
ThreadCount{},
|
||||
l,
|
||||
params);
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Column vector broadcast
|
||||
template<
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
class StrideMNL = Stride<_1,_0,_0>,
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90ColOrScalarBroadcastArray {
|
||||
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
static_assert(
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
|
||||
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
|
||||
|
||||
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
|
||||
struct SharedStorage { };
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||
// scalar that must be broadcast, instead of containing a scalar that is
|
||||
// valid if ptr_col is null.
|
||||
struct Arguments {
|
||||
const Element* const* ptr_col_array = nullptr;
|
||||
bool col_broadcast = true;
|
||||
StrideMNL dCol = {};
|
||||
};
|
||||
|
||||
using Params = Arguments;
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static cutlass::Status
|
||||
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||
return cutlass::Status::kSuccess;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_C_load_needed() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_zero() const {
|
||||
return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcastArray() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params) { }
|
||||
|
||||
Params params;
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(
|
||||
GTensor&& tCgCol,
|
||||
RTensor&& tCrCol,
|
||||
CTensor&& tCcCol,
|
||||
ProblemShape problem_shape,
|
||||
int group,
|
||||
Params const& params
|
||||
):
|
||||
tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||
tCcCol(cute::forward<CTensor>(tCcCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
group(group),
|
||||
params(params) {}
|
||||
|
||||
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
RTensor tCrCol;
|
||||
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
Params const& params;
|
||||
int m;
|
||||
int group;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
}
|
||||
|
||||
if (!params.col_broadcast) {
|
||||
fill(tCrCol, *(params.ptr_col_array[group]));
|
||||
return;
|
||||
}
|
||||
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
copy_if(pred, filter(tCgCol), filter(tCrCol));
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||
Array<Element, FragmentSize> frg_col;
|
||||
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_col;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <
|
||||
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||
class... Args
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
|
||||
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
||||
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
cute::move(tCgCol),
|
||||
cute::move(tCrCol),
|
||||
cute::move(tCcCol),
|
||||
args.problem_shape_mnkl,
|
||||
l,
|
||||
params
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
||||
|
||||
/*
|
||||
This file defines custom epilogues for fusing channel scales, token scales,
|
||||
@ -69,6 +70,16 @@ struct ScaledEpilogueBase {
|
||||
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
|
||||
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoadArray =
|
||||
cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoadArray =
|
||||
cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
// scalar cases.
|
||||
@ -96,6 +107,14 @@ struct ScaledEpilogueBase {
|
||||
std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||
return Arguments{data_ptr};
|
||||
}
|
||||
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
|
||||
std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
|
||||
return Arguments{data_ptr, do_broadcast};
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
@ -381,4 +400,51 @@ struct ScaledEpilogueBiasAzpToken
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
|
||||
to arrays containing different scales used in group gemm. The number of
|
||||
pointers in ScaleA and the number of pointers in ScaleB are equal to the
|
||||
group size.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
struct ScaledEpilogueArray
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;
|
||||
|
||||
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
|
||||
using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;
|
||||
|
||||
static ArgumentType prepare_args(float const* const* a_scales_ptr,
|
||||
float const* const* b_scales_ptr,
|
||||
bool a_col_broadcast, bool b_row_broadcast) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
|
||||
a_scales_ptr, a_col_broadcast);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
|
||||
b_scales_ptr, b_row_broadcast);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace vllm::c3x
|
||||
|
||||
@ -402,7 +402,7 @@ struct CollectiveMma<
|
||||
|
||||
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
|
||||
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
|
||||
Layout<Shape<_32, _1>>{}, Layout<Shape<_4, _1>>{}); // (1,1,1)
|
||||
Layout<Shape<_32>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
|
||||
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
||||
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
||||
|
||||
@ -6,6 +6,11 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
// Need a special dispatch case macro since we will nest the FP8 dispatch.
|
||||
// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
|
||||
#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
@ -14,17 +19,32 @@
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
// TODO(luka/varun): use FP8_TYPE macro after refactoring
|
||||
#ifndef USE_ROCM
|
||||
// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
|
||||
// A host-based check at runtime will create a preferred FP8 type for ROCm
|
||||
// such that the correct kernel is dispatched.
|
||||
#ifdef USE_ROCM
|
||||
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
|
||||
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
||||
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
|
||||
#else
|
||||
#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
|
||||
AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
|
||||
#else
|
||||
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
|
||||
#endif
|
||||
|
||||
// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
|
||||
// See AT_DISPATCH_FP8_CASE above.
|
||||
#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
|
||||
|
||||
|
||||
@ -21,9 +21,9 @@
|
||||
namespace vllm {
|
||||
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void rms_norm_static_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float* __restrict__ scale, // [1]
|
||||
@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
scaled_fp8_conversion<true>(out_norm, scale_inv);
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
}
|
||||
}
|
||||
|
||||
@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
Additional optimizations we can make in this case are
|
||||
packed and vectorized operations, which help with the
|
||||
memory latency bottleneck. */
|
||||
template <typename scalar_t, int width>
|
||||
template <typename scalar_t, int width, typename fp8_type>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; ++i) {
|
||||
out[id * width + i] =
|
||||
scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
|
||||
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
/* Generic fused_add_rms_norm_kernel
|
||||
The width field is not used here but necessary for other specializations.
|
||||
*/
|
||||
template <typename scalar_t, int width>
|
||||
template <typename scalar_t, int width, typename fp8_type>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
scaled_fp8_conversion<true>(out_norm, scale_inv);
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
}
|
||||
}
|
||||
|
||||
@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
|
||||
vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
|
||||
num_tokens, hidden_size);
|
||||
});
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
|
||||
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
|
||||
epsilon, num_tokens, hidden_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
|
||||
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
|
||||
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
|
||||
VLLM_DISPATCH_FP8_TYPES( \
|
||||
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
|
||||
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
|
||||
width, fp8_t> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
|
||||
epsilon, num_tokens, hidden_size); \
|
||||
}); \
|
||||
});
|
||||
|
||||
void fused_add_rms_norm_static_fp8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size],
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
|
||||
@ -18,3 +18,14 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
torch::Tensor b_qweight, torch::Tensor b_scales,
|
||||
std::optional<torch::Tensor> b_qzeros,
|
||||
std::optional<torch::Tensor> topk_weights,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_pad, int64_t top_k,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit);
|
||||
#endif
|
||||
346
csrc/moe/moe_wna16.cu
Normal file
346
csrc/moe/moe_wna16.cu
Normal file
@ -0,0 +1,346 @@
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include "moe_wna16_utils.h"
|
||||
|
||||
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
||||
|
||||
template <typename scalar_t, int bit, int GROUPS>
|
||||
__global__ void moe_wna16_gemm_kernel(
|
||||
const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
|
||||
|
||||
const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
|
||||
const uint32_t* __restrict__ qzeros,
|
||||
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_token_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ num_tokens_post_pad,
|
||||
|
||||
uint16_t num_experts, uint16_t group_size, uint16_t top_k, uint32_t size_m,
|
||||
uint32_t size_n, uint32_t size_k, uint16_t BLOCK_SIZE_M,
|
||||
uint16_t BLOCK_SIZE_N, uint16_t BLOCK_SIZE_K, bool has_zp,
|
||||
bool mul_topk_weight) {
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
||||
if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
|
||||
return;
|
||||
} else {
|
||||
#endif
|
||||
|
||||
using Dtype = ScalarType<scalar_t>;
|
||||
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
||||
|
||||
if (blockIdx.x * BLOCK_SIZE_M >= num_tokens_post_pad[0]) return;
|
||||
|
||||
const int32_t offset_n = blockIdx.y * BLOCK_SIZE_N + threadIdx.x;
|
||||
const int32_t offset_k = blockIdx.z * BLOCK_SIZE_K;
|
||||
|
||||
const int32_t expert_id = expert_ids[blockIdx.x];
|
||||
|
||||
int32_t num_valid_tokens = 0;
|
||||
extern __shared__ uint16_t block_input_tmp[];
|
||||
scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
|
||||
scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);
|
||||
|
||||
// load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory
|
||||
for (int m = 0; m < BLOCK_SIZE_M; m++) {
|
||||
const int32_t offset_m = blockIdx.x * BLOCK_SIZE_M + m;
|
||||
const int32_t token_index = sorted_token_ids[offset_m];
|
||||
if (token_index / top_k >= size_m) break;
|
||||
|
||||
num_valid_tokens = m + 1;
|
||||
if (blockIdx.z == 0 && offset_n < size_n)
|
||||
output[token_index * size_n + offset_n] = Dtype::int2num(0);
|
||||
|
||||
if (expert_id != -1) {
|
||||
int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
|
||||
for (int i = 0; i < k_per_thread; i++) {
|
||||
int k = BLOCK_SIZE_N * i + threadIdx.x;
|
||||
if (k >= BLOCK_SIZE_K) break;
|
||||
if (offset_k + k >= size_k) break;
|
||||
|
||||
// load input to shared memory
|
||||
// use a special layout to fit the layout of dequanted-weight
|
||||
int origin_k;
|
||||
if constexpr (bit == 4) {
|
||||
// [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
int8_t order = (threadIdx.x % 2) * 4 + ((threadIdx.x % 8) / 2);
|
||||
origin_k = BLOCK_SIZE_N * i + threadIdx.x / 8 * 8 + order;
|
||||
} else {
|
||||
// [0, 2, 1, 3]
|
||||
int8_t order = (threadIdx.x % 2) * 2 + ((threadIdx.x % 4) / 2);
|
||||
origin_k = BLOCK_SIZE_N * i + threadIdx.x / 4 * 4 + order;
|
||||
}
|
||||
|
||||
origin_k += token_index / top_k * size_k + blockIdx.z * BLOCK_SIZE_K;
|
||||
block_input[m * BLOCK_SIZE_K + k] = input[origin_k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (expert_id == -1) return;
|
||||
__syncthreads();
|
||||
if (threadIdx.x >= BLOCK_SIZE_N || offset_n >= size_n) return;
|
||||
|
||||
float res[64]; // assume BLOCK_SIZE_M <= 64
|
||||
scalar_t2 res2;
|
||||
scalar_t2 scale_f2;
|
||||
scalar_t2 qzero_f2;
|
||||
|
||||
// note that (size_n * size_k * expert_id) may greater than 2 ** 31
|
||||
constexpr int8_t pack_factor = 32 / bit;
|
||||
const uint64_t expert_offset = ((uint64_t)size_n) * size_k * expert_id;
|
||||
const uint32_t* expert_qweight = qweight + expert_offset / pack_factor;
|
||||
const scalar_t* expert_scales = scales + expert_offset / group_size;
|
||||
const uint32_t* expert_qzeros =
|
||||
qzeros + expert_offset / group_size / pack_factor;
|
||||
|
||||
// load 4*int32 one time: 4 int32 = 128 bit = 1 float4
|
||||
// weight would be loaded in loop
|
||||
uint32_t expert_qweight_tmp[4];
|
||||
float4* expert_qweight_tmp_float4 =
|
||||
reinterpret_cast<float4*>(expert_qweight_tmp);
|
||||
|
||||
// load all required scales one time
|
||||
scalar_t expert_scales_groups[GROUPS];
|
||||
int scales_offset_tmp =
|
||||
(offset_n * size_k + offset_k) / group_size / GROUPS;
|
||||
if constexpr (GROUPS == 1) {
|
||||
*expert_scales_groups = expert_scales[scales_offset_tmp];
|
||||
} else if constexpr (GROUPS == 2) {
|
||||
float* expert_scales_groups_tmp =
|
||||
reinterpret_cast<float*>(expert_scales_groups);
|
||||
*expert_scales_groups_tmp =
|
||||
reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
|
||||
} else if constexpr (GROUPS == 4) {
|
||||
float2* expert_scales_groups_tmp =
|
||||
reinterpret_cast<float2*>(expert_scales_groups);
|
||||
*expert_scales_groups_tmp =
|
||||
reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
|
||||
} else if constexpr (GROUPS == 8) {
|
||||
float4* expert_scales_groups_tmp =
|
||||
reinterpret_cast<float4*>(expert_scales_groups);
|
||||
*expert_scales_groups_tmp =
|
||||
reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
|
||||
}
|
||||
|
||||
// load all required qzeros one time
|
||||
uint8_t expert_qzeros_groups[GROUPS];
|
||||
if (!has_zp) {
|
||||
if constexpr (bit == 4) {
|
||||
qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
|
||||
} else {
|
||||
qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
|
||||
}
|
||||
} else {
|
||||
int qzeros_offset_tmp =
|
||||
(offset_n / (8 / bit)) * (size_k / group_size / GROUPS) +
|
||||
offset_k / group_size / GROUPS;
|
||||
if constexpr (GROUPS == 1) {
|
||||
uint8_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint8_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
} else if constexpr (GROUPS == 2) {
|
||||
uint16_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint16_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
} else if constexpr (GROUPS == 4) {
|
||||
uint32_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint32_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
} else if constexpr (GROUPS == 8) {
|
||||
uint64_t* expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<uint64_t*>(expert_qzeros_groups);
|
||||
*expert_qzeros_groups_tmp =
|
||||
reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
|
||||
}
|
||||
}
|
||||
|
||||
for (int tmp_k = 0; tmp_k < BLOCK_SIZE_K / pack_factor; tmp_k++) {
|
||||
int k = offset_k + tmp_k * pack_factor;
|
||||
if (k >= size_k) break;
|
||||
const int32_t weight_offset = offset_n * size_k + k;
|
||||
|
||||
if (tmp_k % 4 == 0) {
|
||||
*expert_qweight_tmp_float4 = reinterpret_cast<const float4*>(
|
||||
expert_qweight)[weight_offset / pack_factor / 4];
|
||||
}
|
||||
|
||||
if (tmp_k % (group_size / pack_factor) == 0) {
|
||||
scalar_t scale_f =
|
||||
expert_scales_groups[tmp_k / (group_size / pack_factor)];
|
||||
scale_f2 = Dtype::num2num2(scale_f);
|
||||
|
||||
if (has_zp) {
|
||||
uint8_t qzero =
|
||||
expert_qzeros_groups[tmp_k / (group_size / pack_factor)];
|
||||
if constexpr (bit == 4) {
|
||||
qzero = (qzero >> ((threadIdx.x % 2) * 4)) & 0xF;
|
||||
}
|
||||
qzero_f2 = Dtype::num2num2(Dtype::int2num(qzero));
|
||||
}
|
||||
}
|
||||
|
||||
scalar_t2 weight_half2[16 / bit];
|
||||
dequant<scalar_t2, bit>(expert_qweight_tmp[tmp_k % 4], weight_half2);
|
||||
|
||||
for (int m = 0; m < num_valid_tokens; m++) {
|
||||
res2 = {};
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16 / bit; i++) {
|
||||
int32_t offset_input = m * BLOCK_SIZE_K / 2 + tmp_k * (16 / bit) + i;
|
||||
res2 = __hfma2(__hmul2(__hsub2(weight_half2[i], qzero_f2), scale_f2),
|
||||
block_input_half2[offset_input], res2);
|
||||
}
|
||||
|
||||
if (tmp_k == 0) {
|
||||
res[m] = Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
|
||||
} else {
|
||||
res[m] += Dtype::num2float(res2.x) + Dtype::num2float(res2.y);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int m = 0; m < num_valid_tokens; ++m) {
|
||||
const int32_t token_index =
|
||||
sorted_token_ids[blockIdx.x * BLOCK_SIZE_M + m];
|
||||
if (mul_topk_weight) {
|
||||
res[m] *= topk_weights[token_index];
|
||||
}
|
||||
atomicAdd(&output[token_index * size_n + offset_n],
|
||||
Dtype::float2num(res[m]));
|
||||
}
|
||||
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void run_moe_wna16_gemm(const scalar_t* input, scalar_t* output,
|
||||
const uint32_t* b_qweight, const scalar_t* b_scales,
|
||||
const uint32_t* b_qzeros, const float* topk_weights,
|
||||
const int32_t* sorted_token_ids,
|
||||
const int32_t* expert_ids,
|
||||
const int32_t* num_tokens_post_pad, int num_experts,
|
||||
int group_size, int num_token_blocks, int top_k,
|
||||
int size_m, int size_n, int size_k, int BLOCK_SIZE_M,
|
||||
int BLOCK_SIZE_N, int BLOCK_SIZE_K, int bit,
|
||||
bool has_zp, bool mul_topk_weight) {
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_SIZE_N;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = num_token_blocks;
|
||||
gridDim.y = DIVIDE(size_n, BLOCK_SIZE_N);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_SIZE_K);
|
||||
|
||||
auto kernel = moe_wna16_gemm_kernel<scalar_t, 4, 1>;
|
||||
if (bit == 4) {
|
||||
if (BLOCK_SIZE_K / group_size == 2) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 2>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 4) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 4>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 8) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 4, 8>;
|
||||
}
|
||||
} else {
|
||||
if (BLOCK_SIZE_K / group_size == 1) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 1>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 2) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 2>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 4) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 4>;
|
||||
} else if (BLOCK_SIZE_K / group_size == 8) {
|
||||
kernel = moe_wna16_gemm_kernel<scalar_t, 8, 8>;
|
||||
}
|
||||
}
|
||||
|
||||
const int shared_mem_size = BLOCK_SIZE_M * BLOCK_SIZE_K * 2;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
kernel<<<gridDim, blockDim, shared_mem_size, stream>>>(
|
||||
input, output, b_qweight, b_scales, b_qzeros, topk_weights,
|
||||
sorted_token_ids, expert_ids, num_tokens_post_pad, num_experts,
|
||||
group_size, top_k, size_m, size_n, size_k, BLOCK_SIZE_M, BLOCK_SIZE_N,
|
||||
BLOCK_SIZE_K, has_zp, mul_topk_weight);
|
||||
}
|
||||
|
||||
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
torch::Tensor b_qweight, torch::Tensor b_scales,
|
||||
std::optional<torch::Tensor> b_qzeros,
|
||||
std::optional<torch::Tensor> topk_weights,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_pad, int64_t top_k,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device());
|
||||
|
||||
const int num_experts = b_qweight.size(0);
|
||||
const int size_m = input.size(0);
|
||||
const int size_n = b_qweight.size(1);
|
||||
const int size_k = input.size(1);
|
||||
const int group_size = size_k / b_scales.size(2);
|
||||
|
||||
int64_t EM = sorted_token_ids.size(0);
|
||||
if (size_m <= BLOCK_SIZE_M) {
|
||||
EM = min(EM, size_m * BLOCK_SIZE_M * top_k);
|
||||
}
|
||||
const int num_token_blocks = (EM + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
|
||||
|
||||
const uint32_t* b_qzeros_ptr;
|
||||
if (b_qzeros.has_value())
|
||||
b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
|
||||
const float* topk_weights_ptr;
|
||||
if (topk_weights.has_value())
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
|
||||
|
||||
int groups_per_block_row = BLOCK_SIZE_K / group_size;
|
||||
TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
|
||||
TORCH_CHECK(size_k % BLOCK_SIZE_K == 0,
|
||||
"size_k must divisible by BLOCK_SIZE_K");
|
||||
TORCH_CHECK(BLOCK_SIZE_K % group_size == 0,
|
||||
"BLOCK_SIZE_K must divisible by group_size");
|
||||
TORCH_CHECK(BLOCK_SIZE_M <= 64, "BLOCK_SIZE_M must less or equal to 64");
|
||||
TORCH_CHECK(groups_per_block_row == 1 || groups_per_block_row == 2 ||
|
||||
groups_per_block_row == 4 || groups_per_block_row == 8,
|
||||
"BLOCK_SIZE_K // group_size must be one of [1, 2, 4, 8]");
|
||||
|
||||
if (input.scalar_type() == at::ScalarType::Half) {
|
||||
run_moe_wna16_gemm<half>(
|
||||
(const half*)input.data_ptr<at::Half>(),
|
||||
(half*)output.data_ptr<at::Half>(),
|
||||
(const uint32_t*)b_qweight.data_ptr<uint8_t>(),
|
||||
(const half*)b_scales.data_ptr<at::Half>(), b_qzeros_ptr,
|
||||
topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
|
||||
expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
|
||||
size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
|
||||
b_qzeros.has_value(), topk_weights.has_value());
|
||||
} else if (input.scalar_type() == at::ScalarType::BFloat16) {
|
||||
run_moe_wna16_gemm<nv_bfloat16>(
|
||||
(const nv_bfloat16*)input.data_ptr<at::BFloat16>(),
|
||||
(nv_bfloat16*)output.data_ptr<at::BFloat16>(),
|
||||
(const uint32_t*)b_qweight.data_ptr<uint8_t>(),
|
||||
(const nv_bfloat16*)b_scales.data_ptr<at::BFloat16>(), b_qzeros_ptr,
|
||||
topk_weights_ptr, sorted_token_ids.data_ptr<int32_t>(),
|
||||
expert_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts, group_size, num_token_blocks, top_k, size_m, size_n,
|
||||
size_k, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, bit,
|
||||
b_qzeros.has_value(), topk_weights.has_value());
|
||||
} else {
|
||||
TORCH_CHECK(false, "moe_wna16_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
return output;
|
||||
}
|
||||
200
csrc/moe/moe_wna16_utils.h
Normal file
200
csrc/moe/moe_wna16_utils.h
Normal file
@ -0,0 +1,200 @@
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template <typename scalar_t>
|
||||
class ScalarType {};
|
||||
|
||||
template <>
|
||||
class ScalarType<half> {
|
||||
public:
|
||||
using scalar_t = half;
|
||||
using scalar_t2 = half2;
|
||||
|
||||
static __device__ float inline num2float(const half x) {
|
||||
return __half2float(x);
|
||||
}
|
||||
|
||||
static __device__ half2 inline num2num2(const half x) {
|
||||
return __half2half2(x);
|
||||
}
|
||||
|
||||
static __device__ half2 inline nums2num2(const half x1, const half x2) {
|
||||
return __halves2half2(x1, x2);
|
||||
}
|
||||
|
||||
static __host__ __device__ half inline float2num(const float x) {
|
||||
return __float2half(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ half inline int2num(const float x) {
|
||||
return __int2half_rn(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ float2 inline num22float2(const half2 x) {
|
||||
return __half22float2(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ half2 inline float22num2(const float2 x) {
|
||||
return __float22half2_rn(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class ScalarType<nv_bfloat16> {
|
||||
public:
|
||||
using scalar_t = nv_bfloat16;
|
||||
using scalar_t2 = nv_bfloat162;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
static __device__ float inline num2float(const nv_bfloat16 x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
|
||||
return __bfloat162bfloat162(x);
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1,
|
||||
const nv_bfloat16 x2) {
|
||||
return __halves2bfloat162(x1, x2);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
|
||||
return __float2bfloat16(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat16 inline int2num(const float x) {
|
||||
return __int2bfloat16_rn(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) {
|
||||
return __bfloat1622float2(x);
|
||||
}
|
||||
|
||||
static __host__ __device__ nv_bfloat162 inline float22num2(const float2 x) {
|
||||
return __float22bfloat162_rn(x);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <int lut>
|
||||
__device__ inline int lop3(int a, int b, int c) {
|
||||
int res;
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "r"(b), "r"(c), "n"(lut));
|
||||
return res;
|
||||
}
|
||||
|
||||
template <int start_byte, int mask>
|
||||
__device__ inline uint32_t prmt(uint32_t a) {
|
||||
uint32_t res;
|
||||
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
|
||||
: "=r"(res)
|
||||
: "r"(a), "n"(start_byte), "n"(mask));
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename scalar_t2, int bit>
|
||||
__device__ inline void dequant(int q, scalar_t2* res) {}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, 4>(int q, half2* res) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd400d400;
|
||||
|
||||
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
q >>= 8;
|
||||
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
|
||||
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo0),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
res[1] = __hfma2(*reinterpret_cast<half2*>(&hi0),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
res[2] = __hsub2(*reinterpret_cast<half2*>(&lo1),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
res[3] = __hfma2(*reinterpret_cast<half2*>(&hi1),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<half2, 8>(int q, half2* res) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
res[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
res[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, 4>(int q, nv_bfloat162* res) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
q >>= 4;
|
||||
int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC300C300;
|
||||
|
||||
res[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo0),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
res[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi0),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
res[2] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo1),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
res[3] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi1),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline void dequant<nv_bfloat162, 8>(int q, nv_bfloat162* res) {
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||
|
||||
fp32_intermediates[0] -= 8388608.f;
|
||||
fp32_intermediates[1] -= 8388608.f;
|
||||
fp32_intermediates[2] -= 8388608.f;
|
||||
fp32_intermediates[3] -= 8388608.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(res);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
||||
fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
||||
fp32_intermediates_casted[3], 0x7632);
|
||||
}
|
||||
#endif
|
||||
@ -32,6 +32,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
m.def(
|
||||
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
|
||||
"Tensor b_scales, Tensor? b_qzeros, "
|
||||
"Tensor? topk_weights, Tensor sorted_token_ids, "
|
||||
"Tensor expert_ids, Tensor num_tokens_post_pad, "
|
||||
"int top_k, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K, "
|
||||
"int bit) -> Tensor");
|
||||
|
||||
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
|
||||
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
@ -42,6 +52,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
||||
" -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
30
csrc/ops.h
30
csrc/ops.h
@ -151,20 +151,44 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||
int64_t row);
|
||||
|
||||
torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor expert_ids,
|
||||
torch::Tensor num_tokens_post_padded, int64_t type,
|
||||
int64_t row, int64_t top_k, int64_t tokens);
|
||||
|
||||
int64_t ggml_moe_get_block_size(int64_t type);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_moe_mm(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k);
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
|
||||
@ -274,7 +274,7 @@ void advance_step_flashinfer(
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
|
||||
|
||||
int block_tables_stride = block_tables.stride(0);
|
||||
[[maybe_unused]] int block_tables_stride = block_tables.stride(0);
|
||||
TORCH_CHECK((blocks * threads > num_queries),
|
||||
"multi-step: not enough threads to map to num_queries = ",
|
||||
num_queries, " block_tables.stride(0) = ", block_tables.stride(0),
|
||||
|
||||
80
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
Normal file
80
csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh
Normal file
@ -0,0 +1,80 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
template <typename ElementAB, typename ElementC, typename ElementAccumulator>
|
||||
__global__ void get_group_gemm_starts(
|
||||
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
|
||||
ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int, ElementC* out_base_as_int,
|
||||
ElementAccumulator* a_scales_base_as_int,
|
||||
ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
int expert_id = threadIdx.x;
|
||||
|
||||
int64_t expert_offset = expert_offsets[expert_id];
|
||||
|
||||
a_offsets[expert_id] = a_base_as_int + expert_offset * k;
|
||||
b_offsets[expert_id] = b_base_as_int + expert_id * k * n;
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
a_scales_offsets[expert_id] =
|
||||
a_scales_base_as_int + (per_act_token ? expert_offset : 0);
|
||||
b_scales_offsets[expert_id] =
|
||||
b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id);
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float> \
|
||||
<<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()), \
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
|
||||
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||
static_cast<float*>(a_scales.data_ptr()), \
|
||||
static_cast<float*>(b_scales.data_ptr()), out_tensors.size(1), \
|
||||
a_tensors.size(1), per_act_token, per_out_ch); \
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void run_get_group_gemm_starts(
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
bool per_act_token = a_scales.numel() != 1;
|
||||
bool per_out_ch = b_scales.numel() != num_experts;
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
|
||||
else {
|
||||
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
160
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
Normal file
160
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu
Normal file
@ -0,0 +1,160 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "grouped_mm_c3x.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_default {
|
||||
// M in (16, inf)
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M16 {
|
||||
// M in [1, 16]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_K8192 {
|
||||
// K in [8192, inf)
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_N8192 {
|
||||
// N in [8192, inf)
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void run_cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM16 = typename sm90_fp8_config_M16<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const m = a_tensors.size(0);
|
||||
uint32_t const n = out_tensors.size(1);
|
||||
uint32_t const k = a_tensors.size(1);
|
||||
|
||||
if (n >= 8192) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
} else if (k >= 8192) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
} else if (m <= 16) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
} else {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
}
|
||||
}
|
||||
|
||||
void dispatch_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
||||
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
} else {
|
||||
run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides);
|
||||
}
|
||||
149
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
Normal file
149
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh
Normal file
@ -0,0 +1,149 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "get_group_starts.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
using ProblemShape =
|
||||
cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
template <typename ElementAB_, typename ElementC_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_group_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementC = void;
|
||||
using ElementD = ElementC_;
|
||||
using ElementAccumulator = float;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
|
||||
|
||||
using StrideC =
|
||||
cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
|
||||
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
|
||||
LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
|
||||
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
|
||||
Stages, KernelSchedule>::CollectiveOp;
|
||||
|
||||
using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_group_gemm_caller(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
int k_size = a_tensors.size(1);
|
||||
int n_size = out_tensors.size(1);
|
||||
|
||||
bool per_act_token = a_scales.numel() != 1;
|
||||
bool per_out_ch = b_scales.numel() != num_experts;
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());
|
||||
|
||||
torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
|
||||
|
||||
run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
|
||||
a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
|
||||
out_tensors, a_scales, b_scales);
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideC = typename GemmKernel::InternalStrideC;
|
||||
|
||||
ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||
static_cast<ProblemShape::UnderlyingProblemShape*>(
|
||||
problem_sizes.data_ptr());
|
||||
ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(a_strides.data_ptr()),
|
||||
static_cast<const ElementAB**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(b_strides.data_ptr())};
|
||||
|
||||
// Currently, we are only able to do broadcast on either all or none a_scales
|
||||
// and on either all or none b_scales
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
|
||||
per_act_token, per_out_ch),
|
||||
nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(c_strides.data_ptr())};
|
||||
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
|
||||
epilogue_args};
|
||||
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
90
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Normal file
90
csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Normal file
@ -0,0 +1,90 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
|
||||
__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
int32_t* atomic_buffer,
|
||||
const int topk_length, const int n,
|
||||
const int k) {
|
||||
int expert_id = blockIdx.x;
|
||||
|
||||
int occurrences = 0;
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
occurrences += (topk_ids[i] == expert_id);
|
||||
}
|
||||
atomicAdd(&atomic_buffer[expert_id], occurrences);
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int final_occurrences = atomic_buffer[expert_id];
|
||||
problem_sizes1[expert_id * 3] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes2[expert_id * 3] = final_occurrences;
|
||||
problem_sizes2[expert_id * 3 + 1] = k;
|
||||
problem_sizes2[expert_id * 3 + 2] = n;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_expert_offsets(
|
||||
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
||||
int32_t* atomic_buffer, const int num_experts) {
|
||||
int32_t tot_offset = 0;
|
||||
expert_offsets[0] = 0;
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
atomic_buffer[i] = tot_offset;
|
||||
tot_offset += problem_sizes1[i * 3];
|
||||
expert_offsets[i + 1] = tot_offset;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
int32_t* atomic_buffer, const int topk_length,
|
||||
const int topk) {
|
||||
int expert_id = blockIdx.x;
|
||||
|
||||
for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
|
||||
if (topk_ids[i] == expert_id) {
|
||||
int start = atomicAdd(&atomic_buffer[expert_id], 1);
|
||||
input_permutation[start] = i / topk;
|
||||
output_permutation[i] = start;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k) {
|
||||
auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
|
||||
auto options_int32 =
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
|
||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
|
||||
compute_expert_offsets<<<1, 1, 0, stream>>>(
|
||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
|
||||
topk_ids.size(1));
|
||||
}
|
||||
34
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
Normal file
34
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
Normal file
@ -0,0 +1,34 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm100 (Blackwell).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
TORCH_CHECK(
|
||||
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -5,9 +5,11 @@
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
NVIDIA GPUs with sm90a (Hopper).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -72,27 +74,4 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
azp, bias);
|
||||
}
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
||||
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
TORCH_CHECK(
|
||||
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -23,12 +23,29 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_moe_mm_sm90(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides);
|
||||
|
||||
void get_cutlass_moe_mm_data_caller(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k);
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -60,7 +77,7 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -99,6 +116,19 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
|
||||
// CUTLASS groped FP8 kernels need at least CUDA 12.3
|
||||
// and SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability == 90) {
|
||||
return CUDA_VERSION >= 12030;
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
@ -121,26 +151,21 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
// Hopper
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION < 12080
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#else
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
} else if (version_num >= 100) {
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
if (version_num >= 100) {
|
||||
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
// Hopper
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
@ -170,6 +195,46 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
version_num);
|
||||
}
|
||||
|
||||
void cutlass_moe_mm(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
|
||||
". Required capability: 90");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation, torch::Tensor& output_permutation,
|
||||
const int64_t num_experts, const int64_t n, const int64_t k) {
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k);
|
||||
return;
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
||||
"CUDA device capability: ",
|
||||
version_num, ". Required capability: 90");
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
@ -211,7 +276,7 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
#if defined ENABLE_SCALED_MM_SM90 && ENABLE_SCALED_MM_SM90
|
||||
if (version_num >= 90) {
|
||||
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
|
||||
return;
|
||||
|
||||
@ -36,3 +36,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
"be compiled using CUDA 12.8 and target "
|
||||
"compute capability 100 or above.");
|
||||
}
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
|
||||
int runtimeVersion;
|
||||
cudaRuntimeGetVersion(&runtimeVersion);
|
||||
return cuda_device_capability >= 100 && runtimeVersion >= 12080;
|
||||
}
|
||||
@ -201,10 +201,11 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) \
|
||||
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
|
||||
#define CHECK_INPUT(x, st, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m); \
|
||||
|
||||
@ -13,6 +13,40 @@ namespace vllm {
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
template <typename fp8_type>
|
||||
__device__ __forceinline__ fp8_type cvt_c10(float const r) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
|
||||
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
|
||||
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
|
||||
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
|
||||
// the new HW cvt with something reasonable that doesn't rely on the
|
||||
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
|
||||
template <>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
|
||||
#if HIP_FP8_TYPE_OCP
|
||||
return c10::Float8_e4m3fn(
|
||||
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
|
||||
__hip_fp8_e4m3::__default_interpret),
|
||||
c10::Float8_e4m3fn::from_bits());
|
||||
#else
|
||||
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
|
||||
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
|
||||
return c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
|
||||
__hip_fp8_e4m3_fnuz::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x) {
|
||||
return x;
|
||||
@ -412,7 +446,7 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
|
||||
__half2_raw h2r =
|
||||
[[maybe_unused]] __half2_raw h2r =
|
||||
__hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
|
||||
@ -11,8 +11,8 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const float* __restrict__ scale,
|
||||
int64_t num_elems) {
|
||||
@ -25,12 +25,13 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
|
||||
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
FP8_TYPE* __restrict__ out, float* __restrict__ scale,
|
||||
fp8_type* __restrict__ out, float* __restrict__ scale,
|
||||
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
|
||||
const int hidden_size) {
|
||||
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||
float const min_scaling_factor =
|
||||
1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
|
||||
|
||||
int const tid = threadIdx.x;
|
||||
int const token_idx = blockIdx.x;
|
||||
@ -38,7 +39,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
// Use int64 to avoid overflowing an int32 when calculating this offset
|
||||
int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
|
||||
scalar_t const* __restrict__ token_input = &input[offset];
|
||||
FP8_TYPE* __restrict__ token_output = &out[offset];
|
||||
fp8_type* __restrict__ token_output = &out[offset];
|
||||
|
||||
// For vectorization, token_input and token_output pointers need to be
|
||||
// aligned at 8-byte and 4-byte addresses respectively.
|
||||
@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
token_scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
|
||||
token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
|
||||
min_scaling_factor);
|
||||
scale[token_idx] = token_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
|
||||
} else {
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
token_output[i] = scaled_fp8_conversion<false>(
|
||||
token_output[i] = scaled_fp8_conversion<false, fp8_type>(
|
||||
static_cast<float>(token_input[i]), token_scale);
|
||||
}
|
||||
}
|
||||
@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
|
||||
vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
|
||||
scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::segmented_max_reduction<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_elems);
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
hidden_size);
|
||||
input.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
|
||||
VLLM_DISPATCH_FP8_TYPES(
|
||||
out.scalar_type(),
|
||||
"dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>()
|
||||
: nullptr,
|
||||
hidden_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@ -7,18 +7,52 @@
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
|
||||
std::numeric_limits<FP8_TYPE>::max();
|
||||
#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
|
||||
#else
|
||||
#include <ATen/hip/HIPContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
#include "amd/quant_utils.cuh"
|
||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||
// Using the default max value from pytorch (240.0) will cause accuracy
|
||||
// issue when running dynamic quantization. Here use 224.0f for rocm.
|
||||
constexpr auto FP8_E4M3_MAX = 224.0f;
|
||||
// ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
|
||||
#define MAYBE_HOST_DEVICE
|
||||
#endif
|
||||
constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
|
||||
|
||||
// Determines the preferred FP8 type for the current platform.
|
||||
// Note that for CUDA this just returns true,
|
||||
// but on ROCm it will check device props.
|
||||
static bool is_fp8_ocp() {
|
||||
#ifndef USE_ROCM
|
||||
return true;
|
||||
#else
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
std::string device_arch = dprops->gcnArchName;
|
||||
size_t substring = device_arch.find("gfx94");
|
||||
return substring == std::string::npos;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct fp8_e4m3_adjusted_max;
|
||||
|
||||
template <>
|
||||
struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
|
||||
static constexpr c10::Float8_e4m3fn val() {
|
||||
return std::numeric_limits<c10::Float8_e4m3fn>::max();
|
||||
}
|
||||
};
|
||||
|
||||
// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
|
||||
// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
|
||||
template <>
|
||||
struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
|
||||
static constexpr c10::Float8_e4m3fnuz val() {
|
||||
return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
|
||||
fp8_e4m3_adjusted_max<T>::val();
|
||||
|
||||
namespace vllm {
|
||||
|
||||
@ -32,8 +66,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
return old;
|
||||
}
|
||||
|
||||
template <bool is_scale_inverted>
|
||||
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||
template <bool is_scale_inverted, typename fp8_type>
|
||||
__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
float const scale) {
|
||||
float x = 0.0f;
|
||||
if constexpr (is_scale_inverted) {
|
||||
@ -42,15 +76,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||
x = val / scale;
|
||||
}
|
||||
|
||||
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
||||
float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
|
||||
fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
|
||||
#ifndef USE_ROCM
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
return static_cast<fp8_type>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation,
|
||||
fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
return fp8::cvt_c10<fp8_type>(r);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -60,7 +92,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||
// So to get the right answer, *scale needs to be initialized to
|
||||
// a value <= 0.0 and we need to wait for all thread blocks to
|
||||
// finish before consuming *scale.
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t num_elems) {
|
||||
@ -91,7 +123,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
// Finally, since cache[0] contains the maximum for this thread block,
|
||||
// atomically write the max to the target location
|
||||
if (threadIdx.x == 0) {
|
||||
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
|
||||
atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
|
||||
}
|
||||
}
|
||||
|
||||
@ -123,13 +155,13 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
||||
return absmax_val;
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool is_scale_inverted>
|
||||
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
|
||||
template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
|
||||
__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
|
||||
scalar_t const* __restrict__ input,
|
||||
float const scale,
|
||||
int64_t const num_elems,
|
||||
int const tid, int const step) {
|
||||
using float8x4_t = q8x4_t<FP8_TYPE>;
|
||||
using float8x4_t = q8x4_t<fp8_type>;
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
||||
auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
||||
@ -141,22 +173,22 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
float8x4_t out_vec;
|
||||
|
||||
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
|
||||
out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.x), scale);
|
||||
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
|
||||
out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.y), scale);
|
||||
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
|
||||
out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.z), scale);
|
||||
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
|
||||
out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(in_vec.w), scale);
|
||||
vectorized_out[i] = out_vec;
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
out[i] = scaled_fp8_conversion<is_scale_inverted>(
|
||||
out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
|
||||
static_cast<float>(input[i]), scale);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
} // namespace vllm
|
||||
|
||||
@ -144,6 +144,9 @@ void rms_norm_dynamic_per_token_quant(
|
||||
torch::Tensor& scales, // [num_tokens]
|
||||
double const var_epsilon, // Variance epsilon used in norm calculation
|
||||
std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
|
||||
static c10::ScalarType kFp8Type = is_fp8_ocp()
|
||||
? c10::ScalarType::Float8_e4m3fn
|
||||
: c10::ScalarType::Float8_e4m3fnuz;
|
||||
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user