Compare commits
398 Commits
debug
...
codex/remo
| Author | SHA1 | Date | |
|---|---|---|---|
| 85013bf094 | |||
| 07665f8679 | |||
| 9fac6aa30b | |||
| a53ad626d6 | |||
| 1c3dad22ff | |||
| d2a30a2d93 | |||
| 75fb112d80 | |||
| 38db529f66 | |||
| 064cac7bb7 | |||
| e19bce40a1 | |||
| 505805b645 | |||
| bbdc0f2366 | |||
| dc34059360 | |||
| c4cb0af98a | |||
| 1c3b1634aa | |||
| 2ea50e977a | |||
| b419937c78 | |||
| 5f696c33b1 | |||
| 67244c86f0 | |||
| 072d7e53e5 | |||
| 01a583fea4 | |||
| bc19d75985 | |||
| fbd6523ac0 | |||
| 470484a4f5 | |||
| 21da73343a | |||
| 66072b36db | |||
| 3ed1ec4af2 | |||
| 5a33ae9a3f | |||
| c9ff9e6f0c | |||
| eaffe4486c | |||
| 8ed039d527 | |||
| 37970105fe | |||
| cc935fdd7e | |||
| abdfcd4f3d | |||
| 4f02b77de4 | |||
| 29283e8976 | |||
| 05b044e698 | |||
| aa3f105c59 | |||
| ef7eefe17a | |||
| 350c94deb3 | |||
| f4cd80f944 | |||
| 349e0e3462 | |||
| 81b16a2bc9 | |||
| e111d5b0ae | |||
| a904ea78ea | |||
| b7433ca1a4 | |||
| 5c65a72bb1 | |||
| 9d8a2d86d2 | |||
| 3bc18127ff | |||
| bec060fd99 | |||
| 52bc9d5b3e | |||
| dc2979c585 | |||
| 027d37df38 | |||
| b98219670f | |||
| 32baf1d036 | |||
| 3127274d02 | |||
| 4ac510f484 | |||
| 7fb2a5be28 | |||
| 6c036615dc | |||
| 2fc24e94f9 | |||
| 2c3c1bd07a | |||
| 5963b98b46 | |||
| e6585ddb45 | |||
| 2a4d6412e6 | |||
| e67a79db03 | |||
| 9f882d8791 | |||
| 1a456c7c90 | |||
| fedb75fa27 | |||
| bff2e5f1d6 | |||
| 3c068c637b | |||
| f20c3b0951 | |||
| 883131544f | |||
| ee5fd49150 | |||
| 7ae9887542 | |||
| e3db5ebb66 | |||
| 9d442b7c48 | |||
| eb68c2dcd9 | |||
| 8b32464ac1 | |||
| 99cc41ad50 | |||
| d6a518fdde | |||
| 4aa8c7b047 | |||
| 4b946d693e | |||
| 087c6ffc92 | |||
| 4a2d33e371 | |||
| 8f3616f422 | |||
| 47f670b03b | |||
| dd6a910aac | |||
| 1b962e2457 | |||
| bfe9380161 | |||
| 9fccd04e30 | |||
| 252ada5559 | |||
| e120533d7a | |||
| 2b85697031 | |||
| 544fe76b95 | |||
| bb58dc8c20 | |||
| 0fb2551c23 | |||
| 6c47f6bfa4 | |||
| c15309a730 | |||
| 4a9375fe9d | |||
| 03191cd8f0 | |||
| b77bf34e53 | |||
| dd39baf717 | |||
| 43a62c51be | |||
| ca2d1925ef | |||
| 0f7acdd73c | |||
| 5801e49776 | |||
| 58d4c705a8 | |||
| ea3de5ef0d | |||
| 67532a1a68 | |||
| 5672ba90bd | |||
| dd83a157f1 | |||
| 5a411ef6c4 | |||
| eeb135eb87 | |||
| 3059b9cc6b | |||
| 64ad551878 | |||
| cef32104b4 | |||
| 493b10f8bf | |||
| d119fc8614 | |||
| dbebb7f812 | |||
| 3053a22b33 | |||
| 02d4b85454 | |||
| 86daa875fe | |||
| dcf2f3ec06 | |||
| 218454b9b2 | |||
| f4d6eb95cf | |||
| cd1f885bcf | |||
| d593cf28fa | |||
| faa7a5daac | |||
| 567939953b | |||
| 08369289af | |||
| 73cfb3c5ee | |||
| 4e5affeaa1 | |||
| e4f0b4cd96 | |||
| de3e53a75b | |||
| 85e0df1392 | |||
| 0faf3cc3e8 | |||
| 7ea5c73ad7 | |||
| 27fcfe7bcf | |||
| 68dbde5dbb | |||
| 04ad0dc275 | |||
| 238c4c1705 | |||
| 8c54610265 | |||
| 17871983a2 | |||
| 759ef49b15 | |||
| 5206ab20ba | |||
| 0af3ce1355 | |||
| e1279ef00f | |||
| 2942970d44 | |||
| 3c96e7b8a1 | |||
| b42566f440 | |||
| d96e11167d | |||
| 2891603efd | |||
| de2cc3d867 | |||
| e95084308b | |||
| 7f6f2c1182 | |||
| 5bcc153d7b | |||
| 45bfa49cb8 | |||
| fd2f10546c | |||
| e757a629e7 | |||
| aae725af7c | |||
| 73df49ef3a | |||
| 25aba2b6a3 | |||
| 94b03f88dd | |||
| 49bfc538e4 | |||
| a0b26701c9 | |||
| c4afdb69cc | |||
| b834b4cbf1 | |||
| 740f0647b1 | |||
| 01413e0cf5 | |||
| 0e219cd50b | |||
| 72c99f2a75 | |||
| bf214ca226 | |||
| 2e41f5abca | |||
| bc0f6059a2 | |||
| 8de261b04a | |||
| a0d8b9738d | |||
| 59e17dd4a0 | |||
| 4979eb79da | |||
| a8c0f59973 | |||
| f4a948f33f | |||
| 3f3313981c | |||
| 78818dd1b0 | |||
| 8e5cdcda4e | |||
| 90f3f7d73e | |||
| 6dc8da5dc1 | |||
| 79cbcab871 | |||
| ff68035932 | |||
| 1177dd53e9 | |||
| fc2dbcda8b | |||
| fec347dee1 | |||
| cc3173ae98 | |||
| 3e903b6cb4 | |||
| 973c9d01da | |||
| 15b8fef453 | |||
| cfa3234a5b | |||
| 41ae4a1eab | |||
| 4dad72f0d9 | |||
| 59d7ffc17f | |||
| 1da0f1441d | |||
| 98229db244 | |||
| dbeee3844c | |||
| 30498f2a65 | |||
| abc7989adc | |||
| 9a8966bcc2 | |||
| 5febdc8750 | |||
| 99bfef841f | |||
| 89e08d6d18 | |||
| 7f2ea7074e | |||
| 4fdd6f5cbf | |||
| 8226dd56bf | |||
| 5fe643fc26 | |||
| 7ba32aa60b | |||
| c89ed8de43 | |||
| 3beadc2f25 | |||
| bc636f21a6 | |||
| 017354c0ef | |||
| 010acc6e1e | |||
| c8c42597ab | |||
| 9d2a44606d | |||
| f17c075884 | |||
| b0d1213ac3 | |||
| 57f94e88ea | |||
| 684b6870e1 | |||
| a5b84f1cbf | |||
| 9f04d9d55f | |||
| 4d7c1d531b | |||
| 41f17bf290 | |||
| bcb06d7baf | |||
| 0377802c20 | |||
| 72fc8aa412 | |||
| fdb09c77d6 | |||
| 7a1c4025f1 | |||
| 60a0951924 | |||
| 64d90c3e4f | |||
| 59d5d2c736 | |||
| d21a36f5f9 | |||
| 561a0baee0 | |||
| f592b3174b | |||
| 7920de0a2a | |||
| ddcec289c7 | |||
| e090b7b45b | |||
| 6a50eaa0d3 | |||
| 12a8414d81 | |||
| 880c741bb6 | |||
| 40b6c9122b | |||
| 2e6bc46821 | |||
| fcba05c435 | |||
| 7a30fa8708 | |||
| f82f7a8990 | |||
| c3aea10dc8 | |||
| d4fd2768ef | |||
| 7a70a71892 | |||
| 7d4651997a | |||
| 569bf1c9c0 | |||
| 1ec20355f5 | |||
| e42af78b18 | |||
| 074854b24f | |||
| 79ac59f32e | |||
| b971f91504 | |||
| c733bd5e87 | |||
| a892b259b4 | |||
| 127ded0a9e | |||
| bb2b5126da | |||
| 361ae27f8a | |||
| e26fef8397 | |||
| c1eda615ba | |||
| 4aa23892d6 | |||
| 1fdd5c42d7 | |||
| bcbe2a4d9e | |||
| 51d41265ad | |||
| 4984a291d5 | |||
| 404c85ca72 | |||
| 817beef7f3 | |||
| 4f6593b058 | |||
| 94e6b2d55f | |||
| fd1ce98cdd | |||
| d11ec124a0 | |||
| f510715882 | |||
| f946197473 | |||
| 0cd72a7b72 | |||
| 5f5271f1ee | |||
| d6249d0699 | |||
| 25bb9e8c65 | |||
| a1213fae5f | |||
| a8b0361c92 | |||
| ed5ae4aace | |||
| 0fc36463e0 | |||
| d14c4ebf08 | |||
| ba6011027d | |||
| 85df8afdae | |||
| 6aeb1dab4a | |||
| e93f4cc9e3 | |||
| 2048c4e379 | |||
| d13360183a | |||
| 9bd831f501 | |||
| e2b1f863aa | |||
| 41329a0ff9 | |||
| ee0bc5e1b4 | |||
| 3d1393f6fc | |||
| 8a894084d2 | |||
| e2d8c27f68 | |||
| 29799ddacc | |||
| f17a6aa4ec | |||
| 6c8deacd72 | |||
| 55b823ba0f | |||
| 8c5a747246 | |||
| 5931b7e5d9 | |||
| cc99baf14d | |||
| dcb28a332b | |||
| fba7856581 | |||
| b5e383cd8b | |||
| 9a161307f5 | |||
| 37e8182bfe | |||
| 4db4426404 | |||
| a0933c3bd6 | |||
| 09e68bce34 | |||
| 9fb74c27a7 | |||
| 4032949630 | |||
| 08abfa78ec | |||
| 2bef2d1405 | |||
| 36cacd0958 | |||
| bb3eb80d92 | |||
| fcc0a3130a | |||
| 736569da8d | |||
| 2eb9986a2d | |||
| ccee371e86 | |||
| c0bd6a684a | |||
| 3144d90217 | |||
| 2f5e5c18de | |||
| bd98842c8a | |||
| d6069887c6 | |||
| 492196ed0e | |||
| f4f1a8df22 | |||
| 0b9a612fa3 | |||
| 4c04eef706 | |||
| f36355abfd | |||
| 9e3c3a7df2 | |||
| 6cbd41909e | |||
| 72d30108a0 | |||
| 8b83b93739 | |||
| 9dbefd88e9 | |||
| 7c195d43da | |||
| 0ae43dbf8c | |||
| 267c80d31f | |||
| 77f62613f9 | |||
| feaf202e93 | |||
| 91130ae376 | |||
| e40827280b | |||
| 4377b1ae3b | |||
| 009d689b0c | |||
| 0efdb5c3ba | |||
| 53b42f4102 | |||
| 309d7aa401 | |||
| b4a01aaf95 | |||
| 83dd28aae4 | |||
| f88e84016f | |||
| 3c2156b3af | |||
| 7e7db04310 | |||
| 41f160b974 | |||
| dc625ea6b8 | |||
| b23fb78623 | |||
| 561f38dc3c | |||
| 73e688cb79 | |||
| fb1a8f932a | |||
| 0dc9cbb527 | |||
| b5fb3005a8 | |||
| 15de5ff9ea | |||
| b8a93076d3 | |||
| c3f9773b2c | |||
| 3707cb2505 | |||
| 920ed46b09 | |||
| 15cb047e25 | |||
| 9ad0688e43 | |||
| b9a1c4c8a2 | |||
| 1aa427fdc1 | |||
| 1c63a16b65 | |||
| 922d3b401b | |||
| 19332c0479 | |||
| a55cf41a09 | |||
| 6fb2788163 | |||
| 3d2a2de8f7 | |||
| 1116590b16 | |||
| ccb97338af | |||
| 45c9cb5835 | |||
| e283976f3a | |||
| 46876dff32 | |||
| 1823a00d67 | |||
| ed16d0f26f | |||
| 0cdd213641 | |||
| 948dd3443b | |||
| b2f7745774 | |||
| 82dfb12e52 | |||
| bba1042c6f | |||
| b6fbc15634 | |||
| 3e0d4a3475 | |||
| 562663a044 | |||
| ed1623a88a | |||
| 13b89bd823 |
@ -8,7 +8,7 @@ This benchmark aims to:
|
||||
|
||||
Latest results: [results link](https://blog.vllm.ai/2024/09/05/perf-update.html), scroll to the end.
|
||||
|
||||
Latest reproduction guilde: [github issue link](https://github.com/vllm-project/vllm/issues/8176)
|
||||
Latest reproduction guide: [github issue link](https://github.com/vllm-project/vllm/issues/8176)
|
||||
|
||||
## Setup
|
||||
|
||||
|
||||
@ -1,24 +1,22 @@
|
||||
steps:
|
||||
# aarch64 + CUDA builds. PyTorch 2.8 aarch64 + CUDA wheel is only available on CUDA 12.9
|
||||
- label: "Build arm64 wheel - CUDA 12.9"
|
||||
depends_on: ~
|
||||
id: build-wheel-arm64-cuda-12-9
|
||||
agents:
|
||||
queue: arm64_cpu_queue_postmerge
|
||||
commands:
|
||||
# #NOTE: torch_cuda_arch_list is derived from upstream PyTorch build files here:
|
||||
# https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/aarch64_ci_build.sh#L7
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg VLLM_MAIN_CUDA_VERSION=12.9 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- block: "Build CUDA 12.8 wheel"
|
||||
key: block-build-cu128-wheel
|
||||
|
||||
- label: "Build wheel - CUDA 12.8"
|
||||
depends_on: block-build-cu128-wheel
|
||||
depends_on: ~
|
||||
id: build-wheel-cuda-12-8
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
@ -30,12 +28,8 @@ steps:
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- block: "Build CUDA 12.6 wheel"
|
||||
key: block-build-cu126-wheel
|
||||
depends_on: ~
|
||||
|
||||
- label: "Build wheel - CUDA 12.6"
|
||||
depends_on: block-build-cu126-wheel
|
||||
depends_on: ~
|
||||
id: build-wheel-cuda-12-6
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
@ -102,8 +96,6 @@ steps:
|
||||
depends_on:
|
||||
- create-multi-arch-manifest
|
||||
- build-wheel-cuda-12-8
|
||||
- build-wheel-cuda-12-6
|
||||
- build-wheel-cuda-12-9
|
||||
id: annotate-release-workflow
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
|
||||
@ -14,18 +14,33 @@ buildkite-agent annotate --style 'info' --context 'release-workflow' << EOF
|
||||
To download the wheel:
|
||||
\`\`\`
|
||||
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux1_x86_64.whl .
|
||||
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux2014_aarch64.whl .
|
||||
|
||||
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu126/vllm-${RELEASE_VERSION}+cu126-cp38-abi3-manylinux1_x86_64.whl .
|
||||
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu118/vllm-${RELEASE_VERSION}+cu118-cp38-abi3-manylinux1_x86_64.whl .
|
||||
aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu129/vllm-${RELEASE_VERSION}+cu129-cp38-abi3-manylinux1_x86_64.whl .
|
||||
\`\`\`
|
||||
|
||||
To download and upload the image:
|
||||
|
||||
\`\`\`
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}
|
||||
docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT} vllm/vllm-openai
|
||||
docker tag vllm/vllm-openai vllm/vllm-openai:latest
|
||||
docker tag vllm/vllm-openai vllm/vllm-openai:v${RELEASE_VERSION}
|
||||
docker push vllm/vllm-openai:latest
|
||||
docker push vllm/vllm-openai:v${RELEASE_VERSION}
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-x86_64
|
||||
docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-aarch64
|
||||
|
||||
docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-x86_64 vllm/vllm-openai:x86_64
|
||||
docker tag vllm/vllm-openai:x86_64 vllm/vllm-openai:latest-x86_64
|
||||
docker tag vllm/vllm-openai:x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-x86_64
|
||||
docker push vllm/vllm-openai:latest-x86_64
|
||||
docker push vllm/vllm-openai:v${RELEASE_VERSION}-x86_64
|
||||
|
||||
docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}-aarch64 vllm/vllm-openai:aarch64
|
||||
docker tag vllm/vllm-openai:aarch64 vllm/vllm-openai:latest-aarch64
|
||||
docker tag vllm/vllm-openai:aarch64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64
|
||||
docker push vllm/vllm-openai:latest-aarch64
|
||||
docker push vllm/vllm-openai:v${RELEASE_VERSION}-aarch64
|
||||
|
||||
docker manifest create vllm/vllm-openai:latest vllm/vllm-openai:latest-x86_64 vllm/vllm-openai:latest-aarch64 --amend
|
||||
docker manifest create vllm/vllm-openai:v${RELEASE_VERSION} vllm/vllm-openai:v${RELEASE_VERSION}-x86_64 vllm/vllm-openai:v${RELEASE_VERSION}-aarch64 --amend
|
||||
docker manifest push vllm/vllm-openai:latest
|
||||
docker manifest push vllm/vllm-openai:v${RELEASE_VERSION}
|
||||
\`\`\`
|
||||
EOF
|
||||
@ -167,12 +167,6 @@ if [[ $commands == *" entrypoints/llm "* ]]; then
|
||||
--ignore=entrypoints/llm/test_prompt_validation.py "}
|
||||
fi
|
||||
|
||||
#Obsolete currently
|
||||
##ignore certain Entrypoints/llm tests
|
||||
#if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then
|
||||
# commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "}
|
||||
#fi
|
||||
|
||||
# --ignore=entrypoints/openai/test_encoder_decoder.py \
|
||||
# --ignore=entrypoints/openai/test_embedding.py \
|
||||
# --ignore=entrypoints/openai/test_oot_registration.py
|
||||
|
||||
@ -66,7 +66,6 @@ function cpu_tests() {
|
||||
|
||||
pytest -x -v -s tests/models/language/pooling -m cpu_model
|
||||
pytest -x -v -s tests/models/multimodal/generation \
|
||||
--ignore=tests/models/multimodal/generation/test_mllama.py \
|
||||
--ignore=tests/models/multimodal/generation/test_pixtral.py \
|
||||
-m cpu_model"
|
||||
|
||||
|
||||
@ -30,6 +30,7 @@ docker run \
|
||||
bash -c '
|
||||
set -e
|
||||
echo $ZE_AFFINITY_MASK
|
||||
pip install tblib==3.1.0
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE
|
||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
|
||||
|
||||
@ -46,23 +46,19 @@ steps:
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/mq_llm_engine
|
||||
- tests/async_engine
|
||||
- tests/test_inputs.py
|
||||
- tests/test_outputs.py
|
||||
- tests/multimodal
|
||||
- tests/utils_
|
||||
- tests/worker
|
||||
- tests/standalone_tests/lazy_imports.py
|
||||
- tests/transformers_utils
|
||||
commands:
|
||||
- python3 standalone_tests/lazy_imports.py
|
||||
- pytest -v -s mq_llm_engine # MQLLMEngine
|
||||
- pytest -v -s async_engine # AsyncLLMEngine
|
||||
- pytest -v -s test_inputs.py
|
||||
- pytest -v -s test_outputs.py
|
||||
- pytest -v -s multimodal
|
||||
- pytest -v -s utils_ # Utils
|
||||
- pytest -v -s worker # Worker
|
||||
- pytest -v -s transformers_utils # transformers_utils
|
||||
|
||||
- label: Python-only Installation Test # 10min
|
||||
timeout_in_minutes: 20
|
||||
@ -82,27 +78,25 @@ steps:
|
||||
- vllm/
|
||||
- tests/basic_correctness/test_basic_correctness
|
||||
- tests/basic_correctness/test_cpu_offload
|
||||
- tests/basic_correctness/test_preemption
|
||||
- tests/basic_correctness/test_cumem.py
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s basic_correctness/test_cumem.py
|
||||
- pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
|
||||
- label: Core Test # 22min
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental]
|
||||
- label: Entrypoints Unit Tests # 5min
|
||||
timeout_in_minutes: 10
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/core
|
||||
- vllm/distributed
|
||||
- tests/core
|
||||
- vllm/entrypoints
|
||||
- tests/entrypoints/
|
||||
commands:
|
||||
- pytest -v -s core
|
||||
- pytest -v -s entrypoints/openai/tool_parsers
|
||||
- pytest -v -s entrypoints/ --ignore=entrypoints/llm --ignore=entrypoints/openai --ignore=entrypoints/offline_mode --ignore=entrypoints/test_chat_utils.py --ignore=entrypoints/pooling
|
||||
|
||||
- label: Entrypoints Test (LLM) # 30min
|
||||
- label: Entrypoints Integration Test (LLM) # 30min
|
||||
timeout_in_minutes: 40
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
@ -114,12 +108,11 @@ steps:
|
||||
- tests/entrypoints/offline_mode
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Test (API Server) # 100min
|
||||
- label: Entrypoints Integration Test (API Server) # 100min
|
||||
timeout_in_minutes: 130
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
@ -132,9 +125,22 @@ steps:
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py --ignore=entrypoints/openai/tool_parsers/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: Entrypoints Integration Test (Pooling)
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/pooling
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/pooling
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 35min
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -204,16 +210,14 @@ steps:
|
||||
num_gpus: 2
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/metrics
|
||||
- tests/tracing
|
||||
- tests/v1/tracing
|
||||
commands:
|
||||
- pytest -v -s metrics
|
||||
- "pip install \
|
||||
'opentelemetry-sdk>=1.26.0' \
|
||||
'opentelemetry-api>=1.26.0' \
|
||||
'opentelemetry-exporter-otlp>=1.26.0' \
|
||||
'opentelemetry-semantic-conventions-ai>=0.4.1'"
|
||||
- pytest -v -s tracing
|
||||
- pytest -v -s v1/tracing
|
||||
|
||||
##### fast check tests #####
|
||||
##### 1 GPU test #####
|
||||
@ -276,6 +280,7 @@ steps:
|
||||
# split the test to avoid interference
|
||||
- pytest -v -s v1/core
|
||||
- pytest -v -s v1/executor
|
||||
- pytest -v -s v1/kv_offload
|
||||
- pytest -v -s v1/sample
|
||||
- pytest -v -s v1/logits_processors
|
||||
- pytest -v -s v1/worker
|
||||
@ -310,7 +315,6 @@ steps:
|
||||
- python3 offline_inference/vision_language_pooling.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference/encoder_decoder.py
|
||||
- python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
|
||||
- python3 offline_inference/basic/classify.py
|
||||
- python3 offline_inference/basic/embed.py
|
||||
@ -369,6 +373,7 @@ steps:
|
||||
- pytest -v -s compile/test_async_tp.py
|
||||
- pytest -v -s compile/test_fusion_all_reduce.py
|
||||
- pytest -v -s compile/test_decorator.py
|
||||
- pytest -v -s compile/test_noop_elimination.py
|
||||
|
||||
- label: PyTorch Fullgraph Smoke Test # 15min
|
||||
timeout_in_minutes: 30
|
||||
@ -379,11 +384,7 @@ steps:
|
||||
- tests/compile
|
||||
commands:
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
# these tests need to be separated, cannot combine
|
||||
- pytest -v -s compile/piecewise/test_simple.py
|
||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||
- pytest -v -s compile/piecewise/test_full_cudagraph.py
|
||||
- pytest -v -s compile/piecewise/test_multiple_graphs.py
|
||||
- pytest -v -s compile/piecewise/
|
||||
|
||||
- label: PyTorch Fullgraph Test # 20min
|
||||
timeout_in_minutes: 30
|
||||
@ -501,6 +502,10 @@ steps:
|
||||
commands:
|
||||
# temporary install here since we need nightly, will move to requirements/test.in
|
||||
# after torchao 0.12 release, and pin a working version of torchao nightly here
|
||||
|
||||
# since torchao nightly is only compatible with torch nightly currently
|
||||
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
||||
# we can only upgrade after this is resolved
|
||||
- pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
|
||||
|
||||
@ -523,15 +528,6 @@ steps:
|
||||
commands: # LMEval+Transcription WER check
|
||||
- pytest -s entrypoints/openai/correctness/
|
||||
|
||||
- label: Encoder Decoder tests # 12min
|
||||
timeout_in_minutes: 20
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/encoder_decoder
|
||||
commands:
|
||||
- pytest -v -s encoder_decoder
|
||||
|
||||
- label: OpenAI-Compatible Tool Use # 23 min
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -546,36 +542,85 @@ steps:
|
||||
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Test # 57min
|
||||
timeout_in_minutes: 75
|
||||
- label: Basic Models Tests (Initialization)
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models
|
||||
- tests/models/test_initialization.py
|
||||
commands:
|
||||
- pytest -v -s models/test_transformers.py
|
||||
- pytest -v -s models/test_registry.py
|
||||
- pytest -v -s models/test_utils.py
|
||||
- pytest -v -s models/test_vision.py
|
||||
- pytest -v -s models/test_initialization.py
|
||||
# Run a subset of model initialization tests
|
||||
- pytest -v -s models/test_initialization.py::test_can_initialize_small_subset
|
||||
|
||||
- label: Language Models Test (Standard) # 35min
|
||||
- label: Basic Models Tests (Extra Initialization) %N
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/models/
|
||||
- tests/models/test_initialization.py
|
||||
commands:
|
||||
# Only when vLLM model source is modified - test initialization of a large
|
||||
# subset of supported models (the complement of the small subset in the above
|
||||
# test.) Also run if model initialization test file is modified
|
||||
- pytest -v -s models/test_initialization.py \
|
||||
-k 'not test_can_initialize_small_subset' \
|
||||
--num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \
|
||||
--shard-id=$$BUILDKITE_PARALLEL_JOB
|
||||
parallelism: 2
|
||||
|
||||
- label: Basic Models Tests (Other)
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/test_transformers.py
|
||||
- tests/models/test_registry.py
|
||||
- tests/models/test_utils.py
|
||||
- tests/models/test_vision.py
|
||||
commands:
|
||||
- pytest -v -s models/test_transformers.py \
|
||||
models/test_registry.py \
|
||||
models/test_utils.py \
|
||||
models/test_vision.py
|
||||
|
||||
- label: Language Models Tests (Standard)
|
||||
timeout_in_minutes: 25
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/language
|
||||
commands:
|
||||
# Test standard language models, excluding a subset of slow tests
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/language -m core_model
|
||||
- pytest -v -s models/language -m 'core_model and (not slow_test)'
|
||||
|
||||
- label: Language Models Test (Hybrid) # 35 min
|
||||
- label: Language Models Tests (Extra Standard) %N
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/models/
|
||||
- tests/models/language/pooling/test_embedding.py
|
||||
- tests/models/language/generation/test_common.py
|
||||
- tests/models/language/pooling/test_classification.py
|
||||
commands:
|
||||
# Shard slow subset of standard language models tests. Only run when model
|
||||
# source is modified, or when specified test files are modified
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/language -m 'core_model and slow_test' \
|
||||
--num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \
|
||||
--shard-id=$$BUILDKITE_PARALLEL_JOB
|
||||
parallelism: 2
|
||||
|
||||
- label: Language Models Tests (Hybrid) %N
|
||||
timeout_in_minutes: 75
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/language/generation
|
||||
commands:
|
||||
@ -583,7 +628,12 @@ steps:
|
||||
# Note: also needed to run plamo2 model in vLLM
|
||||
- uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5'
|
||||
- uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2'
|
||||
- pytest -v -s models/language/generation -m hybrid_model
|
||||
# Shard hybrid language model tests
|
||||
- pytest -v -s models/language/generation \
|
||||
-m hybrid_model \
|
||||
--num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \
|
||||
--shard-id=$$BUILDKITE_PARALLEL_JOB
|
||||
parallelism: 2
|
||||
|
||||
- label: Language Models Test (Extended Generation) # 80min
|
||||
timeout_in_minutes: 110
|
||||
@ -597,6 +647,16 @@ steps:
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'
|
||||
|
||||
- label: Language Models Test (PPL)
|
||||
timeout_in_minutes: 110
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/language/generation_ppl_test
|
||||
commands:
|
||||
- pytest -v -s models/language/generation_ppl_test
|
||||
|
||||
- label: Language Models Test (Extended Pooling) # 36min
|
||||
timeout_in_minutes: 50
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -607,6 +667,16 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s models/language/pooling -m 'not core_model'
|
||||
|
||||
- label: Language Models Test (MTEB)
|
||||
timeout_in_minutes: 110
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/language/pooling_mteb_test
|
||||
commands:
|
||||
- pytest -v -s models/language/pooling_mteb_test
|
||||
|
||||
- label: Multi-Modal Processor Test # 44min
|
||||
timeout_in_minutes: 60
|
||||
source_file_dependencies:
|
||||
@ -627,7 +697,7 @@ steps:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
|
||||
- cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||
- cd .. && VLLM_WORKER_MULTIPROC_METHOD=spawn pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 1
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -713,11 +783,12 @@ steps:
|
||||
# num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
|
||||
- pytest -v -s tests/kernels/test_cutlass_mla_decode.py
|
||||
- pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py
|
||||
- pytest -v -s tests/kernels/attention/test_flashinfer_mla_decode.py
|
||||
# Quantization
|
||||
- pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
|
||||
- pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
|
||||
- pytest -v -s tests/kernels/quantization/test_silu_mul_nvfp4_quant.py
|
||||
- pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
|
||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
|
||||
@ -729,6 +800,20 @@ steps:
|
||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
|
||||
|
||||
- label: GPT-OSS Eval (Blackwell)
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/"
|
||||
gpu: b200
|
||||
optional: true # disable while debugging
|
||||
source_file_dependencies:
|
||||
- tests/evals/gpt_oss
|
||||
- vllm/model_executor/models/gpt_oss.py
|
||||
- vllm/model_executor/layers/quantization/mxfp4.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
commands:
|
||||
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
||||
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2'
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
@ -743,6 +828,8 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s distributed/test_comm_ops.py
|
||||
- pytest -v -s distributed/test_shm_broadcast.py
|
||||
- pytest -v -s distributed/test_shm_buffer.py
|
||||
- pytest -v -s distributed/test_shm_storage.py
|
||||
|
||||
- label: 2 Node Tests (4 GPUs in total) # 16min
|
||||
timeout_in_minutes: 30
|
||||
@ -801,7 +888,8 @@ steps:
|
||||
# Avoid importing model tests that cause CUDA reinitialization error
|
||||
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
|
||||
- VLLM_WORKER_MULTIPROC_METHOD=spawn pytest models/multimodal/generation/test_whisper.py -v -s -m 'distributed(num_gpus=2)'
|
||||
# test sequence parallel
|
||||
- pytest -v -s distributed/test_sequence_parallel.py
|
||||
# this test fails consistently.
|
||||
@ -827,7 +915,7 @@ steps:
|
||||
# begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin
|
||||
- pip install -e ./plugins/prithvi_io_processor_plugin
|
||||
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
||||
- pip uninstall prithvi_io_processor_plugin -y
|
||||
- pip uninstall prithvi_io_processor_plugin -y
|
||||
# end io_processor plugins test
|
||||
# other tests continue here:
|
||||
- pytest -v -s plugins_tests/test_scheduler_plugins.py
|
||||
@ -851,7 +939,6 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s distributed/test_pp_cudagraph.py
|
||||
- pytest -v -s distributed/test_pipeline_parallel.py
|
||||
# - pytest -v -s distributed/test_context_parallel.py # TODO: enable it on Hopper runners or add triton MLA support
|
||||
|
||||
- label: LoRA TP Test (Distributed) # 17 min
|
||||
timeout_in_minutes: 30
|
||||
@ -875,7 +962,7 @@ steps:
|
||||
timeout_in_minutes: 45
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
num_gpus: 2
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -925,9 +1012,21 @@ steps:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
|
||||
|
||||
- label: Qwen MoE EP Test # optional
|
||||
##### H200 test #####
|
||||
- label: Distrubted Tests (H200) # optional
|
||||
gpu: h200
|
||||
optional: true
|
||||
working_dir: "/vllm-workspace/"
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
||||
|
||||
##### B200 test #####
|
||||
- label: Distributed Tests (B200) # optional
|
||||
gpu: b200
|
||||
optional: true
|
||||
working_dir: "/vllm-workspace/"
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||
|
||||
32
.coveragerc
Normal file
32
.coveragerc
Normal file
@ -0,0 +1,32 @@
|
||||
[run]
|
||||
source = vllm
|
||||
omit =
|
||||
*/tests/*
|
||||
*/test_*
|
||||
*/__pycache__/*
|
||||
*/build/*
|
||||
*/dist/*
|
||||
*/vllm.egg-info/*
|
||||
*/third_party/*
|
||||
*/examples/*
|
||||
*/benchmarks/*
|
||||
*/docs/*
|
||||
|
||||
[report]
|
||||
exclude_lines =
|
||||
pragma: no cover
|
||||
def __repr__
|
||||
if self.debug:
|
||||
if settings.DEBUG
|
||||
raise AssertionError
|
||||
raise NotImplementedError
|
||||
if 0:
|
||||
if __name__ == .__main__.:
|
||||
class .*\bProtocol\):
|
||||
@(abc\.)?abstractmethod
|
||||
|
||||
[html]
|
||||
directory = htmlcov
|
||||
|
||||
[xml]
|
||||
output = coverage.xml
|
||||
24
.github/.bc-linter.yml
vendored
Normal file
24
.github/.bc-linter.yml
vendored
Normal file
@ -0,0 +1,24 @@
|
||||
# doc: https://github.com/pytorch/test-infra/blob/main/tools/stronghold/docs/bc_linter_config.md
|
||||
version: 1
|
||||
paths:
|
||||
# We temporarily disable globally, and will only enable with `annotations.include`
|
||||
# include:
|
||||
# - "vllm/v1/attetion/*.py"
|
||||
# - "vllm/v1/core/*.py"
|
||||
exclude:
|
||||
- "**/*.py"
|
||||
|
||||
scan:
|
||||
functions: true # check free functions and methods
|
||||
classes: true # check classes/dataclasses
|
||||
public_only: true # ignore names starting with "_" at any level
|
||||
|
||||
annotations:
|
||||
include: # decorators that force‑include a symbol
|
||||
- name: "bc_linter_include" # matched by simple name or dotted suffix
|
||||
propagate_to_members: false # for classes, include methods/inner classes
|
||||
exclude: # decorators that force‑exclude a symbol
|
||||
- name: "bc_linter_skip" # matched by simple name or dotted suffix
|
||||
propagate_to_members: true # for classes, exclude methods/inner classes
|
||||
|
||||
excluded_violations: [] # e.g. ["ParameterRenamed", "FieldTypeChanged"]
|
||||
42
.github/CODEOWNERS
vendored
42
.github/CODEOWNERS
vendored
@ -2,23 +2,27 @@
|
||||
# for more info about CODEOWNERS file
|
||||
|
||||
# This lists cover the "core" components of vLLM that require careful review
|
||||
/vllm/attention @LucasWilkinson
|
||||
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
|
||||
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @22quinn
|
||||
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/fused_moe @mgoin
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill @NickLucche
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
|
||||
/vllm/model_executor/layers/mamba @tdoublep
|
||||
/vllm/model_executor/model_loader @22quinn
|
||||
/vllm/multimodal @DarkLight1337 @ywang96
|
||||
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||
/vllm/v1/attention @LucasWilkinson
|
||||
/vllm/v1/sample @22quinn @houseroad
|
||||
/vllm/vllm_flash_attn @LucasWilkinson
|
||||
/vllm/lora @jeejeelee
|
||||
/vllm/reasoning @aarnphm
|
||||
/vllm/entrypoints @aarnphm
|
||||
/vllm/reasoning @aarnphm @chaunceyjiang
|
||||
/vllm/entrypoints @aarnphm @chaunceyjiang
|
||||
/vllm/compilation @zou3519 @youkaichao @ProExpertProg
|
||||
/vllm/distributed/kv_transfer @NickLucche @ApostaC
|
||||
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
|
||||
# Any change to the VllmConfig changes can have a large user-facing impact,
|
||||
@ -29,26 +33,37 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
||||
/vllm/v1/spec_decode @benchislett @luccafong
|
||||
/vllm/v1/attention/backends/flashinfer.py @mgoin
|
||||
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
||||
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
|
||||
/vllm/v1/kv_cache_interface.py @heheda12345
|
||||
/vllm/v1/offloading @ApostaC
|
||||
|
||||
# Test ownership
|
||||
/.buildkite/lm-eval-harness @mgoin @simon-mo
|
||||
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
|
||||
/tests/distributed/test_multi_node_assignment.py @youkaichao
|
||||
/tests/distributed/test_pipeline_parallel.py @youkaichao
|
||||
/tests/distributed/test_same_node.py @youkaichao
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm
|
||||
/tests/kernels @tlrmchlsmth @WoosukKwon @yewentao256
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm @NickLucche
|
||||
/tests/evals @mgoin
|
||||
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
|
||||
/tests/models @DarkLight1337 @ywang96
|
||||
/tests/multimodal @DarkLight1337 @ywang96
|
||||
/tests/prefix_caching @comaniac @KuntaiDu
|
||||
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256
|
||||
/tests/test_inputs.py @DarkLight1337 @ywang96
|
||||
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm
|
||||
/tests/v1/structured_output @mgoin @russellb @aarnphm
|
||||
/tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
|
||||
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
||||
/tests/lora @jeejeelee
|
||||
/tests/models/language/generation/test_hybrid.py @tdoublep
|
||||
/tests/v1/kv_connector/nixl_integration @NickLucche
|
||||
/tests/v1/kv_connector @ApostaC
|
||||
/tests/v1/offloading @ApostaC
|
||||
|
||||
# Transformers backend
|
||||
/vllm/model_executor/models/transformers.py @hmellor
|
||||
/tests/models/test_transformers.py @hmellor
|
||||
|
||||
# Docs
|
||||
/docs @hmellor
|
||||
@ -91,3 +106,12 @@ mkdocs.yaml @hmellor
|
||||
/vllm/v1/attention/backends/mla/rocm*.py @gshtras
|
||||
/vllm/attention/ops/rocm*.py @gshtras
|
||||
/vllm/model_executor/layers/fused_moe/rocm*.py @gshtras
|
||||
|
||||
# TPU
|
||||
/vllm/v1/worker/tpu* @NickLucche
|
||||
/vllm/platforms/tpu.py @NickLucche
|
||||
/vllm/v1/sample/tpu @NickLucche
|
||||
/vllm/tests/v1/tpu @NickLucche
|
||||
|
||||
# KVConnector installation files
|
||||
/requirements/kv_connectors.txt @NickLucche
|
||||
|
||||
26
.github/mergify.yml
vendored
26
.github/mergify.yml
vendored
@ -124,9 +124,16 @@ pull_request_rules:
|
||||
- or:
|
||||
- files~=^examples/.*gpt[-_]?oss.*\.py
|
||||
- files~=^tests/.*gpt[-_]?oss.*\.py
|
||||
- files~=^tests/entrypoints/openai/test_response_api_with_harmony.py
|
||||
- files~=^tests/entrypoints/test_context.py
|
||||
- files~=^vllm/model_executor/models/.*gpt[-_]?oss.*\.py
|
||||
- files~=^vllm/model_executor/layers/.*gpt[-_]?oss.*\.py
|
||||
- files~=^vllm/entrypoints/harmony_utils.py
|
||||
- files~=^vllm/entrypoints/tool_server.py
|
||||
- files~=^vllm/entrypoints/tool.py
|
||||
- files~=^vllm/entrypoints/context.py
|
||||
- title~=(?i)gpt[-_]?oss
|
||||
- title~=(?i)harmony
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
@ -164,7 +171,7 @@ pull_request_rules:
|
||||
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
|
||||
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
|
||||
- files~=^tests/v1/structured_output/
|
||||
- files=tests/v1/entrypoints/llm/test_guided_generate.py
|
||||
- files=tests/v1/entrypoints/llm/test_struct_output_generate.py
|
||||
- files~=^vllm/v1/structured_output/
|
||||
actions:
|
||||
label:
|
||||
@ -295,3 +302,20 @@ pull_request_rules:
|
||||
label:
|
||||
remove:
|
||||
- needs-rebase
|
||||
|
||||
- name: label-kv-connector
|
||||
description: Automatically apply kv-connector label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^examples/online_serving/disaggregated[^/]*/.*
|
||||
- files~=^examples/offline_inference/disaggregated[^/]*/.*
|
||||
- files~=^examples/others/lmcache/
|
||||
- files~=^tests/v1/kv_connector/
|
||||
- files~=^vllm/distributed/kv_transfer/
|
||||
- title~=(?i)\bP/?D\b
|
||||
- title~=(?i)NIXL
|
||||
- title~=(?i)LMCache
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- kv-connector
|
||||
2
.github/workflows/add_label_automerge.yml
vendored
2
.github/workflows/add_label_automerge.yml
vendored
@ -10,7 +10,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Add label
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
script: |
|
||||
github.rest.issues.addLabels({
|
||||
|
||||
29
.github/workflows/bc-lint.yml
vendored
Normal file
29
.github/workflows/bc-lint.yml
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
name: BC Lint
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- synchronize
|
||||
- reopened
|
||||
- labeled
|
||||
- unlabeled
|
||||
|
||||
jobs:
|
||||
bc_lint:
|
||||
if: github.repository_owner == 'vllm-project'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Run BC Lint Action
|
||||
uses: pytorch/test-infra/.github/actions/bc-lint@main
|
||||
with:
|
||||
repo: ${{ github.event.pull_request.head.repo.full_name }}
|
||||
base_sha: ${{ github.event.pull_request.base.sha }}
|
||||
head_sha: ${{ github.event.pull_request.head.sha }}
|
||||
suppression: ${{ contains(github.event.pull_request.labels.*.name, 'suppress-bc-linter') }}
|
||||
docs_link: 'https://github.com/pytorch/test-infra/wiki/BC-Linter'
|
||||
config_dir: .github
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
|
||||
cancel-in-progress: true
|
||||
2
.github/workflows/issue_autolabel.yml
vendored
2
.github/workflows/issue_autolabel.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Label issues based on keywords
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
script: |
|
||||
// Configuration: Add new labels and keywords here
|
||||
|
||||
2
.github/workflows/reminder_comment.yml
vendored
2
.github/workflows/reminder_comment.yml
vendored
@ -9,7 +9,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Remind to run full CI on PR
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
script: |
|
||||
try {
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
||||
actions: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0
|
||||
- uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
|
||||
with:
|
||||
# Increasing this value ensures that changes to this workflow
|
||||
# propagate to all issues and PRs in days rather than months
|
||||
|
||||
12
.gitignore
vendored
12
.gitignore
vendored
@ -4,7 +4,7 @@
|
||||
# vllm-flash-attn built from source
|
||||
vllm/vllm_flash_attn/*
|
||||
|
||||
# triton jit
|
||||
# triton jit
|
||||
.triton
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
@ -177,6 +177,14 @@ cython_debug/
|
||||
# VSCode
|
||||
.vscode/
|
||||
|
||||
# Claude
|
||||
CLAUDE.md
|
||||
.claude/
|
||||
|
||||
# Codex
|
||||
AGENTS.md
|
||||
.codex/
|
||||
|
||||
# DS Store
|
||||
.DS_Store
|
||||
|
||||
@ -209,4 +217,4 @@ shellcheck*/
|
||||
csrc/moe/marlin_moe_wna16/kernel_*
|
||||
|
||||
# Ignore ep_kernels_workspace folder
|
||||
ep_kernels_workspace/
|
||||
ep_kernels_workspace/
|
||||
|
||||
@ -164,9 +164,7 @@ repos:
|
||||
name: Validate configuration has default values and that each field has a docstring
|
||||
entry: python tools/validate_config.py
|
||||
language: python
|
||||
types: [python]
|
||||
pass_filenames: true
|
||||
files: vllm/config.py|tests/test_config.py|vllm/entrypoints/openai/cli_args.py
|
||||
additional_dependencies: [regex]
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
|
||||
@ -1 +1,2 @@
|
||||
collect_env.py
|
||||
vllm/model_executor/layers/fla/ops/*.py
|
||||
|
||||
@ -13,6 +13,10 @@ cmake_minimum_required(VERSION 3.26)
|
||||
# cmake --install . --component _C
|
||||
project(vllm_extensions LANGUAGES CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
|
||||
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
|
||||
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
|
||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||
@ -171,6 +175,16 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Set CUDA include flags for CXX compiler.
|
||||
#
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include")
|
||||
if(CUDA_VERSION VERSION_GREATER_EQUAL 13.0)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I${CUDA_TOOLKIT_ROOT_DIR}/include/cccl")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
|
||||
# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
|
||||
@ -294,7 +308,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu"
|
||||
"csrc/quantization/fp8/per_token_group_quant.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
@ -581,7 +594,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu"
|
||||
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
@ -779,6 +791,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Hadacore kernels
|
||||
cuda_archs_loose_intersection(HADACORE_ARCHS "8.0;8.9;9.0" "${CUDA_ARCHS}")
|
||||
if(HADACORE_ARCHS)
|
||||
set(SRCS "csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${HADACORE_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
message(STATUS "Building hadacore")
|
||||
endif()
|
||||
|
||||
# if CUDA endif
|
||||
endif()
|
||||
|
||||
|
||||
@ -14,6 +14,9 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||
</p>
|
||||
|
||||
---
|
||||
Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) and [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco for our latest updates on vLLM and to meet the vLLM team! Register now for the largest vLLM community events of the year!
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
@ -78,7 +81,7 @@ vLLM is flexible and easy to use with:
|
||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron
|
||||
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
|
||||
- Prefix caching support
|
||||
- Multi-LoRA support
|
||||
|
||||
|
||||
@ -1,807 +1,20 @@
|
||||
# Benchmarking vLLM
|
||||
# Benchmarks
|
||||
|
||||
This README guides you through running benchmark tests with the extensive
|
||||
datasets supported on vLLM. It’s a living document, updated as new features and datasets
|
||||
become available.
|
||||
This directory used to contain vLLM's benchmark scripts and utilities for performance testing and evaluation.
|
||||
|
||||
## Dataset Overview
|
||||
## Contents
|
||||
|
||||
<table style="width:100%; border-collapse: collapse;">
|
||||
<thead>
|
||||
<tr>
|
||||
<th style="width:15%; text-align: left;">Dataset</th>
|
||||
<th style="width:10%; text-align: center;">Online</th>
|
||||
<th style="width:10%; text-align: center;">Offline</th>
|
||||
<th style="width:65%; text-align: left;">Data Path</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td><strong>ShareGPT</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>ShareGPT4V (Image)</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td>
|
||||
<code>wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json</code>
|
||||
<br>
|
||||
<div>Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:</div>
|
||||
<code>wget http://images.cocodataset.org/zips/train2017.zip</code>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>ShareGPT4Video (Video)</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td>
|
||||
<code>git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video</code>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>BurstGPT</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Sonnet (deprecated)</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td>Local file: <code>benchmarks/sonnet.txt</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Random</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>synthetic</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>RandomMultiModal (Image/Video)</strong></td>
|
||||
<td style="text-align: center;">🟡</td>
|
||||
<td style="text-align: center;">🚧</td>
|
||||
<td><code>synthetic</code> </td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Prefix Repetition</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>synthetic</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>HuggingFace-VisionArena</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>lmarena-ai/VisionArena-Chat</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>HuggingFace-InstructCoder</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>likaixin/InstructCoder</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>HuggingFace-AIMO</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>AI-MO/aimo-validation-aime</code> , <code>AI-MO/NuminaMath-1.5</code>, <code>AI-MO/NuminaMath-CoT</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>HuggingFace-Other</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td><code>lmms-lab/LLaVA-OneVision-Data</code>, <code>Aeala/ShareGPT_Vicuna_unfiltered</code></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Custom</strong></td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td style="text-align: center;">✅</td>
|
||||
<td>Local file: <code>data.jsonl</code></td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
- **Serving benchmarks**: Scripts for testing online inference performance (latency, throughput)
|
||||
- **Throughput benchmarks**: Scripts for testing offline batch inference performance
|
||||
- **Specialized benchmarks**: Tools for testing specific features like structured output, prefix caching, long document QA, request prioritization, and multi-modal inference
|
||||
- **Dataset utilities**: Framework for loading and sampling from various benchmark datasets (ShareGPT, HuggingFace datasets, synthetic data, etc.)
|
||||
|
||||
✅: supported
|
||||
## Usage
|
||||
|
||||
🟡: Partial support
|
||||
For detailed usage instructions, examples, and dataset information, see the [Benchmark CLI documentation](https://docs.vllm.ai/en/latest/contributing/benchmarks.html#benchmark-cli).
|
||||
|
||||
🚧: to be supported
|
||||
For full CLI reference see:
|
||||
|
||||
**Note**: HuggingFace dataset's `dataset-name` should be set to `hf`.
|
||||
For local `dataset-path`, please set `hf-name` to its Hugging Face ID like
|
||||
|
||||
```bash
|
||||
--dataset-path /datasets/VisionArena-Chat/ --hf-name lmarena-ai/VisionArena-Chat
|
||||
```
|
||||
|
||||
## 🚀 Example - Online Benchmark
|
||||
|
||||
<details>
|
||||
<summary>Show more</summary>
|
||||
|
||||
<br/>
|
||||
|
||||
First start serving your model
|
||||
|
||||
```bash
|
||||
vllm serve NousResearch/Hermes-3-Llama-3.1-8B
|
||||
```
|
||||
|
||||
Then run the benchmarking script
|
||||
|
||||
```bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--endpoint /v1/completions \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
If successful, you will see the following output
|
||||
|
||||
```text
|
||||
============ Serving Benchmark Result ============
|
||||
Successful requests: 10
|
||||
Benchmark duration (s): 5.78
|
||||
Total input tokens: 1369
|
||||
Total generated tokens: 2212
|
||||
Request throughput (req/s): 1.73
|
||||
Output token throughput (tok/s): 382.89
|
||||
Total Token throughput (tok/s): 619.85
|
||||
---------------Time to First Token----------------
|
||||
Mean TTFT (ms): 71.54
|
||||
Median TTFT (ms): 73.88
|
||||
P99 TTFT (ms): 79.49
|
||||
-----Time per Output Token (excl. 1st token)------
|
||||
Mean TPOT (ms): 7.91
|
||||
Median TPOT (ms): 7.96
|
||||
P99 TPOT (ms): 8.03
|
||||
---------------Inter-token Latency----------------
|
||||
Mean ITL (ms): 7.74
|
||||
Median ITL (ms): 7.70
|
||||
P99 ITL (ms): 8.39
|
||||
==================================================
|
||||
```
|
||||
|
||||
### Custom Dataset
|
||||
|
||||
If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl
|
||||
|
||||
```json
|
||||
{"prompt": "What is the capital of India?"}
|
||||
{"prompt": "What is the capital of Iran?"}
|
||||
{"prompt": "What is the capital of China?"}
|
||||
```
|
||||
|
||||
```bash
|
||||
# start server
|
||||
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct
|
||||
```
|
||||
|
||||
```bash
|
||||
# run benchmarking script
|
||||
vllm bench serve --port 9001 --save-result --save-detailed \
|
||||
--backend vllm \
|
||||
--model meta-llama/Llama-3.1-8B-Instruct \
|
||||
--endpoint /v1/completions \
|
||||
--dataset-name custom \
|
||||
--dataset-path <path-to-your-data-jsonl> \
|
||||
--custom-skip-chat-template \
|
||||
--num-prompts 80 \
|
||||
--max-concurrency 1 \
|
||||
--temperature=0.3 \
|
||||
--top-p=0.75 \
|
||||
--result-dir "./log/"
|
||||
```
|
||||
|
||||
You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`.
|
||||
|
||||
### VisionArena Benchmark for Vision Language Models
|
||||
|
||||
```bash
|
||||
# need a model with vision capability here
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct
|
||||
```
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--endpoint-type openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmarena-ai/VisionArena-Chat \
|
||||
--hf-split train \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
### InstructCoder Benchmark with Speculative Decoding
|
||||
|
||||
``` bash
|
||||
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--speculative-config $'{"method": "ngram",
|
||||
"num_speculative_tokens": 5, "prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 2}'
|
||||
```
|
||||
|
||||
``` bash
|
||||
vllm bench serve \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--dataset-name hf \
|
||||
--dataset-path likaixin/InstructCoder \
|
||||
--num-prompts 2048
|
||||
```
|
||||
|
||||
### Other HuggingFaceDataset Examples
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct
|
||||
```
|
||||
|
||||
`lmms-lab/LLaVA-OneVision-Data`:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--endpoint-type openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmms-lab/LLaVA-OneVision-Data \
|
||||
--hf-split train \
|
||||
--hf-subset "chart2text(cauldron)" \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
`Aeala/ShareGPT_Vicuna_unfiltered`:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--endpoint-type openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
|
||||
--hf-split train \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
`AI-MO/aimo-validation-aime`:
|
||||
|
||||
``` bash
|
||||
vllm bench serve \
|
||||
--model Qwen/QwQ-32B \
|
||||
--dataset-name hf \
|
||||
--dataset-path AI-MO/aimo-validation-aime \
|
||||
--num-prompts 10 \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
`philschmid/mt-bench`:
|
||||
|
||||
``` bash
|
||||
vllm bench serve \
|
||||
--model Qwen/QwQ-32B \
|
||||
--dataset-name hf \
|
||||
--dataset-path philschmid/mt-bench \
|
||||
--num-prompts 80
|
||||
```
|
||||
|
||||
### Running With Sampling Parameters
|
||||
|
||||
When using OpenAI-compatible backends such as `vllm`, optional sampling
|
||||
parameters can be specified. Example client command:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--endpoint /v1/completions \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--top-k 10 \
|
||||
--top-p 0.9 \
|
||||
--temperature 0.5 \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
### Running With Ramp-Up Request Rate
|
||||
|
||||
The benchmark tool also supports ramping up the request rate over the
|
||||
duration of the benchmark run. This can be useful for stress testing the
|
||||
server or finding the maximum throughput that it can handle, given some latency budget.
|
||||
|
||||
Two ramp-up strategies are supported:
|
||||
|
||||
- `linear`: Increases the request rate linearly from a start value to an end value.
|
||||
- `exponential`: Increases the request rate exponentially.
|
||||
|
||||
The following arguments can be used to control the ramp-up:
|
||||
|
||||
- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`).
|
||||
- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark.
|
||||
- `--ramp-up-end-rps`: The request rate at the end of the benchmark.
|
||||
|
||||
</details>
|
||||
|
||||
## 📈 Example - Offline Throughput Benchmark
|
||||
|
||||
<details>
|
||||
<summary>Show more</summary>
|
||||
|
||||
<br/>
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset-name sonnet \
|
||||
--dataset-path vllm/benchmarks/sonnet.txt \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
If successful, you will see the following output
|
||||
|
||||
```text
|
||||
Throughput: 7.15 requests/s, 4656.00 total tokens/s, 1072.15 output tokens/s
|
||||
Total num prompt tokens: 5014
|
||||
Total num output tokens: 1500
|
||||
```
|
||||
|
||||
### VisionArena Benchmark for Vision Language Models
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmarena-ai/VisionArena-Chat \
|
||||
--num-prompts 1000 \
|
||||
--hf-split train
|
||||
```
|
||||
|
||||
The `num prompt tokens` now includes image token counts
|
||||
|
||||
```text
|
||||
Throughput: 2.55 requests/s, 4036.92 total tokens/s, 326.90 output tokens/s
|
||||
Total num prompt tokens: 14527
|
||||
Total num output tokens: 1280
|
||||
```
|
||||
|
||||
### InstructCoder Benchmark with Speculative Decoding
|
||||
|
||||
``` bash
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
VLLM_USE_V1=1 \
|
||||
vllm bench throughput \
|
||||
--dataset-name=hf \
|
||||
--dataset-path=likaixin/InstructCoder \
|
||||
--model=meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--input-len=1000 \
|
||||
--output-len=100 \
|
||||
--num-prompts=2048 \
|
||||
--async-engine \
|
||||
--speculative-config $'{"method": "ngram",
|
||||
"num_speculative_tokens": 5, "prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 2}'
|
||||
```
|
||||
|
||||
```text
|
||||
Throughput: 104.77 requests/s, 23836.22 total tokens/s, 10477.10 output tokens/s
|
||||
Total num prompt tokens: 261136
|
||||
Total num output tokens: 204800
|
||||
```
|
||||
|
||||
### Other HuggingFaceDataset Examples
|
||||
|
||||
`lmms-lab/LLaVA-OneVision-Data`:
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmms-lab/LLaVA-OneVision-Data \
|
||||
--hf-split train \
|
||||
--hf-subset "chart2text(cauldron)" \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
`Aeala/ShareGPT_Vicuna_unfiltered`:
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
--dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
|
||||
--hf-split train \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
`AI-MO/aimo-validation-aime`:
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model Qwen/QwQ-32B \
|
||||
--backend vllm \
|
||||
--dataset-name hf \
|
||||
--dataset-path AI-MO/aimo-validation-aime \
|
||||
--hf-split train \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
Benchmark with LoRA adapters:
|
||||
|
||||
``` bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
vllm bench throughput \
|
||||
--model meta-llama/Llama-2-7b-hf \
|
||||
--backend vllm \
|
||||
--dataset_path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--dataset_name sharegpt \
|
||||
--num-prompts 10 \
|
||||
--max-loras 2 \
|
||||
--max-lora-rank 8 \
|
||||
--enable-lora \
|
||||
--lora-path yard1/llama-2-7b-sql-lora-test
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 🛠️ Example - Structured Output Benchmark
|
||||
|
||||
<details>
|
||||
<summary>Show more</summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of structured output generation (JSON, grammar, regex).
|
||||
|
||||
### Server Setup
|
||||
|
||||
```bash
|
||||
vllm serve NousResearch/Hermes-3-Llama-3.1-8B
|
||||
```
|
||||
|
||||
### JSON Schema Benchmark
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset json \
|
||||
--structured-output-ratio 1.0 \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
### Grammar-based Generation Benchmark
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset grammar \
|
||||
--structure-type grammar \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
### Regex-based Generation Benchmark
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset regex \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
### Choice-based Generation Benchmark
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset choice \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
### XGrammar Benchmark Dataset
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset xgrammar_bench \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 📚 Example - Long Document QA Benchmark
|
||||
|
||||
<details>
|
||||
<summary>Show more</summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of long document question-answering with prefix caching.
|
||||
|
||||
### Basic Long Document QA Test
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 16 \
|
||||
--document-length 2000 \
|
||||
--output-len 50 \
|
||||
--repeat-count 5
|
||||
```
|
||||
|
||||
### Different Repeat Modes
|
||||
|
||||
```bash
|
||||
# Random mode (default) - shuffle prompts randomly
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode random
|
||||
|
||||
# Tile mode - repeat entire prompt list in sequence
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode tile
|
||||
|
||||
# Interleave mode - repeat each prompt consecutively
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode interleave
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 🗂️ Example - Prefix Caching Benchmark
|
||||
|
||||
<details>
|
||||
<summary>Show more</summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the efficiency of automatic prefix caching.
|
||||
|
||||
### Fixed Prompt with Prefix Caching
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prefix_caching.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-prompts 1 \
|
||||
--repeat-count 100 \
|
||||
--input-length-range 128:256
|
||||
```
|
||||
|
||||
### ShareGPT Dataset with Prefix Caching
|
||||
|
||||
```bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
|
||||
python3 benchmarks/benchmark_prefix_caching.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--enable-prefix-caching \
|
||||
--num-prompts 20 \
|
||||
--repeat-count 5 \
|
||||
--input-length-range 128:256
|
||||
```
|
||||
|
||||
### Prefix Repetition Dataset
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--dataset-name prefix_repetition \
|
||||
--num-prompts 100 \
|
||||
--prefix-repetition-prefix-len 512 \
|
||||
--prefix-repetition-suffix-len 128 \
|
||||
--prefix-repetition-num-prefixes 5 \
|
||||
--prefix-repetition-output-len 128
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## ⚡ Example - Request Prioritization Benchmark
|
||||
|
||||
<details>
|
||||
<summary>Show more</summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of request prioritization in vLLM.
|
||||
|
||||
### Basic Prioritization Test
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prioritization.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--input-len 128 \
|
||||
--output-len 64 \
|
||||
--num-prompts 100 \
|
||||
--scheduling-policy priority
|
||||
```
|
||||
|
||||
### Multiple Sequences per Prompt
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prioritization.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--input-len 128 \
|
||||
--output-len 64 \
|
||||
--num-prompts 100 \
|
||||
--scheduling-policy priority \
|
||||
--n 2
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 👁️ Example - Multi-Modal Benchmark
|
||||
|
||||
<details>
|
||||
<summary>Show more</summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of multi-modal requests in vLLM.
|
||||
|
||||
### Images (ShareGPT4V)
|
||||
|
||||
Start vLLM:
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dtype bfloat16 \
|
||||
--limit-mm-per-prompt '{"image": 1}' \
|
||||
--allowed-local-media-path /path/to/sharegpt4v/images
|
||||
```
|
||||
|
||||
Send requests with images:
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \
|
||||
--num-prompts 100 \
|
||||
--save-result \
|
||||
--result-dir ~/vllm_benchmark_results \
|
||||
--save-detailed \
|
||||
--endpoint /v1/chat/completion
|
||||
```
|
||||
|
||||
### Videos (ShareGPT4Video)
|
||||
|
||||
Start vLLM:
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dtype bfloat16 \
|
||||
--limit-mm-per-prompt '{"video": 1}' \
|
||||
--allowed-local-media-path /path/to/sharegpt4video/videos
|
||||
```
|
||||
|
||||
Send requests with videos:
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \
|
||||
--num-prompts 100 \
|
||||
--save-result \
|
||||
--result-dir ~/vllm_benchmark_results \
|
||||
--save-detailed \
|
||||
--endpoint /v1/chat/completion
|
||||
```
|
||||
|
||||
### Synthetic Random Images (random-mm)
|
||||
|
||||
Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.
|
||||
|
||||
Notes:
|
||||
|
||||
- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`.
|
||||
- Video sampling is not yet implemented.
|
||||
|
||||
Start the server (example):
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--dtype bfloat16 \
|
||||
--max-model-len 16384 \
|
||||
--limit-mm-per-prompt '{"image": 3, "video": 0}' \
|
||||
--mm-processor-kwargs max_pixels=1003520
|
||||
```
|
||||
|
||||
Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`.
|
||||
|
||||
Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx 40 tokens:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name random-mm \
|
||||
--num-prompts 100 \
|
||||
--max-concurrency 10 \
|
||||
--random-prefix-len 25 \
|
||||
--random-input-len 300 \
|
||||
--random-output-len 40 \
|
||||
--random-range-ratio 0.2 \
|
||||
--random-mm-base-items-per-request 2 \
|
||||
--random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
|
||||
--random-mm-bucket-config '{(224, 224, 1): 1.0}' \
|
||||
--request-rate inf \
|
||||
--ignore-eos \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
The number of items per request can be controlled by passing multiple image buckets:
|
||||
|
||||
```bash
|
||||
--random-mm-base-items-per-request 2 \
|
||||
--random-mm-num-mm-items-range-ratio 0.5 \
|
||||
--random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
|
||||
--random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
|
||||
```
|
||||
|
||||
Flags specific to `random-mm`:
|
||||
|
||||
- `--random-mm-base-items-per-request`: base number of multimodal items per request.
|
||||
- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items.
|
||||
- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
|
||||
- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported).
|
||||
|
||||
Behavioral notes:
|
||||
|
||||
- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.
|
||||
|
||||
How sampling works:
|
||||
|
||||
- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits.
|
||||
- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
|
||||
- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing.
|
||||
This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`.
|
||||
- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.
|
||||
|
||||
</details>
|
||||
- <https://docs.vllm.ai/en/latest/cli/bench/latency.html>
|
||||
- <https://docs.vllm.ai/en/latest/cli/bench/serve.html>
|
||||
- <https://docs.vllm.ai/en/latest/cli/bench/throughput.html>
|
||||
|
||||
@ -149,3 +149,70 @@ The script follows a systematic process to find the optimal parameters:
|
||||
4. **Track Best Result**: Throughout the process, the script tracks the parameter combination that has yielded the highest valid throughput so far.
|
||||
|
||||
5. **Profile Collection**: For the best-performing run, the script saves the vLLM profiler output, which can be used for deep-dive performance analysis with tools like TensorBoard.
|
||||
|
||||
## Batched `auto_tune`
|
||||
|
||||
The `batch_auto_tune.sh` script allows you to run multiple `auto_tune.sh` experiments sequentially from a single configuration file. It iterates through a list of parameter sets, executes `auto_tune.sh` for each, and records the results back into the input file.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- **jq**: This script requires `jq` to parse the JSON configuration file.
|
||||
- **gcloud**: If you plan to upload results to Google Cloud Storage, the `gcloud` CLI must be installed and authenticated.
|
||||
|
||||
### How to Run
|
||||
|
||||
1. **Create a JSON configuration file**: Create a file (e.g., `runs_config.json`) containing an array of JSON objects. Each object defines the parameters for a single `auto_tune.sh` run.
|
||||
|
||||
2. **Execute the script**:
|
||||
|
||||
```bash
|
||||
bash batch_auto_tune.sh <path_to_json_file> [gcs_upload_path]
|
||||
```
|
||||
|
||||
- `<path_to_json_file>`: **Required.** Path to your JSON configuration file.
|
||||
- `[gcs_upload_path]`: **Optional.** A GCS path (e.g., `gs://my-bucket/benchmark-results`) where the detailed results and profiles for each run will be uploaded. If this is empty, the results will be available on the local filesystem (see the log for `RESULT_FILE=/path/to/results/file.txt`).
|
||||
|
||||
### Configuration File
|
||||
|
||||
The JSON configuration file should contain an array of objects. Each object's keys correspond to the configuration variables for `auto_tune.sh` (see the [Configuration table above](#configuration)). These keys will be converted to uppercase environment variables for each run.
|
||||
|
||||
Here is an example `runs_config.json` with two benchmark configurations:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"base": "/home/user",
|
||||
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"system": "TPU", # OR GPU
|
||||
"tp": 8,
|
||||
"input_len": 128,
|
||||
"output_len": 2048,
|
||||
"max_model_len": 2300,
|
||||
"num_seqs_list": "128 256",
|
||||
"num_batched_tokens_list": "8192 16384"
|
||||
},
|
||||
{
|
||||
"base": "/home/user",
|
||||
"model": "meta-llama/Llama-3.1-70B-Instruct",
|
||||
"system": "TPU", # OR GPU
|
||||
"tp": 8,
|
||||
"input_len": 4000,
|
||||
"output_len": 16,
|
||||
"max_model_len": 4096,
|
||||
"num_seqs_list": "64 128",
|
||||
"num_batched_tokens_list": "4096 8192",
|
||||
"max_latency_allowed_ms": 500
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Output
|
||||
|
||||
The script modifies the input JSON file in place, adding the results of each run to the corresponding object. The following fields are added:
|
||||
|
||||
- `run_id`: A unique identifier for the run, derived from the timestamp.
|
||||
- `status`: The outcome of the run (`SUCCESS`, `FAILURE`, or `WARNING_NO_RESULT_FILE`).
|
||||
- `results`: The content of the `result.txt` file from the `auto_tune.sh` run.
|
||||
- `gcs_results`: The GCS URL where the run's artifacts are stored (if a GCS path was provided).
|
||||
|
||||
A summary of successful and failed runs is also printed to the console upon completion.
|
||||
|
||||
128
benchmarks/auto_tune/batch_auto_tune.sh
Executable file
128
benchmarks/auto_tune/batch_auto_tune.sh
Executable file
@ -0,0 +1,128 @@
|
||||
#!/bin/bash
|
||||
|
||||
INPUT_JSON="$1"
|
||||
GCS_PATH="$2" # Optional GCS path for uploading results for each run
|
||||
|
||||
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
|
||||
AUTOTUNE_SCRIPT="$SCRIPT_DIR/auto_tune.sh"
|
||||
|
||||
if [[ -z "$INPUT_JSON" ]]; then
|
||||
echo "Error: Input JSON file not provided."
|
||||
echo "Usage: $0 <path_to_json_file> [gcs_upload_path]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ ! -f "$INPUT_JSON" ]]; then
|
||||
echo "Error: File not found at '$INPUT_JSON'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v jq &> /dev/null; then
|
||||
echo "Error: 'jq' command not found. Please install jq to process the JSON input."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -n "$GCS_PATH" ]] && ! command -v gcloud &> /dev/null; then
|
||||
echo "Error: 'gcloud' command not found, but a GCS_PATH was provided."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
SUCCESS_COUNT=0
|
||||
FAILURE_COUNT=0
|
||||
FAILED_RUNS=()
|
||||
SCRIPT_START_TIME=$(date +%s)
|
||||
|
||||
json_content=$(cat "$INPUT_JSON")
|
||||
if ! num_runs=$(echo "$json_content" | jq 'length'); then
|
||||
echo "Error: Invalid JSON in $INPUT_JSON. 'jq' failed to get array length." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found $num_runs benchmark configurations in $INPUT_JSON."
|
||||
echo "Starting benchmark runs..."
|
||||
echo "--------------------------------------------------"
|
||||
|
||||
for i in $(seq 0 $(($num_runs - 1))); do
|
||||
run_object=$(echo "$json_content" | jq ".[$i]")
|
||||
|
||||
RUN_START_TIME=$(date +%s)
|
||||
ENV_VARS_ARRAY=()
|
||||
# Dynamically create env vars from the JSON object's keys
|
||||
for key in $(echo "$run_object" | jq -r 'keys_unsorted[]'); do
|
||||
value=$(echo "$run_object" | jq -r ".$key")
|
||||
var_name=$(echo "$key" | tr '[:lower:]' '[:upper:]' | tr -cd 'A-Z0-9_')
|
||||
ENV_VARS_ARRAY+=("${var_name}=${value}")
|
||||
done
|
||||
|
||||
echo "Executing run #$((i+1))/$num_runs with parameters: ${ENV_VARS_ARRAY[*]}"
|
||||
|
||||
# Execute auto_tune.sh and capture output
|
||||
RUN_OUTPUT_FILE=$(mktemp)
|
||||
if env "${ENV_VARS_ARRAY[@]}" bash "$AUTOTUNE_SCRIPT" > >(tee -a "$RUN_OUTPUT_FILE") 2>&1; then
|
||||
STATUS="SUCCESS"
|
||||
((SUCCESS_COUNT++))
|
||||
else
|
||||
STATUS="FAILURE"
|
||||
((FAILURE_COUNT++))
|
||||
FAILED_RUNS+=("Run #$((i+1)): $(echo $run_object | jq -c .)")
|
||||
fi
|
||||
|
||||
RUN_OUTPUT=$(<"$RUN_OUTPUT_FILE")
|
||||
rm "$RUN_OUTPUT_FILE"
|
||||
|
||||
# Parse results and optionally upload them to GCS
|
||||
RUN_ID=""
|
||||
RESULTS=""
|
||||
GCS_RESULTS_URL=""
|
||||
if [[ "$STATUS" == "SUCCESS" ]]; then
|
||||
RESULT_FILE_PATH=$(echo "$RUN_OUTPUT" | grep 'RESULT_FILE=' | tail -n 1 | cut -d'=' -f2 | tr -s '/' || true)
|
||||
|
||||
if [[ -n "$RESULT_FILE_PATH" && -f "$RESULT_FILE_PATH" ]]; then
|
||||
RUN_ID=$(basename "$(dirname "$RESULT_FILE_PATH")")
|
||||
RESULT_DIR=$(dirname "$RESULT_FILE_PATH")
|
||||
RESULTS=$(cat "$RESULT_FILE_PATH")
|
||||
|
||||
if [[ -n "$GCS_PATH" ]]; then
|
||||
GCS_RESULTS_URL="${GCS_PATH}/${RUN_ID}"
|
||||
echo "Uploading results to GCS..."
|
||||
if gcloud storage rsync --recursive "$RESULT_DIR/" "$GCS_RESULTS_URL"; then
|
||||
echo "GCS upload successful."
|
||||
else
|
||||
echo "Warning: GCS upload failed for RUN_ID $RUN_ID."
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "Warning: Could not find result file for a successful run."
|
||||
STATUS="WARNING_NO_RESULT_FILE"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Add the results back into the JSON object for this run
|
||||
json_content=$(echo "$json_content" | jq --argjson i "$i" --arg run_id "$RUN_ID" --arg status "$STATUS" --arg results "$RESULTS" --arg gcs_results "$GCS_RESULTS_URL" \
|
||||
'.[$i] += {run_id: $run_id, status: $status, results: $results, gcs_results: $gcs_results}')
|
||||
|
||||
RUN_END_TIME=$(date +%s)
|
||||
echo "Run finished in $((RUN_END_TIME - RUN_START_TIME)) seconds. Status: $STATUS"
|
||||
echo "--------------------------------------------------"
|
||||
|
||||
# Save intermediate progress back to the file
|
||||
echo "$json_content" > "$INPUT_JSON.tmp" && mv "$INPUT_JSON.tmp" "$INPUT_JSON"
|
||||
|
||||
done
|
||||
|
||||
SCRIPT_END_TIME=$(date +%s)
|
||||
echo "All benchmark runs completed in $((SCRIPT_END_TIME - SCRIPT_START_TIME)) seconds."
|
||||
echo
|
||||
echo "====================== SUMMARY ======================"
|
||||
echo "Successful runs: $SUCCESS_COUNT"
|
||||
echo "Failed runs: $FAILURE_COUNT"
|
||||
echo "==================================================="
|
||||
|
||||
if [[ $FAILURE_COUNT -gt 0 ]]; then
|
||||
echo "Details of failed runs (see JSON file for full parameters):"
|
||||
for failed in "${FAILED_RUNS[@]}"; do
|
||||
echo " - $failed"
|
||||
done
|
||||
fi
|
||||
|
||||
echo "Updated results have been saved to '$INPUT_JSON'."
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,191 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, results: dict[str, Any]
|
||||
) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={"latency": results["latencies"]},
|
||||
extra_info={k: results[k] for k in ["avg_latency", "percentiles"]},
|
||||
)
|
||||
if pt_records:
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"benchmark_latency.py is deprecated and will be removed in a "
|
||||
"future version. Please use 'vllm bench latency' instead.",
|
||||
)
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
|
||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert llm.llm_engine.model_config.max_model_len >= (
|
||||
args.input_len + args.output_len
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than"
|
||||
" the sum of input_len and output_len."
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=args.output_len,
|
||||
detokenize=not args.disable_detokenize,
|
||||
)
|
||||
print(sampling_params)
|
||||
dummy_prompt_token_ids = np.random.randint(
|
||||
10000, size=(args.batch_size, args.input_len)
|
||||
)
|
||||
dummy_prompts: list[PromptType] = [
|
||||
{"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
|
||||
]
|
||||
|
||||
def llm_generate():
|
||||
if not args.use_beam_search:
|
||||
llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
|
||||
else:
|
||||
llm.beam_search(
|
||||
dummy_prompts,
|
||||
BeamSearchParams(
|
||||
beam_width=args.n,
|
||||
max_tokens=args.output_len,
|
||||
ignore_eos=True,
|
||||
),
|
||||
)
|
||||
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
llm.start_profile()
|
||||
llm_generate()
|
||||
llm.stop_profile()
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
llm_generate()
|
||||
end_time = time.perf_counter()
|
||||
latency = end_time - start_time
|
||||
return latency
|
||||
|
||||
print("Warming up...")
|
||||
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
|
||||
run_to_completion(profile_dir=None)
|
||||
|
||||
if args.profile:
|
||||
profile_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||
run_to_completion(profile_dir=profile_dir)
|
||||
return
|
||||
|
||||
# Benchmark.
|
||||
latencies = []
|
||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
||||
latencies.append(run_to_completion(profile_dir=None))
|
||||
latencies = np.array(latencies)
|
||||
percentages = [10, 25, 50, 75, 90, 99]
|
||||
percentiles = np.percentile(latencies, percentages)
|
||||
print(f"Avg latency: {np.mean(latencies)} seconds")
|
||||
for percentage, percentile in zip(percentages, percentiles):
|
||||
print(f"{percentage}% percentile latency: {percentile} seconds")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"avg_latency": np.mean(latencies),
|
||||
"latencies": latencies.tolist(),
|
||||
"percentiles": dict(zip(percentages, percentiles.tolist())),
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
|
||||
|
||||
def create_argument_parser():
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the latency of processing a single batch of "
|
||||
"requests till completion."
|
||||
)
|
||||
parser.add_argument("--input-len", type=int, default=32)
|
||||
parser.add_argument("--output-len", type=int, default=128)
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument(
|
||||
"--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument(
|
||||
"--num-iters-warmup",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of iterations to run for warmup.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-iters", type=int, default=30, help="Number of iterations to run."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="profile the generation process of a single batch",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save the latency results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Do not detokenize responses (i.e. do not include "
|
||||
"detokenization time in the latency measurement)"
|
||||
),
|
||||
)
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# V1 enables prefix caching by default which skews the latency
|
||||
# numbers. We need to disable prefix caching by default.
|
||||
parser.set_defaults(enable_prefix_caching=False)
|
||||
|
||||
return parser
|
||||
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = create_argument_parser()
|
||||
args = parser.parse_args()
|
||||
if args.profile and not envs.VLLM_TORCH_PROFILER_DIR:
|
||||
raise OSError(
|
||||
"The environment variable 'VLLM_TORCH_PROFILER_DIR' is not set. "
|
||||
"Please set it to a valid path to use torch profiler."
|
||||
)
|
||||
main(args)
|
||||
print("""DEPRECATED: This script has been moved to the vLLM CLI.
|
||||
|
||||
Please use the following command instead:
|
||||
vllm bench latency
|
||||
|
||||
For help with the new command, run:
|
||||
vllm bench latency --help
|
||||
|
||||
Alternatively, you can run the new command directly with:
|
||||
python -m vllm.entrypoints.cli.main bench latency --help
|
||||
""")
|
||||
sys.exit(1)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -696,11 +696,11 @@ def evaluate(ret, args):
|
||||
return re.match(args.regex, actual) is not None
|
||||
|
||||
def _eval_correctness(expected, actual):
|
||||
if args.structure_type == "guided_json":
|
||||
if args.structure_type == "json":
|
||||
return _eval_correctness_json(expected, actual)
|
||||
elif args.structure_type == "guided_regex":
|
||||
elif args.structure_type == "regex":
|
||||
return _eval_correctness_regex(expected, actual)
|
||||
elif args.structure_type == "guided_choice":
|
||||
elif args.structure_type == "choice":
|
||||
return _eval_correctness_choice(expected, actual)
|
||||
else:
|
||||
return None
|
||||
@ -780,18 +780,18 @@ def main(args: argparse.Namespace):
|
||||
)
|
||||
|
||||
if args.dataset == "grammar":
|
||||
args.structure_type = "guided_grammar"
|
||||
args.structure_type = "grammar"
|
||||
elif args.dataset == "regex":
|
||||
args.structure_type = "guided_regex"
|
||||
args.structure_type = "regex"
|
||||
elif args.dataset == "choice":
|
||||
args.structure_type = "guided_choice"
|
||||
args.structure_type = "choice"
|
||||
else:
|
||||
args.structure_type = "guided_json"
|
||||
args.structure_type = "json"
|
||||
|
||||
if args.no_structured_output:
|
||||
args.structured_output_ratio = 0
|
||||
if args.save_results:
|
||||
result_file_name = f"{args.structured_output_ratio}guided"
|
||||
result_file_name = f"{args.structured_output_ratio}so"
|
||||
result_file_name += f"_{backend}"
|
||||
result_file_name += f"_{args.request_rate}qps"
|
||||
result_file_name += f"_{args.model.split('/')[-1]}"
|
||||
|
||||
@ -1,741 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Benchmark offline inference throughput."""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from benchmark_dataset import (
|
||||
AIMODataset,
|
||||
BurstGPTDataset,
|
||||
ConversationDataset,
|
||||
InstructCoderDataset,
|
||||
RandomDataset,
|
||||
SampleRequest,
|
||||
ShareGPTDataset,
|
||||
SonnetDataset,
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args,
|
||||
)
|
||||
from vllm.inputs import TextPrompt, TokensPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
def run_vllm(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, Optional[list[RequestOutput]]]:
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
# Add the requests to the engine.
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data,
|
||||
)
|
||||
if "prompt_token_ids" in request.prompt
|
||||
else TextPrompt(
|
||||
prompt=request.prompt, multi_modal_data=request.multi_modal_data
|
||||
)
|
||||
)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
)
|
||||
)
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
if engine_args.enable_lora:
|
||||
lora_requests = [request.lora_request for request in requests]
|
||||
|
||||
use_beam_search = False
|
||||
|
||||
outputs = None
|
||||
if not use_beam_search:
|
||||
start = time.perf_counter()
|
||||
outputs = llm.generate(
|
||||
prompts, sampling_params, lora_request=lora_requests, use_tqdm=True
|
||||
)
|
||||
end = time.perf_counter()
|
||||
else:
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
# output_len should be the same for all requests.
|
||||
output_len = requests[0].expected_output_len
|
||||
for request in requests:
|
||||
assert request.expected_output_len == output_len
|
||||
start = time.perf_counter()
|
||||
llm.beam_search(
|
||||
prompts,
|
||||
BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
),
|
||||
)
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
def run_vllm_chat(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
disable_detokenize: bool = False,
|
||||
) -> tuple[float, list[RequestOutput]]:
|
||||
"""
|
||||
Run vLLM chat benchmark. This function is recommended ONLY for benchmarking
|
||||
multimodal models as it properly handles multimodal inputs and chat
|
||||
formatting. For non-multimodal models, use run_vllm() instead.
|
||||
"""
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of "
|
||||
"prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
|
||||
prompts = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
)
|
||||
)
|
||||
start = time.perf_counter()
|
||||
outputs = llm.chat(prompts, sampling_params, use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
return end - start, outputs
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: AsyncEngineArgs,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args,
|
||||
disable_frontend_multiprocessing=disable_frontend_multiprocessing,
|
||||
) as llm:
|
||||
model_config = await llm.get_model_config()
|
||||
assert all(
|
||||
model_config.max_model_len
|
||||
>= (request.prompt_len + request.expected_output_len)
|
||||
for request in requests
|
||||
), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests."
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[Union[TextPrompt, TokensPrompt]] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
lora_requests: list[Optional[LoRARequest]] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=request.prompt["prompt_token_ids"],
|
||||
multi_modal_data=request.multi_modal_data,
|
||||
)
|
||||
if "prompt_token_ids" in request.prompt
|
||||
else TextPrompt(
|
||||
prompt=request.prompt, multi_modal_data=request.multi_modal_data
|
||||
)
|
||||
)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
detokenize=not disable_detokenize,
|
||||
)
|
||||
)
|
||||
lora_requests.append(request.lora_request)
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp, lr) in enumerate(
|
||||
zip(prompts, sampling_params, lora_requests)
|
||||
):
|
||||
generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_hf(
|
||||
requests: list[SampleRequest],
|
||||
model: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
n: int,
|
||||
max_batch_size: int,
|
||||
trust_remote_code: bool,
|
||||
disable_detokenize: bool = False,
|
||||
) -> float:
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
|
||||
)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
llm = llm.cuda()
|
||||
|
||||
pbar = tqdm(total=len(requests))
|
||||
start = time.perf_counter()
|
||||
batch: list[str] = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
for i in range(len(requests)):
|
||||
prompt = requests[i].prompt
|
||||
prompt_len = requests[i].prompt_len
|
||||
output_len = requests[i].expected_output_len
|
||||
# Add the prompt to the batch.
|
||||
batch.append(prompt)
|
||||
max_prompt_len = max(max_prompt_len, prompt_len)
|
||||
max_output_len = max(max_output_len, output_len)
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
next_prompt_len = requests[i + 1].prompt_len
|
||||
next_output_len = requests[i + 1].expected_output_len
|
||||
if (
|
||||
max(max_prompt_len, next_prompt_len)
|
||||
+ max(max_output_len, next_output_len)
|
||||
) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
continue
|
||||
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
do_sample=True,
|
||||
num_return_sequences=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
use_cache=True,
|
||||
max_new_tokens=max_output_len,
|
||||
)
|
||||
if not disable_detokenize:
|
||||
# Include the decoding time.
|
||||
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
|
||||
pbar.update(len(batch))
|
||||
|
||||
# Clear the batch.
|
||||
batch = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_mii(
|
||||
requests: list[SampleRequest],
|
||||
model: str,
|
||||
tensor_parallel_size: int,
|
||||
output_len: int,
|
||||
) -> float:
|
||||
from mii import client, serve
|
||||
|
||||
llm = serve(model, tensor_parallel=tensor_parallel_size)
|
||||
prompts = [request.prompt for request in requests]
|
||||
|
||||
start = time.perf_counter()
|
||||
llm.generate(prompts, max_new_tokens=output_len)
|
||||
end = time.perf_counter()
|
||||
client = client(model)
|
||||
client.terminate_server()
|
||||
return end - start
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(
|
||||
args: argparse.Namespace, results: dict[str, Any]
|
||||
) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={
|
||||
"requests_per_second": [results["requests_per_second"]],
|
||||
"tokens_per_second": [results["tokens_per_second"]],
|
||||
},
|
||||
extra_info={
|
||||
k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"]
|
||||
},
|
||||
)
|
||||
if pt_records:
|
||||
# Don't use json suffix here as we don't want CI to pick it up
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def get_requests(args, tokenizer):
|
||||
# Common parameters for all dataset types.
|
||||
common_kwargs = {
|
||||
"dataset_path": args.dataset_path,
|
||||
"random_seed": args.seed,
|
||||
}
|
||||
sample_kwargs = {
|
||||
"tokenizer": tokenizer,
|
||||
"lora_path": args.lora_path,
|
||||
"max_loras": args.max_loras,
|
||||
"num_requests": args.num_prompts,
|
||||
"input_len": args.input_len,
|
||||
"output_len": args.output_len,
|
||||
}
|
||||
|
||||
if args.dataset_path is None or args.dataset_name == "random":
|
||||
sample_kwargs["range_ratio"] = args.random_range_ratio
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
dataset_cls = RandomDataset
|
||||
elif args.dataset_name == "sharegpt":
|
||||
dataset_cls = ShareGPTDataset
|
||||
if args.backend == "vllm-chat":
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_name == "sonnet":
|
||||
assert tokenizer.chat_template or tokenizer.default_chat_template, (
|
||||
"Tokenizer/model must have chat template for sonnet dataset."
|
||||
)
|
||||
dataset_cls = SonnetDataset
|
||||
sample_kwargs["prefix_len"] = args.prefix_len
|
||||
sample_kwargs["return_prompt_formatted"] = True
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
common_kwargs["no_stream"] = args.no_stream
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs["dataset_subset"] = None
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = InstructCoderDataset
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = ConversationDataset
|
||||
common_kwargs["dataset_subset"] = args.hf_subset
|
||||
common_kwargs["dataset_split"] = args.hf_split
|
||||
sample_kwargs["enable_multimodal_chat"] = True
|
||||
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = AIMODataset
|
||||
common_kwargs["dataset_subset"] = None
|
||||
common_kwargs["dataset_split"] = "train"
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
|
||||
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
|
||||
|
||||
|
||||
@deprecated(
|
||||
"benchmark_throughput.py is deprecated and will be removed in a "
|
||||
"future version. Please use 'vllm bench throughput' instead.",
|
||||
)
|
||||
def main(args: argparse.Namespace):
|
||||
if args.seed is None:
|
||||
args.seed = 0
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
# Sample the requests.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code
|
||||
)
|
||||
requests = get_requests(args, tokenizer)
|
||||
is_multi_modal = any(request.multi_modal_data is not None for request in requests)
|
||||
request_outputs: Optional[list[RequestOutput]] = None
|
||||
if args.backend == "vllm":
|
||||
if args.async_engine:
|
||||
elapsed_time = uvloop.run(
|
||||
run_vllm_async(
|
||||
requests,
|
||||
args.n,
|
||||
AsyncEngineArgs.from_cli_args(args),
|
||||
args.disable_frontend_multiprocessing,
|
||||
args.disable_detokenize,
|
||||
)
|
||||
)
|
||||
else:
|
||||
elapsed_time, request_outputs = run_vllm(
|
||||
requests,
|
||||
args.n,
|
||||
EngineArgs.from_cli_args(args),
|
||||
args.disable_detokenize,
|
||||
)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(
|
||||
requests,
|
||||
args.model,
|
||||
tokenizer,
|
||||
args.n,
|
||||
args.hf_max_batch_size,
|
||||
args.trust_remote_code,
|
||||
args.disable_detokenize,
|
||||
)
|
||||
elif args.backend == "mii":
|
||||
elapsed_time = run_mii(
|
||||
requests, args.model, args.tensor_parallel_size, args.output_len
|
||||
)
|
||||
elif args.backend == "vllm-chat":
|
||||
elapsed_time, request_outputs = run_vllm_chat(
|
||||
requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
|
||||
if request_outputs:
|
||||
# Note: with the vllm and vllm-chat backends,
|
||||
# we have request_outputs, which we use to count tokens.
|
||||
total_prompt_tokens = 0
|
||||
total_output_tokens = 0
|
||||
for ro in request_outputs:
|
||||
if not isinstance(ro, RequestOutput):
|
||||
continue
|
||||
total_prompt_tokens += (
|
||||
len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
|
||||
)
|
||||
total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o)
|
||||
total_num_tokens = total_prompt_tokens + total_output_tokens
|
||||
else:
|
||||
total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests)
|
||||
total_output_tokens = sum(r.expected_output_len for r in requests)
|
||||
total_prompt_tokens = total_num_tokens - total_output_tokens
|
||||
|
||||
if is_multi_modal and args.backend != "vllm-chat":
|
||||
print(
|
||||
"\033[91mWARNING\033[0m: Multi-modal request with "
|
||||
f"{args.backend} backend detected. The "
|
||||
"following metrics are not accurate because image tokens are not"
|
||||
" counted. See vllm-project/vllm/issues/9778 for details."
|
||||
)
|
||||
# TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
|
||||
# vllm-chat backend counts the image tokens now
|
||||
|
||||
print(
|
||||
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s"
|
||||
)
|
||||
print(f"Total num prompt tokens: {total_prompt_tokens}")
|
||||
print(f"Total num output tokens: {total_output_tokens}")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json:
|
||||
results = {
|
||||
"elapsed_time": elapsed_time,
|
||||
"num_requests": len(requests),
|
||||
"total_num_tokens": total_num_tokens,
|
||||
"requests_per_second": len(requests) / elapsed_time,
|
||||
"tokens_per_second": total_num_tokens / elapsed_time,
|
||||
}
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
save_to_pytorch_benchmark_format(args, results)
|
||||
|
||||
|
||||
def validate_args(args):
|
||||
"""
|
||||
Validate command-line arguments.
|
||||
"""
|
||||
|
||||
# === Deprecation and Defaulting ===
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next release. "
|
||||
"Please use '--dataset-name' and '--dataset-path' instead.",
|
||||
stacklevel=2,
|
||||
)
|
||||
args.dataset_path = args.dataset
|
||||
|
||||
if not getattr(args, "tokenizer", None):
|
||||
args.tokenizer = args.model
|
||||
|
||||
# === Backend Validation ===
|
||||
valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
|
||||
if args.backend not in valid_backends:
|
||||
raise ValueError(f"Unsupported backend: {args.backend}")
|
||||
|
||||
# === Dataset Configuration ===
|
||||
if not args.dataset and not args.dataset_path:
|
||||
print("When dataset path is not set, it will default to random dataset")
|
||||
args.dataset_name = "random"
|
||||
if args.input_len is None:
|
||||
raise ValueError("input_len must be provided for a random dataset")
|
||||
|
||||
# === Dataset Name Specific Checks ===
|
||||
# --hf-subset and --hf-split: only used
|
||||
# when dataset_name is 'hf'
|
||||
if args.dataset_name != "hf" and (
|
||||
getattr(args, "hf_subset", None) is not None
|
||||
or getattr(args, "hf_split", None) is not None
|
||||
):
|
||||
warnings.warn(
|
||||
"--hf-subset and --hf-split will be ignored \
|
||||
since --dataset-name is not 'hf'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
elif args.dataset_name == "hf":
|
||||
if args.dataset_path in (
|
||||
VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
|
||||
| ConversationDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
assert args.backend == "vllm-chat", (
|
||||
f"{args.dataset_path} needs to use vllm-chat as the backend."
|
||||
) # noqa: E501
|
||||
elif args.dataset_path in (
|
||||
InstructCoderDataset.SUPPORTED_DATASET_PATHS
|
||||
| AIMODataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
assert args.backend == "vllm", (
|
||||
f"{args.dataset_path} needs to use vllm as the backend."
|
||||
) # noqa: E501
|
||||
else:
|
||||
raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
|
||||
|
||||
# --random-range-ratio: only used when dataset_name is 'random'
|
||||
if args.dataset_name != "random" and args.random_range_ratio is not None:
|
||||
warnings.warn(
|
||||
"--random-range-ratio will be ignored since \
|
||||
--dataset-name is not 'random'.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
|
||||
# set.
|
||||
if (
|
||||
args.dataset_name not in {"random", "sonnet", None}
|
||||
and args.prefix_len is not None
|
||||
):
|
||||
warnings.warn(
|
||||
"--prefix-len will be ignored since --dataset-name\
|
||||
is not 'random', 'sonnet', or not set.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# === LoRA Settings ===
|
||||
if getattr(args, "enable_lora", False) and args.backend != "vllm":
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM backend")
|
||||
if getattr(args, "enable_lora", False) and args.lora_path is None:
|
||||
raise ValueError("LoRA path must be provided when enable_lora is True")
|
||||
|
||||
# === Backend-specific Validations ===
|
||||
if args.backend == "hf" and args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend")
|
||||
if args.backend != "hf" and args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
|
||||
if (
|
||||
args.backend in {"hf", "mii"}
|
||||
and getattr(args, "quantization", None) is not None
|
||||
):
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
|
||||
if args.backend == "mii" and args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.backend == "mii" and args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.backend == "mii" and args.tokenizer != args.model:
|
||||
raise ValueError("Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
# --data-parallel is not supported currently.
|
||||
# https://github.com/vllm-project/vllm/issues/16222
|
||||
if args.data_parallel_size > 1:
|
||||
raise ValueError(
|
||||
"Data parallel is not supported in offline benchmark, "
|
||||
"please use benchmark serving instead"
|
||||
)
|
||||
|
||||
|
||||
def create_argument_parser():
|
||||
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii", "vllm-chat"],
|
||||
default="vllm",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-stream",
|
||||
action="store_true",
|
||||
help="Do not load the dataset in streaming mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the ShareGPT dataset, will be deprecated in\
|
||||
the next release. The dataset is expected to "
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path", type=str, default=None, help="Path to the dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input prompt length for each request",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n", type=int, default=1, help="Number of generated sequences per prompt."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-max-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to save the throughput results in JSON format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--async-engine",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-frontend-multiprocessing",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-detokenize",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Do not detokenize the response (i.e. do not include "
|
||||
"detokenization time in the measurement)"
|
||||
),
|
||||
)
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the LoRA adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help=f"Number of prefix tokens to be used in RandomDataset "
|
||||
"and SonnetDataset. For RandomDataset, the total input "
|
||||
"length is the sum of prefix-len (default: "
|
||||
f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
|
||||
"sampled from [input_len * (1 - range_ratio), "
|
||||
"input_len * (1 + range_ratio)]. For SonnetDataset, "
|
||||
f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
|
||||
"controls how much of the input is fixed lines versus "
|
||||
"random lines, but the total input length remains approximately "
|
||||
"input_len tokens.",
|
||||
)
|
||||
# random dataset
|
||||
parser.add_argument(
|
||||
"--random-range-ratio",
|
||||
type=float,
|
||||
default=None,
|
||||
help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
|
||||
"for sampling input/output length, "
|
||||
"used only for RandomDataset. Must be in the range [0, 1) to "
|
||||
"define a symmetric sampling range "
|
||||
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
|
||||
)
|
||||
|
||||
# hf dataset
|
||||
parser.add_argument(
|
||||
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-split", type=str, default=None, help="Split of the HF dataset."
|
||||
)
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
return parser
|
||||
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = create_argument_parser()
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
validate_args(args)
|
||||
main(args)
|
||||
print("""DEPRECATED: This script has been moved to the vLLM CLI.
|
||||
|
||||
Please use the following command instead:
|
||||
vllm bench throughput
|
||||
|
||||
For help with the new command, run:
|
||||
vllm bench throughput --help
|
||||
|
||||
Alternatively, you can run the new command directly with:
|
||||
python -m vllm.entrypoints.cli.main bench throughput --help
|
||||
""")
|
||||
sys.exit(1)
|
||||
|
||||
@ -4,7 +4,10 @@
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul,
|
||||
apply_w8a8_block_fp8_linear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton as vllm_triton
|
||||
@ -29,7 +32,7 @@ DEEPSEEK_V3_SHAPES = [
|
||||
]
|
||||
|
||||
|
||||
def build_w8a8_block_fp8_runner(M, N, K, block_size, device):
|
||||
def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
||||
"""Build runner function for w8a8 block fp8 matmul."""
|
||||
factor_for_scale = 1e-2
|
||||
|
||||
@ -37,37 +40,54 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device):
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
# Create random FP8 tensors
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
|
||||
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
|
||||
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
||||
B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
# Create scales
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32, device=device) * factor_for_scale
|
||||
Bs = (
|
||||
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device)
|
||||
* factor_for_scale
|
||||
)
|
||||
|
||||
# SM90 CUTLASS requires row-major format for scales
|
||||
if use_cutlass and current_platform.is_device_capability(90):
|
||||
Bs = Bs.T.contiguous()
|
||||
|
||||
def run():
|
||||
return w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16)
|
||||
if use_cutlass:
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
A_ref, B, block_size, Bs, cutlass_block_fp8_supported=True
|
||||
)
|
||||
else:
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
A_ref, B, block_size, Bs, cutlass_block_fp8_supported=False
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
|
||||
# Determine available providers
|
||||
available_providers = ["torch-bf16", "w8a8-block-fp8-triton"]
|
||||
plot_title = "BF16 vs W8A8 Block FP8 GEMMs"
|
||||
|
||||
if CUTLASS_BLOCK_FP8_SUPPORTED:
|
||||
available_providers.append("w8a8-block-fp8-cutlass")
|
||||
|
||||
|
||||
@vllm_triton.testing.perf_report(
|
||||
vllm_triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=["torch-bf16", "w8a8-block-fp8"],
|
||||
line_names=["torch-bf16", "w8a8-block-fp8"],
|
||||
line_vals=available_providers,
|
||||
line_names=available_providers,
|
||||
ylabel="TFLOP/s (larger is better)",
|
||||
plot_name="BF16 vs W8A8 Block FP8 GEMMs",
|
||||
args={},
|
||||
@ -85,11 +105,22 @@ def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
|
||||
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||
)
|
||||
else: # w8a8-block-fp8
|
||||
run_w8a8 = build_w8a8_block_fp8_runner(M, N, K, block_size, device)
|
||||
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
||||
lambda: run_w8a8(), quantiles=quantiles
|
||||
elif provider == "w8a8-block-fp8-triton":
|
||||
run_w8a8_triton = build_w8a8_block_fp8_runner(
|
||||
M, N, K, block_size, device, use_cutlass=False
|
||||
)
|
||||
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
||||
lambda: run_w8a8_triton(), quantiles=quantiles
|
||||
)
|
||||
elif provider == "w8a8-block-fp8-cutlass":
|
||||
run_w8a8_cutlass = build_w8a8_block_fp8_runner(
|
||||
M, N, K, block_size, device, use_cutlass=True
|
||||
)
|
||||
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
||||
lambda: run_w8a8_cutlass(), quantiles=quantiles
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
||||
|
||||
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||
|
||||
@ -2,14 +2,25 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from typing import Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
||||
|
||||
|
||||
def with_triton_mode(fn):
|
||||
"""Temporarily force the Triton fallback path"""
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
# TODO(luka): use standalone_compile utility
|
||||
@ -21,78 +32,236 @@ def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
|
||||
return inner
|
||||
|
||||
|
||||
torch._dynamo.config.recompile_limit = 8888
|
||||
compilation_config = CompilationConfig(custom_ops=["none"])
|
||||
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
|
||||
torch_per_token_quant_fp8 = torch.compile(
|
||||
QuantFP8(False, GroupShape.PER_TOKEN),
|
||||
fullgraph=True,
|
||||
dynamic=False, # recompile for different shapes
|
||||
)
|
||||
def bench_compile(fn: Callable):
|
||||
# recompile for different shapes
|
||||
fwd = torch.compile(fn, fullgraph=True, dynamic=False)
|
||||
|
||||
# First dim is explicitly dynamic to simulate vLLM usage
|
||||
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
|
||||
return with_dyn_arg(fwd, 0, 0)
|
||||
|
||||
|
||||
def cuda_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input)
|
||||
torch._dynamo.config.recompile_limit = 8888
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int):
|
||||
"""Calculate difference between Triton and CUDA implementations."""
|
||||
def calculate_diff(
|
||||
batch_size: int,
|
||||
hidden_size: int,
|
||||
group_shape: GroupShape,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Calculate the difference between Inductor and CUDA implementations."""
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
|
||||
x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)
|
||||
|
||||
torch_out, torch_scale = torch_per_token_quant_fp8(x)
|
||||
cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
|
||||
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)
|
||||
|
||||
if torch.allclose(
|
||||
cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
|
||||
torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
|
||||
torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
|
||||
cuda_out, cuda_scale = quant_fp8.forward_cuda(x)
|
||||
|
||||
out_allclose = lambda o1, o2: torch.allclose(
|
||||
o1.to(torch.float32),
|
||||
o2.to(torch.float32),
|
||||
rtol=1e-3,
|
||||
atol=1e-5,
|
||||
)
|
||||
scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)
|
||||
|
||||
if (
|
||||
out_allclose(cuda_out, torch_out)
|
||||
and scale_allclose(cuda_scale, torch_scale)
|
||||
and out_allclose(cuda_out, torch_eager_out)
|
||||
and scale_allclose(cuda_scale, torch_eager_scale)
|
||||
):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range))
|
||||
configs = []
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "cuda"],
|
||||
line_names=["Torch", "CUDA"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-token-dynamic-quant-fp8-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_quantization(batch_size, seq_len, provider):
|
||||
dtype = torch.float16
|
||||
def benchmark_quantization(
|
||||
batch_size,
|
||||
hidden_size,
|
||||
provider,
|
||||
group_shape: GroupShape,
|
||||
col_major: bool,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
|
||||
x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)
|
||||
|
||||
if provider == "torch":
|
||||
fn = lambda: torch_per_token_quant_fp8(x.clone())
|
||||
fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
|
||||
elif provider == "cuda":
|
||||
fn = lambda: cuda_per_token_quant_fp8(x.clone())
|
||||
fn = lambda: quant_fp8.forward_cuda(x.clone())
|
||||
elif provider == "triton":
|
||||
if not group_shape.is_per_group():
|
||||
# Triton only supported for per-group
|
||||
return 0, 0, 0
|
||||
|
||||
fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
# TODO(luka) extract to utils
|
||||
def compute_geomean_speedups(
|
||||
df: pd.DataFrame,
|
||||
baseline_col: str,
|
||||
speedup_cols: list[str],
|
||||
groupby_cols: list[str] | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Compute geometric mean speedups over a baseline column.
|
||||
|
||||
Args:
|
||||
df: Input dataframe
|
||||
baseline_col: Column to use as baseline
|
||||
speedup_cols: Columns to compute speedups for
|
||||
groupby_cols: Columns to group by. If None, compute over entire df.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame with geometric mean speedups
|
||||
"""
|
||||
from scipy.stats import gmean
|
||||
|
||||
def geo_speedup(group: pd.DataFrame) -> pd.Series:
|
||||
ratios = {
|
||||
col: (group[baseline_col] / group[col]).values for col in speedup_cols
|
||||
}
|
||||
return pd.Series({col: gmean(vals) for col, vals in ratios.items()})
|
||||
|
||||
if groupby_cols is None:
|
||||
result = geo_speedup(df).to_frame().T
|
||||
else:
|
||||
result = (
|
||||
df.groupby(groupby_cols)
|
||||
.apply(geo_speedup, include_groups=False)
|
||||
.reset_index()
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff(batch_size=4, seq_len=4096)
|
||||
benchmark_quantization.run(print_data=True)
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
|
||||
)
|
||||
parser.add_argument("-c", "--check", action="store_true")
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden-sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Batch sizes to benchmark (default: 1,16,32,64,128)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group-sizes",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Group sizes for GroupShape(1,N) to benchmark. "
|
||||
"Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-column-major",
|
||||
action="store_true",
|
||||
help="Disable column-major scales testing",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
assert args
|
||||
|
||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
|
||||
|
||||
hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128]
|
||||
|
||||
if args.group_sizes is not None:
|
||||
group_shapes = []
|
||||
for size in args.group_sizes:
|
||||
if size == 0:
|
||||
group_shapes.append(GroupShape.PER_TENSOR)
|
||||
elif size == -1:
|
||||
group_shapes.append(GroupShape.PER_TOKEN)
|
||||
else:
|
||||
group_shapes.append(GroupShape(1, size))
|
||||
else:
|
||||
group_shapes = [
|
||||
GroupShape.PER_TENSOR,
|
||||
GroupShape.PER_TOKEN,
|
||||
GroupShape(1, 64),
|
||||
GroupShape(1, 128),
|
||||
]
|
||||
|
||||
column_major_scales = [False] if args.no_column_major else [True, False]
|
||||
|
||||
config_gen = itertools.product(
|
||||
group_shapes,
|
||||
column_major_scales,
|
||||
batch_sizes,
|
||||
hidden_sizes,
|
||||
)
|
||||
|
||||
# filter out column-major scales for non-group, reverse order
|
||||
configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))
|
||||
|
||||
print(f"Running {len(configs)} configurations:")
|
||||
print(f" Hidden sizes: {hidden_sizes}")
|
||||
print(f" Batch sizes: {batch_sizes}")
|
||||
print(f" Group shapes: {[str(g) for g in group_shapes]}")
|
||||
print(f" Column major scales: {column_major_scales}")
|
||||
print()
|
||||
|
||||
if args.check:
|
||||
for group_shape in group_shapes:
|
||||
group_size = group_shape[1]
|
||||
print(f"{group_size=}")
|
||||
calculate_diff(
|
||||
batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
|
||||
)
|
||||
|
||||
benchmark = triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "cuda", "triton"],
|
||||
line_names=["Torch (Compiled)", "CUDA", "Triton"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("black", "-")],
|
||||
ylabel="us",
|
||||
plot_name="QuantFP8 performance",
|
||||
args={},
|
||||
)
|
||||
)(benchmark_quantization)
|
||||
|
||||
df = benchmark.run(print_data=True, dtype=dtype, return_df=True)
|
||||
|
||||
# Print geomean speedups
|
||||
geo_table_grouped = compute_geomean_speedups(
|
||||
df,
|
||||
baseline_col="Torch (Compiled)",
|
||||
speedup_cols=["CUDA", "Triton"],
|
||||
groupby_cols=["col_major", "group_shape"],
|
||||
)
|
||||
|
||||
print("Speedup over Torch (Compiled)")
|
||||
print(geo_table_grouped.to_string(index=False))
|
||||
|
||||
@ -13,6 +13,10 @@ import torch.utils.benchmark as benchmark
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
nvfp4_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.scalar_type import scalar_types
|
||||
@ -140,6 +144,12 @@ def bench_run(
|
||||
a_fp8_scale: torch.Tensor,
|
||||
num_repeats: int,
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_fp8_scale,
|
||||
)
|
||||
|
||||
for _ in range(num_repeats):
|
||||
fused_experts(
|
||||
a,
|
||||
@ -147,10 +157,7 @@ def bench_run(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_fp8_scale,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_cutlass_moe_fp4(
|
||||
@ -172,25 +179,27 @@ def bench_run(
|
||||
device: torch.device,
|
||||
num_repeats: int,
|
||||
):
|
||||
quant_config = nvfp4_moe_quant_config(
|
||||
a1_gscale=a1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w1_scale=w1_blockscale,
|
||||
w2_scale=w2_blockscale,
|
||||
g1_alphas=w1_gs,
|
||||
g2_alphas=w2_gs,
|
||||
)
|
||||
for _ in range(num_repeats):
|
||||
with nvtx.annotate("cutlass_moe_fp4", color="green"):
|
||||
cutlass_moe_fp4(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w1_fp4=w1_fp4,
|
||||
w1_blockscale=w1_blockscale,
|
||||
w1_alphas=w1_gs,
|
||||
w2_fp4=w2_fp4,
|
||||
w2_blockscale=w2_blockscale,
|
||||
w2_alphas=w2_gs,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=num_experts,
|
||||
device=device,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
@ -211,26 +220,29 @@ def bench_run(
|
||||
e: int,
|
||||
device: torch.device,
|
||||
):
|
||||
quant_config = nvfp4_moe_quant_config(
|
||||
a1_gscale=a1_gs,
|
||||
a2_gscale=a2_gs,
|
||||
w1_scale=w1_blockscale,
|
||||
w2_scale=w2_blockscale,
|
||||
g1_alphas=w1_gs,
|
||||
g2_alphas=w2_gs,
|
||||
)
|
||||
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
return cutlass_moe_fp4(
|
||||
a=a,
|
||||
a1_gscale=a1_gs,
|
||||
w1_fp4=w1_fp4,
|
||||
w1_blockscale=w1_blockscale,
|
||||
w1_alphas=w1_alphas,
|
||||
a2_gscale=a2_gs,
|
||||
w2_fp4=w2_fp4,
|
||||
w2_blockscale=w2_blockscale,
|
||||
w2_alphas=w2_alphas,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=num_experts,
|
||||
device=device,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
@ -246,16 +258,18 @@ def bench_run(
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_fp8_scale,
|
||||
)
|
||||
return fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_fp8_scale,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def replay_graph(graph, num_repeats):
|
||||
|
||||
486
benchmarks/kernels/benchmark_device_communicators.py
Normal file
486
benchmarks/kernels/benchmark_device_communicators.py
Normal file
@ -0,0 +1,486 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""
|
||||
Benchmark script for device communicators:
|
||||
CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
|
||||
and SymmMemCommunicator (multimem, two-shot).
|
||||
|
||||
Usage:
|
||||
torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]
|
||||
|
||||
Example:
|
||||
torchrun --nproc_per_node=2 benchmark_device_communicators.py
|
||||
--sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Default sequence lengths to benchmark
|
||||
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
|
||||
|
||||
# Fixed hidden size and dtype for all benchmarks
|
||||
HIDDEN_SIZE = 8192
|
||||
BENCHMARK_DTYPE = torch.bfloat16
|
||||
|
||||
# CUDA graph settings
|
||||
CUDA_GRAPH_CAPTURE_CYCLES = 10
|
||||
|
||||
|
||||
class CommunicatorBenchmark:
|
||||
"""Benchmark class for testing device communicators."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
cpu_group: ProcessGroup,
|
||||
sequence_lengths: list[int],
|
||||
):
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.device = device
|
||||
self.cpu_group = cpu_group
|
||||
|
||||
# Calculate max_size_override based on largest sequence length
|
||||
max_seq_len = max(sequence_lengths)
|
||||
max_tensor_elements = max_seq_len * HIDDEN_SIZE
|
||||
self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1
|
||||
|
||||
# Initialize communicators
|
||||
self.custom_allreduce = None
|
||||
self.pynccl_comm = None
|
||||
self.symm_mem_comm = None
|
||||
self.symm_mem_comm_multimem = None
|
||||
self.symm_mem_comm_two_shot = None
|
||||
|
||||
self._init_communicators()
|
||||
|
||||
def _init_communicators(self):
|
||||
"""Initialize all available communicators."""
|
||||
try:
|
||||
self.custom_allreduce = CustomAllreduce(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
max_size=self.max_size_override,
|
||||
)
|
||||
if not self.custom_allreduce.disabled:
|
||||
logger.info("Rank %s: CustomAllreduce initialized", self.rank)
|
||||
else:
|
||||
logger.info("Rank %s: CustomAllreduce disabled", self.rank)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e
|
||||
)
|
||||
self.custom_allreduce = None
|
||||
|
||||
try:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group, device=self.device
|
||||
)
|
||||
if not self.pynccl_comm.disabled:
|
||||
logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
|
||||
else:
|
||||
logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
|
||||
self.pynccl_comm = None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e
|
||||
)
|
||||
self.pynccl_comm = None
|
||||
|
||||
# Initialize variants for SymmMemCommunicator
|
||||
try:
|
||||
self.symm_mem_comm_multimem = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
force_multimem=True,
|
||||
max_size_override=self.max_size_override,
|
||||
)
|
||||
if not self.symm_mem_comm_multimem.disabled:
|
||||
logger.info(
|
||||
"Rank %s: SymmMemCommunicator (multimem) initialized", self.rank
|
||||
)
|
||||
else:
|
||||
self.symm_mem_comm_multimem = None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s",
|
||||
self.rank,
|
||||
e,
|
||||
)
|
||||
self.symm_mem_comm_multimem = None
|
||||
|
||||
try:
|
||||
self.symm_mem_comm_two_shot = SymmMemCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
force_multimem=False,
|
||||
max_size_override=self.max_size_override,
|
||||
)
|
||||
if not self.symm_mem_comm_two_shot.disabled:
|
||||
logger.info(
|
||||
"Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank
|
||||
)
|
||||
else:
|
||||
self.symm_mem_comm_two_shot = None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s",
|
||||
self.rank,
|
||||
e,
|
||||
)
|
||||
self.symm_mem_comm_two_shot = None
|
||||
|
||||
def benchmark_allreduce(
|
||||
self, sequence_length: int, num_warmup: int, num_trials: int
|
||||
) -> dict[str, float]:
|
||||
"""Benchmark allreduce operations for all available communicators."""
|
||||
|
||||
results = {}
|
||||
|
||||
# Define communicators with their benchmark functions
|
||||
communicators = []
|
||||
|
||||
if self.custom_allreduce is not None:
|
||||
comm = self.custom_allreduce
|
||||
# CustomAllreduce one-shot
|
||||
communicators.append(
|
||||
(
|
||||
"ca_1stage",
|
||||
lambda t, c=comm: c.custom_all_reduce(t),
|
||||
lambda t, c=comm: c.should_custom_ar(t),
|
||||
comm.capture(),
|
||||
"1stage", # env variable value
|
||||
)
|
||||
)
|
||||
# CustomAllreduce two-shot
|
||||
communicators.append(
|
||||
(
|
||||
"ca_2stage",
|
||||
lambda t, c=comm: c.custom_all_reduce(t),
|
||||
lambda t, c=comm: c.should_custom_ar(t),
|
||||
comm.capture(),
|
||||
"2stage", # env variable value
|
||||
)
|
||||
)
|
||||
|
||||
if self.pynccl_comm is not None:
|
||||
comm = self.pynccl_comm
|
||||
communicators.append(
|
||||
(
|
||||
"pynccl",
|
||||
lambda t, c=comm: c.all_reduce(t),
|
||||
lambda t: True, # Always available if initialized
|
||||
nullcontext(),
|
||||
None, # no env variable needed
|
||||
)
|
||||
)
|
||||
|
||||
if self.symm_mem_comm_multimem is not None:
|
||||
comm = self.symm_mem_comm_multimem
|
||||
communicators.append(
|
||||
(
|
||||
"symm_mem_multimem",
|
||||
lambda t, c=comm: c.all_reduce(t),
|
||||
lambda t, c=comm: c.should_use_symm_mem(t),
|
||||
nullcontext(),
|
||||
None, # no env variable needed
|
||||
)
|
||||
)
|
||||
|
||||
if self.symm_mem_comm_two_shot is not None:
|
||||
comm = self.symm_mem_comm_two_shot
|
||||
communicators.append(
|
||||
(
|
||||
"symm_mem_two_shot",
|
||||
lambda t, c=comm: c.all_reduce(t),
|
||||
lambda t, c=comm: c.should_use_symm_mem(t),
|
||||
nullcontext(),
|
||||
None, # no env variable needed
|
||||
)
|
||||
)
|
||||
|
||||
# Benchmark each communicator
|
||||
for name, allreduce_fn, should_use_fn, context, env_var in communicators:
|
||||
# Set environment variable if needed
|
||||
if env_var is not None:
|
||||
os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
|
||||
else:
|
||||
# Clear the environment variable to avoid interference
|
||||
os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)
|
||||
|
||||
latency = self.benchmark_allreduce_single(
|
||||
sequence_length,
|
||||
allreduce_fn,
|
||||
should_use_fn,
|
||||
context,
|
||||
num_warmup,
|
||||
num_trials,
|
||||
)
|
||||
if latency is not None:
|
||||
results[name] = latency
|
||||
|
||||
return results
|
||||
|
||||
def benchmark_allreduce_single(
|
||||
self,
|
||||
sequence_length: int,
|
||||
allreduce_fn: Callable[[torch.Tensor], Optional[torch.Tensor]],
|
||||
should_use_fn: Callable[[torch.Tensor], bool],
|
||||
context,
|
||||
num_warmup: int,
|
||||
num_trials: int,
|
||||
) -> Optional[float]:
|
||||
"""Benchmark method with CUDA graph optimization."""
|
||||
try:
|
||||
# Create test tensor (2D: sequence_length x hidden_size)
|
||||
tensor = torch.randn(
|
||||
sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device
|
||||
)
|
||||
if not should_use_fn(tensor):
|
||||
return None
|
||||
|
||||
torch.cuda.synchronize()
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
graph_input = tensor.clone()
|
||||
|
||||
# Warmup before capture
|
||||
for _ in range(3):
|
||||
allreduce_fn(graph_input)
|
||||
|
||||
# Capture the graph using context manager
|
||||
with context:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
|
||||
allreduce_fn(graph_input)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
for _ in range(num_warmup):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
for _ in range(num_trials):
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
|
||||
# Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES
|
||||
return (
|
||||
(end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("CUDA graph benchmark failed: %s", e)
|
||||
raise RuntimeError(
|
||||
f"CUDA graph benchmark failed for communicator: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def _calculate_speedup_info(comm_results: dict[str, float]) -> str:
|
||||
"""Calculate speedup information for a single tensor size."""
|
||||
if not comm_results:
|
||||
return "N/A"
|
||||
|
||||
# Find the fastest communicator
|
||||
fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k])
|
||||
fastest_time = comm_results[fastest_comm]
|
||||
|
||||
# Calculate speedup vs PyNccl if available
|
||||
if "pynccl" in comm_results:
|
||||
pynccl_time = comm_results["pynccl"]
|
||||
speedup = pynccl_time / fastest_time
|
||||
return f"{fastest_comm} ({speedup:.2f}x)"
|
||||
else:
|
||||
return f"{fastest_comm} (N/A)"
|
||||
|
||||
|
||||
def print_results(
|
||||
results: dict[str, dict[str, float]], sequence_lengths: list[int], world_size: int
|
||||
):
|
||||
"""Print benchmark results in a formatted table."""
|
||||
|
||||
print(f"\n{'=' * 130}")
|
||||
print("Device Communicator Benchmark Results")
|
||||
print(
|
||||
f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, "
|
||||
f"Hidden Size: {HIDDEN_SIZE}"
|
||||
)
|
||||
print(f"{'=' * 130}")
|
||||
|
||||
# Get all communicator names
|
||||
all_comms = set()
|
||||
for size_results in results.values():
|
||||
all_comms.update(size_results.keys())
|
||||
|
||||
all_comms = sorted(list(all_comms))
|
||||
|
||||
# Print header
|
||||
header = f"{'Tensor Shape':<20}{'Tensor Size':<15}"
|
||||
for comm in all_comms:
|
||||
header += f"{comm:<20}"
|
||||
header += f"{'Best (Speedup vs PyNccl)':<30}"
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
|
||||
# Print results for each sequence length
|
||||
for seq_len in sequence_lengths:
|
||||
if seq_len in results:
|
||||
# Calculate tensor size in elements and bytes
|
||||
tensor_elements = seq_len * HIDDEN_SIZE
|
||||
tensor_bytes = tensor_elements * BENCHMARK_DTYPE.itemsize
|
||||
|
||||
# Format tensor size (MB)
|
||||
tensor_size_mb = tensor_bytes / (1024 * 1024)
|
||||
tensor_size_str = f"{tensor_size_mb:.2f} MB"
|
||||
|
||||
# Format tensor shape
|
||||
tensor_shape = f"({seq_len}, {HIDDEN_SIZE})"
|
||||
|
||||
row = f"{tensor_shape:<20}{tensor_size_str:<15}"
|
||||
for comm in all_comms:
|
||||
if comm in results[seq_len]:
|
||||
row += f"{results[seq_len][comm]:<20.3f}"
|
||||
else:
|
||||
row += f"{'N/A':<20}"
|
||||
|
||||
# Calculate speedup information
|
||||
speedup_info = _calculate_speedup_info(results[seq_len])
|
||||
row += f"{speedup_info:<30}"
|
||||
|
||||
print(row)
|
||||
|
||||
print(f"{'=' * 130}")
|
||||
print("All times are in milliseconds (ms) per allreduce operation")
|
||||
print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)")
|
||||
|
||||
|
||||
def main():
|
||||
parser = FlexibleArgumentParser(description="Benchmark device communicators")
|
||||
|
||||
parser.add_argument(
|
||||
"--sequence-lengths",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=DEFAULT_SEQUENCE_LENGTHS,
|
||||
help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-warmup", type=int, default=5, help="Number of warmup iterations"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-trials", type=int, default=50, help="Number of benchmark trials"
|
||||
)
|
||||
|
||||
parser.add_argument("--output-json", type=str, help="Output results to JSON file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize distributed
|
||||
if not dist.is_initialized():
|
||||
dist.init_process_group(backend="gloo")
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
# Set device
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
# Get CPU process group
|
||||
cpu_group = dist.new_group(backend="gloo")
|
||||
|
||||
# Disable USE_SYMM_MEM to avoid affecting the max_sizes
|
||||
# in symm_mem and custom_all_reduce for benchmark
|
||||
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
|
||||
|
||||
# Initialize benchmark
|
||||
benchmark = CommunicatorBenchmark(
|
||||
rank, world_size, device, cpu_group, args.sequence_lengths
|
||||
)
|
||||
|
||||
# Run benchmarks
|
||||
all_results = {}
|
||||
|
||||
for seq_len in args.sequence_lengths:
|
||||
if rank == 0:
|
||||
logger.info(
|
||||
"Benchmarking sequence length: %s (tensor shape: %s x %s)",
|
||||
seq_len,
|
||||
seq_len,
|
||||
HIDDEN_SIZE,
|
||||
)
|
||||
|
||||
results = benchmark.benchmark_allreduce(
|
||||
sequence_length=seq_len,
|
||||
num_warmup=args.num_warmup,
|
||||
num_trials=args.num_trials,
|
||||
)
|
||||
|
||||
all_results[seq_len] = results
|
||||
|
||||
# Synchronize between ranks
|
||||
dist.barrier()
|
||||
|
||||
# Print results (only rank 0)
|
||||
if rank == 0:
|
||||
print_results(all_results, args.sequence_lengths, world_size)
|
||||
|
||||
# Save to JSON if requested
|
||||
if args.output_json:
|
||||
# Add speedup information to results
|
||||
enhanced_results = {}
|
||||
for seq_len, comm_results in all_results.items():
|
||||
enhanced_results[seq_len] = {
|
||||
"timings": comm_results,
|
||||
"speedup_info": _calculate_speedup_info(comm_results),
|
||||
}
|
||||
|
||||
output_data = {
|
||||
"world_size": world_size,
|
||||
"dtype": str(BENCHMARK_DTYPE),
|
||||
"hidden_size": HIDDEN_SIZE,
|
||||
"sequence_lengths": args.sequence_lengths,
|
||||
"num_warmup": args.num_warmup,
|
||||
"num_trials": args.num_trials,
|
||||
"cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES,
|
||||
"results": enhanced_results,
|
||||
}
|
||||
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(output_data, f, indent=2)
|
||||
|
||||
logger.info("Results saved to %s", args.output_json)
|
||||
|
||||
# Cleanup
|
||||
if cpu_group != dist.group.WORLD:
|
||||
dist.destroy_process_group(cpu_group)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -7,6 +7,7 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_experts,
|
||||
@ -96,6 +97,11 @@ def bench_run(
|
||||
a_scale: torch.Tensor,
|
||||
num_repeats: int,
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
)
|
||||
for _ in range(num_repeats):
|
||||
fused_experts(
|
||||
a,
|
||||
@ -103,10 +109,7 @@ def bench_run(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_cutlass_moe(
|
||||
@ -125,6 +128,12 @@ def bench_run(
|
||||
per_act_token: bool,
|
||||
num_repeats: int,
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
)
|
||||
|
||||
for _ in range(num_repeats):
|
||||
cutlass_moe_fp8(
|
||||
a,
|
||||
@ -132,14 +141,11 @@ def bench_run(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
@ -156,6 +162,12 @@ def bench_run(
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
per_act_token_quant=per_act_token,
|
||||
)
|
||||
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
@ -165,14 +177,11 @@ def bench_run(
|
||||
w2_q,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
@ -185,6 +194,11 @@ def bench_run(
|
||||
w2_scale: torch.Tensor,
|
||||
a_scale: torch.Tensor,
|
||||
):
|
||||
quant_config = fp8_w8a8_moe_quant_config(
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
@ -194,10 +208,7 @@ def bench_run(
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a_scale,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def replay_graph(graph, num_repeats):
|
||||
|
||||
@ -464,7 +464,11 @@ class BenchmarkTensors:
|
||||
for field_name in LoRAKernelMeta.__dataclass_fields__:
|
||||
field = getattr(self.lora_kernel_meta, field_name)
|
||||
assert isinstance(field, torch.Tensor)
|
||||
setattr(self.lora_kernel_meta, field_name, to_device(field))
|
||||
setattr(
|
||||
self.lora_kernel_meta,
|
||||
field_name,
|
||||
to_device(field) if field_name != "no_lora_flag_cpu" else field,
|
||||
)
|
||||
|
||||
def metadata(self) -> tuple[int, int, int]:
|
||||
"""
|
||||
@ -512,6 +516,7 @@ class BenchmarkTensors:
|
||||
"lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc,
|
||||
"lora_ids": self.lora_kernel_meta.active_lora_ids,
|
||||
"scaling": 1.0,
|
||||
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
|
||||
}
|
||||
|
||||
def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
@ -552,6 +557,7 @@ class BenchmarkTensors:
|
||||
"lora_ids": self.lora_kernel_meta.active_lora_ids,
|
||||
"offset_start": 0,
|
||||
"add_inputs": add_inputs,
|
||||
"no_lora_flag_cpu": self.lora_kernel_meta.no_lora_flag_cpu,
|
||||
}
|
||||
|
||||
def bench_fn_kwargs(
|
||||
|
||||
@ -14,6 +14,10 @@ import ray
|
||||
import torch
|
||||
from ray.experimental.tqdm_ray import tqdm
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
_get_config_dtype_str,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
@ -134,43 +138,36 @@ def benchmark_config(
|
||||
def run():
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
|
||||
if use_fp8_w8a8:
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
elif use_int8_w8a16:
|
||||
quant_dtype = torch.int8
|
||||
else:
|
||||
quant_dtype = None
|
||||
|
||||
quant_config = FusedMoEQuantConfig.make(
|
||||
quant_dtype=quant_dtype,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
)
|
||||
|
||||
with override_config(config):
|
||||
if use_deep_gemm:
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
x, input_gating, topk, False
|
||||
)
|
||||
return fused_experts(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
allow_deep_gemm=True,
|
||||
)
|
||||
else:
|
||||
fused_moe(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
input_gating,
|
||||
topk,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
x, input_gating, topk, renormalize=not use_deep_gemm
|
||||
)
|
||||
return fused_experts(
|
||||
x,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=True,
|
||||
quant_config=quant_config,
|
||||
allow_deep_gemm=use_deep_gemm,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
run()
|
||||
@ -414,7 +411,7 @@ class BenchmarkWorker:
|
||||
use_deep_gemm: bool = False,
|
||||
) -> tuple[dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype_str = _get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
)
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
@ -547,7 +544,7 @@ def save_configs(
|
||||
block_quant_shape: list[int],
|
||||
save_dir: str,
|
||||
) -> None:
|
||||
dtype_str = get_config_dtype_str(
|
||||
dtype_str = _get_config_dtype_str(
|
||||
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||
)
|
||||
|
||||
@ -560,7 +557,7 @@ def save_configs(
|
||||
filename = os.path.join(save_dir, filename)
|
||||
print(f"Writing best config to {filename}...")
|
||||
with open(filename, "w") as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
json.dump({"triton_version": triton.__version__, **configs}, f, indent=4)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
@ -594,7 +591,11 @@ def main(args: argparse.Namespace):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
|
||||
elif config.architectures[0] in (
|
||||
"Qwen2MoeForCausalLM",
|
||||
"Qwen3MoeForCausalLM",
|
||||
"Qwen3NextForCausalLM",
|
||||
):
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
|
||||
155
benchmarks/kernels/benchmark_polynorm.py
Normal file
155
benchmarks/kernels/benchmark_polynorm.py
Normal file
@ -0,0 +1,155 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as vllm_ops
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def polynorm_naive(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
def norm(x, eps: float):
|
||||
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
||||
|
||||
x = x.float()
|
||||
return (
|
||||
(
|
||||
weight[0] * norm(x**3, eps)
|
||||
+ weight[1] * norm(x**2, eps)
|
||||
+ weight[2] * norm(x, eps)
|
||||
+ bias
|
||||
)
|
||||
.to(weight.dtype)
|
||||
.view(orig_shape)
|
||||
)
|
||||
|
||||
|
||||
def polynorm_vllm(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
orig_shape = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
|
||||
out = torch.empty_like(x)
|
||||
vllm_ops.poly_norm(out, x, weight, bias, eps)
|
||||
output = out
|
||||
|
||||
output = output.view(orig_shape)
|
||||
return output
|
||||
|
||||
|
||||
def calculate_diff(batch_size, seq_len, hidden_dim):
|
||||
dtype = torch.bfloat16
|
||||
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(3, dtype=dtype, device="cuda")
|
||||
bias = torch.ones(1, dtype=dtype, device="cuda")
|
||||
|
||||
output_naive = polynorm_naive(x, weight, bias)
|
||||
output_vllm = polynorm_vllm(x, weight, bias)
|
||||
|
||||
if torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [2**i for i in range(0, 7, 2)]
|
||||
seq_length_range = [2**i for i in range(6, 11, 1)]
|
||||
dim_range = [2048, 4096]
|
||||
configs = list(itertools.product(dim_range, batch_size_range, seq_length_range))
|
||||
|
||||
|
||||
def get_benchmark():
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["dim", "batch_size", "seq_len"],
|
||||
x_vals=[list(_) for _ in configs],
|
||||
line_arg="provider",
|
||||
line_vals=["naive", "vllm"],
|
||||
line_names=["Naive", "vLLM"],
|
||||
styles=[("blue", "-"), ("red", "-")],
|
||||
ylabel="us",
|
||||
plot_name="polynorm-perf",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(dim, batch_size, seq_len, provider):
|
||||
dtype = torch.bfloat16
|
||||
hidden_dim = dim * 4
|
||||
|
||||
x = torch.randn(batch_size, seq_len, hidden_dim, dtype=dtype, device="cuda")
|
||||
weight = torch.ones(3, dtype=dtype, device="cuda")
|
||||
bias = torch.ones(1, dtype=dtype, device="cuda")
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "naive":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: polynorm_naive(x, weight, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
else:
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: polynorm_vllm(x, weight, bias),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
return benchmark
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq-len",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Sequence length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden-dim",
|
||||
type=int,
|
||||
default=8192,
|
||||
help="Intermediate size of MLP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default="./configs/polnorm/",
|
||||
help="Path to save polnorm benchmark results",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run correctness test
|
||||
calculate_diff(
|
||||
batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
hidden_dim=args.hidden_dim,
|
||||
)
|
||||
|
||||
benchmark = get_benchmark()
|
||||
# Run performance benchmark
|
||||
benchmark.run(print_data=True, save_path=args.save_path)
|
||||
@ -1,77 +1,675 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
silu_mul_fp8_quant_deep_gemm,
|
||||
silu_mul_fp8_quant_deep_gemm_cuda,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||
|
||||
|
||||
def benchmark(E, T, H, G=128, runs=50):
|
||||
current_platform.seed_everything(42)
|
||||
y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
|
||||
tokens_per_expert = torch.randint(
|
||||
T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
|
||||
@triton.jit
|
||||
def _silu_mul_fp8_quant_deep_gemm(
|
||||
# Pointers ------------------------------------------------------------
|
||||
input_ptr, # 16-bit activations (E, T, 2*H)
|
||||
y_q_ptr, # fp8 quantized activations (E, T, H)
|
||||
y_s_ptr, # 16-bit scales (E, T, G)
|
||||
counts_ptr, # int32 num tokens per expert (E)
|
||||
# Sizes ---------------------------------------------------------------
|
||||
H: tl.constexpr, # hidden dimension (per output)
|
||||
GROUP_SIZE: tl.constexpr, # elements per group (usually 128)
|
||||
# Strides for input (elements) ---------------------------------------
|
||||
stride_i_e,
|
||||
stride_i_t,
|
||||
stride_i_h,
|
||||
# Strides for y_q (elements) -----------------------------------------
|
||||
stride_yq_e,
|
||||
stride_yq_t,
|
||||
stride_yq_h,
|
||||
# Strides for y_s (elements) -----------------------------------------
|
||||
stride_ys_e,
|
||||
stride_ys_t,
|
||||
stride_ys_g,
|
||||
# Stride for counts (elements)
|
||||
stride_counts_e,
|
||||
# Numeric params ------------------------------------------------------
|
||||
eps: tl.constexpr,
|
||||
fp8_min: tl.constexpr,
|
||||
fp8_max: tl.constexpr,
|
||||
use_ue8m0: tl.constexpr,
|
||||
# Meta ---------------------------------------------------------------
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_STAGES: tl.constexpr,
|
||||
):
|
||||
G = H // GROUP_SIZE
|
||||
|
||||
# map program id -> (e, g)
|
||||
pid = tl.program_id(0)
|
||||
e = pid // G
|
||||
g = pid % G
|
||||
|
||||
e = e.to(tl.int64)
|
||||
g = g.to(tl.int64)
|
||||
|
||||
# number of valid tokens for this expert
|
||||
n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
|
||||
|
||||
cols = tl.arange(0, BLOCK).to(tl.int64)
|
||||
mask = cols < BLOCK
|
||||
|
||||
base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
|
||||
base_gate_offset = base_input_offset + cols * stride_i_h
|
||||
base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
|
||||
base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
|
||||
base_ys_offset = e * stride_ys_e + g * stride_ys_g
|
||||
|
||||
for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
|
||||
gate = tl.load(
|
||||
input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
|
||||
).to(tl.float32)
|
||||
up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0)
|
||||
|
||||
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
|
||||
y = gate * up
|
||||
|
||||
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
|
||||
if use_ue8m0:
|
||||
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
|
||||
|
||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
|
||||
tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
|
||||
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
|
||||
|
||||
|
||||
def silu_mul_fp8_quant_deep_gemm_triton(
|
||||
y: torch.Tensor, # (E, T, 2*H)
|
||||
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
|
||||
num_parallel_tokens,
|
||||
group_size: int = 128,
|
||||
eps: float = 1e-10,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
|
||||
|
||||
y has shape (E, T, 2*H). The first half of the last dimension is
|
||||
silu-activated, multiplied by the second half, then quantized into FP8.
|
||||
|
||||
Returns `(y_q, y_s)` where
|
||||
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
|
||||
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
|
||||
"""
|
||||
assert y.ndim == 3, "y must be (E, T, 2*H)"
|
||||
E, T, H2 = y.shape
|
||||
assert H2 % 2 == 0, "last dim of y must be even (2*H)"
|
||||
H = H2 // 2
|
||||
G = (H + group_size - 1) // group_size
|
||||
assert H % group_size == 0, "H must be divisible by group_size"
|
||||
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, (
|
||||
"tokens_per_expert must be shape (E,)"
|
||||
)
|
||||
tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
|
||||
|
||||
# allocate outputs
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
|
||||
|
||||
# strides (elements)
|
||||
stride_i_e, stride_i_t, stride_i_h = y.stride()
|
||||
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
|
||||
|
||||
# desired scale strides (elements): (T*G, 1, T)
|
||||
stride_ys_e = T * G
|
||||
stride_ys_t = 1
|
||||
stride_ys_g = T
|
||||
y_s = torch.empty_strided(
|
||||
(E, T, G),
|
||||
(stride_ys_e, stride_ys_t, stride_ys_g),
|
||||
dtype=torch.float32,
|
||||
device=y.device,
|
||||
)
|
||||
|
||||
stride_cnt_e = tokens_per_expert.stride()[0]
|
||||
|
||||
# Static grid over experts and H-groups.
|
||||
# A loop inside the kernel handles the token dim
|
||||
grid = (E * G,)
|
||||
|
||||
f_info = torch.finfo(fp8_dtype)
|
||||
fp8_max = f_info.max
|
||||
fp8_min = f_info.min
|
||||
|
||||
_silu_mul_fp8_quant_deep_gemm[grid](
|
||||
y,
|
||||
y_q,
|
||||
y_s,
|
||||
tokens_per_expert,
|
||||
H,
|
||||
group_size,
|
||||
stride_i_e,
|
||||
stride_i_t,
|
||||
stride_i_h,
|
||||
stride_yq_e,
|
||||
stride_yq_t,
|
||||
stride_yq_h,
|
||||
stride_ys_e,
|
||||
stride_ys_t,
|
||||
stride_ys_g,
|
||||
stride_cnt_e,
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
is_deep_gemm_e8m0_used(),
|
||||
BLOCK=group_size,
|
||||
NUM_STAGES=4,
|
||||
num_warps=1,
|
||||
)
|
||||
|
||||
return y_q, y_s
|
||||
|
||||
|
||||
# Parse generation strategies
|
||||
strategies = ["uniform", "max_t", "first_t"]
|
||||
|
||||
|
||||
def benchmark(
|
||||
kernel: Callable,
|
||||
E: int,
|
||||
T: int,
|
||||
H: int,
|
||||
total_tokens: int,
|
||||
num_parallel_tokens: int = 64,
|
||||
G: int = 128,
|
||||
runs: int = 200,
|
||||
num_warmups: int = 20,
|
||||
gen_strategy: str = "default",
|
||||
iterations_per_run: int = 20,
|
||||
):
|
||||
def generate_data(seed_offset=0):
|
||||
"""Generate input data with given seed offset"""
|
||||
current_platform.seed_everything(42 + seed_offset)
|
||||
y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
|
||||
|
||||
if gen_strategy == "uniform":
|
||||
r = torch.rand(size=(E,), device="cuda")
|
||||
r /= r.sum()
|
||||
r *= total_tokens
|
||||
tokens_per_expert = r.int()
|
||||
tokens_per_expert = torch.minimum(
|
||||
tokens_per_expert,
|
||||
torch.ones((E,), device=r.device, dtype=torch.int) * T,
|
||||
)
|
||||
elif gen_strategy == "max_t":
|
||||
tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
|
||||
tokens_per_expert.fill_(total_tokens / E)
|
||||
elif gen_strategy == "first_t":
|
||||
tokens_per_expert = torch.zeros(size=(E,), dtype=torch.int32, device="cuda")
|
||||
tokens_per_expert[0] = min(T, total_tokens)
|
||||
else:
|
||||
raise ValueError(f"Unknown generation strategy: {gen_strategy}")
|
||||
return y, tokens_per_expert
|
||||
|
||||
dataset_count = 4
|
||||
# Pre-generate different input matrices for each iteration to avoid cache effects
|
||||
data_sets = [generate_data(i) for i in range(dataset_count)]
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
|
||||
torch.cuda.synchronize()
|
||||
y, tokens_per_expert = data_sets[0]
|
||||
for _ in range(num_warmups):
|
||||
kernel(
|
||||
y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
# Benchmark
|
||||
torch.cuda.synchronize()
|
||||
start = time.perf_counter()
|
||||
latencies: list[float] = []
|
||||
for _ in range(runs):
|
||||
silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
avg_time = (time.perf_counter() - start) / runs * 1000
|
||||
start_event.record()
|
||||
for i in range(iterations_per_run):
|
||||
y, tokens_per_expert = data_sets[i % dataset_count]
|
||||
kernel(
|
||||
y,
|
||||
tokens_per_expert,
|
||||
num_parallel_tokens=num_parallel_tokens,
|
||||
group_size=G,
|
||||
)
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
|
||||
# Calculate actual work done (only count valid tokens)
|
||||
total_time_ms = start_event.elapsed_time(end_event)
|
||||
per_iter_time_ms = total_time_ms / iterations_per_run
|
||||
latencies.append(per_iter_time_ms)
|
||||
|
||||
# Use median instead of average for better outlier handling
|
||||
median_time_ms = np.median(latencies)
|
||||
median_time_s = median_time_ms / 1000
|
||||
|
||||
# Calculate actual work done (using first dataset for consistency)
|
||||
_, tokens_per_expert = data_sets[0]
|
||||
actual_tokens = tokens_per_expert.sum().item()
|
||||
actual_elements = actual_tokens * H
|
||||
|
||||
# GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
|
||||
ops_per_element = 8
|
||||
total_ops = actual_elements * ops_per_element
|
||||
gflops = total_ops / (avg_time / 1000) / 1e9
|
||||
gflops = total_ops / median_time_s / 1e9
|
||||
|
||||
# Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
|
||||
input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs
|
||||
output_bytes = actual_tokens * H * 1 # H fp8 outputs
|
||||
scale_bytes = actual_tokens * (H // G) * 4 # scales in float32
|
||||
total_bytes = input_bytes + output_bytes + scale_bytes
|
||||
memory_bw = total_bytes / (avg_time / 1000) / 1e9
|
||||
memory_bw = total_bytes / median_time_s / 1e9
|
||||
|
||||
return avg_time, gflops, memory_bw
|
||||
HOPPER_BANDWIDTH_TBPS = 3.35
|
||||
return (
|
||||
median_time_ms,
|
||||
gflops,
|
||||
memory_bw,
|
||||
(memory_bw / (HOPPER_BANDWIDTH_TBPS * 1024)) * 100,
|
||||
)
|
||||
|
||||
|
||||
def create_comparison_plot(
|
||||
ratio, cuda_times, baseline_times, config_labels, strategy_name, id
|
||||
):
|
||||
"""Create a comparison plot for a specific generation strategy"""
|
||||
fig, ax = plt.subplots(1, 1, figsize=(16, 6))
|
||||
|
||||
# Configure x-axis positions
|
||||
x = np.arange(len(config_labels))
|
||||
width = 0.35
|
||||
|
||||
# Execution Time plot (lower is better)
|
||||
ax.bar(
|
||||
x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue"
|
||||
)
|
||||
ax.bar(
|
||||
x + width / 2,
|
||||
baseline_times,
|
||||
width,
|
||||
label="Baseline",
|
||||
alpha=0.8,
|
||||
color="orange",
|
||||
)
|
||||
|
||||
# Add speedup labels over each bar pair
|
||||
for i in range(len(x)):
|
||||
speedup = ratio[i]
|
||||
max_height = max(cuda_times[i], baseline_times[i])
|
||||
ax.text(
|
||||
x[i],
|
||||
max_height + max_height * 0.02,
|
||||
f"{speedup:.2f}x",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontweight="bold",
|
||||
fontsize=9,
|
||||
)
|
||||
|
||||
ax.set_xlabel("Configuration")
|
||||
ax.set_ylabel("% Utilization")
|
||||
ax.set_title(
|
||||
f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)"
|
||||
)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(config_labels, rotation=45, ha="right")
|
||||
ax.legend()
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
return fig, ax
|
||||
|
||||
|
||||
def create_combined_plot(all_results):
|
||||
"""Create a combined plot with all strategies in one PNG"""
|
||||
num_strategies = len(all_results)
|
||||
fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies))
|
||||
|
||||
if num_strategies == 1:
|
||||
axes = [axes]
|
||||
|
||||
for idx, (
|
||||
strategy_name,
|
||||
ratio,
|
||||
cuda_times,
|
||||
baseline_times,
|
||||
config_labels,
|
||||
) in enumerate(all_results):
|
||||
ax = axes[idx]
|
||||
|
||||
# Configure x-axis positions
|
||||
x = np.arange(len(config_labels))
|
||||
width = 0.35
|
||||
|
||||
# Execution Time plot (lower is better)
|
||||
ax.bar(
|
||||
x - width / 2,
|
||||
cuda_times,
|
||||
width,
|
||||
label="CUDA Kernel",
|
||||
alpha=0.8,
|
||||
color="blue",
|
||||
)
|
||||
ax.bar(
|
||||
x + width / 2,
|
||||
baseline_times,
|
||||
width,
|
||||
label="Baseline",
|
||||
alpha=0.8,
|
||||
color="orange",
|
||||
)
|
||||
|
||||
# Add speedup labels over each bar pair
|
||||
for i in range(len(x)):
|
||||
speedup = ratio[i]
|
||||
max_height = max(cuda_times[i], baseline_times[i])
|
||||
ax.text(
|
||||
x[i],
|
||||
max_height + max_height * 0.02,
|
||||
f"{speedup:.2f}x",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontweight="bold",
|
||||
fontsize=9,
|
||||
)
|
||||
|
||||
ax.set_xlabel("Configuration")
|
||||
ax.set_ylabel("% Utilization")
|
||||
ax.set_title(
|
||||
f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)"
|
||||
)
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(config_labels, rotation=45, ha="right")
|
||||
ax.legend()
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
filename = "../../silu_bench/silu_benchmark_combined.png"
|
||||
plt.savefig(filename, dpi=300, bbox_inches="tight")
|
||||
plt.show()
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
outer_dim = 7168
|
||||
configs = [
|
||||
(8, 32, 1024),
|
||||
(16, 64, 2048),
|
||||
(32, 128, 4096),
|
||||
# DeepSeekV3 Configs
|
||||
(256, 16, 7168),
|
||||
(256, 32, 7168),
|
||||
(256, 64, 7168),
|
||||
(256, 128, 7168),
|
||||
(256, 256, 7168),
|
||||
(256, 512, 7168),
|
||||
(8, 1024, 7168),
|
||||
# DeepSeekV3 Configs
|
||||
(32, 1024, 7168),
|
||||
# DeepSeekV3 Configs
|
||||
(256, 1024, 7168),
|
||||
]
|
||||
|
||||
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
|
||||
print("-" * 50)
|
||||
runs = 100
|
||||
num_warmups = 20
|
||||
|
||||
for E, T, H in configs:
|
||||
try:
|
||||
time_ms, gflops, gbps = benchmark(E, T, H)
|
||||
print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
|
||||
except Exception:
|
||||
print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
|
||||
strategy_descriptions = {
|
||||
"uniform": "Uniform Random",
|
||||
"max_t": "Even Assignment",
|
||||
"first_t": "experts[0] = T, experts[1:] = 0",
|
||||
}
|
||||
|
||||
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||
print(f"Testing strategies: {', '.join(strategies)}")
|
||||
print(f"Configurations: {len(configs)} configs")
|
||||
|
||||
all_results = []
|
||||
|
||||
# Run benchmarks for each strategy
|
||||
for id, strategy in enumerate(strategies):
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Testing strategy: {strategy_descriptions[strategy]}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# Collect benchmark data for both algorithms
|
||||
config_labels = []
|
||||
config_x_axis = []
|
||||
all_cuda_results = []
|
||||
all_baseline_results = []
|
||||
all_ratios = []
|
||||
|
||||
for E, T, H in configs:
|
||||
total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E]
|
||||
config_x_axis.append(total_tokens_config)
|
||||
|
||||
cuda_results = []
|
||||
baseline_results = []
|
||||
ratios = []
|
||||
|
||||
for total_tokens in total_tokens_config:
|
||||
config_label = f"E={E},T={T},H={H},TT={total_tokens}"
|
||||
config_labels.append(config_label)
|
||||
|
||||
# CUDA kernel results
|
||||
time_ms_cuda, gflops, gbps, perc = benchmark(
|
||||
silu_mul_fp8_quant_deep_gemm_cuda,
|
||||
E,
|
||||
T,
|
||||
H,
|
||||
total_tokens,
|
||||
runs=runs,
|
||||
num_warmups=num_warmups,
|
||||
gen_strategy=strategy,
|
||||
)
|
||||
cuda_results.append((time_ms_cuda, gflops, gbps, perc))
|
||||
|
||||
# Baseline results
|
||||
time_ms_triton, gflops, gbps, perc = benchmark(
|
||||
silu_mul_fp8_quant_deep_gemm_triton,
|
||||
E,
|
||||
T,
|
||||
H,
|
||||
total_tokens,
|
||||
runs=runs,
|
||||
num_warmups=num_warmups,
|
||||
gen_strategy=strategy,
|
||||
)
|
||||
baseline_results.append((time_ms_triton, gflops, gbps, perc))
|
||||
ratios.append(time_ms_triton / time_ms_cuda)
|
||||
|
||||
print(f"Completed: {config_label}")
|
||||
all_cuda_results.append(cuda_results)
|
||||
all_baseline_results.append(baseline_results)
|
||||
all_ratios.append(ratios)
|
||||
|
||||
# Store results for combined plotting
|
||||
all_results.append(
|
||||
(
|
||||
strategy_descriptions[strategy],
|
||||
all_ratios,
|
||||
all_cuda_results,
|
||||
all_baseline_results,
|
||||
config_labels,
|
||||
config_x_axis,
|
||||
)
|
||||
)
|
||||
|
||||
# Print summary table for this strategy
|
||||
print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
|
||||
print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}")
|
||||
print("-" * 60)
|
||||
|
||||
for i, (E, T, H) in enumerate(configs):
|
||||
speedup = baseline_results[i][0] / cuda_results[i][0]
|
||||
config_label = f"E={E:3d},T={T:4d},H={H:4d}"
|
||||
print(
|
||||
f"{config_label:<20} {cuda_results[i][0]:8.5f} "
|
||||
f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x"
|
||||
)
|
||||
|
||||
|
||||
def create_total_tokens_plot(all_results):
|
||||
num_strategies = len(all_results)
|
||||
num_configs = len(configs)
|
||||
|
||||
# Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
|
||||
fig, axs = plt.subplots(
|
||||
num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies)
|
||||
)
|
||||
|
||||
# Add main title to the entire figure
|
||||
fig.suptitle(
|
||||
"Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)",
|
||||
fontsize=16,
|
||||
fontweight="bold",
|
||||
y=0.98,
|
||||
)
|
||||
|
||||
# Handle single strategy case
|
||||
if num_strategies == 1:
|
||||
axs = axs.reshape(1, -1)
|
||||
|
||||
# Handle single config case
|
||||
if num_configs == 1:
|
||||
axs = axs.reshape(-1, 2)
|
||||
|
||||
for strategy_idx, result in enumerate(all_results):
|
||||
(
|
||||
strategy_name,
|
||||
all_ratios,
|
||||
all_cuda_results,
|
||||
all_baseline_results,
|
||||
config_labels,
|
||||
config_x_axis,
|
||||
) = result
|
||||
|
||||
for config_idx in range(num_configs):
|
||||
# Speedup plot (left column)
|
||||
ax_speedup = axs[strategy_idx, config_idx * 2]
|
||||
# Bandwidth plot (right column)
|
||||
ax_bandwidth = axs[strategy_idx, config_idx * 2 + 1]
|
||||
|
||||
E, T, H = configs[config_idx]
|
||||
ratios = all_ratios[config_idx]
|
||||
total_tokens_values = config_x_axis[config_idx]
|
||||
|
||||
# Extract CUDA and Triton bandwidth percentages
|
||||
cuda_bandwidth_percentages = [
|
||||
result[3] for result in all_cuda_results[config_idx]
|
||||
]
|
||||
triton_bandwidth_percentages = [
|
||||
result[3] for result in all_baseline_results[config_idx]
|
||||
]
|
||||
|
||||
# Plot speedup ratios vs total tokens (left plot)
|
||||
ax_speedup.plot(
|
||||
total_tokens_values, ratios, "bo-", linewidth=3, markersize=8
|
||||
)
|
||||
ax_speedup.set_title(
|
||||
f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}",
|
||||
fontsize=12,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
|
||||
ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
|
||||
ax_speedup.grid(True, alpha=0.3)
|
||||
|
||||
ax_bandwidth.plot(
|
||||
total_tokens_values,
|
||||
cuda_bandwidth_percentages,
|
||||
"ro-",
|
||||
linewidth=3,
|
||||
markersize=8,
|
||||
label="CUDA",
|
||||
)
|
||||
ax_bandwidth.plot(
|
||||
total_tokens_values,
|
||||
triton_bandwidth_percentages,
|
||||
"go-",
|
||||
linewidth=3,
|
||||
markersize=8,
|
||||
label="Triton",
|
||||
)
|
||||
ax_bandwidth.set_title(
|
||||
f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
|
||||
fontsize=12,
|
||||
fontweight="bold",
|
||||
)
|
||||
ax_bandwidth.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
|
||||
ax_bandwidth.set_ylabel(
|
||||
"% of Peak Bandwidth", fontweight="bold", fontsize=11
|
||||
)
|
||||
ax_bandwidth.legend(prop={"weight": "bold"})
|
||||
ax_bandwidth.grid(True, alpha=0.3)
|
||||
|
||||
# Format x-axis labels for both plots
|
||||
for ax in [ax_speedup, ax_bandwidth]:
|
||||
ax.set_xticks(total_tokens_values)
|
||||
ax.set_xticklabels(
|
||||
[
|
||||
f"{tt // 1000}K" if tt >= 1000 else str(tt)
|
||||
for tt in total_tokens_values
|
||||
],
|
||||
fontweight="bold",
|
||||
)
|
||||
# Make tick labels bold
|
||||
for label in ax.get_xticklabels() + ax.get_yticklabels():
|
||||
label.set_fontweight("bold")
|
||||
|
||||
# Add value labels on speedup points
|
||||
for x, y in zip(total_tokens_values, ratios):
|
||||
ax_speedup.annotate(
|
||||
f"{y:.2f}x",
|
||||
(x, y),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 12),
|
||||
ha="center",
|
||||
fontsize=10,
|
||||
fontweight="bold",
|
||||
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
|
||||
)
|
||||
|
||||
# Add value labels on CUDA bandwidth points
|
||||
for x, y in zip(total_tokens_values, cuda_bandwidth_percentages):
|
||||
ax_bandwidth.annotate(
|
||||
f"{y:.1f}%",
|
||||
(x, y),
|
||||
textcoords="offset points",
|
||||
xytext=(0, 12),
|
||||
ha="center",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3),
|
||||
)
|
||||
|
||||
# Add value labels on Triton bandwidth points
|
||||
for x, y in zip(total_tokens_values, triton_bandwidth_percentages):
|
||||
ax_bandwidth.annotate(
|
||||
f"{y:.1f}%",
|
||||
(x, y),
|
||||
textcoords="offset points",
|
||||
xytext=(0, -15),
|
||||
ha="center",
|
||||
fontsize=9,
|
||||
fontweight="bold",
|
||||
bbox=dict(boxstyle="round,pad=0.2", facecolor="green", alpha=0.3),
|
||||
)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.subplots_adjust(top=0.93) # Make room for main title
|
||||
filename = "silu_benchmark_total_tokens.png"
|
||||
plt.savefig(filename, dpi=300, bbox_inches="tight")
|
||||
plt.show()
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
# Create combined plot with all strategies
|
||||
combined_plot_filename = create_total_tokens_plot(all_results)
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print("Benchmark Complete!")
|
||||
print(f"Generated combined plot: {combined_plot_filename}")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
@ -259,6 +259,7 @@ if __name__ == "__main__":
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(None, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
|
||||
@ -274,6 +274,7 @@ if __name__ == "__main__":
|
||||
quant_dtypes = [
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
|
||||
@ -56,7 +56,7 @@ def w8a8_block_matmul(
|
||||
Bs: The per-block quantization scale for `B`.
|
||||
block_size: The block size for per-block quantization.
|
||||
It should be 2-dim, e.g., [128, 128].
|
||||
output_dytpe: The dtype of the returned tensor.
|
||||
output_dtype: The dtype of the returned tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The result of matmul.
|
||||
|
||||
@ -55,6 +55,107 @@ output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75
|
||||
----------------------------------------------------------------------------------------------------
|
||||
```
|
||||
|
||||
### JSON configuration file for synthetic conversations generation
|
||||
|
||||
The input flag `--input-file` is used to determine the input conversations for the benchmark.<br/>
|
||||
When the input is a JSON file with the field `"filetype": "generate_conversations"` the tool will generate synthetic multi-turn (questions and answers) conversations.
|
||||
|
||||
The file `generate_multi_turn.json` is an example file.
|
||||
|
||||
The file must contain the sections `prompt_input` and `prompt_output`.
|
||||
|
||||
The `prompt_input` section must contain `num_turns`, `prefix_num_tokens` and `num_tokens`:
|
||||
|
||||
* `num_turns` - Number of total turns in the conversation (both user & assistant).<br/>
|
||||
The final value will always be rounded to an even number so each user turn has a reply.
|
||||
* `prefix_num_tokens` - Tokens added at the start of only the **first user turn** in a conversation (unique per conversation).
|
||||
* `num_tokens` - Total token length of each **user** message (one turn).
|
||||
|
||||
The `prompt_output` section must contain `num_tokens`:
|
||||
|
||||
* `num_tokens` - Total token length of each **assistant** message (one turn).
|
||||
|
||||
### Random distributions for synthetic conversations generation
|
||||
|
||||
When creating an input JSON file (such as `generate_multi_turn.json`),<br/>
|
||||
every numeric field (such as `num_turns` or `num_tokens`) requires a distribution.<br/>
|
||||
The distribution determines how to randomly sample values for the field.
|
||||
|
||||
The available distributions are listed below.
|
||||
|
||||
**Note:** The optional `max` field (for lognormal, zipf, and poisson) can be used to cap sampled values at an upper bound.</br>
|
||||
Can be used to make sure that the total number of tokens in every request does not exceed `--max-model-len`.
|
||||
|
||||
#### constant
|
||||
|
||||
```json
|
||||
{
|
||||
"distribution": "constant",
|
||||
"value": 500
|
||||
}
|
||||
```
|
||||
|
||||
* `value` - the fixed integer value (always returns the same number).
|
||||
|
||||
#### uniform
|
||||
|
||||
```json
|
||||
{
|
||||
"distribution": "uniform",
|
||||
"min": 12,
|
||||
"max": 18
|
||||
}
|
||||
```
|
||||
|
||||
* `min` - minimum value (inclusive).
|
||||
* `max` - maximum value (inclusive), should be equal or larger than min.
|
||||
|
||||
#### lognormal
|
||||
|
||||
```json
|
||||
{
|
||||
"distribution": "lognormal",
|
||||
"average": 1000,
|
||||
"max": 5000
|
||||
}
|
||||
```
|
||||
|
||||
You can parameterize the lognormal distribution in one of two ways:
|
||||
|
||||
Using the average and optional median ratio:
|
||||
|
||||
* `average` - target average value of the distribution.
|
||||
* `median_ratio` - the ratio of the median to the average; controls the skewness. Must be in the range (0, 1).
|
||||
|
||||
Using the parameters of the underlying normal distribution:
|
||||
|
||||
* `mean` - mean of the underlying normal distribution.
|
||||
* `sigma` - standard deviation of the underlying normal distribution.
|
||||
|
||||
#### zipf
|
||||
|
||||
```json
|
||||
{
|
||||
"distribution": "zipf",
|
||||
"alpha": 1.2,
|
||||
"max": 100
|
||||
}
|
||||
```
|
||||
|
||||
* `alpha` - skew parameter (> 1). Larger values produce stronger skew toward smaller integers.
|
||||
|
||||
#### poisson
|
||||
|
||||
```json
|
||||
{
|
||||
"distribution": "poisson",
|
||||
"alpha": 10,
|
||||
"max": 50
|
||||
}
|
||||
```
|
||||
|
||||
* `alpha` - expected value (λ). Also the variance of the distribution.
|
||||
|
||||
## ShareGPT Conversations
|
||||
|
||||
To run with the ShareGPT data, download the following ShareGPT dataset:
|
||||
|
||||
@ -99,21 +99,105 @@ class PoissonDistribution(Distribution):
|
||||
|
||||
class LognormalDistribution(Distribution):
|
||||
def __init__(
|
||||
self, mean: float, sigma: float, max_val: Optional[int] = None
|
||||
self,
|
||||
mean: Optional[float] = None,
|
||||
sigma: Optional[float] = None,
|
||||
average: Optional[int] = None,
|
||||
median_ratio: Optional[float] = None,
|
||||
max_val: Optional[int] = None,
|
||||
) -> None:
|
||||
self.average = average
|
||||
self.median_ratio = median_ratio
|
||||
self.max_val = max_val
|
||||
|
||||
if average is not None:
|
||||
if average < 1:
|
||||
raise ValueError("Lognormal average must be positive")
|
||||
|
||||
if mean or sigma:
|
||||
raise ValueError(
|
||||
"When using lognormal average, you can't provide mean/sigma"
|
||||
)
|
||||
|
||||
if self.median_ratio is None:
|
||||
# Default value that provides relatively wide range of values
|
||||
self.median_ratio = 0.85
|
||||
|
||||
# Calculate mean/sigma of np.random.lognormal based on the average
|
||||
mean, sigma = self._generate_lognormal_by_median(
|
||||
target_average=self.average, median_ratio=self.median_ratio
|
||||
)
|
||||
else:
|
||||
if mean is None or sigma is None:
|
||||
raise ValueError(
|
||||
"Must provide both mean and sigma if average is not used"
|
||||
)
|
||||
|
||||
if mean <= 0 or sigma < 0:
|
||||
raise ValueError(
|
||||
"Lognormal mean must be positive and sigma must be non-negative"
|
||||
)
|
||||
|
||||
# Mean and standard deviation of the underlying normal distribution
|
||||
# Based on numpy.random.lognormal
|
||||
self.mean = mean
|
||||
self.sigma = sigma
|
||||
self.max_val = max_val
|
||||
|
||||
@staticmethod
|
||||
def _generate_lognormal_by_median(
|
||||
target_average: int, median_ratio: float
|
||||
) -> tuple[float, float]:
|
||||
"""
|
||||
Compute (mu, sigma) for a lognormal distribution given:
|
||||
- a target average (mean of the distribution)
|
||||
- a ratio of median / mean (controls skewness), assume mean > median
|
||||
|
||||
Background:
|
||||
If Z ~ Normal(mu, sigma^2), then X = exp(Z) ~ LogNormal(mu, sigma).
|
||||
* mean(X) = exp(mu + sigma^2 / 2)
|
||||
* median(X) = exp(mu)
|
||||
|
||||
So:
|
||||
median / mean = exp(mu) / exp(mu + sigma^2 / 2)
|
||||
= exp(-sigma^2 / 2)
|
||||
|
||||
Rearranging:
|
||||
sigma^2 = 2 * ln(mean / median)
|
||||
mu = ln(median)
|
||||
|
||||
This gives a unique (mu, sigma) for any valid mean and median.
|
||||
"""
|
||||
# Check input validity: median must be smaller than mean
|
||||
if median_ratio <= 0 or median_ratio >= 1:
|
||||
raise ValueError("median_ratio must be in range (0, 1)")
|
||||
|
||||
target_median = target_average * median_ratio
|
||||
|
||||
# Solve sigma^2 = 2 * ln(mean / median)
|
||||
sigma = np.sqrt(2 * np.log(target_average / target_median))
|
||||
mu = np.log(target_median)
|
||||
|
||||
return mu, sigma
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size)
|
||||
|
||||
if self.average is not None:
|
||||
# Scale to average
|
||||
samples *= self.average / samples.mean()
|
||||
|
||||
if self.max_val:
|
||||
samples = np.minimum(samples, self.max_val)
|
||||
|
||||
return np.round(samples).astype(int)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LognormalDistribution[{self.mean}, {self.sigma}]"
|
||||
if self.average:
|
||||
return (
|
||||
f"LognormalDistribution[{self.average}, "
|
||||
f"{self.median_ratio}, {self.max_val}]"
|
||||
)
|
||||
return f"LognormalDistribution[{self.mean}, {self.sigma}, {self.max_val}]"
|
||||
|
||||
|
||||
class GenConvArgs(NamedTuple):
|
||||
@ -173,10 +257,21 @@ def get_random_distribution(
|
||||
return PoissonDistribution(conf["alpha"], max_val=max_val)
|
||||
|
||||
elif distribution == "lognormal":
|
||||
max_val = conf.get("max", None)
|
||||
|
||||
if "average" in conf:
|
||||
# Infer lognormal mean/sigma (numpy) from input average
|
||||
median_ratio = conf.get("median_ratio", None)
|
||||
return LognormalDistribution(
|
||||
average=conf["average"], median_ratio=median_ratio, max_val=max_val
|
||||
)
|
||||
|
||||
# Use mean/sigma directly (for full control over the distribution)
|
||||
verify_field_exists(conf, "mean", section, subsection)
|
||||
verify_field_exists(conf, "sigma", section, subsection)
|
||||
max_val = conf.get("max", None)
|
||||
return LognormalDistribution(conf["mean"], conf["sigma"], max_val=max_val)
|
||||
return LognormalDistribution(
|
||||
mean=conf["mean"], sigma=conf["sigma"], max_val=max_val
|
||||
)
|
||||
|
||||
elif distribution == "uniform":
|
||||
verify_field_exists(conf, "min", section, subsection)
|
||||
|
||||
@ -15,9 +15,8 @@
|
||||
},
|
||||
"prefix_num_tokens": {
|
||||
"distribution": "lognormal",
|
||||
"mean": 6,
|
||||
"sigma": 4,
|
||||
"max": 1500
|
||||
"average": 1000,
|
||||
"max": 5000
|
||||
},
|
||||
"num_tokens": {
|
||||
"distribution": "uniform",
|
||||
|
||||
@ -480,7 +480,6 @@ function (define_gpu_extension_target GPU_MOD_NAME)
|
||||
${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
|
||||
endif()
|
||||
|
||||
set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
|
||||
|
||||
target_compile_options(${GPU_MOD_NAME} PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
|
||||
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale);
|
||||
#endif
|
||||
|
||||
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale) {
|
||||
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
|
||||
return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA");
|
||||
}
|
||||
@ -1,225 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#include "device/sm100_mla.hpp"
|
||||
#include "kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <typename T, bool PersistenceOption = true>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption, Sm100MlaPersistentTileScheduler,
|
||||
Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel =
|
||||
cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler,
|
||||
/*kIsCpAsync=*/true>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table, double scale) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_latent = cute::make_tuple(
|
||||
static_cast<int64_t>(D_latent), _1{}, static_cast<int64_t>(H * D_latent));
|
||||
StrideQ stride_Q_rope = cute::make_tuple(static_cast<int64_t>(D_rope), _1{},
|
||||
static_cast<int64_t>(H * D_rope));
|
||||
StrideK stride_C =
|
||||
cute::make_tuple(static_cast<int64_t>(D_latent + D_rope), _1{},
|
||||
static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast<int>(H));
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(D_latent), _1{},
|
||||
static_cast<int64_t>(H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_latent_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_rope_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
auto scale_f = static_cast<float>(scale);
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr,
|
||||
stride_C, C_ptr + D_latent, stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()), stride_PT, page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O,
|
||||
static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
1, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
void runMla(at::Tensor const& out, at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens, at::Tensor const& page_table,
|
||||
float scale, cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale);
|
||||
size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table, double scale) {
|
||||
TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA");
|
||||
TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor");
|
||||
TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor");
|
||||
TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3,
|
||||
"kv_c_and_k_pe_cache must be a 3D tensor");
|
||||
TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor");
|
||||
TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor");
|
||||
TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor");
|
||||
|
||||
auto B_q_nope = q_nope.size(0);
|
||||
auto H_q_nope = q_nope.size(1);
|
||||
auto D_q_nope = q_nope.size(2);
|
||||
auto B_q_pe = q_pe.size(0);
|
||||
auto H_q_pe = q_pe.size(1);
|
||||
auto D_q_pe = q_pe.size(2);
|
||||
auto B_pt = page_table.size(0);
|
||||
auto PAGE_NUM = page_table.size(1);
|
||||
auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1);
|
||||
auto D_ckv = kv_c_and_k_pe_cache.size(2);
|
||||
auto B_o = out.size(0);
|
||||
auto H_o = out.size(1);
|
||||
auto D_o = out.size(2);
|
||||
|
||||
TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512");
|
||||
TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64");
|
||||
TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576");
|
||||
TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128,
|
||||
"H_q_nope, H_q_pe, and H_o must be equal to 128");
|
||||
TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0,
|
||||
"PAGE_SIZE must be a power of 2");
|
||||
TORCH_CHECK(
|
||||
B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o,
|
||||
"Batch dims must be same for page_table, q_nope and q_pe, and out");
|
||||
TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0,
|
||||
"PAGE_NUM must be divisible by 128 / PAGE_SIZE");
|
||||
TORCH_CHECK(D_o == 512, "D_o must be equal to 512");
|
||||
|
||||
TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half ||
|
||||
q_nope.dtype() == at::ScalarType::BFloat16 ||
|
||||
q_nope.dtype() == at::ScalarType::Float8_e4m3fn,
|
||||
"q_nope must be a half, bfloat16, or float8_e4m3fn tensor");
|
||||
TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() &&
|
||||
q_nope.dtype() == q_pe.dtype(),
|
||||
"kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type");
|
||||
TORCH_CHECK(seq_lens.dtype() == torch::kInt32,
|
||||
"seq_lens must be a 32-bit integer tensor");
|
||||
TORCH_CHECK(page_table.dtype() == torch::kInt32,
|
||||
"page_table must be a 32-bit integer tensor");
|
||||
|
||||
auto in_dtype = q_nope.dtype();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens,
|
||||
page_table, scale, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t>(out, q_nope, q_pe, kv_c_and_k_pe_cache,
|
||||
seq_lens, page_table, scale, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
}
|
||||
@ -133,6 +133,14 @@ public:
|
||||
// printf(" sm_count = %d\n", sm_count);
|
||||
int max_splits = ceil_div(K, 128);
|
||||
max_splits = min(16, max_splits);
|
||||
|
||||
// TODO: This avoids a hang when the batch size larger than 1 and
|
||||
// there is more than 4 kv_splits.
|
||||
// Discuss with NVIDIA how this can be fixed.
|
||||
if (B > 1) {
|
||||
max_splits = min(2, max_splits);
|
||||
}
|
||||
|
||||
// printf(" max_splits = %d\n", max_splits);
|
||||
int sms_per_batch = max(1, sm_count / B);
|
||||
// printf(" sms_per_batch = %d\n", sms_per_batch);
|
||||
|
||||
@ -43,6 +43,7 @@ void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||
}
|
||||
|
||||
@ -17,4 +17,8 @@
|
||||
#warning "unsupported vLLM cpu implementation"
|
||||
#endif
|
||||
|
||||
#ifdef _OPENMP
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
#endif
|
||||
@ -12,7 +12,7 @@ namespace vec_op {
|
||||
#define vec_sub(a, b) ((a) - (b))
|
||||
#define vec_mul(a, b) ((a) * (b))
|
||||
#define vec_div(a, b) ((a) / (b))
|
||||
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebaic
|
||||
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
|
||||
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
|
||||
|
||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
||||
|
||||
@ -523,7 +523,7 @@ void onednn_mm(torch::Tensor& c, // [M, OC], row-major
|
||||
CPU_KERNEL_GUARD_IN(onednn_mm)
|
||||
TORCH_CHECK(a.dim() == 2);
|
||||
TORCH_CHECK(a.stride(-1) == 1);
|
||||
TORCH_CHECK(c.is_contiguous());
|
||||
TORCH_CHECK(c.stride(-1) == 1);
|
||||
MatMulPrimitiveHandler* ptr =
|
||||
reinterpret_cast<MatMulPrimitiveHandler*>(handler);
|
||||
|
||||
|
||||
@ -215,7 +215,7 @@ int moe_align_block_size(
|
||||
offsets[mb + 1] = sorted_id_size(sorted_ids + mb * BLOCK_M);
|
||||
}
|
||||
});
|
||||
// TODO: do we need to vecterize this ?
|
||||
// TODO: do we need to vectorize this ?
|
||||
for (int mb = 0; mb < num_token_blocks; ++mb) {
|
||||
offsets[mb + 1] += offsets[mb];
|
||||
}
|
||||
|
||||
17
csrc/cub_helpers.h
Normal file
17
csrc/cub_helpers.h
Normal file
@ -0,0 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#if CUB_VERSION >= 200800
|
||||
#include <cuda/std/functional>
|
||||
using CubAddOp = cuda::std::plus<>;
|
||||
using CubMaxOp = cuda::maximum<>;
|
||||
#else // if CUB_VERSION < 200800
|
||||
using CubAddOp = cub::Sum;
|
||||
using CubMaxOp = cub::Max;
|
||||
#endif // CUB_VERSION
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
using CubAddOp = cub::Sum;
|
||||
using CubMaxOp = cub::Max;
|
||||
#endif // USE_ROCM
|
||||
@ -15,6 +15,8 @@ typedef __hip_bfloat16 nv_bfloat16;
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
namespace vllm {
|
||||
#define CUDACHECK(cmd) \
|
||||
@ -555,22 +557,47 @@ class CustomAllreduce {
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
||||
|
||||
// Check environment variable once
|
||||
const char* env_algo = std::getenv("VLLM_CUSTOM_ALLREDUCE_ALGO");
|
||||
bool force_1stage = false;
|
||||
bool force_2stage = false;
|
||||
if (env_algo != nullptr) {
|
||||
if (std::strcmp(env_algo, "1stage") == 0 ||
|
||||
std::strcmp(env_algo, "oneshot") == 0) {
|
||||
force_1stage = true;
|
||||
} else if (std::strcmp(env_algo, "2stage") == 0 ||
|
||||
std::strcmp(env_algo, "twoshot") == 0) {
|
||||
force_2stage = true;
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Invalid VLLM_CUSTOM_ALLREDUCE_ALGO: " + std::string(env_algo) +
|
||||
". Valid values: 1stage, oneshot, 2stage, twoshot");
|
||||
}
|
||||
}
|
||||
|
||||
#define KL(ngpus, name) \
|
||||
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
||||
rank_, size);
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (fully_connected_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (force_1stage) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (force_2stage) { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} else { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (fully_connected_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
|
||||
@ -1,123 +0,0 @@
|
||||
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GMMA_TMA_WS_SS (BlockScaled Builders)
|
||||
template <
|
||||
class ElementA,
|
||||
class GmemLayoutATag,
|
||||
int AlignmentA,
|
||||
class ElementB,
|
||||
class GmemLayoutBTag,
|
||||
int AlignmentB,
|
||||
class ElementAccumulator,
|
||||
class TileShape_MNK,
|
||||
class ClusterShape_MNK,
|
||||
class StageCountType,
|
||||
int ScaleGranularityM
|
||||
>
|
||||
struct CollectiveBuilder<
|
||||
arch::Sm90,
|
||||
arch::OpClassTensorOp,
|
||||
ElementA,
|
||||
GmemLayoutATag,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
GmemLayoutBTag,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
TileShape_MNK,
|
||||
ClusterShape_MNK,
|
||||
StageCountType,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
|
||||
cute::enable_if_t<
|
||||
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
|
||||
> {
|
||||
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;
|
||||
|
||||
static_assert(is_static<TileShape_MNK>::value);
|
||||
static_assert(is_static<ClusterShape_MNK>::value);
|
||||
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
|
||||
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
|
||||
#endif
|
||||
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
|
||||
"Should meet TMA alignment requirement\n");
|
||||
|
||||
static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedPingpong>);
|
||||
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
|
||||
static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
|
||||
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");
|
||||
|
||||
// For fp32 types, map to tf32 MMA value type
|
||||
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
|
||||
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
|
||||
|
||||
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
|
||||
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();
|
||||
|
||||
static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperative,
|
||||
KernelPtrArrayTmaWarpSpecializedCooperative,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
|
||||
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
|
||||
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
|
||||
ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));
|
||||
|
||||
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
|
||||
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
|
||||
|
||||
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
|
||||
GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
|
||||
|
||||
static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
|
||||
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);
|
||||
|
||||
static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
|
||||
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;
|
||||
|
||||
using SmemCopyAtomA = void;
|
||||
using SmemCopyAtomB = void;
|
||||
|
||||
using CollectiveOp = CollectiveMma<
|
||||
DispatchPolicy,
|
||||
TileShape_MNK,
|
||||
ElementA,
|
||||
TagToStrideA_t<GmemLayoutATag>,
|
||||
ElementB,
|
||||
TagToStrideB_t<GmemLayoutBTag>,
|
||||
TiledMma,
|
||||
GmemTiledCopyA,
|
||||
SmemLayoutAtomA,
|
||||
SmemCopyAtomA,
|
||||
cute::identity,
|
||||
GmemTiledCopyB,
|
||||
SmemLayoutAtomB,
|
||||
SmemCopyAtomB,
|
||||
cute::identity
|
||||
>;
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,183 +0,0 @@
|
||||
// clang-format off
|
||||
// adapted from: https://github.com/soundOfDestiny/cutlass/blob/a4208aa6958864923505cade9c63eb2a6daf16e5/include/cutlass/gemm/collective/fp8_accumulation.hpp
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/algorithm/clear.hpp"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
///////////////////////////////////FP8 Accumulation///////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// This class provides API to promote (add) or scale (multiply_add) the results
|
||||
/// from the tensor core accumulators to the main accumulators when the number
|
||||
/// of MMAs reaches the max number of MMA interval specified by user, after that
|
||||
/// the tensor core accumulators are zeroed.
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
|
||||
template <
|
||||
class EngineAccum,
|
||||
class LayoutAccum>
|
||||
struct GmmaFP8AccumulationWithScale {
|
||||
using TensorAccum = cute::Tensor<EngineAccum, LayoutAccum>;
|
||||
using ElementAccumulator = typename EngineAccum::value_type;
|
||||
|
||||
static_assert(is_static<LayoutAccum>::value, "Accumulator Layout should be static");
|
||||
static_assert(is_rmem<TensorAccum>::value , "Accumulator tensor must be rmem resident.");
|
||||
|
||||
private:
|
||||
TensorAccum& accum_;
|
||||
TensorAccum accum_temp_;
|
||||
|
||||
uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted.
|
||||
uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop
|
||||
uint32_t mma_count_; // current executed MMAs
|
||||
uint32_t reset_accum_flag_; // accum needs to be zeroed or not.
|
||||
|
||||
// promote or `add` the partial accumulators to main accumulator (FADD).
|
||||
CUTLASS_DEVICE
|
||||
void promote_core() {
|
||||
warpgroup_wait<0>();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accum_); ++i) {
|
||||
accum_(i) += accum_temp_(i);
|
||||
}
|
||||
}
|
||||
|
||||
// `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE
|
||||
void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
||||
using TensorScale = cute::Tensor<EngineScale, LayoutScale>;
|
||||
|
||||
static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
|
||||
static_assert(is_rmem<TensorScale>::value , "Scale tensor must be rmem resident.");
|
||||
|
||||
static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");
|
||||
|
||||
warpgroup_wait<0>();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(accum_); ++i) {
|
||||
accum_(i) += accum_temp_(i) * scale(i);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
GmmaFP8AccumulationWithScale(
|
||||
TensorAccum &accum,
|
||||
uint32_t accum_promotion_interval,
|
||||
uint32_t mma_count_per_mainloop_iteration)
|
||||
: accum_(accum),
|
||||
accum_promotion_interval_(accum_promotion_interval),
|
||||
mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration),
|
||||
mma_count_(0),
|
||||
reset_accum_flag_(0)
|
||||
{
|
||||
accum_temp_ = cute::make_fragment_like(accum);
|
||||
}
|
||||
|
||||
//
|
||||
// Methods (Common)
|
||||
//
|
||||
|
||||
CUTLASS_DEVICE
|
||||
TensorAccum& operator()() {
|
||||
return accum_temp_;
|
||||
}
|
||||
|
||||
/// prepare the MMA accumulators when initialization or zeroing is required.
|
||||
CUTLASS_DEVICE
|
||||
bool prepare_if_needed() {
|
||||
return reset_accum_flag_;
|
||||
}
|
||||
|
||||
//
|
||||
// Methods (for FADD version)
|
||||
//
|
||||
|
||||
/// promote (add) the results from the MMA accumulators to main accumulator if needed.
|
||||
CUTLASS_DEVICE
|
||||
void promote_if_needed() {
|
||||
mma_count_ += mma_count_per_mainloop_iteration_;
|
||||
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
|
||||
if (reset_accum_flag_) {
|
||||
promote_core();
|
||||
mma_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// promote (add) the residue results from the MMA accumulators to main accumulator if needed.
|
||||
CUTLASS_DEVICE
|
||||
void promote_residue_if_needed() {
|
||||
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
|
||||
promote_core();
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Methods (for FFMA version)
|
||||
//
|
||||
|
||||
/// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE
|
||||
void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
||||
mma_count_ += mma_count_per_mainloop_iteration_;
|
||||
reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
|
||||
if (reset_accum_flag_) {
|
||||
scale_core(scale);
|
||||
mma_count_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
|
||||
template <
|
||||
class EngineScale,
|
||||
class LayoutScale>
|
||||
CUTLASS_DEVICE
|
||||
void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
|
||||
if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
|
||||
scale_core(scale);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
@ -1,729 +0,0 @@
|
||||
// clang-format off
|
||||
// Adapted (Heavily) from: https://github.com/soundOfDestiny/cutlass/blob/9d997ce0dea4c5fa1a617db6b7ff29aa9235822c/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp
|
||||
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/arch/cluster_sm90.hpp"
|
||||
#include "cute/arch/copy_sm80.hpp"
|
||||
#include "cute/arch/copy_sm90.hpp"
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/fp8_accumulation.hpp"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// WarpSpecialized Mainloop
|
||||
template <
|
||||
int Stages,
|
||||
class ClusterShape,
|
||||
class KernelSchedule,
|
||||
int ScaleGranularityM_,
|
||||
class TileShape_,
|
||||
class ElementA_,
|
||||
class StrideA_,
|
||||
class ElementB_,
|
||||
class StrideB_,
|
||||
class TiledMma_,
|
||||
class GmemTiledCopyA_,
|
||||
class SmemLayoutAtomA_,
|
||||
class SmemCopyAtomA_,
|
||||
class TransformA_,
|
||||
class GmemTiledCopyB_,
|
||||
class SmemLayoutAtomB_,
|
||||
class SmemCopyAtomB_,
|
||||
class TransformB_>
|
||||
struct CollectiveMma<
|
||||
MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>,
|
||||
TileShape_,
|
||||
ElementA_,
|
||||
StrideA_,
|
||||
ElementB_,
|
||||
StrideB_,
|
||||
TiledMma_,
|
||||
GmemTiledCopyA_,
|
||||
SmemLayoutAtomA_,
|
||||
SmemCopyAtomA_,
|
||||
TransformA_,
|
||||
GmemTiledCopyB_,
|
||||
SmemLayoutAtomB_,
|
||||
SmemCopyAtomB_,
|
||||
TransformB_>
|
||||
{
|
||||
//
|
||||
// Type Aliases
|
||||
//
|
||||
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<Stages, ClusterShape, KernelSchedule, ScaleGranularityM_>;
|
||||
using TileShape = TileShape_;
|
||||
using ElementA = ElementA_;
|
||||
using StrideA = StrideA_;
|
||||
using ElementB = ElementB_;
|
||||
using StrideB = StrideB_;
|
||||
using TiledMma = TiledMma_;
|
||||
using ElementAccumulator = typename TiledMma::ValTypeC;
|
||||
using ElementBlockScale = ElementAccumulator;
|
||||
using GmemTiledCopyA = GmemTiledCopyA_;
|
||||
using GmemTiledCopyB = GmemTiledCopyB_;
|
||||
using SmemLayoutAtomA = SmemLayoutAtomA_;
|
||||
using SmemLayoutAtomB = SmemLayoutAtomB_;
|
||||
using SmemCopyAtomA = SmemCopyAtomA_;
|
||||
using SmemCopyAtomB = SmemCopyAtomB_;
|
||||
using TransformA = TransformA_;
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
|
||||
using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{}));
|
||||
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
|
||||
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
|
||||
using PipelineParams = typename MainloopPipeline::Params;
|
||||
|
||||
// Two threads per CTA are producers (1 for operand tile and 32 for scales)
|
||||
static constexpr int NumProducerThreadEvents = 33;
|
||||
|
||||
static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape{}) : ScaleGranularityM_;
|
||||
static constexpr int ScaleMsPerTile = size<0>(TileShape{}) / ScaleGranularityM;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
|
||||
|
||||
static_assert((size<0>(TileShape{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
|
||||
|
||||
// Tile along modes in a way that maximizes the TMA box size.
|
||||
using SmemLayoutA = decltype(tile_to_shape(
|
||||
SmemLayoutAtomA{},
|
||||
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
using SmemLayoutB = decltype(tile_to_shape(
|
||||
SmemLayoutAtomB{},
|
||||
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
|
||||
cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}));
|
||||
|
||||
// Block scaling gmem-to-smem copy atom
|
||||
using SmemBlockScalingCopyAtomA = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
using SmemBlockScalingCopyAtomB = Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementBlockScale>, ElementBlockScale>;
|
||||
|
||||
// Block scaling smem layout
|
||||
using SmemLayoutScaleA = Layout<Shape<Int<ScaleMsPerTile>, Int<DispatchPolicy::Stages>>>;
|
||||
using SmemLayoutScaleB = Layout<Shape<Int<DispatchPolicy::Stages>>, Stride<_1>>; // `ScaleNsPerTile` is always 1.
|
||||
|
||||
static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
|
||||
static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value &&
|
||||
cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
|
||||
"MMA atom must source both A and B operand from smem_desc for this mainloop.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
|
||||
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
|
||||
static_assert(cute::is_same_v<ElementAccumulator, ElementBlockScale>,
|
||||
"ElementAccumulator and ElementBlockScale should be same datatype");
|
||||
|
||||
struct SharedStorage
|
||||
{
|
||||
struct TensorStorage : cute::aligned_struct<128> {
|
||||
cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A; // mxk
|
||||
cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B; // nxk
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleA>> smem_scale_A; // ScaleMsPerTile x k
|
||||
cute::array_aligned<ElementBlockScale, cute::cosize_v<SmemLayoutScaleB>> smem_scale_B; // 1xk
|
||||
} tensors;
|
||||
|
||||
using PipelineStorage = typename MainloopPipeline::SharedStorage;
|
||||
PipelineStorage pipeline;
|
||||
};
|
||||
using TensorStorage = typename SharedStorage::TensorStorage;
|
||||
using PipelineStorage = typename SharedStorage::PipelineStorage;
|
||||
|
||||
// Host side kernel arguments
|
||||
struct Arguments {
|
||||
ElementA const* ptr_A;
|
||||
StrideA dA;
|
||||
ElementB const* ptr_B;
|
||||
StrideB dB;
|
||||
ElementBlockScale const* ptr_scale_A;
|
||||
ElementBlockScale const* ptr_scale_B;
|
||||
};
|
||||
|
||||
// Device side kernel params
|
||||
struct Params {
|
||||
// Assumption: StrideA is congruent with Problem_MK
|
||||
using TMA_A = decltype(make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
|
||||
SmemLayoutA{}(_,_,0),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
// Assumption: StrideB is congruent with Problem_NK
|
||||
using TMA_B = decltype(make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
|
||||
SmemLayoutB{}(_,_,0),
|
||||
TileShape{},
|
||||
ClusterShape{}));
|
||||
TMA_A tma_load_a;
|
||||
TMA_B tma_load_b;
|
||||
uint32_t tma_transaction_bytes = TmaTransactionBytes;
|
||||
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
// Block scaling factors for A and B
|
||||
ElementBlockScale const* ptr_scale_A;
|
||||
ElementBlockScale const* ptr_scale_B;
|
||||
};
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
template <class ProblemShape>
|
||||
static constexpr Params
|
||||
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||
(void) workspace;
|
||||
|
||||
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
|
||||
auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);
|
||||
|
||||
Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA));
|
||||
Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB));
|
||||
typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90(
|
||||
GmemTiledCopyA{},
|
||||
tensor_a,
|
||||
SmemLayoutA{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90(
|
||||
GmemTiledCopyB{},
|
||||
tensor_b,
|
||||
SmemLayoutB{}(_,_,cute::Int<0>{}),
|
||||
TileShape{},
|
||||
ClusterShape{});
|
||||
uint32_t transaction_bytes_mk = TmaTransactionBytesMK;
|
||||
uint32_t transaction_bytes_nk = TmaTransactionBytesNK;
|
||||
uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk;
|
||||
|
||||
return {
|
||||
tma_load_a,
|
||||
tma_load_b,
|
||||
transaction_bytes,
|
||||
transaction_bytes_mk,
|
||||
transaction_bytes_nk,
|
||||
args.ptr_scale_A,
|
||||
args.ptr_scale_B
|
||||
};
|
||||
}
|
||||
|
||||
template<class ProblemShape>
|
||||
static bool
|
||||
can_implement(
|
||||
ProblemShape const& problem_shape,
|
||||
[[maybe_unused]] Arguments const& args) {
|
||||
constexpr int tma_alignment_bits = 128;
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
bool implementable = true;
|
||||
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
|
||||
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
|
||||
}
|
||||
return implementable;
|
||||
}
|
||||
|
||||
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
|
||||
static constexpr int K_PIPE_MMAS = 1;
|
||||
static constexpr uint32_t TmaTransactionBytesMK =
|
||||
cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value));
|
||||
static constexpr uint32_t TmaTransactionBytesNK =
|
||||
cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value));
|
||||
static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
|
||||
|
||||
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
||||
CUTLASS_DEVICE
|
||||
static void prefetch_tma_descriptors(Params const& mainloop_params)
|
||||
{
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
|
||||
}
|
||||
|
||||
/// Set up the data needed by this collective for load and mma.
|
||||
/// Returns a tuple of tensors. The collective and the kernel layer have the contract
|
||||
/// Returned tuple must contain at least two elements, with the first two elements being:
|
||||
/// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
|
||||
/// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
|
||||
template <class ProblemShape_MNKL>
|
||||
CUTLASS_DEVICE auto
|
||||
load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const {
|
||||
using X = Underscore;
|
||||
// Separate out problem shape for convenience
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
// TMA requires special handling of strides to deal with coord codomain mapping
|
||||
// Represent the full tensors -- get these from TMA
|
||||
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l)
|
||||
Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l)
|
||||
|
||||
// Make tiled views, defer the slice
|
||||
Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l)
|
||||
Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l)
|
||||
|
||||
constexpr auto scales_m = Int<ScaleMsPerTile>{};
|
||||
auto tM = get<2>(gA_mkl.shape());
|
||||
auto tN = get<2>(gB_nkl.shape());
|
||||
auto tK = get<3>(gA_mkl.shape());
|
||||
|
||||
// Make the tiled views of scale tensors
|
||||
auto scaleA_shape = make_shape(M / ScaleGranularityM, tK, L); // (scale_m,k,l)
|
||||
auto scaleA_layout = make_ordered_layout(scaleA_shape, Step<_0, _1, _2>{});
|
||||
auto scaleB_shape = make_shape(tN, tK, L); // (n,k,l)
|
||||
auto scaleB_layout = make_ordered_layout(scaleB_shape, Step<_1, _0, _2>{});
|
||||
|
||||
// Note that mScaleA_mkl and mScaleB_nkl are already blocked tiled in the `m` host and
|
||||
// gScaleA_mkl and gScaleB_nkl in `g` global memory are same as mScaleA_mkl and mScaleB_nkl.
|
||||
Tensor mScaleA_mkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_A), scaleA_layout); // (scale_m,k,l)
|
||||
Tensor mScaleB_nkl = make_tensor(make_gmem_ptr(mainloop_params.ptr_scale_B), scaleB_layout); // (n,k,l)
|
||||
|
||||
return cute::make_tuple(gA_mkl, gB_nkl, mScaleA_mkl, mScaleB_nkl);
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Producer Perspective
|
||||
template <
|
||||
class TensorA, class TensorB,
|
||||
class TensorScaleA, class TensorScaleB,
|
||||
class KTileIterator, class BlockCoord
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
load(
|
||||
Params const& mainloop_params,
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write,
|
||||
cute::tuple<TensorA, TensorB, TensorScaleA, TensorScaleB> const& load_inputs,
|
||||
BlockCoord const& blk_coord,
|
||||
KTileIterator k_tile_iter, int k_tile_count,
|
||||
int thread_idx,
|
||||
uint32_t block_rank_in_cluster,
|
||||
TensorStorage& shared_tensors) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Blockscaling: Tma loads for load_input and CpAsync for load_scale
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
Tensor sScaleA = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()), SmemLayoutScaleA{}); // (ScaleMsPerTile,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||
|
||||
//
|
||||
// Prepare the TMA loads for A and B
|
||||
//
|
||||
|
||||
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
|
||||
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
|
||||
|
||||
Tensor gA_mkl = get<0>(load_inputs);
|
||||
Tensor gB_nkl = get<1>(load_inputs);
|
||||
|
||||
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
|
||||
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
|
||||
|
||||
// Partition the inputs based on the current block coordinates.
|
||||
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
|
||||
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
|
||||
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k)
|
||||
|
||||
|
||||
// Block scaling: load_scale has scaling tensors in global memory which are not tiled
|
||||
Tensor mScaleA_mkl = get<2>(load_inputs);
|
||||
Tensor mScaleB_nkl = get<3>(load_inputs);
|
||||
auto scales_m = get<0>(mScaleA_mkl.shape());
|
||||
|
||||
Tensor cScaleA_mkl = make_identity_tensor(mScaleA_mkl.shape());
|
||||
|
||||
Tensor gScaleA = local_tile(
|
||||
mScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
||||
make_coord(m_coord,_,l_coord)); // (ScaleMsPerTile,k,1)
|
||||
Tensor cScaleA = local_tile(
|
||||
cScaleA_mkl, make_tile(Int<ScaleMsPerTile>{}),
|
||||
make_coord(m_coord,_,l_coord));
|
||||
Tensor gScaleB = mScaleB_nkl(n_coord,_,l_coord); // (1,k,1)
|
||||
|
||||
// TODO: test `scale_copy_a` with `ScaleMsPerTile` < 128
|
||||
TiledCopy scale_copy_a = make_tiled_copy(SmemBlockScalingCopyAtomA{},
|
||||
Layout<Shape<_32>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
||||
TiledCopy scale_copy_b = make_tiled_copy(SmemBlockScalingCopyAtomB{},
|
||||
Layout<Shape<_1>>{}, Layout<Shape<_1>>{}); // (1,1,1)
|
||||
ThrCopy thr_scale_copy_a = scale_copy_a.get_slice(threadIdx.x);
|
||||
ThrCopy thr_scale_copy_b = scale_copy_b.get_slice(threadIdx.x);
|
||||
|
||||
Tensor tAgA_ScaleA = thr_scale_copy_a.partition_S(gScaleA);
|
||||
Tensor tAcA_ScaleA = thr_scale_copy_a.partition_S(cScaleA);
|
||||
Tensor tAsA_ScaleA = thr_scale_copy_a.partition_D(sScaleA);
|
||||
|
||||
Tensor tBgB_ScaleB = thr_scale_copy_b.partition_S(gScaleB);
|
||||
Tensor tBsB_ScaleB = thr_scale_copy_b.partition_D(sScaleB);
|
||||
|
||||
// Applies the mapping from block_tma_a
|
||||
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
|
||||
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
|
||||
|
||||
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
|
||||
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
|
||||
|
||||
uint16_t mcast_mask_a = 0;
|
||||
uint16_t mcast_mask_b = 0;
|
||||
|
||||
// Issue TmaLoads for GEMM operands A/B and CpAsync for scale tensors
|
||||
// Maps the tile -> block, value
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int n = 0; n < size<1>(block_layout); ++n) {
|
||||
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
|
||||
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
|
||||
for (int m = 0; m < size<0>(block_layout); ++m) {
|
||||
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate predicate tensors for a_scales (since we can't guarantee that
|
||||
// all scales are valid, since we could have a partial tiles along M)
|
||||
Tensor tApA_ScaleA = make_tensor<bool>(shape(tAsA_ScaleA(_,_,0)));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tApA_ScaleA); ++i) {
|
||||
tApA_ScaleA(i) = get<0>(tAcA_ScaleA(i)) < scales_m;
|
||||
}
|
||||
|
||||
// Mainloop
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count) {
|
||||
// LOCK smem_pipe_write for _writing_
|
||||
pipeline.producer_acquire(smem_pipe_write);
|
||||
|
||||
//
|
||||
// Copy gmem to smem for *k_tile_iter
|
||||
//
|
||||
int write_stage = smem_pipe_write.index();
|
||||
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
|
||||
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
|
||||
|
||||
// Copy operands A and B from global memory to shared memory
|
||||
if (lane_predicate) copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
|
||||
if (lane_predicate) copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
|
||||
|
||||
// Copy scale tensors from global memory to shared memory
|
||||
copy_if(scale_copy_a, tApA_ScaleA, tAgA_ScaleA(_,_,*k_tile_iter), tAsA_ScaleA(_,_,write_stage));
|
||||
copy(scale_copy_b, tBgB_ScaleB(_,*k_tile_iter), tBsB_ScaleB(_,write_stage));
|
||||
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
|
||||
|
||||
++k_tile_iter;
|
||||
|
||||
// Advance smem_pipe_write
|
||||
++smem_pipe_write;
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
|
||||
CUTLASS_DEVICE void
|
||||
load_tail(
|
||||
MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_write) {
|
||||
int lane_predicate = cute::elect_one_sync();
|
||||
|
||||
// Issue the epilogue waits
|
||||
if (lane_predicate) {
|
||||
/* This helps avoid early exit of blocks in Cluster
|
||||
* Waits for all stages to either be released (all
|
||||
* Consumer UNLOCKs), or if the stage was never used
|
||||
* then would just be acquired since the phase was
|
||||
* still inverted from make_producer_start_state
|
||||
*/
|
||||
pipeline.producer_tail(smem_pipe_write);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a collective-scoped matrix multiply-accumulate
|
||||
/// Consumer Perspective
|
||||
template <
|
||||
class FrgTensorC
|
||||
>
|
||||
CUTLASS_DEVICE void
|
||||
mma(MainloopPipeline pipeline,
|
||||
PipelineState smem_pipe_read,
|
||||
FrgTensorC& accum,
|
||||
int k_tile_count,
|
||||
int thread_idx,
|
||||
TensorStorage& shared_tensors,
|
||||
Params const& mainloop_params) {
|
||||
|
||||
|
||||
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
|
||||
static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomA>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
static_assert(cute::is_void_v<SmemCopyAtomB>,
|
||||
"SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
|
||||
|
||||
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
|
||||
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
|
||||
|
||||
// Block scaling
|
||||
Tensor sScaleAViewAsC = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_A.data()),
|
||||
Layout<
|
||||
Shape<Shape<Int<ScaleGranularityM>, Int<ScaleMsPerTile>>, cute::tuple_element_t<1, TileShape>, Int<DispatchPolicy::Stages>>,
|
||||
Stride<Stride<_0, _1>, _0, Int<ScaleMsPerTile>>
|
||||
>{}); // ((ScaleGranularityM,ScaleMsPerTile),n,k)
|
||||
Tensor sScaleB = make_tensor(cute::make_smem_ptr(shared_tensors.smem_scale_B.data()), SmemLayoutScaleB{}); // (k)
|
||||
|
||||
//
|
||||
// Define C accumulators and A/B partitioning
|
||||
//
|
||||
|
||||
// Layout of warp group to thread mapping
|
||||
|
||||
static_assert(stride<0>(typename TiledMma::ALayout{}) == 0 and
|
||||
stride<0>(typename TiledMma::BLayout{}) == 0 and
|
||||
size<0>(typename TiledMma::ALayout{}) == NumThreadsPerWarpGroup and
|
||||
size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup,
|
||||
"Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup");
|
||||
|
||||
constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup;
|
||||
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
|
||||
Int<NumThreadsPerWarpGroup>{});
|
||||
|
||||
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
|
||||
|
||||
TiledMma tiled_mma;
|
||||
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
|
||||
|
||||
Tensor tCsScaleAViewAsC = tiled_mma.get_slice(thread_idx).partition_C(sScaleAViewAsC); // (MMA,MMA_M,MMA_N,PIPE), `thread_mma` above is correct when partitioning A and B, but it is not correct when partitioning C.
|
||||
|
||||
Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
// Allocate "fragments/descriptors"
|
||||
Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
|
||||
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
|
||||
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
|
||||
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA)); // PIPE
|
||||
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
|
||||
|
||||
//
|
||||
// PIPELINED MAIN LOOP
|
||||
//
|
||||
static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX),
|
||||
"ERROR : Incorrect number of MMAs in flight");
|
||||
|
||||
// We release buffers to producer warps(dma load) with some mmas in flight
|
||||
PipelineState smem_pipe_release = smem_pipe_read;
|
||||
|
||||
// Per block scale values for operand A and B
|
||||
|
||||
using RegLayoutScaleAViewAsC = decltype(make_layout_like(tCsScaleAViewAsC(_, _, _, 0).layout())); // `make_layout_like` makes a compact layout.
|
||||
using RegLayoutScaleAEssential = decltype(filter_zeros(RegLayoutScaleAViewAsC{}.stride(), RegLayoutScaleAViewAsC{}.shape())); // an interface to traverse the underlying storage for the compact layout mentioned above
|
||||
|
||||
Tensor tCrScaleAViewAsC = make_tensor<ElementBlockScale>(RegLayoutScaleAViewAsC{}); // (MMA,MMA_M,MMA_N)
|
||||
ElementBlockScale scale_b;
|
||||
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
|
||||
GmmaFP8AccumulationWithScale accumulation(accum, size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}), size<2>(tCrA));
|
||||
warpgroup_fence_operand(accumulation());
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
if (accumulation.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
// Load per block scale values from shared memory to registers.
|
||||
scale_b = sScaleB[read_stage];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1) {
|
||||
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||
} else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||
}
|
||||
}
|
||||
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
++smem_pipe_read;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
// Mainloop GMMAs
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for ( ; k_tile_count > 0; --k_tile_count)
|
||||
{
|
||||
// WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
|
||||
auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
|
||||
pipeline.consumer_wait(smem_pipe_read, barrier_token);
|
||||
|
||||
//
|
||||
// Compute on k_tile
|
||||
//
|
||||
|
||||
int read_stage = smem_pipe_read.index();
|
||||
|
||||
// Load per block scale values from shared memory to registers (at most twice per block along M and exactly once per block along N)
|
||||
scale_b = sScaleB[read_stage];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCsScaleAViewAsC(_, _, _, read_stage)(idx2crd(i, RegLayoutScaleAEssential{}));
|
||||
}
|
||||
if constexpr (ScaleMsPerTile == 1) {
|
||||
static_assert(size(RegLayoutScaleAEssential{}) == 1);
|
||||
tCrScaleAViewAsC.data()[0] = __shfl_sync(0xffffffff, tCrScaleAViewAsC.data()[0] * scale_b, 0); // `tCrScaleAViewAsC.data()[0]` are all same in a warp group when `ScaleMsPerTile == 1`.
|
||||
} else {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(RegLayoutScaleAEssential{}); i++) {
|
||||
tCrScaleAViewAsC.data()[i] = tCrScaleAViewAsC.data()[i] * scale_b;
|
||||
}
|
||||
}
|
||||
|
||||
if (accumulation.prepare_if_needed()) {
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||
}
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
warpgroup_arrive();
|
||||
// Unroll the K mode manually to set scale D to 1
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
|
||||
// (V,M,K) x (V,N,K) => (V,M,N)
|
||||
cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation());
|
||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
|
||||
/// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
|
||||
warpgroup_wait<K_PIPE_MMAS>();
|
||||
warpgroup_fence_operand(accumulation());
|
||||
|
||||
// Block scale the accumulators with reg tensor `tCrScaleAViewAsC`
|
||||
accumulation.scale_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
|
||||
// Advance smem_pipe_read and smem_pipe_release
|
||||
++smem_pipe_read;
|
||||
++smem_pipe_release;
|
||||
}
|
||||
|
||||
accumulation.scale_residue_if_needed(tCrScaleAViewAsC);
|
||||
|
||||
warpgroup_fence_operand(accumulation());
|
||||
}
|
||||
|
||||
/// Perform a Consumer Epilogue to release all buffers
|
||||
CUTLASS_DEVICE void
|
||||
mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) {
|
||||
// Prologue GMMAs
|
||||
int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
|
||||
k_tile_count -= prologue_mma_count;
|
||||
|
||||
smem_pipe_release.advance(k_tile_count);
|
||||
|
||||
// Wait on all GMMAs to complete
|
||||
warpgroup_wait<0>();
|
||||
|
||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||
++smem_pipe_release;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm::collective
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,39 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
namespace cutlass::gemm {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// FP8 related policies (including Blocked Scaled Accumulation)
|
||||
// `ScaleGranularityM` specifies scaling granularity along M, while zero-value
|
||||
// `ScaleGranularityM` indicates that scaling granularity is
|
||||
// `size<0>(TileShape_MNK{})` along M.
|
||||
template <int ScaleGranularityM = 0>
|
||||
struct KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum
|
||||
: KernelTmaWarpSpecializedCooperative {};
|
||||
|
||||
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp
|
||||
// specialized dynamic schedule For FP8 kernels with Block Scaling
|
||||
template <int Stages_, class ClusterShape_ = Shape<_1, _1, _1>,
|
||||
class KernelSchedule = KernelTmaWarpSpecialized,
|
||||
int ScaleGranularityM =
|
||||
0 // `ScaleGranularityM` specifies scaling granularity along M,
|
||||
// while zero-value `ScaleGranularityM` indicates that scaling
|
||||
// granularity is `size<0>(TileShape_MNK{})` along M.
|
||||
>
|
||||
struct MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8
|
||||
: MainloopSm90TmaGmmaWarpSpecialized<Stages_, ClusterShape_,
|
||||
KernelSchedule> {
|
||||
static_assert(
|
||||
cute::is_same_v<
|
||||
KernelSchedule,
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
|
||||
ScaleGranularityM>>,
|
||||
"KernelSchedule must be one of the warp specialized policies");
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::gemm
|
||||
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
@ -1,15 +1,10 @@
|
||||
#include "type_convert.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "cub_helpers.h"
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
@ -30,7 +25,7 @@ __global__ void rms_norm_kernel(
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
@ -85,7 +80,7 @@ fused_add_rms_norm_kernel(
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
@ -126,7 +121,7 @@ fused_add_rms_norm_kernel(
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
@ -140,6 +135,211 @@ fused_add_rms_norm_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
/* Function specialization in the case of FP16/BF16 tensors.
|
||||
Additional optimizations we can make in this case are
|
||||
packed and vectorized operations, which help with the
|
||||
memory latency bottleneck.
|
||||
|
||||
_f16VecPN struct extends _f16Vec to add operations specifically required for
|
||||
polynomial normalization (poly norm).
|
||||
The original _f16Vec does not include the sum-of-powers computation or
|
||||
in-place polynomial normalization logic. */
|
||||
template <typename scalar_t, int width>
|
||||
struct alignas(16) _f16VecPN : _f16Vec<scalar_t, width> {
|
||||
using Base = _f16Vec<scalar_t, width>;
|
||||
using Converter = typename Base::Converter;
|
||||
using T1 = typename Base::T1;
|
||||
using T2 = typename Base::T2;
|
||||
using Base::data;
|
||||
|
||||
__device__ auto sum_pows() const {
|
||||
float s2 = 0.0f, s4 = 0.0f, s6 = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
float2 z = Converter::convert(T2{data[i], data[i + 1]});
|
||||
float x2 = z.x * z.x;
|
||||
float x4 = x2 * x2;
|
||||
float x6 = x4 * x2;
|
||||
|
||||
float y2 = z.y * z.y;
|
||||
float y4 = y2 * y2;
|
||||
float y6 = y4 * y2;
|
||||
|
||||
s2 += x2 + y2;
|
||||
s4 += x4 + y4;
|
||||
s6 += x6 + y6;
|
||||
}
|
||||
return std::make_tuple(s2, s4, s6);
|
||||
}
|
||||
|
||||
__device__ void poly_norm_inplace(const float w2_inv_std,
|
||||
const float w1_inv_std2,
|
||||
const float w0_inv_std3, const float bias) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < width; i += 2) {
|
||||
float2 z = Converter::convert(T2{data[i], data[i + 1]});
|
||||
|
||||
float x2 = z.x * z.x;
|
||||
float x3 = x2 * z.x;
|
||||
z.x = w2_inv_std * z.x + w1_inv_std2 * x2 + w0_inv_std3 * x3 + bias;
|
||||
|
||||
float y2 = z.y * z.y;
|
||||
float y3 = y2 * z.y;
|
||||
z.y = w2_inv_std * z.y + w1_inv_std2 * y2 + w0_inv_std3 * y3 + bias;
|
||||
|
||||
auto out = Converter::convert(z);
|
||||
data[i] = out.x;
|
||||
data[i + 1] = out.y;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [3]
|
||||
const scalar_t* __restrict__ bias, // [1]
|
||||
const float epsilon, const int hidden_size) {
|
||||
// Sanity checks on our vector struct and type-punned pointer arithmetic
|
||||
static_assert(std::is_pod_v<_f16VecPN<scalar_t, width>>);
|
||||
static_assert(sizeof(_f16VecPN<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||
|
||||
/* These and the argument pointers are all declared `restrict` as they are
|
||||
not aliased in practice. Argument pointers should not be dereferenced
|
||||
in this kernel as that would be undefined behavior */
|
||||
auto* __restrict__ input_v =
|
||||
reinterpret_cast<const _f16VecPN<scalar_t, width>*>(input);
|
||||
const int vec_hidden_size = hidden_size / width;
|
||||
float variance = 0.0f;
|
||||
float variance2 = 0.0f;
|
||||
float variance3 = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16VecPN<scalar_t, width> temp = input_v[id];
|
||||
auto [x2, x4, x6] = temp.sum_pows();
|
||||
|
||||
variance += x2;
|
||||
variance2 += x4;
|
||||
variance3 += x6;
|
||||
}
|
||||
|
||||
float3 thread_variances = make_float3(variance, variance2, variance3);
|
||||
|
||||
struct SumOp {
|
||||
__device__ float3 operator()(const float3& a, const float3& b) const {
|
||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
};
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float3, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
float3 block_variances =
|
||||
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
|
||||
|
||||
variance = block_variances.x;
|
||||
variance2 = block_variances.y;
|
||||
variance3 = block_variances.z;
|
||||
|
||||
__shared__ float s_w2_inv_std;
|
||||
__shared__ float s_w1_inv_std2;
|
||||
__shared__ float s_w0_inv_std3;
|
||||
__shared__ float s_bias;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
float w0 = (float)weight[0];
|
||||
float w1 = (float)weight[1];
|
||||
float w2 = (float)weight[2];
|
||||
s_bias = (float)bias[0];
|
||||
|
||||
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
|
||||
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
|
||||
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
auto* __restrict__ out_v = reinterpret_cast<_f16VecPN<scalar_t, width>*>(out);
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16VecPN<scalar_t, width> temp = input_v[id];
|
||||
temp.poly_norm_inplace(s_w2_inv_std, s_w1_inv_std2, s_w0_inv_std3, s_bias);
|
||||
out_v[id] = temp;
|
||||
}
|
||||
}
|
||||
|
||||
/* Generic poly_norm_kernel
|
||||
The width field is not used here but necessary for other specializations.
|
||||
*/
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
poly_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [3]
|
||||
const scalar_t* __restrict__ bias, // [1]
|
||||
const float epsilon, const int hidden_size) {
|
||||
float variance = 0.0f;
|
||||
float variance2 = 0.0f;
|
||||
float variance3 = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float x2 = x * x;
|
||||
float x4 = x2 * x2;
|
||||
float x6 = x4 * x2;
|
||||
|
||||
variance += x2;
|
||||
variance2 += x4;
|
||||
variance3 += x6;
|
||||
}
|
||||
|
||||
float3 thread_variances = make_float3(variance, variance2, variance3);
|
||||
|
||||
struct SumOp {
|
||||
__device__ float3 operator()(const float3& a, const float3& b) const {
|
||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
};
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float3, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
float3 block_variances =
|
||||
BlockReduce(reduceStore).Reduce(thread_variances, SumOp{}, blockDim.x);
|
||||
|
||||
variance = block_variances.x;
|
||||
variance2 = block_variances.y;
|
||||
variance3 = block_variances.z;
|
||||
|
||||
__shared__ float s_w2_inv_std;
|
||||
__shared__ float s_w1_inv_std2;
|
||||
__shared__ float s_w0_inv_std3;
|
||||
__shared__ float s_bias;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
float w0 = (float)weight[0];
|
||||
float w1 = (float)weight[1];
|
||||
float w2 = (float)weight[2];
|
||||
s_bias = (float)bias[0];
|
||||
|
||||
s_w2_inv_std = w2 * rsqrtf(variance / hidden_size + epsilon);
|
||||
s_w1_inv_std2 = w1 * rsqrtf(variance2 / hidden_size + epsilon);
|
||||
s_w0_inv_std3 = w0 * rsqrtf(variance3 / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float x2 = x * x;
|
||||
float x3 = x2 * x;
|
||||
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
(scalar_t)(x * s_w2_inv_std + x2 * s_w1_inv_std2 + x3 * s_w0_inv_std3 +
|
||||
s_bias);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
@ -219,3 +419,49 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_POLY_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
|
||||
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
|
||||
hidden_size); \
|
||||
});
|
||||
|
||||
void poly_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [3]
|
||||
torch::Tensor& bias, // [1]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.data_ptr() != input.data_ptr());
|
||||
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
/* This kernel is memory-latency bound in many scenarios.
|
||||
When num_tokens is large, a smaller block size allows
|
||||
for increased block occupancy on CUs and better latency
|
||||
hiding on global mem ops. */
|
||||
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
|
||||
dim3 block(std::min(hidden_size, max_block_size));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
/*If the tensor types are FP16/BF16, try to use the optimized kernel
|
||||
with packed + vectorized ops.
|
||||
Max optimization is achieved with a width-8 vector of FP16/BF16s
|
||||
since we can load at most 128 bits at once in a global memory op.
|
||||
However, this requires each tensor's data to be aligned to 16
|
||||
bytes.
|
||||
*/
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
|
||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||
LAUNCH_FUSED_POLY_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_POLY_NORM(0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -8,16 +8,11 @@
|
||||
#include "type_convert.cuh"
|
||||
#include "quantization/fp8/common.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "cub_helpers.h"
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
@ -39,7 +34,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
@ -100,7 +95,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
@ -149,7 +144,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include <torch/all.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
namespace cg = cooperative_groups;
|
||||
@ -28,7 +29,6 @@ namespace cg = cooperative_groups;
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
constexpr float kNegInfinity = INFINITY * -1;
|
||||
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||
constexpr int32_t WARP_SIZE = 32;
|
||||
constexpr int32_t BLOCK_SIZE = 512;
|
||||
@ -411,14 +411,21 @@ __device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T neg_inf() {
|
||||
// cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
|
||||
// so we need to cast from fp32
|
||||
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void topk_with_k2(T* output, T const* input,
|
||||
cg::thread_block_tile<32> const& tile,
|
||||
int32_t const lane_id,
|
||||
int const num_experts_per_group) {
|
||||
// Get the top2 per thread
|
||||
T largest = -INFINITY;
|
||||
T second_largest = -INFINITY;
|
||||
T largest = neg_inf<T>();
|
||||
T second_largest = neg_inf<T>();
|
||||
|
||||
if (num_experts_per_group > WARP_SIZE) {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
@ -513,8 +520,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
warp_id * topk;
|
||||
s_topk_idx += warp_id * topk;
|
||||
|
||||
T value = kNegInfinity;
|
||||
T topk_group_value = kNegInfinity;
|
||||
T value = neg_inf<T>();
|
||||
T topk_group_value = neg_inf<T>();
|
||||
int32_t num_equalto_topkth_group;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
@ -525,11 +532,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
if (case_id < num_tokens) {
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
if (lane_id < n_group &&
|
||||
(isfinite(cuda_cast<float, T>(
|
||||
group_scores[lane_id])))) // The check is necessary to avoid
|
||||
// abnormal input
|
||||
{
|
||||
// The check is necessary to avoid abnormal input
|
||||
if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
@ -540,11 +544,11 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = kNegInfinity;
|
||||
value = neg_inf<T>();
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value = __popc(__ballot_sync(
|
||||
FULL_WARP_MASK, (value == cuda_cast<T, float>(kNegInfinity))));
|
||||
count_equal_to_top_value =
|
||||
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
}
|
||||
@ -552,11 +556,10 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
|
||||
/* is_stable */ true>
|
||||
queue((int32_t)topk, -INFINITY);
|
||||
queue((int32_t)topk, neg_inf<T>());
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
bool if_proceed_next_topk =
|
||||
(topk_group_value != cuda_cast<T, float>(kNegInfinity));
|
||||
bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
@ -566,10 +569,10 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates =
|
||||
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
|
||||
scores_with_bias[offset + i]))
|
||||
(i < num_experts_per_group) &&
|
||||
cuda::std::isfinite(scores_with_bias[offset + i])
|
||||
? scores_with_bias[offset + i]
|
||||
: cuda_cast<T, float>(kNegInfinity);
|
||||
: neg_inf<T>();
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
@ -598,7 +601,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
if (i < topk) {
|
||||
s_topk_value[i] = value;
|
||||
}
|
||||
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
topk_sum +=
|
||||
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -20,17 +20,7 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../cuda_compat.h"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/util_type.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cuda/std/functional>
|
||||
using AddOp = cuda::std::plus<float>;
|
||||
#else
|
||||
#include <hipcub/util_type.hpp>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
using AddOp = cub::Sum;
|
||||
#endif
|
||||
#include "../cub_helpers.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -79,7 +69,7 @@ __launch_bounds__(TPB) __global__
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, CubMaxOp());
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
float_max = maxElem;
|
||||
@ -94,7 +84,7 @@ __launch_bounds__(TPB) __global__
|
||||
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, AddOp());
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, CubAddOp());
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
|
||||
19
csrc/ops.h
19
csrc/ops.h
@ -92,6 +92,9 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
|
||||
torch::Tensor& weight, double epsilon);
|
||||
|
||||
void poly_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
|
||||
torch::Tensor& bias, double epsilon);
|
||||
|
||||
void apply_repetition_penalties_(torch::Tensor& logits,
|
||||
const torch::Tensor& prompt_mask,
|
||||
const torch::Tensor& output_mask,
|
||||
@ -119,12 +122,6 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key, int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, bool is_neox);
|
||||
|
||||
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
std::optional<torch::Tensor> key,
|
||||
int64_t head_size, torch::Tensor& cos_sin_cache,
|
||||
bool is_neox, int64_t rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets);
|
||||
|
||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
@ -136,6 +133,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& input_global_scale);
|
||||
#endif
|
||||
void silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
const at::Tensor& input, // (E, T, 2*H)
|
||||
const at::Tensor& counts, // (E)
|
||||
at::Tensor& y_q, // (E, T, H) [OUT]
|
||||
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
|
||||
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
|
||||
|
||||
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
@ -344,6 +347,8 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t open_mem_handle(torch::Tensor& mem_handle);
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
|
||||
std::optional<int64_t> qr_max_size = std::nullopt);
|
||||
@ -353,4 +358,4 @@ void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
|
||||
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
int64_t quant_level, bool cast_bf2half = false);
|
||||
int64_t qr_max_size();
|
||||
#endif
|
||||
#endif
|
||||
|
||||
@ -99,35 +99,6 @@ __global__ void rotary_embedding_kernel(
|
||||
token_idx, query_stride, key_stride, head_stride);
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool IS_NEOX>
|
||||
__global__ void batched_rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or
|
||||
// [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
|
||||
// head_size] or [num_tokens, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ key, // nullptr or
|
||||
// [batch_size, seq_len, num_kv_heads,
|
||||
// head_size] or [num_tokens, num_kv_heads,
|
||||
// head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
|
||||
// 2]
|
||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
|
||||
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
|
||||
const int64_t head_stride, const int num_heads, const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
|
||||
const scalar_t* cache_ptr =
|
||||
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
|
||||
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(
|
||||
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
|
||||
token_idx, query_stride, key_stride, head_stride);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rotary_embedding(
|
||||
@ -211,96 +182,3 @@ void rotary_embedding(
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/*
|
||||
Batched version of rotary embedding, pack multiple LoRAs together
|
||||
and process in batched manner.
|
||||
*/
|
||||
void batched_rotary_embedding(
|
||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
|
||||
// [num_tokens, num_heads * head_size] or
|
||||
// [batch_size, seq_len, num_heads, head_size] or
|
||||
// [num_tokens, num_heads, head_size]
|
||||
std::optional<torch::Tensor>
|
||||
key, // null or
|
||||
// [batch_size, seq_len, num_kv_heads * head_size] or
|
||||
// [num_tokens, num_kv_heads * head_size] or
|
||||
// [batch_size, seq_len, num_heads, head_size] or
|
||||
// [num_tokens, num_heads, head_size]
|
||||
int64_t head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox, int64_t rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size]
|
||||
) {
|
||||
// num_tokens = batch_size * seq_len
|
||||
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
||||
TORCH_CHECK(
|
||||
positions.size(0) == num_tokens || positions.numel() == num_tokens,
|
||||
"positions must have the same num_tokens or batch_size as "
|
||||
"cos_sin_cache_offsets");
|
||||
|
||||
int positions_ndim = positions.dim();
|
||||
// Make sure num_tokens dim is consistent across positions, query, and key
|
||||
TORCH_CHECK(
|
||||
positions_ndim == 1 || positions_ndim == 2,
|
||||
"positions must have shape [num_tokens] or [batch_size, seq_len]");
|
||||
if (positions_ndim == 1) {
|
||||
TORCH_CHECK(query.size(0) == positions.size(0) &&
|
||||
(!key.has_value() || key->size(0) == positions.size(0)),
|
||||
"query, key and positions must have the same number of tokens");
|
||||
}
|
||||
if (positions_ndim == 2) {
|
||||
TORCH_CHECK(
|
||||
query.size(0) == positions.size(0) &&
|
||||
(!key.has_value() || key->size(0) == positions.size(0)) &&
|
||||
query.size(1) == positions.size(1) &&
|
||||
(!key.has_value() || key->size(1) == positions.size(1)),
|
||||
"query, key and positions must have the same batch_size and seq_len");
|
||||
}
|
||||
|
||||
// Make sure head_size is valid for query and key
|
||||
int query_hidden_size = query.numel() / num_tokens;
|
||||
int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
|
||||
TORCH_CHECK(query_hidden_size % head_size == 0);
|
||||
TORCH_CHECK(key_hidden_size % head_size == 0);
|
||||
|
||||
// Make sure query and key have concistent number of heads
|
||||
int num_heads = query_hidden_size / head_size;
|
||||
int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
|
||||
TORCH_CHECK(num_heads % num_kv_heads == 0);
|
||||
|
||||
int seq_dim_idx = positions_ndim - 1;
|
||||
int64_t query_stride = query.stride(seq_dim_idx);
|
||||
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
|
||||
// Determine head stride: for [*, heads, head_size] use stride of last dim;
|
||||
// for flat [*, heads*head_size], heads blocks are contiguous of size
|
||||
// head_size
|
||||
int query_ndim = query.dim();
|
||||
int64_t head_stride =
|
||||
(query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
|
||||
if (is_neox) {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, true>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||
} else {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, false>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
|
||||
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
|
||||
key_stride, head_stride, num_heads, num_kv_heads, head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@ -9,6 +9,26 @@
|
||||
|
||||
#include "quantization/fp8/common.cuh"
|
||||
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
typedef __hip_bfloat16_raw __nv_bfloat16_raw;
|
||||
|
||||
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
|
||||
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
|
||||
#endif
|
||||
|
||||
#include "core/registration.h"
|
||||
namespace vllm {
|
||||
|
||||
template <typename T>
|
||||
@ -87,6 +107,336 @@ __global__ void act_and_mul_quant_kernel(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float silu(float x) {
|
||||
return (__fdividef(x, (1.f + expf(-x))));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float2 silu2(float2 x) {
|
||||
return make_float2(silu(x.x), silu(x.y));
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM
|
||||
__device__ __forceinline__ float warp_max(float v) {
|
||||
static constexpr unsigned FULL_MASK = 0xffffffffu;
|
||||
for (int offset = 1; offset < WARP_SIZE; offset *= 2) {
|
||||
v = fmaxf(v, __shfl_xor_sync(FULL_MASK, v, offset));
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ __nv_bfloat16 warp_max(__nv_bfloat16 v) {
|
||||
static constexpr unsigned FULL_MASK = 0xffffffffu;
|
||||
for (int offset = 1; offset < WARP_SIZE; offset *= 2) {
|
||||
v = __hmax(v, __shfl_xor_sync(FULL_MASK, v, offset));
|
||||
}
|
||||
return v;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, typename U>
|
||||
__device__ __forceinline__ void cp_async4(T* _smem_ptr, const U* _glob_ptr) {
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
auto smem_ptr = reinterpret_cast<void*>(_smem_ptr);
|
||||
auto glob_ptr = reinterpret_cast<const void*>(_glob_ptr);
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
#else
|
||||
_smem_ptr[0] = _glob_ptr[0];
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void cp_async_fence() {
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
#else
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ __forceinline__ void cp_async_wait() {
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
|
||||
#else
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void cp_async_wait<0>() {
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cp.async.wait_all;\n" ::);
|
||||
#else
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float clip(float v, float mmin, float mmax) {
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
return fminf(mmax, fmaxf(v, mmin));
|
||||
#else
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ __nv_bfloat16 clip(__nv_bfloat16 v,
|
||||
__nv_bfloat16 mmin,
|
||||
__nv_bfloat16 mmax) {
|
||||
return __hmin(mmax, __hmax(v, mmin));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ __nv_bfloat162 clip(__nv_bfloat162 v,
|
||||
__nv_bfloat162 mmin,
|
||||
__nv_bfloat162 mmax) {
|
||||
return __hmin2(mmax, __hmax2(v, mmin));
|
||||
}
|
||||
|
||||
// We use the following values for fp8 min/max:
|
||||
// __nv_fp8_e4m3 = (-448, +448)
|
||||
// __nv_fp8_e4m3uz = (-240.0, +240.0)
|
||||
// It is currently assumed that only
|
||||
template <class T>
|
||||
constexpr __nv_bfloat16 get_fp8_max() {
|
||||
static_assert(std::is_same_v<T, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<T, c10::Float8_e4m3fnuz>);
|
||||
if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
|
||||
return __nv_bfloat16(__nv_bfloat16_raw{.x = 17376});
|
||||
} else {
|
||||
return __nv_bfloat16(__nv_bfloat16_raw{.x = 17264});
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
constexpr __nv_bfloat16 get_fp8_min() {
|
||||
static_assert(std::is_same_v<T, c10::Float8_e4m3fn> ||
|
||||
std::is_same_v<T, c10::Float8_e4m3fnuz>);
|
||||
if constexpr (std::is_same_v<T, c10::Float8_e4m3fn>) {
|
||||
return __nv_bfloat16(__nv_bfloat16_raw{.x = 50144});
|
||||
} else {
|
||||
return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
|
||||
}
|
||||
}
|
||||
#ifndef USE_ROCM
|
||||
template <typename fp8_type, int32_t NUM_WARPS, typename Idx_t,
|
||||
int NUM_PARALLEL_TOKENS, bool USE_UE8M0, int GROUP_SIZE = 128,
|
||||
int NUM_STAGES = 3>
|
||||
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
|
||||
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
|
||||
float* __restrict__ _y_s, const int32_t* __restrict__ counts,
|
||||
|
||||
// sizes
|
||||
int H, int G,
|
||||
|
||||
// strides (in elements)
|
||||
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
|
||||
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
|
||||
Idx_t stride_ys_g, Idx_t stride_counts_e) {
|
||||
static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
|
||||
static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
|
||||
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
|
||||
static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
|
||||
|
||||
// We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
|
||||
static constexpr int32_t BFLOAT16_PER_GROUP = 8;
|
||||
|
||||
// We split the shared memory in half, corresponding to gate and up matrices:
|
||||
// [...gate_i, ...up_i] where 0 <= i < stages.
|
||||
static constexpr int32_t S_NUM_128 =
|
||||
2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
|
||||
static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
|
||||
static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
|
||||
static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
|
||||
__shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];
|
||||
|
||||
const int32_t tid = threadIdx.x;
|
||||
const int32_t warp_id = tid / WARP_SIZE;
|
||||
const int32_t lane_id = tid % WARP_SIZE;
|
||||
|
||||
auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);
|
||||
|
||||
// block handles one (expert e, group g)
|
||||
int32_t pid = blockIdx.x;
|
||||
int32_t e = pid / G;
|
||||
int32_t g = pid % G;
|
||||
|
||||
const int32_t n_tokens = counts[e * stride_counts_e];
|
||||
|
||||
if (!n_tokens) {
|
||||
return; // Exit ASAP.
|
||||
}
|
||||
|
||||
const Idx_t stride_i_t_128 = stride_i_t / 8u;
|
||||
|
||||
int32_t n_tokens_lower, n_tokens_upper;
|
||||
|
||||
// Each block i iterates over tokens of a slice of n_tokens =
|
||||
// expert_counts[i], with the size of chunk being
|
||||
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
|
||||
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
|
||||
if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
|
||||
// Specialize this, but can be likely fused.
|
||||
if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
|
||||
return;
|
||||
}
|
||||
n_tokens_lower = blockIdx.y;
|
||||
n_tokens_upper = blockIdx.y + 1;
|
||||
} else {
|
||||
auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
|
||||
auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
|
||||
auto calc_id = [&](int32_t id) {
|
||||
if (id < residual) {
|
||||
return min(n_tokens, id * (chunk_size + 1));
|
||||
} else {
|
||||
return min(n_tokens, id * chunk_size + residual);
|
||||
}
|
||||
};
|
||||
n_tokens_lower = calc_id(blockIdx.y);
|
||||
n_tokens_upper = calc_id(blockIdx.y + 1);
|
||||
}
|
||||
|
||||
if (n_tokens_lower >= n_tokens_upper) {
|
||||
return;
|
||||
}
|
||||
|
||||
// We do calculations here, using constexpr wherever possible.
|
||||
const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
|
||||
const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
|
||||
const Idx_t base_yq =
|
||||
e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
|
||||
Idx_t gate_off_128 = (base_i / static_cast<Idx_t>(8u));
|
||||
auto input_128_ptr = reinterpret_cast<const __int128_t*>(_input);
|
||||
auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
|
||||
stride_i_t_128 * n_tokens_lower;
|
||||
auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
|
||||
auto y_s_ptr =
|
||||
_y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
|
||||
auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
|
||||
stride_yq_t * n_tokens_lower + 4 * lane_id;
|
||||
int32_t t_load = n_tokens_lower, load_stage_id = 0;
|
||||
auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
|
||||
auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
|
||||
int32_t stage_offset{};
|
||||
|
||||
static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
|
||||
static constexpr int32_t LOAD_STAGE_MOD =
|
||||
NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);
|
||||
|
||||
// Two halves of all threads in a block conduct global loads for gate and up,
|
||||
// repsectively.
|
||||
auto load_and_advance_y_pred = [&] {
|
||||
if (t_load < n_tokens_upper) {
|
||||
auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
|
||||
auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;
|
||||
|
||||
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
|
||||
// unnecessary ALU ops.
|
||||
stage_offset += LOAD_STAGE_SIZE;
|
||||
stage_offset %= LOAD_STAGE_MOD;
|
||||
|
||||
if (tid < HALF_THREAD_COUNT) {
|
||||
cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
|
||||
gate_128_ptr += stride_i_t_128;
|
||||
} else {
|
||||
cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
|
||||
up_128_ptr += stride_i_t_128;
|
||||
}
|
||||
++t_load;
|
||||
++load_stage_id;
|
||||
}
|
||||
// We fence even if there is nothing to load to simplify pipelining.
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_STAGES - 1; i++) {
|
||||
load_and_advance_y_pred();
|
||||
}
|
||||
|
||||
__int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
|
||||
s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
|
||||
lane_id;
|
||||
__int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;
|
||||
|
||||
static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
|
||||
static constexpr int32_t STAGE_MOD = STAGE_SIZE * NUM_STAGES;
|
||||
|
||||
int32_t compute_pipeline_offset_64 = 0;
|
||||
|
||||
for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
|
||||
__nv_bfloat162 results_bf162[2];
|
||||
|
||||
cp_async_wait<NUM_STAGES - 2>();
|
||||
__syncthreads();
|
||||
|
||||
// We double-buffer pipelined loads so that the next load will
|
||||
// concurrently run with compute without overwrites.
|
||||
load_and_advance_y_pred();
|
||||
|
||||
auto s_gate_compute_64 = s_gate_ptr + compute_pipeline_offset_64;
|
||||
auto s_up_compute_64 = s_up_ptr + compute_pipeline_offset_64;
|
||||
|
||||
// STAGE_SIZE must also be constexpr!
|
||||
compute_pipeline_offset_64 += STAGE_SIZE;
|
||||
compute_pipeline_offset_64 %= STAGE_MOD;
|
||||
|
||||
// Each thread loads (gate/up) 2X 4X bfloat16 values into registers.
|
||||
__int64_t gate64 = *s_gate_compute_64;
|
||||
__nv_bfloat162* s_gate_compute_32 =
|
||||
reinterpret_cast<__nv_bfloat162*>(&gate64);
|
||||
|
||||
__int64_t up64 = *s_up_compute_64;
|
||||
__nv_bfloat162* s_up_compute_32 = reinterpret_cast<__nv_bfloat162*>(&up64);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; i++) {
|
||||
// For silu, we make sure that div is emitted.
|
||||
float2 gate = silu2(__bfloat1622float2(s_gate_compute_32[i]));
|
||||
results_bf162[i] = __float22bfloat162_rn(gate);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; i++) {
|
||||
results_bf162[i] = __hmul2(results_bf162[i], s_up_compute_32[i]);
|
||||
}
|
||||
|
||||
auto _y_max2 =
|
||||
__hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));
|
||||
|
||||
__nv_bfloat16 y_max_bf16 = __hmax(EPS, __hmax(_y_max2.x, _y_max2.y));
|
||||
|
||||
// An entire group is assigned to a single warp, so a simple warp reduce
|
||||
// is used.
|
||||
__nv_bfloat16 y_s = warp_max(y_max_bf16) / fp8_max;
|
||||
|
||||
if constexpr (USE_UE8M0) {
|
||||
y_s = hexp2(hceil(hlog2(y_s)));
|
||||
}
|
||||
|
||||
auto inv_y = __float2bfloat16_rn(1.f) / y_s;
|
||||
|
||||
auto y_s2 = make_bfloat162(inv_y, inv_y);
|
||||
|
||||
#pragma unroll
|
||||
for (int32_t i = 0; i < 2; ++i) {
|
||||
results_bf162[i] =
|
||||
clip(__hmul2(results_bf162[i], y_s2), __bfloat162bfloat162(fp8_min),
|
||||
__bfloat162bfloat162(fp8_max));
|
||||
}
|
||||
|
||||
auto fp8x4 = __nv_fp8x4_e4m3(results_bf162[0], results_bf162[1]);
|
||||
*reinterpret_cast<__nv_fp8x4_e4m3*>(y_q_ptr) = fp8x4;
|
||||
y_q_ptr += stride_yq_t;
|
||||
|
||||
if (lane_id == 0) {
|
||||
*y_s_ptr = y_s;
|
||||
y_s_ptr += stride_ys_t;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// Launch activation, gating, and quantize kernel.
|
||||
@ -119,3 +469,117 @@ void silu_and_mul_quant(torch::Tensor& out, // [..., d]
|
||||
TORCH_CHECK(input.size(-1) % 2 == 0);
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
||||
}
|
||||
|
||||
void silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
const at::Tensor& input, // (E, T, 2*H)
|
||||
const at::Tensor& counts, // (E)
|
||||
at::Tensor& y_q, // (E, T, H) [OUT]
|
||||
at::Tensor& y_s, // (E, T, H//group_size) [OUT]
|
||||
int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
|
||||
#ifndef USE_ROCM
|
||||
// This kernel relies heavily on cp.async and fp8 support.
|
||||
// This kernel currently only supports H % 128 == 0 and assumes a
|
||||
// fixed GROUP_SIZE of 128.
|
||||
TORCH_CHECK(input.dtype() == torch::kBFloat16);
|
||||
TORCH_CHECK(y_q.dtype() == torch::kFloat8_e4m3fn ||
|
||||
y_q.dtype() == torch::kFloat8_e4m3fnuz);
|
||||
TORCH_CHECK(y_s.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(input.size(-1) % 256 == 0);
|
||||
|
||||
// Check that num_parallel_tokens is of power of 2 and between 1 and 64.
|
||||
TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64);
|
||||
TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1)));
|
||||
|
||||
using Idx_t = int64_t;
|
||||
|
||||
Idx_t E = input.size(0);
|
||||
Idx_t T = input.size(1);
|
||||
Idx_t H = input.size(2) / 2;
|
||||
Idx_t stride_i_e = input.stride(0);
|
||||
Idx_t stride_i_t = input.stride(1);
|
||||
Idx_t stride_i_h = input.stride(2);
|
||||
Idx_t stride_yq_e = y_q.stride(0);
|
||||
Idx_t stride_yq_t = y_q.stride(1);
|
||||
Idx_t stride_yq_h = y_q.stride(2);
|
||||
Idx_t stride_ys_e = y_s.stride(0);
|
||||
Idx_t stride_ys_t = y_s.stride(1);
|
||||
Idx_t stride_ys_g = y_s.stride(2);
|
||||
|
||||
Idx_t stride_counts_e = counts.stride(0);
|
||||
|
||||
static constexpr int GROUP_SIZE = 128;
|
||||
|
||||
#define KERNEL_FN \
|
||||
if (use_ue8m0) { \
|
||||
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
|
||||
NUM_PARALLEL_TOKENS, true> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
|
||||
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
|
||||
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
|
||||
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
|
||||
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
|
||||
stride_counts_e); \
|
||||
} else { \
|
||||
vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t, \
|
||||
NUM_PARALLEL_TOKENS, false> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), \
|
||||
(fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(), \
|
||||
reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G, \
|
||||
stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
|
||||
stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g, \
|
||||
stride_counts_e); \
|
||||
}
|
||||
|
||||
#define KERNEL_CALL_H \
|
||||
if (H % (4 * GROUP_SIZE) == 0) { \
|
||||
static constexpr int NUM_WARPS = 4; \
|
||||
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
|
||||
KERNEL_FN \
|
||||
} else { \
|
||||
static constexpr int NUM_WARPS = 1; \
|
||||
populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
|
||||
KERNEL_FN \
|
||||
}
|
||||
|
||||
#define KERNEL_CALL_TOP_LEVEL \
|
||||
if (num_parallel_tokens == 1) { \
|
||||
static constexpr int NUM_PARALLEL_TOKENS = 1; \
|
||||
KERNEL_CALL_H \
|
||||
} else if (num_parallel_tokens == 2) { \
|
||||
static constexpr int NUM_PARALLEL_TOKENS = 2; \
|
||||
KERNEL_CALL_H \
|
||||
} else if (num_parallel_tokens == 4) { \
|
||||
static constexpr int NUM_PARALLEL_TOKENS = 4; \
|
||||
KERNEL_CALL_H \
|
||||
} else if (num_parallel_tokens == 8) { \
|
||||
static constexpr int NUM_PARALLEL_TOKENS = 8; \
|
||||
KERNEL_CALL_H \
|
||||
} else if (num_parallel_tokens == 16) { \
|
||||
static constexpr int NUM_PARALLEL_TOKENS = 16; \
|
||||
KERNEL_CALL_H \
|
||||
} else if (num_parallel_tokens == 32) { \
|
||||
static constexpr int NUM_PARALLEL_TOKENS = 32; \
|
||||
KERNEL_CALL_H \
|
||||
} else if (num_parallel_tokens == 64) { \
|
||||
static constexpr int NUM_PARALLEL_TOKENS = 64; \
|
||||
KERNEL_CALL_H \
|
||||
}
|
||||
|
||||
Idx_t G;
|
||||
dim3 block, grid;
|
||||
auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) {
|
||||
G = H / Idx_t(group_size * num_warps);
|
||||
grid = dim3(E * G, _num_parallel_tokens);
|
||||
block = dim3(num_warps * WARP_SIZE);
|
||||
};
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(),
|
||||
"silu_mul_fp8_quant_deep_gemm_kernel",
|
||||
[&] { KERNEL_CALL_TOP_LEVEL });
|
||||
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -7,17 +7,10 @@
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "../../cub_helpers.h"
|
||||
#include "../../dispatch_utils.h"
|
||||
#include "../vectorization_utils.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <hipcub/util_type.hpp>
|
||||
#endif
|
||||
|
||||
static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||
#ifdef USE_ROCM
|
||||
static constexpr auto i8_min =
|
||||
@ -173,7 +166,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
});
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
|
||||
float block_max = BlockReduce(tmp).Reduce(thread_max, CubMaxOp{}, blockDim.x);
|
||||
__shared__ float absmax;
|
||||
if (tid == 0) {
|
||||
absmax = block_max;
|
||||
|
||||
@ -25,6 +25,8 @@
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace vllm::cutlass_w4a8 {
|
||||
|
||||
using namespace cute;
|
||||
@ -393,6 +395,71 @@ torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
|
||||
return packed_scales;
|
||||
}
|
||||
|
||||
/*
|
||||
GPU-accelerated implementation of cutlass::unified_encode_int4b.
|
||||
Constructs a lookup table in constant memory to map 8 bits
|
||||
(two 4-bit values) at a time. Assumes memory is contiguous
|
||||
and pointers are 16-byte aligned.
|
||||
*/
|
||||
__constant__ uint8_t kNibbleLUT[256];
|
||||
|
||||
__global__ void unified_encode_int4b_device(const uint8_t* in, uint8_t* out,
|
||||
size_t nbytes) {
|
||||
constexpr size_t V = sizeof(uint4); // 16 bytes
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t nthreads = size_t(gridDim.x) * blockDim.x;
|
||||
const size_t nvec = nbytes / V;
|
||||
|
||||
// 1-D grid-stride loop over 16-byte chunks
|
||||
for (size_t vec = tid; vec < nvec; vec += nthreads) {
|
||||
uint4 v = reinterpret_cast<const uint4*>(in)[vec];
|
||||
uint8_t* b = reinterpret_cast<uint8_t*>(&v);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < int(V); ++i) b[i] = kNibbleLUT[b[i]];
|
||||
reinterpret_cast<uint4*>(out)[vec] = v;
|
||||
}
|
||||
}
|
||||
|
||||
static bool upload_lut() {
|
||||
std::array<uint8_t, 256> lut{};
|
||||
auto map_nib = [](uint8_t v) -> uint8_t {
|
||||
// 1..7 -> (8 - v); keep 0 and 8..15
|
||||
return (v == 0 || (v & 0x8)) ? v : uint8_t(8 - v);
|
||||
};
|
||||
for (int b = 0; b < 256; ++b) {
|
||||
uint8_t lo = b & 0xF;
|
||||
uint8_t hi = (b >> 4) & 0xF;
|
||||
lut[b] = uint8_t((map_nib(hi) << 4) | map_nib(lo));
|
||||
}
|
||||
cudaError_t e = cudaMemcpyToSymbol(kNibbleLUT, lut.data(), lut.size(),
|
||||
/*offset=*/0, cudaMemcpyHostToDevice);
|
||||
|
||||
return (e == cudaSuccess);
|
||||
}
|
||||
|
||||
static bool unified_encode_int4b(cutlass::int4b_t const* in,
|
||||
cutlass::int4b_t* out, size_t num_int4_elems) {
|
||||
// Build/upload LUT
|
||||
if (!upload_lut()) return false;
|
||||
|
||||
static_assert(sizeof(typename cutlass::int4b_t::Storage) == 1,
|
||||
"int4 storage must be 1 byte");
|
||||
const size_t nbytes = num_int4_elems >> 1;
|
||||
|
||||
auto* in_bytes = reinterpret_cast<uint8_t const*>(in);
|
||||
auto* out_bytes = reinterpret_cast<uint8_t*>(out);
|
||||
|
||||
// kernel launch params
|
||||
constexpr int block = 256;
|
||||
const size_t nvec = nbytes / sizeof(uint4); // # of 16B vectors
|
||||
int grid = int((nvec + block - 1) / block);
|
||||
if (grid == 0) grid = 1; // ensure we still cover the tail in the kernel
|
||||
|
||||
unified_encode_int4b_device<<<grid, block>>>(in_bytes, out_bytes, nbytes);
|
||||
cudaError_t err = cudaGetLastError();
|
||||
return (err == cudaSuccess);
|
||||
}
|
||||
|
||||
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
||||
TORCH_CHECK(B.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(B.dim() == 2);
|
||||
@ -401,6 +468,7 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
||||
|
||||
int k = B.size(0) * PackFactor; // logical k
|
||||
int n = B.size(1);
|
||||
TORCH_CHECK((n * k) % 32 == 0, "need multiples of 32 int4s for 16B chunks");
|
||||
|
||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
|
||||
@ -409,7 +477,9 @@ torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
||||
LayoutB_Reordered layout_B_reordered =
|
||||
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
|
||||
cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
|
||||
bool ok =
|
||||
vllm::cutlass_w4a8::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
|
||||
TORCH_CHECK(ok, "unified_encode_int4b failed");
|
||||
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
|
||||
|
||||
return B_packed;
|
||||
|
||||
@ -14,9 +14,6 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
namespace vllm {
|
||||
@ -149,6 +146,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
using ElementBlockScale = typename Gemm::ElementBlockScale;
|
||||
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
|
||||
@ -169,26 +167,29 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
||||
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<ElementBlockScale const*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
|
||||
|
||||
auto mainloop_args = [&](){
|
||||
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
|
||||
if (swap_ab) {
|
||||
return typename GemmKernel::MainloopArguments{
|
||||
b_ptr, b_stride, a_ptr, a_stride,
|
||||
b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
|
||||
};
|
||||
}
|
||||
else {
|
||||
return typename GemmKernel::MainloopArguments{
|
||||
a_ptr, a_stride, b_ptr, b_stride,
|
||||
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
|
||||
};
|
||||
}
|
||||
}();
|
||||
typename GemmKernel::MainloopArguments mainloop_args{};
|
||||
mainloop_args.layout_SFA = layout_SFA;
|
||||
mainloop_args.layout_SFB = layout_SFB;
|
||||
if (swap_ab) {
|
||||
mainloop_args.ptr_A = b_ptr;
|
||||
mainloop_args.dA = b_stride;
|
||||
mainloop_args.ptr_B = a_ptr;
|
||||
mainloop_args.dB = a_stride;
|
||||
mainloop_args.ptr_SFA = b_scales_ptr;
|
||||
mainloop_args.ptr_SFB = a_scales_ptr;
|
||||
} else {
|
||||
mainloop_args.ptr_A = a_ptr;
|
||||
mainloop_args.dA = a_stride;
|
||||
mainloop_args.ptr_B = b_ptr;
|
||||
mainloop_args.dB = b_stride;
|
||||
mainloop_args.ptr_SFA = a_scales_ptr;
|
||||
mainloop_args.ptr_SFB = b_scales_ptr;
|
||||
}
|
||||
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
|
||||
@ -14,9 +14,6 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
namespace vllm {
|
||||
@ -128,6 +125,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
using ElementBlockScale = typename Gemm::ElementBlockScale;
|
||||
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
|
||||
@ -146,17 +144,20 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
LayoutSFB layout_SFB =
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
||||
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<ElementBlockScale const*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
|
||||
|
||||
auto mainloop_args = [&](){
|
||||
return typename GemmKernel::MainloopArguments{
|
||||
a_ptr, a_stride, b_ptr, b_stride,
|
||||
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
|
||||
};
|
||||
}();
|
||||
typename GemmKernel::MainloopArguments mainloop_args{};
|
||||
mainloop_args.ptr_A = a_ptr;
|
||||
mainloop_args.dA = a_stride;
|
||||
mainloop_args.ptr_B = b_ptr;
|
||||
mainloop_args.dB = b_stride;
|
||||
mainloop_args.ptr_SFA = a_scales_ptr;
|
||||
mainloop_args.layout_SFA = layout_SFA;
|
||||
mainloop_args.ptr_SFB = b_scales_ptr;
|
||||
mainloop_args.layout_SFB = layout_SFB;
|
||||
auto prob_shape = cute::make_shape(m, n, k, 1);
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
|
||||
@ -13,27 +13,18 @@
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename SchedulerType, typename OutType, int GroupSizeM_,
|
||||
int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
|
||||
class ClusterShape = Shape<_1, _2, _1>>
|
||||
// clang-format off
|
||||
template <class OutType, int ScaleGranularityM,
|
||||
int ScaleGranularityN, int ScaleGranularityK,
|
||||
class MmaTileShape, class ClusterShape,
|
||||
class EpilogueScheduler, class MainloopScheduler>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using GroupSizeM = Int<GroupSizeM_>;
|
||||
using GroupSizeN = Int<GroupSizeN_>;
|
||||
using GroupSizeK = Int<GroupSizeK_>;
|
||||
using TileSizeM = Int<TileSizeM_>;
|
||||
|
||||
static_assert(TileSizeM_ % GroupSizeM_ == 0,
|
||||
"TileSizeM must be a multiple of GroupSizeM");
|
||||
|
||||
using ElementAB = cutlass::float_e4m3_t;
|
||||
|
||||
using ElementA = ElementAB;
|
||||
@ -45,52 +36,67 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ElementD = OutType;
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using StrideC = StrideD;
|
||||
using ElementC = void; // TODO: support bias
|
||||
using LayoutC = LayoutD;
|
||||
static constexpr int AlignmentC = AlignmentD;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ElementBlockScale = float;
|
||||
using ElementCompute = float;
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK>;
|
||||
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
using TileShape = Shape<TileSizeM, GroupSizeN, GroupSizeK>;
|
||||
|
||||
using KernelSchedule = cutlass::gemm::
|
||||
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<
|
||||
GroupSizeM_>;
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
using ElementScalar = float;
|
||||
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
AlignmentD,
|
||||
EpilogueScheduler,
|
||||
DefaultOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::Sm90AccFetch>;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementCompute, ElementC, StrideC, AlignmentC,
|
||||
ElementD, StrideD, AlignmentD, EpilogueSchedule,
|
||||
StoreEpilogueCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB,
|
||||
LayoutB, AlignmentB, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
cute::tuple<LayoutA, LayoutSFA>,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
cute::tuple<LayoutB, LayoutSFB>,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduler
|
||||
>::CollectiveOp;
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
SchedulerType>>;
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
|
||||
using StrideA = typename GemmKernel::StrideA;
|
||||
using StrideB = typename GemmKernel::StrideB;
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
@ -99,76 +105,58 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutSFA = typename Gemm::LayoutSFA;
|
||||
using LayoutSFB = typename Gemm::LayoutSFB;
|
||||
using ScaleConfig = typename Gemm::ScaleConfig;
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
using ElementBlockScale = typename Gemm::ElementBlockScale;
|
||||
|
||||
auto prob_shape = c3x::get_problem_shape(a, b);
|
||||
int32_t m = get<0>(prob_shape), n = get<1>(prob_shape),
|
||||
k = get<2>(prob_shape);
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
StrideA a_stride;
|
||||
StrideB b_stride;
|
||||
StrideC c_stride;
|
||||
a_stride =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
b_stride =
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
c_stride =
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
LayoutSFA layout_SFA =
|
||||
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||
LayoutSFB layout_SFB =
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
||||
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<ElementBlockScale const*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());
|
||||
|
||||
// Check is the t is contiguous and is 1D or 2D with one of the dimensions
|
||||
// being 1 (i.e. a row or column vector)
|
||||
auto is_contiguous_vector = [](const torch::Tensor& t) {
|
||||
auto t_sizes = t.sizes();
|
||||
return t.is_contiguous() &&
|
||||
(t.dim() == 1 ||
|
||||
(t.dim() == 2 &&
|
||||
*std::min_element(t_sizes.begin(), t_sizes.end()) == 1));
|
||||
};
|
||||
|
||||
// TODO(lucas): lets clean-up the kernel so that we pass in Strides so
|
||||
// we don't have to deal with enforcing implicit layouts
|
||||
TORCH_CHECK(a_scales.size(0) == m / Gemm::GroupSizeM::value);
|
||||
TORCH_CHECK(a_scales.size(1) == k / Gemm::GroupSizeK::value);
|
||||
TORCH_CHECK(a_scales.stride(0) == 1 || is_contiguous_vector(a_scales),
|
||||
"a_scales must be M major");
|
||||
TORCH_CHECK(b_scales.size(0) == k / Gemm::GroupSizeK::value);
|
||||
TORCH_CHECK(b_scales.size(1) == n / Gemm::GroupSizeN::value);
|
||||
TORCH_CHECK(b_scales.stride(0) == 1 || is_contiguous_vector(b_scales),
|
||||
"b_scales must be K major");
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
a_ptr, a_stride, b_ptr, b_stride, a_scales_ptr, b_scales_ptr};
|
||||
typename GemmKernel::MainloopArguments mainloop_args{};
|
||||
mainloop_args.ptr_A = a_ptr;
|
||||
mainloop_args.dA = a_stride;
|
||||
mainloop_args.ptr_B = b_ptr;
|
||||
mainloop_args.dB = b_stride;
|
||||
mainloop_args.ptr_SFA = a_scales_ptr;
|
||||
mainloop_args.layout_SFA = layout_SFA;
|
||||
mainloop_args.ptr_SFB = b_scales_ptr;
|
||||
mainloop_args.layout_SFB = layout_SFB;
|
||||
auto prob_shape = cute::make_shape(m, n, k, 1);
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::TileSchedulerArguments scheduler;
|
||||
|
||||
static constexpr bool UsesStreamKScheduler =
|
||||
cute::is_same_v<typename GemmKernel::TileSchedulerTag,
|
||||
cutlass::gemm::StreamKScheduler>;
|
||||
|
||||
if constexpr (UsesStreamKScheduler) {
|
||||
using DecompositionMode = typename cutlass::gemm::kernel::detail::
|
||||
PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
||||
using ReductionMode = typename cutlass::gemm::kernel::detail::
|
||||
PersistentTileSchedulerSm90StreamKParams::ReductionMode;
|
||||
|
||||
scheduler.decomposition_mode = DecompositionMode::StreamK;
|
||||
scheduler.reduction_mode = ReductionMode::Nondeterministic;
|
||||
}
|
||||
|
||||
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args, scheduler);
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
@ -177,18 +165,12 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
auto k = a.size(1);
|
||||
auto n = b.size(1);
|
||||
|
||||
if (k > 3 * n) {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
// TODO: better heuristics
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, 128, 128, Shape<_128, _128, _128>,
|
||||
Shape<_1, _2, _1>, cutlass::epilogue::TmaWarpSpecializedCooperative,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -32,7 +32,7 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor.");
|
||||
TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor.");
|
||||
int32_t version_num = get_sm_version_num();
|
||||
if (version_num >= 100) {
|
||||
if (version_num >= 90) {
|
||||
TORCH_CHECK(
|
||||
a.size(0) == a_scales.size(0) &&
|
||||
cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1),
|
||||
@ -41,32 +41,6 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) &&
|
||||
cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1),
|
||||
"b_scale_group_shape must be [128, 128].");
|
||||
} else {
|
||||
// TODO: Remove this after using cutlass sm90 blockwise scaling gemm
|
||||
// kernel, or introducing ceil_div to the load_init() of mainloop.
|
||||
using GroupShape = std::array<int64_t, 2>;
|
||||
auto make_group_shape = [](torch::Tensor const& x,
|
||||
torch::Tensor const& s) -> GroupShape {
|
||||
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
||||
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
||||
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
||||
};
|
||||
|
||||
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
||||
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
|
||||
|
||||
// 1x128 per-token group scales for activations
|
||||
// 128x128 blockwise scales for weights
|
||||
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
|
||||
b_scale_group_shape == GroupShape{128, 128} &&
|
||||
a.dtype() == torch::kFloat8_e4m3fn &&
|
||||
b.dtype() == torch::kFloat8_e4m3fn),
|
||||
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
|
||||
"a_scale_group_shape must be [1, 128]. Got: [",
|
||||
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
|
||||
"]\n"
|
||||
"b_scale_group_shape must be [128, 128]. Got: [",
|
||||
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
|
||||
}
|
||||
|
||||
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
|
||||
|
||||
@ -30,109 +30,41 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// silu in float32
|
||||
__device__ __forceinline__ float silu(float x) {
|
||||
return __fdividef(x, (1.f + __expf(-x)));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float2 silu2(float2 x) {
|
||||
return make_float2(silu(x.x), silu(x.y));
|
||||
}
|
||||
|
||||
template <class Type>
|
||||
__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
|
||||
PackedVec<Type>& vec2) {
|
||||
__inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,
|
||||
PackedVec<Type>& vec2) {
|
||||
PackedVec<Type> result;
|
||||
using packed_type = typename TypeConverter<Type>::Type;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
|
||||
// silu_mul in float32
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
half2 val(0.5f, 0.5f);
|
||||
half2 t0 = __hmul2(vec.elts[i], val);
|
||||
half2 t1 = __hfma2(h2tanh(t0), val, val);
|
||||
half2 t2 = __hmul2(vec.elts[i], t1);
|
||||
result.elts[i] = __hmul2(t2, vec2.elts[i]);
|
||||
float2 silu_vec = silu2(__half22float2(vec.elts[i]));
|
||||
result.elts[i] =
|
||||
__float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i])));
|
||||
} else {
|
||||
__nv_bfloat162 val(0.5f, 0.5f);
|
||||
__nv_bfloat162 t0 = __hmul2(vec.elts[i], val);
|
||||
__nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val);
|
||||
__nv_bfloat162 t2 = __hmul2(vec.elts[i], t1);
|
||||
result.elts[i] = __hmul2(t2, vec2.elts[i]);
|
||||
float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i]));
|
||||
result.elts[i] = __float22bfloat162_rn(
|
||||
__fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i])));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Quantizes the provided PackedVec into the uint32_t output
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
|
||||
PackedVec<Type>& vec2,
|
||||
float SFScaleVal,
|
||||
uint8_t* SFout) {
|
||||
PackedVec<Type> out_silu = compute_silu(vec, vec2);
|
||||
// Get absolute maximum values among the local 8 values.
|
||||
auto localMax = __habs2(out_silu.elts[0]);
|
||||
|
||||
// Local maximum value.
|
||||
#pragma unroll
|
||||
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
|
||||
}
|
||||
|
||||
// Get the absolute maximum among all 16 values (two threads).
|
||||
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
|
||||
// Get the final absolute maximum values.
|
||||
float vecMax = float(__hmax(localMax.x, localMax.y));
|
||||
|
||||
// Get the SF (max value of the vector / max value of e2m1).
|
||||
// maximum value of e2m1 = 6.0.
|
||||
// TODO: use half as compute data type.
|
||||
float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
|
||||
// 8 bits representation of the SF.
|
||||
uint8_t fp8SFVal;
|
||||
// Write the SF to global memory (STG.8).
|
||||
if constexpr (UE8M0_SF) {
|
||||
// Extract the 8 exponent bits from float32.
|
||||
// float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
|
||||
uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
|
||||
fp8SFVal = tmp & 0xff;
|
||||
// Convert back to fp32.
|
||||
reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
|
||||
} else {
|
||||
// Here SFValue is always positive, so E4M3 is the same as UE4M3.
|
||||
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
|
||||
reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
|
||||
// Convert back to fp32.
|
||||
SFValue = float(tmp);
|
||||
}
|
||||
// Get the output scale.
|
||||
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
|
||||
// reciprocal(SFScaleVal))
|
||||
float outputScale =
|
||||
SFValue != 0 ? reciprocal_approximate_ftz(
|
||||
SFValue * reciprocal_approximate_ftz(SFScaleVal))
|
||||
: 0.0f;
|
||||
|
||||
if (SFout) {
|
||||
// Write the SF to global memory (STG.8).
|
||||
*SFout = fp8SFVal;
|
||||
}
|
||||
|
||||
// Convert the input to float.
|
||||
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
|
||||
if constexpr (std::is_same_v<Type, half>) {
|
||||
fp2Vals[i] = __half22float2(out_silu.elts[i]);
|
||||
} else {
|
||||
fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
|
||||
}
|
||||
fp2Vals[i].x *= outputScale;
|
||||
fp2Vals[i].y *= outputScale;
|
||||
}
|
||||
|
||||
// Convert to e2m1 values.
|
||||
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
|
||||
|
||||
// Write the e2m1 values to global memory.
|
||||
return e2m1Vec;
|
||||
}
|
||||
|
||||
// Use UE4M3 by default.
|
||||
template <class Type, bool UE8M0_SF = false>
|
||||
__global__ void __launch_bounds__(1024, 4)
|
||||
silu_and_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
|
||||
float const* SFScale, uint32_t* out,
|
||||
uint32_t* SFout) {
|
||||
using PackedVec = PackedVec<Type>;
|
||||
@ -160,16 +92,18 @@ __global__ void __launch_bounds__(1024, 4)
|
||||
// Get the output tensor offset.
|
||||
// Same as inOffset because 8 elements are packed into one uint32_t.
|
||||
int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
|
||||
;
|
||||
auto& out_pos = out[outOffset];
|
||||
|
||||
// Compute silu and mul
|
||||
PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2);
|
||||
|
||||
auto sf_out =
|
||||
cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
|
||||
CVT_FP4_NUM_THREADS_PER_SF>(
|
||||
rowIdx, colIdx, numCols, SFout);
|
||||
|
||||
out_pos = silu_and_cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(
|
||||
in_vec, in_vec2, SFScaleVal, sf_out);
|
||||
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal,
|
||||
sf_out);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -204,7 +138,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
|
||||
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
|
||||
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
|
||||
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
|
||||
vllm::silu_and_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
|
||||
vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
|
||||
m, n, input_ptr, input_sf_ptr,
|
||||
reinterpret_cast<uint32_t*>(output_ptr),
|
||||
reinterpret_cast<uint32_t*>(sf_out));
|
||||
|
||||
@ -1,15 +1,10 @@
|
||||
#include "common.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "../../cub_helpers.h"
|
||||
#include "../vectorization_utils.cuh"
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
@ -116,7 +111,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
const float block_max =
|
||||
BlockReduce(tmp).Reduce(absmax_val, cub::Max{}, blockDim.x);
|
||||
BlockReduce(tmp).Reduce(absmax_val, CubMaxOp{}, blockDim.x);
|
||||
|
||||
__shared__ float token_scale;
|
||||
if (tid == 0) {
|
||||
|
||||
@ -5,7 +5,9 @@
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#ifndef USE_ROCM
|
||||
#include "nvidia/quant_utils.cuh"
|
||||
#else
|
||||
#include "amd/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
@ -48,7 +50,9 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
|
||||
float r =
|
||||
fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
|
||||
#ifndef USE_ROCM
|
||||
return static_cast<fp8_type>(r);
|
||||
// Use hardware cvt instruction for fp8 on nvidia
|
||||
// Currently only support fp8_type = c10::Float8_e4m3fn
|
||||
return fp8::vec_conversion<fp8_type, float>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return fp8::cvt_c10<fp8_type>(r);
|
||||
|
||||
@ -12,13 +12,26 @@ namespace vllm {
|
||||
namespace fp8 {
|
||||
#ifdef ENABLE_FP8
|
||||
|
||||
#if 0 // Disable the following code to reduce the binary size.
|
||||
template <typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout
|
||||
vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
|
||||
__inline__ __device__ Tout vec_conversion(
|
||||
const Tin& x, const __nv_fp8_interpretation_t fp8_type = __NV_E4M3) {
|
||||
return x;
|
||||
}
|
||||
|
||||
// float -> c10::Float8_e4m3fn
|
||||
template <>
|
||||
__inline__ __device__ c10::Float8_e4m3fn
|
||||
vec_conversion<c10::Float8_e4m3fn, float>(
|
||||
const float& a, const __nv_fp8_interpretation_t fp8_type) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return static_cast<c10::Float8_e4m3fn>(a);
|
||||
#else
|
||||
return c10::Float8_e4m3fn(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type),
|
||||
c10::Float8_e4m3fn::from_bits());
|
||||
#endif
|
||||
}
|
||||
|
||||
#if 0 // Disable the following code to reduce the binary size.
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
|
||||
|
||||
@ -8,11 +8,7 @@
|
||||
#include "quantization/utils.cuh"
|
||||
#include "quant_conversions.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#endif
|
||||
#include "../../cub_helpers.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
@ -36,7 +32,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
|
||||
ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x);
|
||||
|
||||
__shared__ float s_rms;
|
||||
if (threadIdx.x == 0) {
|
||||
@ -73,7 +69,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
block_absmax_val_maybe =
|
||||
BlockReduce(reduceStore)
|
||||
.Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
|
||||
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
|
||||
|
||||
__shared__ float s_token_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
@ -169,7 +165,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
|
||||
ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x);
|
||||
|
||||
__shared__ float s_rms;
|
||||
if (threadIdx.x == 0) {
|
||||
@ -240,7 +236,7 @@ __device__ void compute_dynamic_per_token_scales(
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
block_absmax_val_maybe =
|
||||
BlockReduce(reduceStore)
|
||||
.Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
|
||||
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
|
||||
|
||||
__shared__ float s_token_scale;
|
||||
if (threadIdx.x == 0) {
|
||||
|
||||
817
csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu
Normal file
817
csrc/quantization/hadamard/hadacore/hadamard_transform_cuda.cu
Normal file
@ -0,0 +1,817 @@
|
||||
// clang-format off
|
||||
// Adapted from: https://github.com/meta-pytorch/applied-ai/blob/main/kernels/cuda/inference/hadamard_transform/hadamard_transform_cuda.cu
|
||||
|
||||
/***********
|
||||
Copyright 2024 Meta
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
||||
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
***********/
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <stdint.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <mma.h>
|
||||
#include <cuda/annotated_ptr>
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "core/registration.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace hadacore {
|
||||
|
||||
#ifndef __CUDACC__
|
||||
#define __launch_bounds__(x,y)
|
||||
#endif
|
||||
|
||||
#define MAX_WARPS_PER_SM 48
|
||||
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
|
||||
using b16 = uint16_t;
|
||||
using b32 = uint32_t;
|
||||
|
||||
constexpr int launch_configs_big[7][3] = {
|
||||
// default
|
||||
{2, 1, 24},
|
||||
{2, 2, 16},
|
||||
{2, 4, 8},
|
||||
{2, 8, 4},
|
||||
{2, 16, 3},
|
||||
{4, 16, 2},
|
||||
{8, 16, 1}
|
||||
// // extra coalescing
|
||||
// {2, 1, 24},
|
||||
// {2, 2, 16},
|
||||
// {2, 4, 8},
|
||||
// {2, 8, 4},
|
||||
// {4, 8, 3},
|
||||
// {8, 8, 2},
|
||||
// {16, 8, 1}
|
||||
// // less coalescing
|
||||
// {2, 1, 24},
|
||||
// {2, 2, 16},
|
||||
// {2, 4, 8},
|
||||
// {2, 8, 4},
|
||||
// {1, 32, 1},
|
||||
// {2, 32, 1},
|
||||
// {4, 32, 1}
|
||||
};
|
||||
|
||||
// a 4x2, b 2x2, c 2x2
|
||||
template <torch::ScalarType dtype>
|
||||
__device__ __forceinline__ void mma_m16_n8_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32& c0, b32& c1){
|
||||
static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16);
|
||||
// d, a, b, c
|
||||
b32 zero = 0;
|
||||
if constexpr(dtype == torch::ScalarType::Half) {
|
||||
asm (
|
||||
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
|
||||
"{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n\t"
|
||||
: "=r"(c0), "=r"(c1) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(zero), "r"(zero)
|
||||
);
|
||||
} else {
|
||||
b32 temp0, temp1, temp2, temp3;
|
||||
asm (
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n\t"
|
||||
: "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3) : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(zero), "r"(zero), "r"(zero), "r"(zero)
|
||||
);
|
||||
asm ("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0));
|
||||
asm ("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2));
|
||||
}
|
||||
}
|
||||
|
||||
// a 4x2, b 4x2, c 4x2
|
||||
template <torch::ScalarType dtype>
|
||||
__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(b32 a0, b32 a1, b32 a2, b32 a3, b32 b0, b32 b1, b32 b2, b32 b3, b32& c0, b32& c1, b32& c2, b32& c3){
|
||||
mma_m16_n8_k16_b16_b16_b16_noacc<dtype>(a0, a1, a2, a3, b0, b1, c0, c1);
|
||||
mma_m16_n8_k16_b16_b16_b16_noacc<dtype>(a0, a1, a2, a3, b2, b3, c2, c3);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(b32& a0) {
|
||||
asm (
|
||||
"movmatrix.sync.aligned.m8n8.trans.b16 "
|
||||
"%0, %1;\n\t"
|
||||
: "=r"(a0) : "r"(a0)
|
||||
);
|
||||
}
|
||||
|
||||
#define p_p(i) ((val_1p[i] & 0x0000FFFF) | val_1p[i] << 16)
|
||||
#define p_n(i) ((val_1p[i] & 0x0000FFFF) | val_1n[i] << 16)
|
||||
#define n_p(i) ((val_1n[i] & 0x0000FFFF) | val_1p[i] << 16)
|
||||
#define n_n(i) ((val_1n[i] & 0x0000FFFF) | val_1n[i] << 16)
|
||||
|
||||
template<int64_t num_chunks, int64_t warps_per_block, int64_t log_had_size, int64_t blocks_per_sm, bool enable_mask, torch::ScalarType dtype>
|
||||
__global__ void __launch_bounds__(32 * warps_per_block, blocks_per_sm)
|
||||
// a is column major, b is row major
|
||||
hadamard_transform_kernel(b16* a, b16* out, int total_num_chunks) {
|
||||
static_assert(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently");
|
||||
|
||||
b32 b_frag_all[num_chunks][4]; // for all chunks, holds matrix fragment (which takes 4 regs of b16x2 * 32 threads)
|
||||
|
||||
int64_t blockid = blockIdx.x * warps_per_block + threadIdx.x / 32;
|
||||
int64_t threadid = threadIdx.x % 32;
|
||||
extern __shared__ b32 bfrag_arr[]; // num_chunks * warps_per_block * 128
|
||||
int64_t real_num_chunks = ((blockid + 1) * num_chunks) > total_num_chunks ? (total_num_chunks - (blockid * num_chunks)) : num_chunks;
|
||||
int64_t diff_num_chunks = real_num_chunks - num_chunks;
|
||||
|
||||
b32* a_start_ptr = (b32*) (a + blockid * num_chunks * 256); // offset a to where this warp starts
|
||||
b32* out_start_ptr = (b32*) (out + blockid * num_chunks * 256);
|
||||
b32* a_ptr = a_start_ptr + threadid * 4;
|
||||
b32* b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * 128 + threadid * 4;
|
||||
|
||||
#if (__CUDA_ARCH__ < 900) // SM80, SM89
|
||||
uint64_t cache_policy;
|
||||
asm volatile(
|
||||
"createpolicy.fractional.L2::evict_first.b64 %0, 1.0;\n"
|
||||
: "=l"(cache_policy)
|
||||
);
|
||||
#endif
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t k = 0; k < num_chunks; k++) {
|
||||
size_t shared_ptr = __cvta_generic_to_shared(b_frag_ptr);
|
||||
#if (__CUDA_ARCH__ >= 900) // SM90
|
||||
asm volatile(
|
||||
"cp.async.cg.shared.global [%0], [%1], 16;\n"
|
||||
"cp.async.commit_group;\n"
|
||||
:: "l"(shared_ptr), "l"(a_ptr)
|
||||
);
|
||||
#else // SM80, SM89
|
||||
asm volatile(
|
||||
"cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2;\n"
|
||||
"cp.async.commit_group;\n"
|
||||
:: "l"(shared_ptr), "l"(a_ptr), "l"(cache_policy)
|
||||
);
|
||||
#endif
|
||||
|
||||
a_ptr += 128;
|
||||
b_frag_ptr += 128;
|
||||
}
|
||||
|
||||
// generate hadamard 16x16 (up to 2 of them)
|
||||
constexpr b16 fp16_1p[4] = {0b0011100110101000, 0b0011100000000000, 0b0011010110101000, 0b0011010000000000};
|
||||
constexpr b16 fp16_1n[4] = {0b1011100110101000, 0b1011100000000000, 0b1011010110101000, 0b1011010000000000};
|
||||
constexpr b16 bf16_1p[4] = {0b0011111100110101, 0b0011111100000000, 0b0011111010110101, 0b0011111010000000};
|
||||
constexpr b16 bf16_1n[4] = {0b1011111100110101, 0b1011111100000000, 0b1011111010110101, 0b1011111010000000};
|
||||
|
||||
#define val_type_1p(i) (((dtype) == torch::ScalarType::Half) ? (fp16_1p[i]) : (bf16_1p[i]))
|
||||
#define val_type_1n(i) (((dtype) == torch::ScalarType::Half) ? (fp16_1n[i]) : (bf16_1n[i]))
|
||||
constexpr b16 val_1p[4] = {val_type_1p(0), val_type_1p(1), val_type_1p(2), val_type_1p(3)};
|
||||
constexpr b16 val_1n[4] = {val_type_1n(0), val_type_1n(1), val_type_1n(2), val_type_1n(3)};
|
||||
|
||||
constexpr b32 p_p[4] = {p_p(0), p_p(1), p_p(2), p_p(3)};
|
||||
constexpr b32 p_n[4] = {p_n(0), p_n(1), p_n(2), p_n(3)};
|
||||
constexpr b32 n_p[4] = {n_p(0), n_p(1), n_p(2), n_p(3)};
|
||||
constexpr b32 n_n[4] = {n_n(0), n_n(1), n_n(2), n_n(3)};
|
||||
const b32 had_16_p1[4][4] = {
|
||||
{
|
||||
0b10001000010001000010001000010001,
|
||||
0b00000000000000000000000000000000,
|
||||
0b00000000000000000000000000000000,
|
||||
0b10001000010001000010001000010001
|
||||
},
|
||||
{
|
||||
0b11001100100010000011001100100010,
|
||||
0b00000000000000000000000000000000,
|
||||
0b00000000000000000000000000000000,
|
||||
0b11001100100010000011001100100010
|
||||
},
|
||||
{
|
||||
0b11111111101010101100110010011001,
|
||||
0b00000000000000000000000000000000,
|
||||
0b00000000000000000000000000000000,
|
||||
0b11111111101010101100110010011001
|
||||
},
|
||||
{
|
||||
0b11111111101010101100110010011001,
|
||||
0b11111111101010101100110010011001,
|
||||
0b11111111101010101100110010011001,
|
||||
0b00000000010101010011001101100110
|
||||
}
|
||||
};
|
||||
const b32 had_16_p2[4][4] = {
|
||||
{
|
||||
0b10000000010000000010000000010000,
|
||||
0b00000000000000000000000000000000,
|
||||
0b00000000000000000000000000000000,
|
||||
0b10000000010000000010000000010000
|
||||
},
|
||||
{
|
||||
0b11000000100001000011000000100001,
|
||||
0b00000000000000000000000000000000,
|
||||
0b00000000000000000000000000000000,
|
||||
0b11000000100001000011000000100001
|
||||
},
|
||||
{
|
||||
0b11110000101001011100001110010110,
|
||||
0b00000000000000000000000000000000,
|
||||
0b00000000000000000000000000000000,
|
||||
0b11110000101001011100001110010110
|
||||
},
|
||||
{
|
||||
0b11110000101001011100001110010110,
|
||||
0b11110000101001011100001110010110,
|
||||
0b11110000101001011100001110010110,
|
||||
0b00001111010110100011110001101001
|
||||
}
|
||||
};
|
||||
const b32 had_16_mask[3][4] = {
|
||||
{
|
||||
0b10001000010001000010001000010001,
|
||||
0b00000000000000000000000000000000,
|
||||
0b00000000000000000000000000000000,
|
||||
0b10001000010001000010001000010001
|
||||
},
|
||||
{
|
||||
0b11001100110011000011001100110011,
|
||||
0b00000000000000000000000000000000,
|
||||
0b00000000000000000000000000000000,
|
||||
0b11001100110011000011001100110011
|
||||
},
|
||||
{
|
||||
0b11111111111111111111111111111111,
|
||||
0b00000000000000000000000000000000,
|
||||
0b00000000000000000000000000000000,
|
||||
0b11111111111111111111111111111111
|
||||
}
|
||||
};
|
||||
b32 had_frag[8];
|
||||
#pragma unroll
|
||||
for (int64_t i = 0; i < 2; i++) {
|
||||
int64_t c_log_h = (i == 0) ? MIN(4, log_had_size) : log_had_size % 4;
|
||||
#pragma unroll
|
||||
for (int64_t j = 0; j < 4; j++) {
|
||||
if (c_log_h < 4) {
|
||||
bool mask = had_16_mask[c_log_h - 1][j] & (1 << (31 - threadid));
|
||||
if (!mask) {
|
||||
had_frag[i * 4 + j] = 0;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
bool pred1 = had_16_p1[c_log_h - 1][j] & (1 << (31 - threadid));
|
||||
bool pred2 = had_16_p2[c_log_h - 1][j] & (1 << (31 - threadid));
|
||||
b32 val = pred1 ? (pred2 ? p_p[c_log_h - 1] : p_n[c_log_h - 1]) : (pred2 ? n_p[c_log_h - 1] : n_n[c_log_h - 1]);
|
||||
had_frag[i * 4 + j] = val;
|
||||
}
|
||||
if constexpr(log_had_size <= 4 || log_had_size % 4 == 0) break;
|
||||
}
|
||||
|
||||
// log had size above 8, only used for above 2^8 = 256 size
|
||||
constexpr int64_t part8_log_had_size = log_had_size - 8;
|
||||
|
||||
b32* a_chunk_ptr = a_start_ptr; // first chunk starts at this warp's data starts
|
||||
b32* out_chunk_ptr = out_start_ptr;
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t l = 0; l < 2; l++) {
|
||||
if constexpr(log_had_size <= 8) { // l == 0 guaranteed, redundant simplified version of else body, to help compiler warnings
|
||||
b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * 128;
|
||||
} else {
|
||||
b_frag_ptr = bfrag_arr + (blockid % warps_per_block) * num_chunks * (l == 0 ? 128 : (128 >> part8_log_had_size));
|
||||
}
|
||||
|
||||
if (l == 1) {
|
||||
if constexpr(log_had_size > 8) {
|
||||
__syncthreads(); // sync between first and second iterations if above size 256
|
||||
|
||||
if constexpr(log_had_size >= 12) {
|
||||
// sizes 4k and above
|
||||
|
||||
// a + threadblock offset + warp offset
|
||||
// can then index into all chunks owned by this warp
|
||||
b32* store = bfrag_arr + (128 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block));
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t j = 0; j < 4; j++) {
|
||||
#pragma unroll
|
||||
for (int64_t k = 0; k < num_chunks; k++) {
|
||||
// here, j represents register, and k represents 8-offset/chunk
|
||||
uint64_t real_chunk_num = (num_chunks - (threadid % num_chunks) + k) % num_chunks; // chunk at which you have target thread #'s data
|
||||
|
||||
int64_t real_thread_id = (threadid / num_chunks) * num_chunks + k; // target thread #
|
||||
int64_t chunk_idx = 128 * real_chunk_num; // index due to fetching from another chunk (chunk in which this thread has the target thread's original data)
|
||||
int64_t thread_group_idx = (real_thread_id / 4) * 16; // index due to fetching from another group of num_chunk threads (since shuffle is between num_chunk threads)
|
||||
int64_t thread_idx = (real_thread_id % 4) * 2; // index due to original thread's position within the group of num_chunk threads
|
||||
int64_t reg_idx = (j / 2) * 8 + (j % 2); // index due to target register
|
||||
int64_t idx = chunk_idx + thread_group_idx + thread_idx + reg_idx; // final index
|
||||
|
||||
// fix idx for majorness
|
||||
int64_t rowidx = idx % (1 << part8_log_had_size);
|
||||
int64_t colidx = idx >> part8_log_had_size;
|
||||
|
||||
// store[rowidx * 128 + colidx] = data;
|
||||
b32 data = store[rowidx * 128 + colidx];
|
||||
|
||||
// compiler generates excessive instructions, so we manually do the if statement
|
||||
#pragma unroll
|
||||
for (uint64_t i = 0; i < num_chunks; i++) {
|
||||
asm volatile (
|
||||
"{\n\t"
|
||||
" .reg .pred p0;\n\t"
|
||||
" setp.eq.s64 p0, %1, %2;\n\t"
|
||||
" @p0 mov.b32 %0, %3;\n\t"
|
||||
"}\n\t"
|
||||
: "+r"(b_frag_all[i][j]) // Output operand %0
|
||||
: "l"(real_chunk_num), "l"(i), "r"(data) // Input operands %1, %2, %3
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t j = 0; j < 4; j++) {
|
||||
#pragma unroll
|
||||
for (int64_t k = 1; k < num_chunks; k++) {
|
||||
int64_t threadid_contig = threadid % num_chunks;
|
||||
int64_t threadid_mul = threadid / num_chunks;
|
||||
int64_t threadid2 = (threadid_contig + num_chunks - k) % num_chunks + threadid_mul * num_chunks; // thread to give your data to
|
||||
b_frag_all[k][j] = __shfl_sync(0xFFFFFFFF, b_frag_all[k][j], threadid2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t k = 0; k < num_chunks; k++) {
|
||||
if constexpr(enable_mask) {
|
||||
if (k >= real_num_chunks)
|
||||
break;
|
||||
}
|
||||
if (l == 0) {
|
||||
// bad fix for k not being recognized as a constexpr by compiler
|
||||
// asm("cp.async.wait_group %0;\n" :: "n"(num_chunks - k - 1));
|
||||
#define SWITCH_WAIT_ASYNC_LOAD_GROUP(i) case i: asm volatile("cp.async.wait_group %0;\n" :: "n"(num_chunks - i - 1)); break;
|
||||
if constexpr(enable_mask) {
|
||||
switch(k + diff_num_chunks) {
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(0)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(1)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(2)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(3)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(4)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(5)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(6)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(7)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(8)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(9)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(10)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(11)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(12)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(13)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(14)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(15)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(16)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(17)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(18)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(19)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(20)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(21)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(22)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(23)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(24)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(25)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(26)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(27)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(28)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(29)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(30)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(31)
|
||||
}
|
||||
} else {
|
||||
switch(k) {
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(0)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(1)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(2)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(3)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(4)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(5)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(6)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(7)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(8)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(9)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(10)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(11)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(12)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(13)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(14)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(15)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(16)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(17)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(18)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(19)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(20)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(21)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(22)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(23)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(24)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(25)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(26)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(27)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(28)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(29)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(30)
|
||||
SWITCH_WAIT_ASYNC_LOAD_GROUP(31)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (l == 0) {
|
||||
// loading for the first iteration
|
||||
|
||||
// thread 0 loads [t0r0, t16r1, t0r2, t16r3]
|
||||
// thread 16 loads [t0r1, t16r0, t0r3, t16r2]
|
||||
// allows full coalescing, same for t1/t17, t2/t18, etc.
|
||||
#pragma unroll
|
||||
for (int64_t j = 0; j < 4; j++) {
|
||||
int64_t reg = ((threadid & 16) == 0) ? j : (j / 2 * 2 + (1 - j % 2));
|
||||
int64_t real_thread_id = (reg == 0 || reg == 2) ? threadid : (threadid ^ 16);
|
||||
int64_t real_row = real_thread_id % 4;
|
||||
int64_t real_col = real_thread_id / 4;
|
||||
b_frag_all[k][j] = b_frag_ptr[(real_row + (reg % 2) * 4) + (real_col + (j / 2) * 8) * 8];
|
||||
}
|
||||
|
||||
// for t16 swap r0/r1 and r2/r3 to have [t16r0, t0r1, t16r2, t0r3]
|
||||
// so registers are in right order, same for t17, t18, etc.
|
||||
if ((threadid & 16) != 0) {
|
||||
b32 temp = b_frag_all[k][0];
|
||||
b_frag_all[k][0] = b_frag_all[k][1];
|
||||
b_frag_all[k][1] = temp;
|
||||
|
||||
temp = b_frag_all[k][2];
|
||||
b_frag_all[k][2] = b_frag_all[k][3];
|
||||
b_frag_all[k][3] = temp;
|
||||
}
|
||||
|
||||
// t0 and t16 swap r1 and r3 to have their own data,
|
||||
// same for t1/t17, t2/18, etc.
|
||||
#pragma unroll
|
||||
for (int64_t j = 1; j < 4; j += 2) {
|
||||
b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], 16);
|
||||
}
|
||||
} else if constexpr(log_had_size > 8) { // condition is redundant to help compiler warnings
|
||||
if constexpr(log_had_size < 12) {
|
||||
// sizes 512, 1k, and 2k
|
||||
|
||||
// for 512:
|
||||
// thread 0 loads [t0r0, t0r1, t16r2, t16r3]
|
||||
// thread 16 loads [t0r2, t0r3, t16r0, t16r1]
|
||||
// same for t1/t17, t2/t18, etc.
|
||||
// for 1k and 2k:
|
||||
// thread 0 loads [t0r0, t0r1, t1r2, t1r3]
|
||||
// thread 1 loads [t0r2, t0r3, t1r0, t1r1]
|
||||
// same for t2/t3, t4/t5, etc.
|
||||
// allows full coalescing for 512 and 1k, 16x coalescing for 2k
|
||||
constexpr int64_t xor_val = log_had_size == 9 ? 16 : 1;
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t j = 0; j < 4; j++) {
|
||||
int64_t reg = ((threadid & xor_val) == 0) ? j : (j + 2) % 4;
|
||||
int64_t real_thread_id = reg < 2 ? threadid : (threadid ^ xor_val);
|
||||
int64_t idx = (real_thread_id / 4 * 16) + (real_thread_id % 4 * 2) + (reg / 2 * 8) + (reg % 2);
|
||||
int64_t rowidx = idx % (1 << part8_log_had_size);
|
||||
int64_t colidx = idx >> part8_log_had_size;
|
||||
b_frag_all[k][j] = b_frag_ptr[rowidx * 128 + colidx];
|
||||
}
|
||||
|
||||
if ((threadid & xor_val) != 0) {
|
||||
b32 temp = b_frag_all[k][0];
|
||||
b_frag_all[k][0] = b_frag_all[k][2];
|
||||
b_frag_all[k][2] = temp;
|
||||
|
||||
temp = b_frag_all[k][1];
|
||||
b_frag_all[k][1] = b_frag_all[k][3];
|
||||
b_frag_all[k][3] = temp;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t j = 2; j < 4; j++) {
|
||||
b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], xor_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (l == 1) {
|
||||
// for second iteration, we load 2 consecutive b16s (1 b32) per register,
|
||||
// but tensor core register layout requires 2 b16s that are in the
|
||||
// same column/consecutive rows to be in the same register, so do the swap
|
||||
b32 f0 = ((b_frag_all[k][1] & 0xFFFF) << 16) | (b_frag_all[k][0] & 0xFFFF);
|
||||
b32 f1 = ((b_frag_all[k][3] & 0xFFFF) << 16) | (b_frag_all[k][2] & 0xFFFF);
|
||||
b32 f2 = (b_frag_all[k][1] & 0xFFFF0000) | (b_frag_all[k][0] >> 16);
|
||||
b32 f3 = (b_frag_all[k][3] & 0xFFFF0000) | (b_frag_all[k][2] >> 16);
|
||||
b_frag_all[k][0] = f0;
|
||||
b_frag_all[k][1] = f1;
|
||||
b_frag_all[k][2] = f2;
|
||||
b_frag_all[k][3] = f3;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(int64_t i = 0, remaining_log_had_size = log_had_size - l * 8; i < 2 && remaining_log_had_size > 0; i++) {
|
||||
int64_t had_off = ((remaining_log_had_size < 4) && !(log_had_size <= 4 || log_had_size % 4 == 0)) ? 4 : 0;
|
||||
mma_m16_n16_k16_b16_b16_b16_noacc<dtype>(had_frag[had_off + 0], had_frag[had_off + 1], had_frag[had_off + 2], had_frag[had_off + 3], b_frag_all[k][0], b_frag_all[k][1], b_frag_all[k][2], b_frag_all[k][3], b_frag_all[k][0], b_frag_all[k][1], b_frag_all[k][2], b_frag_all[k][3]);
|
||||
|
||||
remaining_log_had_size -= 4;
|
||||
if (remaining_log_had_size <= 0 && i == 0) {
|
||||
// TODO: consider different storing so no need for transpose
|
||||
matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][0]);
|
||||
matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][1]);
|
||||
matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][2]);
|
||||
matrix_transpose_m8_n8_b16_inplace(b_frag_all[k][3]);
|
||||
} else {
|
||||
// swap and use output directly as b_frag for next iteration as an actually free transpose
|
||||
b32 temp = b_frag_all[k][1];
|
||||
b_frag_all[k][1] = b_frag_all[k][2];
|
||||
b_frag_all[k][2] = temp;
|
||||
}
|
||||
}
|
||||
|
||||
if (l == 1) {
|
||||
// invert swap from above for second iteration
|
||||
b32 f0 = ((b_frag_all[k][2] & 0xFFFF) << 16) | (b_frag_all[k][0] & 0xFFFF);
|
||||
b32 f1 = (b_frag_all[k][2] & 0xFFFF0000) | (b_frag_all[k][0] >> 16);
|
||||
b32 f2 = ((b_frag_all[k][3] & 0xFFFF) << 16) | (b_frag_all[k][1] & 0xFFFF);
|
||||
b32 f3 = (b_frag_all[k][3] & 0xFFFF0000) | (b_frag_all[k][1] >> 16);
|
||||
b_frag_all[k][0] = f0;
|
||||
b_frag_all[k][1] = f1;
|
||||
b_frag_all[k][2] = f2;
|
||||
b_frag_all[k][3] = f3;
|
||||
}
|
||||
|
||||
if (l == 0) {
|
||||
// inverse of coalesced load for first iteration to store result
|
||||
#pragma unroll
|
||||
for (int64_t j = 1; j < 4; j += 2) {
|
||||
b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], 16);
|
||||
}
|
||||
|
||||
if ((threadid & 16) != 0) {
|
||||
b32 temp = b_frag_all[k][0];
|
||||
b_frag_all[k][0] = b_frag_all[k][1];
|
||||
b_frag_all[k][1] = temp;
|
||||
|
||||
temp = b_frag_all[k][2];
|
||||
b_frag_all[k][2] = b_frag_all[k][3];
|
||||
b_frag_all[k][3] = temp;
|
||||
}
|
||||
|
||||
// if only going up to 256 size, store directly back to global memory,
|
||||
// otherwise store back to shared memory for next iteration
|
||||
b32* store = (log_had_size <= 8) ? out_chunk_ptr : b_frag_ptr;
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t j = 0; j < 4; j++) {
|
||||
int64_t reg = ((threadid & 16) == 0) ? j : (j / 2 * 2 + (1 - j % 2));
|
||||
int64_t real_thread_id = (reg == 0 || reg == 2) ? threadid : (threadid ^ 16);
|
||||
int64_t real_row = real_thread_id % 4;
|
||||
int64_t real_col = real_thread_id / 4;
|
||||
store[(real_row + (reg % 2) * 4) + (real_col + (reg / 2) * 8) * 8] = b_frag_all[k][j];
|
||||
}
|
||||
} else if constexpr(log_had_size > 8) { // condition is redundant to help compiler warnings
|
||||
if (log_had_size < 12) {
|
||||
// inverse of coalesced load for sizes 512, 1k and 2k to store result
|
||||
constexpr int xor_val = log_had_size == 9 ? 16 : 1;
|
||||
#pragma unroll
|
||||
for (int64_t j = 2; j < 4; j++) {
|
||||
b_frag_all[k][j] = __shfl_xor_sync(0xFFFFFFFF, b_frag_all[k][j], xor_val);
|
||||
}
|
||||
|
||||
if ((threadid & xor_val) != 0) {
|
||||
b32 temp = b_frag_all[k][0];
|
||||
b_frag_all[k][0] = b_frag_all[k][2];
|
||||
b_frag_all[k][2] = temp;
|
||||
|
||||
temp = b_frag_all[k][1];
|
||||
b_frag_all[k][1] = b_frag_all[k][3];
|
||||
b_frag_all[k][3] = temp;
|
||||
}
|
||||
|
||||
b32* store = (b32*)(out + (blockid / warps_per_block) * (num_chunks * warps_per_block) * 256 + (256 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block) + k));
|
||||
#pragma unroll
|
||||
for (int64_t j = 0; j < 4; j++) {
|
||||
int64_t reg = ((threadid & xor_val) == 0) ? j : (j + 2) % 4;
|
||||
b32 data = b_frag_all[k][j];
|
||||
int64_t real_thread_id = reg < 2 ? threadid : (threadid ^ xor_val);
|
||||
int64_t idx = (real_thread_id / 4 * 16) + (real_thread_id % 4 * 2) + (reg / 2 * 8) + (reg % 2);
|
||||
int64_t rowidx = idx % (1 << part8_log_had_size);
|
||||
int64_t colidx = idx >> part8_log_had_size;
|
||||
store[rowidx * 128 + colidx] = data;
|
||||
}
|
||||
}
|
||||
// for size 4k and above, wait to process all chunks so a final store can be performed coalesced
|
||||
}
|
||||
|
||||
a_chunk_ptr += 128; // (only affects first 256 size) move on to next chunk by skipping 256 elements in b16 (= 128 in b32)
|
||||
out_chunk_ptr += 128;
|
||||
if constexpr(log_had_size > 8) {
|
||||
b_frag_ptr += (l == 0 ? 128 : (128 >> part8_log_had_size));
|
||||
} else { // else is redundant, simplified version of if body, to help compiler warnings
|
||||
b_frag_ptr += 128;
|
||||
}
|
||||
}
|
||||
if (log_had_size <= 8)
|
||||
break;
|
||||
}
|
||||
|
||||
if constexpr(log_had_size >= 12) {
|
||||
// for sizes 4k and above, perform final coalesced store after processing all chunks
|
||||
#pragma unroll
|
||||
for (int64_t j = 0; j < 4; j++) {
|
||||
#pragma unroll
|
||||
for (int64_t k = 1; k < num_chunks; k++) {
|
||||
int64_t threadid_contig = threadid % num_chunks;
|
||||
int64_t threadid_mul = threadid / num_chunks;
|
||||
int64_t threadid2 = (threadid_contig + k) % num_chunks + threadid_mul * num_chunks; // thread to give your data to
|
||||
b_frag_all[k][j] = __shfl_sync(0xFFFFFFFF, b_frag_all[k][j], threadid2);
|
||||
}
|
||||
}
|
||||
|
||||
// a + threadblock offset + warp offset
|
||||
// can then index into all chunks owned by this warp
|
||||
b32* store = bfrag_arr + (128 >> part8_log_had_size) * (num_chunks * (blockid % warps_per_block));
|
||||
|
||||
#pragma unroll
|
||||
for (int64_t j = 0; j < 4; j++) {
|
||||
#pragma unroll
|
||||
for (int64_t k = 0; k < num_chunks; k++) {
|
||||
// here, j represents register, and k represents 8-offset/chunk
|
||||
int64_t real_chunk_num = (num_chunks - (threadid % num_chunks) + k) % num_chunks; // chunk at which you have target thread #'s data
|
||||
|
||||
// b32 data = b_frag_all[real_chunk_num][j]; // target thread data
|
||||
b32 data;
|
||||
#pragma unroll
|
||||
for (int64_t i = 0; i < num_chunks; i++) {
|
||||
if (real_chunk_num == i) data = b_frag_all[i][j];
|
||||
}
|
||||
|
||||
int64_t real_thread_id = (threadid / num_chunks) * num_chunks + k; // target thread #
|
||||
int64_t chunk_idx = 128 * real_chunk_num; // index due to fetching from another chunk (chunk in which this thread has the target thread's original data)
|
||||
int64_t thread_group_idx = (real_thread_id / 4) * 16; // index due to fetching from another group of num_chunk threads (since shuffle is between num_chunk threads)
|
||||
int64_t thread_idx = (real_thread_id % 4) * 2; // index due to original thread's position within the group of num_chunk threads
|
||||
int64_t reg_idx = (j / 2) * 8 + (j % 2); // index due to target register
|
||||
int64_t idx = chunk_idx + thread_group_idx + thread_idx + reg_idx; // final index
|
||||
|
||||
// fix idx for majorness
|
||||
int64_t rowidx = idx % (1 << part8_log_had_size);
|
||||
int64_t colidx = idx >> part8_log_had_size;
|
||||
|
||||
store[rowidx * 128 + colidx] = data;
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
store = ((b32*) out) + (blockid / warps_per_block) * (num_chunks * warps_per_block) * 128;
|
||||
int4* store4 = (int4*) store;
|
||||
int4* bfrag_arr4 = (int4*) bfrag_arr;
|
||||
// flush smem, simply linearly write to store
|
||||
// always divisible by 128*32b, so (32*4)*32b is ok
|
||||
#pragma unroll
|
||||
for (int64_t warp_off = 0; warp_off < (num_chunks * warps_per_block * 128 / 4); warp_off += 32 * warps_per_block) {
|
||||
int64_t total_off = warp_off + threadid + (blockid % warps_per_block) * 32;
|
||||
store4[total_off] = bfrag_arr4[total_off];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
constexpr int64_t ceil_div(int64_t a, int64_t b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
template <torch::ScalarType dtype, int64_t chunks_per_warp, int64_t warps_per_block, int64_t log_had_size, int64_t blocks_per_sm, bool check_masking = false>
|
||||
void __forceinline__ run_kernel(b16* a_mat, b16* out, int64_t num_chunks, cudaStream_t stream) {
|
||||
int64_t shared_size = chunks_per_warp * warps_per_block * 128 * 4;
|
||||
dim3 block_size = 32 * warps_per_block;
|
||||
|
||||
#define CHECK_SHARED_LIM() { \
|
||||
if (shared_size > 48 * 1024) { \
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536)); \
|
||||
} \
|
||||
} \
|
||||
|
||||
if constexpr(check_masking) {
|
||||
if (num_chunks % (chunks_per_warp * warps_per_block) != 0) {
|
||||
dim3 grid_size = ceil_div(ceil_div(num_chunks, chunks_per_warp), warps_per_block);
|
||||
auto kernel = hadamard_transform_kernel<chunks_per_warp, warps_per_block, log_had_size, blocks_per_sm, true, dtype>;
|
||||
CHECK_SHARED_LIM();
|
||||
kernel<<<dim3(grid_size), dim3(block_size), shared_size, stream>>>(a_mat, out, num_chunks);
|
||||
} else {
|
||||
dim3 grid_size = num_chunks / chunks_per_warp / warps_per_block;
|
||||
auto kernel = hadamard_transform_kernel<chunks_per_warp, warps_per_block, log_had_size, blocks_per_sm, false, dtype>;
|
||||
CHECK_SHARED_LIM();
|
||||
kernel<<<dim3(grid_size), dim3(block_size), shared_size, stream>>>(a_mat, out, num_chunks);
|
||||
}
|
||||
} else {
|
||||
dim3 grid_size = num_chunks / chunks_per_warp / warps_per_block;
|
||||
auto kernel = hadamard_transform_kernel<chunks_per_warp, warps_per_block, log_had_size, blocks_per_sm, false, dtype>;
|
||||
CHECK_SHARED_LIM();
|
||||
kernel<<<dim3(grid_size), dim3(block_size), shared_size, stream>>>(a_mat, out, num_chunks);
|
||||
}
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template <torch::ScalarType dtype>
|
||||
void run_fht(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream) {
|
||||
int64_t num_chunks = numel / 256; // caller required to ensure divisible by 256
|
||||
// for size 256, use (2, 1)
|
||||
// for size 32k use (8, 16)
|
||||
constexpr int64_t chunks_per_warp_small = 1;// 8;
|
||||
constexpr int64_t warps_per_block_small = 1;//2;//16;
|
||||
constexpr int64_t blocks_per_sm_small = 24;
|
||||
constexpr int64_t chunks_per_warp_large = 2;
|
||||
constexpr int64_t warps_per_block_large = 1;
|
||||
constexpr int64_t blocks_per_sm_large = 24;
|
||||
|
||||
b16* a_mat = (b16*) a_mat_ptr;
|
||||
b16* out = (b16*) out_ptr;
|
||||
|
||||
if (numel <= 256) {
|
||||
switch (had_size) {
|
||||
case (1<<1): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 1, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<2): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 2, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<3): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 3, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<4): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 4, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<5): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 5, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<6): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 6, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<7): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 7, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<8): run_kernel<dtype, chunks_per_warp_small, warps_per_block_small, 8, blocks_per_sm_small>(a_mat, out, num_chunks, stream); break;
|
||||
}
|
||||
} else {
|
||||
switch (had_size) {
|
||||
case (1<<1): run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 1, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<2): run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 2, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<3): run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 3, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<4): run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 4, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<5): run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 5, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<6): run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 6, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<7): run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 7, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<8): run_kernel<dtype, chunks_per_warp_large, warps_per_block_large, 8, blocks_per_sm_large, true>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<9): run_kernel<dtype, launch_configs_big[0][0], launch_configs_big[0][1], 9 , launch_configs_big[0][2]>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<10): run_kernel<dtype, launch_configs_big[1][0], launch_configs_big[1][1], 10, launch_configs_big[1][2]>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<11): run_kernel<dtype, launch_configs_big[2][0], launch_configs_big[2][1], 11, launch_configs_big[2][2]>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<12): run_kernel<dtype, launch_configs_big[3][0], launch_configs_big[3][1], 12, launch_configs_big[3][2]>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<13): run_kernel<dtype, launch_configs_big[4][0], launch_configs_big[4][1], 13, launch_configs_big[4][2]>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<14): run_kernel<dtype, launch_configs_big[5][0], launch_configs_big[5][1], 14, launch_configs_big[5][2]>(a_mat, out, num_chunks, stream); break;
|
||||
case (1<<15): run_kernel<dtype, launch_configs_big[6][0], launch_configs_big[6][1], 15, launch_configs_big[6][2]>(a_mat, out, num_chunks, stream); break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template void run_fht<torch::ScalarType::Half>(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream);
|
||||
template void run_fht<torch::ScalarType::BFloat16>(void* a_mat_ptr, void* out_ptr, int64_t numel, int64_t had_size, cudaStream_t stream);
|
||||
|
||||
} // namespace hadacore
|
||||
|
||||
constexpr bool is_power_of_two(int x) { return x && !(x & (x - 1)); }
|
||||
|
||||
torch::Tensor hadacore_transform(torch::Tensor& x, bool inplace) {
|
||||
auto dtype = x.scalar_type();
|
||||
TORCH_CHECK(dtype == torch::ScalarType::Half || dtype == torch::ScalarType::BFloat16, "Only fp16 and bf16 supported currently");
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
|
||||
const int had_size = x.size(-1);
|
||||
TORCH_CHECK(is_power_of_two(had_size) && (had_size <= (1U << 15)),
|
||||
"Only power of two Hadamard sizes up to 2^15 are supported, got ", had_size);
|
||||
|
||||
const auto res_shape = x.sizes();
|
||||
x = x.reshape({-1, had_size});
|
||||
|
||||
auto numel = x.numel();
|
||||
if (numel % 256 != 0) {
|
||||
x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 0, 0, (256 - numel % 256) / had_size}));
|
||||
}
|
||||
|
||||
if (x.stride(-1) != 1) {
|
||||
x = x.contiguous();
|
||||
}
|
||||
torch::Tensor out = inplace ? x : torch::empty_like(x);
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
VLLM_DISPATCH_HALF_TYPES(x.scalar_type(), "hadacore_transform_runfht", [&] {
|
||||
auto constexpr SCALAR_TYPE = c10::CppTypeToScalarType<scalar_t>::value;
|
||||
hadacore::run_fht<SCALAR_TYPE>(x.data_ptr(), x.data_ptr(), x.numel(), had_size, stream);
|
||||
});
|
||||
|
||||
if (numel % 256 != 0) {
|
||||
out = out.index({torch::indexing::Slice(0, numel / had_size)});
|
||||
}
|
||||
|
||||
if (inplace && out.data_ptr() != x.data_ptr()) {
|
||||
x.copy_(out.view(res_shape));
|
||||
return x;
|
||||
}
|
||||
return out.reshape(res_shape);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("hadacore_transform", &hadacore_transform);
|
||||
}
|
||||
@ -30,6 +30,10 @@
|
||||
#define __HIP__GFX9__
|
||||
#endif
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx942__) || defined(__gfx950__))
|
||||
#define __HIP__FP8MFMA__
|
||||
#endif
|
||||
|
||||
#if defined(__HIPCC__) && (defined(__gfx1100__) || defined(__gfx1101__))
|
||||
#define __HIP__GFX11__
|
||||
#endif
|
||||
@ -51,6 +55,12 @@
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
enum class MFMAType {
|
||||
F16 = 0,
|
||||
Fp8 = 1,
|
||||
Fp4 = 2,
|
||||
};
|
||||
|
||||
#if defined(__HIP__GFX9__)
|
||||
|
||||
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
|
||||
@ -112,6 +122,21 @@ __device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int absz, int cbid, int blgp>
|
||||
__device__ __forceinline__ floatx4 gcn_mfma16x16x32_instr(const long& inpA,
|
||||
const long& inpB,
|
||||
const floatx4& inpC) {
|
||||
if constexpr (std::is_same<T, __hip_fp8_e4m3>::value) {
|
||||
return __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(inpA, inpB, inpC, absz,
|
||||
cbid, blgp);
|
||||
} else if constexpr (std::is_same<T, __hip_fp8_e5m2>::value) {
|
||||
return __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(inpA, inpB, inpC, absz,
|
||||
cbid, blgp);
|
||||
} else {
|
||||
static_assert(false, "unsupported 8b dtype");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ float to_float(const T& inp) {
|
||||
if constexpr (std::is_same<T, _Float16>::value) {
|
||||
@ -256,12 +281,44 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
typedef union u64_cvt {
|
||||
half f16x4[4];
|
||||
int16_t b16x4[4];
|
||||
_B8x8 b8x8;
|
||||
_B16x4 b64;
|
||||
int64_t i64;
|
||||
} _T8x8;
|
||||
|
||||
__device__ __forceinline__ _B8x8 convert_b16x8(const _B16x8& input,
|
||||
_T8x8& Mtemp) {
|
||||
_T8x8 Qtmp8x8;
|
||||
|
||||
for (int i = 0; i < 2; i++) {
|
||||
floatx4 q_out = {0, 0, 0, 0};
|
||||
q_out = gcn_mfma16x16x16_instr<_Float16, 0, 0, 0>(Mtemp.b64, input.xy[i],
|
||||
q_out);
|
||||
Qtmp8x8.b16x4[i * 2] =
|
||||
__builtin_amdgcn_cvt_pk_fp8_f32(q_out[0], q_out[1], 0, false);
|
||||
Qtmp8x8.b16x4[i * 2 + 1] =
|
||||
__builtin_amdgcn_cvt_pk_fp8_f32(q_out[2], q_out[3], 0, false);
|
||||
}
|
||||
return Qtmp8x8.b8x8;
|
||||
}
|
||||
|
||||
__device__ float warpReduceMax(float val) {
|
||||
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
|
||||
val = max(
|
||||
val, __shfl_down(val, offset, WARP_SIZE)); // Using max() for reduction
|
||||
}
|
||||
return val;
|
||||
}
|
||||
|
||||
// grid (num_seqs, num_partitions,num_kv_heads)
|
||||
// block (256)
|
||||
// clang-format off
|
||||
template <typename scalar_t, typename cache_t,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
|
||||
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
|
||||
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO, MFMAType MFMA_TYPE>
|
||||
__global__
|
||||
__launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
@ -367,6 +424,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
|
||||
int kphysical_block_number[TLOOP];
|
||||
#if defined(__HIP__FP8MFMA__)
|
||||
float q_max = 0;
|
||||
float q_scale = 1.0;
|
||||
#endif
|
||||
|
||||
// fetch k physical block numbers
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
@ -416,6 +477,15 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
Qlocal[qkhe_depth][qkratio].xy[i] =
|
||||
shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO]
|
||||
[2 * qkratio + i];
|
||||
#if defined(__HIP__FP8MFMA__)
|
||||
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto &&
|
||||
MFMA_TYPE == MFMAType::Fp8) {
|
||||
scalar_t* qptr =
|
||||
reinterpret_cast<scalar_t*>(&Qlocal[qkhe_depth][qkratio].xy[i]);
|
||||
for (int k = 0; k < 4; k++)
|
||||
q_max = fmax(fabs(to_float<scalar_t>(qptr[k])), q_max);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -515,6 +585,14 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
|
||||
// multiply by k_scale if fp8 kv cache
|
||||
scale2 *= *k_scale;
|
||||
#if defined(__HIP__FP8MFMA__)
|
||||
q_max = warpReduceMax(q_max);
|
||||
constexpr float FP8_E4M3_SCALE_TARGET = 224.0f;
|
||||
if constexpr (MFMA_TYPE == MFMAType::Fp8) {
|
||||
q_scale = q_max > 0 ? FP8_E4M3_SCALE_TARGET / q_max : 1.0f;
|
||||
scale2 /= q_scale;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
floatx4 d_out[TLOOP];
|
||||
@ -534,12 +612,41 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
auto Ktmp = Klocal[token_depth][qkhe_depth];
|
||||
_B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp);
|
||||
for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
|
||||
_B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio];
|
||||
_B16x8 Klocaltmp = convert_b8x8_custom<scalar_t>(Ktmp8x8);
|
||||
for (int i = 0; i < 2; i++) {
|
||||
d_out[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
|
||||
Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i],
|
||||
d_out[token_depth]);
|
||||
if constexpr (MFMA_TYPE == MFMAType::F16) {
|
||||
_B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio];
|
||||
_B16x8 Klocaltmp = convert_b8x8_custom<scalar_t>(Ktmp8x8);
|
||||
for (int i = 0; i < 2; i++) {
|
||||
d_out[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
|
||||
Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i],
|
||||
d_out[token_depth]);
|
||||
}
|
||||
} else {
|
||||
#if defined(__HIP__FP8MFMA__)
|
||||
_T8x8 Ktmp8x8, Qtmp8x8;
|
||||
Ktmp8x8.b8x8 = Ktmp8x16.xy[qkratio];
|
||||
|
||||
for (int n = 0; n < 2; n++) {
|
||||
scalar_t* qptr = reinterpret_cast<scalar_t*>(
|
||||
&Qlocal[qkhe_depth][qkratio].xy[n]);
|
||||
|
||||
Qtmp8x8.b16x4[n * 2] =
|
||||
vllm::fp8::scaled_vec_conversion<uint16_t, float2>(
|
||||
make_float2(to_float<scalar_t>(qptr[0]),
|
||||
to_float<scalar_t>(qptr[1])),
|
||||
q_scale);
|
||||
Qtmp8x8.b16x4[n * 2 + 1] =
|
||||
vllm::fp8::scaled_vec_conversion<uint16_t, float2>(
|
||||
make_float2(to_float<scalar_t>(qptr[2]),
|
||||
to_float<scalar_t>(qptr[3])),
|
||||
q_scale);
|
||||
}
|
||||
|
||||
d_out[token_depth] =
|
||||
gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(
|
||||
Ktmp8x8.i64, Qtmp8x8.i64, d_out[token_depth]);
|
||||
#else
|
||||
UNREACHABLE_CODE
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -629,17 +736,36 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
// disable rtz conversion due to its impact on accuracy.
|
||||
constexpr bool LOGITS_RTZ_CONVERSION = false;
|
||||
|
||||
#if defined(__HIP__FP8MFMA__)
|
||||
int rowid_8x8 = rowid / 2;
|
||||
int offset = rowid % 2;
|
||||
#endif
|
||||
|
||||
// write logits to shared mem
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
d_out[token_depth] *= inv_sum_scale;
|
||||
if constexpr (LOGITS_RTZ_CONVERSION) {
|
||||
// use rtz conversion for better performance, with negligible impact on
|
||||
// accuracy
|
||||
shared_logits[warpid][token_depth][lane16id][rowid] =
|
||||
from_floatx4_rtz<scalar_t>(d_out[token_depth]);
|
||||
if constexpr (MFMA_TYPE != MFMAType::Fp8) {
|
||||
if constexpr (LOGITS_RTZ_CONVERSION) {
|
||||
// use rtz conversion for better performance, with negligible impact on
|
||||
// accuracy
|
||||
shared_logits[warpid][token_depth][lane16id][rowid] =
|
||||
from_floatx4_rtz<scalar_t>(d_out[token_depth]);
|
||||
} else {
|
||||
shared_logits[warpid][token_depth][lane16id][rowid] =
|
||||
from_floatx4<scalar_t>(d_out[token_depth]);
|
||||
}
|
||||
} else {
|
||||
shared_logits[warpid][token_depth][lane16id][rowid] =
|
||||
from_floatx4<scalar_t>(d_out[token_depth]);
|
||||
#if defined(__HIP__FP8MFMA__)
|
||||
// cast _B16x4* to _B8x8*
|
||||
_T8x8& logits_8x8 = *reinterpret_cast<_T8x8*>(
|
||||
&shared_logits[warpid][token_depth][lane16id][rowid_8x8]);
|
||||
logits_8x8.b16x4[offset * 2] = __builtin_amdgcn_cvt_pk_fp8_f32(
|
||||
d_out[token_depth][0], d_out[token_depth][1], 0, false);
|
||||
logits_8x8.b16x4[offset * 2 + 1] = __builtin_amdgcn_cvt_pk_fp8_f32(
|
||||
d_out[token_depth][2], d_out[token_depth][3], 0, false);
|
||||
#else
|
||||
UNREACHABLE_CODE
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@ -692,19 +818,42 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
_B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp);
|
||||
for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) {
|
||||
_B8x8 Vtmp8x8 = Vtmp8x16.xy[j];
|
||||
_B16x8 Vlocaltmp = convert_b8x8_custom<scalar_t>(Vtmp8x8);
|
||||
for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) {
|
||||
const int offset =
|
||||
rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
|
||||
j * ELEMS8_ELEMS4_RATIO + i;
|
||||
const int offset1 = offset % ROWS_PER_WARP;
|
||||
const int offset2 = offset / ROWS_PER_WARP;
|
||||
// output format is 16 qheads across 16 lanes, 16 head elems
|
||||
// spread across 4 rows
|
||||
tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
|
||||
Vlocaltmp.xy[i],
|
||||
shared_logits[vtoken_depth][offset2][lane16id][offset1],
|
||||
tmp_out);
|
||||
if constexpr (MFMA_TYPE == MFMAType::F16) {
|
||||
_B16x8 Vlocaltmp = convert_b8x8_custom<scalar_t>(Vtmp8x8);
|
||||
for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) {
|
||||
const int offset =
|
||||
rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
|
||||
j * ELEMS8_ELEMS4_RATIO + i;
|
||||
const int offset1 = offset % ROWS_PER_WARP;
|
||||
const int offset2 = offset / ROWS_PER_WARP;
|
||||
// output format is 16 qheads across 16 lanes, 16 head elems
|
||||
// spread across 4 rows
|
||||
tmp_out = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
|
||||
Vlocaltmp.xy[i],
|
||||
shared_logits[vtoken_depth][offset2][lane16id][offset1],
|
||||
tmp_out);
|
||||
}
|
||||
} else {
|
||||
#if defined(__HIP__FP8MFMA__)
|
||||
for (int i = 0; i < ELEMS8_ELEMS4_RATIO / 2; i++) {
|
||||
const int offset =
|
||||
rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
|
||||
j * ELEMS8_ELEMS4_RATIO + i;
|
||||
const int offset1 = (offset % ROWS_PER_WARP) / 2;
|
||||
const int offset2 = offset / ROWS_PER_WARP;
|
||||
// output format is 16 qheads across 16 lanes, 16 head elems
|
||||
// spread across 4 rows
|
||||
tmp_out = gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(
|
||||
reinterpret_cast<_T8x8*>(&Vtmp8x8)->i64,
|
||||
reinterpret_cast<_T8x8*>(
|
||||
&shared_logits[vtoken_depth][offset2][lane16id]
|
||||
[offset1])
|
||||
->i64,
|
||||
tmp_out);
|
||||
}
|
||||
#else
|
||||
UNREACHABLE_CODE
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1570,7 +1719,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
|
||||
// clang-format off
|
||||
template <typename scalar_t, typename cache_t,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
|
||||
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
|
||||
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO,
|
||||
MFMAType MFMA_TYPE>
|
||||
__global__
|
||||
__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
@ -2337,7 +2487,8 @@ __device__ __forceinline__ _B16x8 from_floatx8(const floatx8& inp) {
|
||||
// clang-format off
|
||||
template <typename scalar_t, typename cache_t,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
|
||||
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO>
|
||||
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED, int GQA_RATIO,
|
||||
MFMAType MFMA_TYPE>
|
||||
__global__
|
||||
__launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
@ -2969,7 +3120,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
template <typename scalar_t, typename cache_t,
|
||||
vllm::Fp8KVCacheDataType KV_DTYPE, typename OUTT, int BLOCK_SIZE,
|
||||
int HEAD_SIZE, int NUM_THREADS, bool ALIBI_ENABLED,
|
||||
int GQA_RATIO>
|
||||
int GQA_RATIO, MFMAType MFMA_TYPE>
|
||||
__global__
|
||||
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
@ -3041,7 +3192,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
|
||||
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
|
||||
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
|
||||
GQA_RATIO> \
|
||||
GQA_RATIO, MFMA_TYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
|
||||
block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
|
||||
@ -3069,7 +3220,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
|
||||
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
|
||||
bool ALIBI_ENABLED>
|
||||
bool ALIBI_ENABLED, MFMAType MFMA_TYPE>
|
||||
void paged_attention_custom_launcher(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
@ -3225,7 +3376,7 @@ void paged_attention_custom_launcher(
|
||||
|
||||
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
|
||||
bool ALIBI_ENABLED>
|
||||
bool ALIBI_ENABLED, MFMAType MFMA_TYPE>
|
||||
void paged_attention_custom_launcher_navi(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
@ -3397,74 +3548,77 @@ void paged_attention_custom_launcher_navi(
|
||||
}
|
||||
|
||||
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
|
||||
PSIZE, ALIBI_ENABLED) \
|
||||
PSIZE, ALIBI_ENABLED, MFMA_TYPE) \
|
||||
if (!is_navi) { \
|
||||
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
|
||||
OUTT, PSIZE, ALIBI_ENABLED>( \
|
||||
OUTT, PSIZE, ALIBI_ENABLED, MFMA_TYPE>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
|
||||
max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
|
||||
} else { \
|
||||
paged_attention_custom_launcher_navi< \
|
||||
T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
|
||||
paged_attention_custom_launcher_navi<T, KVT, KV_DTYPE, BLK_SIZE, \
|
||||
HEAD_SIZE, OUTT, PSIZE, \
|
||||
ALIBI_ENABLED, MFMA_TYPE>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
|
||||
max_seq_len, alibi_slopes, k_scale, v_scale); \
|
||||
}
|
||||
|
||||
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
|
||||
OUTT, PSIZE) \
|
||||
OUTT, PSIZE, MFMA_TYPE) \
|
||||
if (alibi_slopes) { \
|
||||
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
|
||||
true); \
|
||||
true, MFMA_TYPE); \
|
||||
} else { \
|
||||
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
|
||||
false); \
|
||||
false, MFMA_TYPE); \
|
||||
}
|
||||
|
||||
#if defined(__HIPCC__) && defined(__gfx90a__)
|
||||
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
|
||||
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
|
||||
MFMA_TYPE) \
|
||||
if (fp8_out_scale) { \
|
||||
TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
|
||||
} else { \
|
||||
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
|
||||
256); \
|
||||
256, MFMA_TYPE); \
|
||||
}
|
||||
#else
|
||||
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
|
||||
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
|
||||
MFMA_TYPE) \
|
||||
if (fp8_out_scale) { \
|
||||
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
|
||||
uint8_t, 256); \
|
||||
uint8_t, 256, MFMA_TYPE); \
|
||||
} else { \
|
||||
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
|
||||
256); \
|
||||
256, MFMA_TYPE); \
|
||||
}
|
||||
#endif
|
||||
|
||||
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE, MFMA_TYPE) \
|
||||
switch (block_size) { \
|
||||
case 16: \
|
||||
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE, MFMA_TYPE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE, MFMA_TYPE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
|
||||
switch (head_size) { \
|
||||
case 64: \
|
||||
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \
|
||||
break; \
|
||||
case 128: \
|
||||
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size); \
|
||||
break; \
|
||||
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE, MFMA_TYPE) \
|
||||
switch (head_size) { \
|
||||
case 64: \
|
||||
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64, MFMA_TYPE); \
|
||||
break; \
|
||||
case 128: \
|
||||
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128, MFMA_TYPE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
bool is_navi_gpu() {
|
||||
@ -3503,28 +3657,43 @@ void paged_attention(
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale,
|
||||
const std::optional<torch::Tensor>& fp8_out_scale) {
|
||||
const std::optional<torch::Tensor>& fp8_out_scale,
|
||||
const std::string& mfma_type) {
|
||||
// clang-format on
|
||||
bool is_navi = is_navi_gpu();
|
||||
|
||||
const int head_size = query.size(2);
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16,
|
||||
vllm::Fp8KVCacheDataType::kAuto);
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(
|
||||
_Float16, _Float16, vllm::Fp8KVCacheDataType::kAuto, MFMAType::F16);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16,
|
||||
vllm::Fp8KVCacheDataType::kAuto);
|
||||
vllm::Fp8KVCacheDataType::kAuto,
|
||||
MFMAType::F16);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
if (mfma_type == "fp8") {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3,
|
||||
MFMAType::Fp8);
|
||||
} else {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3,
|
||||
MFMAType::F16);
|
||||
}
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3);
|
||||
if (mfma_type == "fp8") {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3,
|
||||
MFMAType::Fp8);
|
||||
} else {
|
||||
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3,
|
||||
MFMAType::F16);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
|
||||
@ -19,4 +19,5 @@ void paged_attention(
|
||||
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);
|
||||
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale,
|
||||
const std::string& mfma_type);
|
||||
|
||||
@ -48,7 +48,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
" Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor k_scale, Tensor v_scale,"
|
||||
" Tensor? fp8_out_scale) -> ()");
|
||||
" Tensor? fp8_out_scale,"
|
||||
" str mfma_type) -> ()");
|
||||
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
|
||||
}
|
||||
|
||||
|
||||
@ -32,6 +32,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
#define stride_tag
|
||||
#endif
|
||||
|
||||
ops.def(
|
||||
"silu_mul_fp8_quant_deep_gemm_cuda(Tensor input, Tensor counts, Tensor! "
|
||||
"y_q, Tensor! y_s, int group_size, "
|
||||
"bool use_ue8m0, int num_parallel_tokens) -> ()");
|
||||
ops.impl("silu_mul_fp8_quant_deep_gemm_cuda", torch::kCUDA,
|
||||
&silu_mul_fp8_quant_deep_gemm_cuda);
|
||||
|
||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
||||
|
||||
@ -168,6 +175,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"float epsilon) -> ()");
|
||||
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
|
||||
|
||||
// Polynomial Normalization.
|
||||
ops.def(
|
||||
"poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float "
|
||||
"epsilon) -> ()");
|
||||
ops.impl("poly_norm", torch::kCUDA, &poly_norm);
|
||||
|
||||
// Apply repetition penalties to logits in-place
|
||||
ops.def(
|
||||
"apply_repetition_penalties_(Tensor! logits, Tensor prompt_mask, "
|
||||
@ -208,16 +221,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
|
||||
|
||||
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key
|
||||
// (supports multiple loras).
|
||||
ops.def(
|
||||
"batched_rotary_embedding(Tensor positions, Tensor! query,"
|
||||
" Tensor!? key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox,"
|
||||
" int rot_dim,"
|
||||
" Tensor cos_sin_cache_offsets) -> ()");
|
||||
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
// Quantized GEMM for AWQ.
|
||||
@ -507,13 +510,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
|
||||
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
|
||||
|
||||
// CUTLASS MLA decode
|
||||
ops.def(
|
||||
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
|
||||
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
|
||||
" Tensor page_table, float scale) -> ()");
|
||||
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
|
||||
// SM100 CUTLASS MLA decode
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
|
||||
@ -610,6 +606,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
// Hadamard transforms
|
||||
ops.def("hadacore_transform(Tensor! x, bool inplace) -> Tensor");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Compute per-token-group FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
|
||||
@ -196,6 +196,7 @@ ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
|
||||
# Flag to control whether to use pre-built vLLM wheels
|
||||
ARG VLLM_USE_PRECOMPILED=""
|
||||
ARG VLLM_MAIN_CUDA_VERSION=""
|
||||
|
||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
@ -213,6 +214,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
&& export SCCACHE_IDLE_TIMEOUT=0 \
|
||||
&& export CMAKE_BUILD_TYPE=Release \
|
||||
&& export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \
|
||||
&& export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \
|
||||
&& export VLLM_DOCKER_BUILD_CONTEXT=1 \
|
||||
&& sccache --show-stats \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
|
||||
@ -281,6 +283,10 @@ WORKDIR /vllm-workspace
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
ARG GDRCOPY_CUDA_VERSION=12.8
|
||||
# Keep in line with FINAL_BASE_IMAGE
|
||||
ARG GDRCOPY_OS_VERSION=Ubuntu22_04
|
||||
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
@ -375,7 +381,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
# Install FlashInfer from source
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
# Keep this in sync with "flashinfer" extra in setup.py
|
||||
ARG FLASHINFER_GIT_REF="v0.3.0"
|
||||
ARG FLASHINFER_GIT_REF="v0.3.1"
|
||||
# Flag to control whether to compile FlashInfer AOT kernels
|
||||
# Set to "true" to enable AOT compilation:
|
||||
# docker build --build-arg FLASHINFER_AOT_COMPILE=true ...
|
||||
@ -439,13 +445,21 @@ COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" ${DEEPGEMM_GIT_REF:+--ref "$DEEPGEMM_GIT_REF"}
|
||||
|
||||
# Install EP kernels(pplx-kernels and DeepEP), NixL
|
||||
COPY tools/install_gdrcopy.sh install_gdrcopy.sh
|
||||
RUN set -eux; \
|
||||
case "${TARGETPLATFORM}" in \
|
||||
linux/arm64) UUARCH="aarch64" ;; \
|
||||
linux/amd64) UUARCH="x64" ;; \
|
||||
*) echo "Unsupported TARGETPLATFORM: ${TARGETPLATFORM}" >&2; exit 1 ;; \
|
||||
esac; \
|
||||
./install_gdrcopy.sh "${GDRCOPY_OS_VERSION}" "${GDRCOPY_CUDA_VERSION}" "${UUARCH}"; \
|
||||
rm ./install_gdrcopy.sh
|
||||
|
||||
# Install EP kernels(pplx-kernels and DeepEP)
|
||||
COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh
|
||||
COPY tools/install_nixl.sh install_nixl.sh
|
||||
ENV CUDA_HOME=/usr/local/cuda
|
||||
RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \
|
||||
&& bash install_python_libraries.sh \
|
||||
&& bash install_nixl.sh --force
|
||||
&& bash install_python_libraries.sh
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
@ -519,7 +533,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
else \
|
||||
BITSANDBYTES_VERSION="0.46.1"; \
|
||||
fi; \
|
||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]
|
||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' boto3 runai-model-streamer runai-model-streamer[s3]
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
|
||||
@ -246,7 +246,7 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
|
||||
|
||||
|
||||
# build flashinfer for torch nightly from source around 10 mins
|
||||
# release version: v0.2.2.post1
|
||||
# release version: v0.3.1
|
||||
# todo(elainewy): cache flashinfer build result for faster build
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
@ -254,7 +254,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
echo "git clone flashinfer..." \
|
||||
&& git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
|
||||
&& cd flashinfer \
|
||||
&& git checkout v0.2.2.post1 \
|
||||
&& git checkout v0.3.1 \
|
||||
&& git submodule update --init --recursive \
|
||||
&& echo "finish git clone flashinfer..." \
|
||||
&& rm -rf build \
|
||||
|
||||
@ -29,7 +29,10 @@ ARG VLLM_BRANCH="main"
|
||||
ONBUILD RUN git clone ${VLLM_REPO} \
|
||||
&& cd vllm \
|
||||
&& git fetch -v --prune -- origin ${VLLM_BRANCH} \
|
||||
&& git checkout FETCH_HEAD
|
||||
&& git checkout FETCH_HEAD \
|
||||
&& if [ ${VLLM_REPO} != "https://github.com/vllm-project/vllm.git" ] ; then \
|
||||
git remote add upstream "https://github.com/vllm-project/vllm.git" \
|
||||
&& git fetch upstream ; fi
|
||||
FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm
|
||||
|
||||
# -----------------------
|
||||
@ -47,6 +50,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements /requirements
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/docker/Dockerfile.rocm /docker/
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
|
||||
|
||||
# -----------------------
|
||||
@ -71,7 +75,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
|
||||
RUN cd /vllm-workspace \
|
||||
&& rm -rf vllm \
|
||||
&& python3 -m pip install -e tests/vllm_test_utils \
|
||||
&& python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \
|
||||
&& python3 -m pip install lm-eval[api]==0.4.4 \
|
||||
&& python3 -m pip install pytest-shard
|
||||
|
||||
# -----------------------
|
||||
@ -100,8 +104,10 @@ ARG COMMON_WORKDIR
|
||||
# Copy over the benchmark scripts as well
|
||||
COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks
|
||||
COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples
|
||||
COPY --from=export_vllm /docker ${COMMON_WORKDIR}/vllm/docker
|
||||
|
||||
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
|
||||
ENV RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES=1
|
||||
ENV TOKENIZERS_PARALLELISM=false
|
||||
|
||||
# ENV that can improve safe tensor loading, and end-to-end time
|
||||
|
||||
@ -1,27 +1,23 @@
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete
|
||||
ARG HIPBLASLT_BRANCH="db8e93b4"
|
||||
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
|
||||
ARG LEGACY_HIPBLASLT_OPTION=
|
||||
ARG RCCL_BRANCH="648a58d"
|
||||
ARG RCCL_REPO="https://github.com/ROCm/rccl"
|
||||
ARG TRITON_BRANCH="e5be006"
|
||||
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
||||
ARG PYTORCH_BRANCH="295f2ed4"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.21.0"
|
||||
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete
|
||||
ARG TRITON_BRANCH="f9e5bf54"
|
||||
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
|
||||
ARG PYTORCH_BRANCH="b2fb6885"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.23.0"
|
||||
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="1a7f4dfa"
|
||||
ARG FA_BRANCH="0e60e394"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="916bf3c"
|
||||
ARG AITER_BRANCH="2ab9f4cd"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
ENV PATH=/opt/rocm/llvm/bin:$PATH
|
||||
ENV PATH=/opt/rocm/llvm/bin:/opt/rocm/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
ENV ROCM_PATH=/opt/rocm
|
||||
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
|
||||
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
|
||||
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201
|
||||
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
|
||||
ENV AITER_ROCM_ARCH=gfx942;gfx950
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
@ -45,38 +41,7 @@ RUN apt-get update -y \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
RUN pip install -U packaging 'cmake<4' ninja wheel setuptools pybind11 Cython
|
||||
|
||||
FROM base AS build_hipblaslt
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG HIPBLAS_COMMON_BRANCH
|
||||
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
RUN git clone https://github.com/ROCm/hipBLAS-common.git
|
||||
RUN cd hipBLAS-common \
|
||||
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
|
||||
&& mkdir build \
|
||||
&& cd build \
|
||||
&& cmake .. \
|
||||
&& make package \
|
||||
&& dpkg -i ./*.deb
|
||||
RUN git clone https://github.com/ROCm/hipBLASLt
|
||||
RUN cd hipBLASLt \
|
||||
&& git checkout ${HIPBLASLT_BRANCH} \
|
||||
&& apt-get install -y llvm-dev \
|
||||
&& ./install.sh -dc --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
|
||||
&& cd build/release \
|
||||
&& make package
|
||||
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
|
||||
|
||||
FROM base AS build_rccl
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
RUN git clone ${RCCL_REPO}
|
||||
RUN cd rccl \
|
||||
&& git checkout ${RCCL_BRANCH} \
|
||||
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
|
||||
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
|
||||
RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython
|
||||
|
||||
FROM base AS build_triton
|
||||
ARG TRITON_BRANCH
|
||||
@ -84,9 +49,11 @@ ARG TRITON_REPO
|
||||
RUN git clone ${TRITON_REPO}
|
||||
RUN cd triton \
|
||||
&& git checkout ${TRITON_BRANCH} \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
|
||||
&& if [ ! -f setup.py ]; then cd python; fi \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist \
|
||||
&& mkdir -p /app/install && cp dist/*.whl /app/install
|
||||
RUN if [ -d triton/python/triton_kernels ]; then pip install build && cd triton/python/triton_kernels \
|
||||
&& python3 -m build --wheel && cp dist/*.whl /app/install; fi
|
||||
|
||||
FROM base AS build_amdsmi
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
@ -129,18 +96,21 @@ RUN cd aiter \
|
||||
&& git checkout ${AITER_BRANCH} \
|
||||
&& git submodule update --init --recursive \
|
||||
&& pip install -r requirements.txt
|
||||
RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
|
||||
RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
|
||||
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
|
||||
|
||||
FROM base AS debs
|
||||
RUN mkdir /app/debs
|
||||
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
|
||||
FROM base AS final
|
||||
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
@ -151,11 +121,6 @@ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
|
||||
ARG BASE_IMAGE
|
||||
ARG HIPBLAS_COMMON_BRANCH
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
ARG PYTORCH_BRANCH
|
||||
@ -167,11 +132,6 @@ ARG FA_REPO
|
||||
ARG AITER_BRANCH
|
||||
ARG AITER_REPO
|
||||
RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
|
||||
&& echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \
|
||||
&& echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
|
||||
@ -179,5 +139,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
|
||||
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
|
||||
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
|
||||
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
|
||||
@ -44,11 +44,12 @@ nav:
|
||||
- contributing/model/registration.md
|
||||
- contributing/model/tests.md
|
||||
- contributing/model/multimodal.md
|
||||
- contributing/model/transcription.md
|
||||
- CI: contributing/ci
|
||||
- Design Documents: design
|
||||
- API Reference:
|
||||
- api/README.md
|
||||
- api/vllm/*
|
||||
- api/vllm
|
||||
- CLI Reference: cli
|
||||
- Community:
|
||||
- community/*
|
||||
|
||||
@ -56,7 +56,7 @@ vLLM is flexible and easy to use with:
|
||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
|
||||
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
|
||||
- Prefix caching support
|
||||
- Multi-LoRA support
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ API documentation for vLLM's configuration classes.
|
||||
- [vllm.config.LoRAConfig][]
|
||||
- [vllm.config.MultiModalConfig][]
|
||||
- [vllm.config.PoolerConfig][]
|
||||
- [vllm.config.DecodingConfig][]
|
||||
- [vllm.config.StructuredOutputsConfig][]
|
||||
- [vllm.config.ObservabilityConfig][]
|
||||
- [vllm.config.KVTransferConfig][]
|
||||
- [vllm.config.CompilationConfig][]
|
||||
@ -46,7 +46,6 @@ Engine classes for offline and online inference.
|
||||
Inference parameters for vLLM APIs.
|
||||
|
||||
[](){ #sampling-params }
|
||||
[](){ #pooling-params }
|
||||
|
||||
- [vllm.SamplingParams][]
|
||||
- [vllm.PoolingParams][]
|
||||
|
||||
@ -175,6 +175,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u
|
||||
Known supported models:
|
||||
|
||||
- GLM-4.5V GLM-4.1V (<gh-pr:23168>)
|
||||
- InternVL (<gh-pr:23909>)
|
||||
- Kimi-VL (<gh-pr:23817>)
|
||||
- Llama4 (<gh-pr:18368>)
|
||||
- MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
|
||||
@ -230,6 +231,20 @@ Multi-modal IPC caching is automatically enabled when
|
||||
there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes,
|
||||
to avoid repeatedly transferring the same multi-modal inputs between them.
|
||||
|
||||
#### Key-Replicated Cache
|
||||
|
||||
By default, IPC caching uses a **key-replicated cache**, where cache keys exist
|
||||
in both the API (`P0`) and engine core (`P1`) processes, but the actual cache
|
||||
data resides only in `P1`.
|
||||
|
||||
#### Shared Memory Cache
|
||||
|
||||
When multiple worker processes are involved (e.g., when TP > 1), a
|
||||
**shared-memory cache** is more efficient. This can be enabled by setting
|
||||
`mm_processor_cache_type="shm"`. In this mode, cache keys are stored
|
||||
on `P0`, while the cache data itself lives in shared memory accessible by all
|
||||
processes.
|
||||
|
||||
### Configuration
|
||||
|
||||
You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB).
|
||||
@ -244,6 +259,12 @@ Examples:
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
mm_processor_cache_gb=8)
|
||||
|
||||
# Use a shared-memory based IPC cache
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
tensor_parallel_size=2,
|
||||
mm_processor_cache_type="shm",
|
||||
mm_processor_cache_gb=8)
|
||||
|
||||
# Disable the cache
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
mm_processor_cache_gb=0)
|
||||
@ -253,11 +274,12 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
|
||||
Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows:
|
||||
|
||||
| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory |
|
||||
|-------------------|-------------|------------|------------|-------------|
|
||||
| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` |
|
||||
| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
|
||||
| ❌ | ❌ | N/A | N/A | `0` |
|
||||
| mm_processor_cache_type | Cache Type | `P0` Cache | `P1` Engine Cache | `P1` Worker Cache | Max. Memory |
|
||||
|-------------------|-------------|------------|------------|-------------|-------------|
|
||||
| lru | Processor Caching | K + V | N/A | N/A | `mm_processor_cache_gb * data_parallel_size` |
|
||||
| lru | Key-Replicated Caching | K | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
|
||||
| shm | Shared Memory Caching | K | N/A | V | `mm_processor_cache_gb * api_server_count` |
|
||||
| N/A | Disabled | N/A | N/A | N/A | `0` |
|
||||
|
||||
K: Stores the hashes of multi-modal items
|
||||
V: Stores the processed tensor data of multi-modal items
|
||||
|
||||
@ -26,113 +26,123 @@ See <gh-file:LICENSE>.
|
||||
|
||||
## Developing
|
||||
|
||||
--8<-- "docs/getting_started/installation/python_env_setup.inc.md"
|
||||
|
||||
Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation.
|
||||
Check out the [building from source][build-from-source] documentation for details.
|
||||
|
||||
For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations.
|
||||
|
||||
### Building the docs with MkDocs
|
||||
|
||||
#### Introduction to MkDocs
|
||||
|
||||
[MkDocs](https://github.com/mkdocs/mkdocs) is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file.
|
||||
|
||||
#### Install MkDocs and Plugins
|
||||
|
||||
Install MkDocs along with the [plugins](https://github.com/vllm-project/vllm/blob/main/mkdocs.yaml) used in the vLLM documentation, as well as required dependencies:
|
||||
|
||||
```bash
|
||||
uv pip install -r requirements/docs.txt
|
||||
```
|
||||
|
||||
!!! note
|
||||
Ensure that your Python version is compatible with the plugins (e.g., `mkdocs-awesome-nav` requires Python 3.10+)
|
||||
|
||||
#### Verify Installation
|
||||
|
||||
Confirm that MkDocs is correctly installed:
|
||||
|
||||
```bash
|
||||
mkdocs --version
|
||||
```
|
||||
|
||||
Example output:
|
||||
|
||||
```console
|
||||
mkdocs, version 1.6.1 from /opt/miniconda3/envs/mkdoc/lib/python3.10/site-packages/mkdocs (Python 3.10)
|
||||
```
|
||||
|
||||
#### Clone the `vLLM` repository
|
||||
The first step of contributing to vLLM is to clone the GitHub repository:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vllm-project/vllm.git
|
||||
cd vllm
|
||||
```
|
||||
|
||||
#### Start the Development Server
|
||||
Then, configure your Python virtual environment.
|
||||
|
||||
MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. Make sure you're in the same directory as the `mkdocs.yml` configuration file, and then start the server by running the `mkdocs serve` command:
|
||||
--8<-- "docs/getting_started/installation/python_env_setup.inc.md"
|
||||
|
||||
If you are only developing vLLM's Python code, install vLLM using:
|
||||
|
||||
```bash
|
||||
mkdocs serve
|
||||
VLLM_USE_PRECOMPILED=1 uv pip install -e .
|
||||
```
|
||||
|
||||
Example output:
|
||||
If you are developing vLLM's Python and CUDA/C++ code, install vLLM using:
|
||||
|
||||
```console
|
||||
INFO - Documentation built in 106.83 seconds
|
||||
INFO - [22:02:02] Watching paths for changes: 'docs', 'mkdocs.yaml'
|
||||
INFO - [22:02:02] Serving on http://127.0.0.1:8000/
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
#### View in Your Browser
|
||||
For more details about installing from source and installing for other hardware, check out the [installation instructions](../getting_started/installation/README.md) for your hardware and head to the "Build wheel from source" section.
|
||||
|
||||
Open up [http://127.0.0.1:8000/](http://127.0.0.1:8000/) in your browser to see a live preview:.
|
||||
|
||||
#### Learn More
|
||||
|
||||
For additional features and advanced configurations, refer to the official [MkDocs Documentation](https://www.mkdocs.org/).
|
||||
|
||||
## Testing
|
||||
|
||||
??? console "Commands"
|
||||
|
||||
```bash
|
||||
# These commands are only for Nvidia CUDA platforms.
|
||||
uv pip install -r requirements/common.txt -r requirements/dev.txt --torch-backend=auto
|
||||
|
||||
# Linting, formatting and static type checking
|
||||
pre-commit install
|
||||
|
||||
# You can manually run pre-commit with
|
||||
pre-commit run --all-files --show-diff-on-failure
|
||||
|
||||
# To manually run something from CI that does not run
|
||||
# locally by default, you can run:
|
||||
pre-commit run mypy-3.9 --hook-stage manual --all-files
|
||||
|
||||
# Unit tests
|
||||
pytest tests/
|
||||
|
||||
# Run tests for a single test file with detailed output
|
||||
pytest -s -v tests/test_logger.py
|
||||
```
|
||||
For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations.
|
||||
|
||||
!!! tip
|
||||
Since the <gh-file:docker/Dockerfile> ships with Python 3.12, all tests in CI (except `mypy`) are run with Python 3.12.
|
||||
vLLM is compatible with Python versions 3.9 to 3.12. However, vLLM's default [Dockerfile](gh-file:docker/Dockerfile) ships with Python 3.12 and tests in CI (except `mypy`) are run with Python 3.12.
|
||||
|
||||
Therefore, we recommend developing with Python 3.12 to minimise the chance of your local environment clashing with our CI environment.
|
||||
|
||||
!!! note "Install python3-dev if Python.h is missing"
|
||||
### Linting
|
||||
|
||||
vLLM uses `pre-commit` to lint and format the codebase. See <https://pre-commit.com/#usage> if `pre-commit` is new to you. Setting up `pre-commit` is as easy as:
|
||||
|
||||
```bash
|
||||
uv pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
vLLM's `pre-commit` hooks will now run automatically every time you commit.
|
||||
|
||||
!!! tip "Tips"
|
||||
You can manually run the `pre-commit` hooks using:
|
||||
|
||||
```bash
|
||||
pre-commit run # runs on staged files
|
||||
pre-commit run -a # runs on all files (short for --all-files)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Some `pre-commit` hooks only run in CI. If you need to, you can run them locally with:
|
||||
|
||||
```bash
|
||||
pre-commit run --hook-stage manual markdownlint
|
||||
pre-commit run --hook-stage manual mypy-3.9
|
||||
```
|
||||
|
||||
### Documentation
|
||||
|
||||
MkDocs is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file, <gh-file:mkdocs.yaml>.
|
||||
|
||||
Get started with:
|
||||
|
||||
```bash
|
||||
uv pip install -r requirements/docs.txt
|
||||
```
|
||||
|
||||
!!! tip
|
||||
Ensure that your Python version is compatible with the plugins
|
||||
(e.g., `mkdocs-awesome-nav` requires Python 3.10+)
|
||||
|
||||
MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it.
|
||||
From the root of the repository, run:
|
||||
|
||||
```bash
|
||||
mkdocs serve # with API ref (~10 minutes)
|
||||
API_AUTONAV_EXCLUDE=vllm mkdocs serve # API ref off (~15 seconds)
|
||||
```
|
||||
|
||||
Once you see `Serving on http://127.0.0.1:8000/` in the logs, the live preview is ready!
|
||||
Open <http://127.0.0.1:8000/> in your browser to see it.
|
||||
|
||||
For additional features and advanced configurations, refer to the:
|
||||
|
||||
- [MkDocs documentation](https://www.mkdocs.org/)
|
||||
- [Material for MkDocs documentation](https://squidfunk.github.io/mkdocs-material/) (the MkDocs theme we use)
|
||||
|
||||
### Testing
|
||||
|
||||
vLLM uses `pytest` to test the codebase.
|
||||
|
||||
```bash
|
||||
# Install the test dependencies used in CI (CUDA only)
|
||||
uv pip install -r requirements/common.txt -r requirements/dev.txt --torch-backend=auto
|
||||
|
||||
# Install some common test dependencies (hardware agnostic)
|
||||
uv pip install pytest pytest-asyncio
|
||||
|
||||
# Run all tests
|
||||
pytest tests/
|
||||
|
||||
# Run tests for a single test file with detailed output
|
||||
pytest -s -v tests/test_logger.py
|
||||
```
|
||||
|
||||
!!! tip "Install python3-dev if Python.h is missing"
|
||||
If any of the above commands fails with `Python.h: No such file or directory`, install
|
||||
`python3-dev` with `sudo apt install python3-dev`.
|
||||
|
||||
!!! note
|
||||
!!! warning "Warnings"
|
||||
Currently, the repository is not fully checked by `mypy`.
|
||||
|
||||
!!! note
|
||||
---
|
||||
|
||||
Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU
|
||||
platform to run unit tests locally, rely on the continuous integration system to run the tests for
|
||||
now.
|
||||
@ -194,8 +204,7 @@ appropriately to indicate the type of change. Please use one of the following:
|
||||
The PR needs to meet the following code quality standards:
|
||||
|
||||
- We adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
|
||||
- Pass all linter checks. Please use `pre-commit` to format your code. See
|
||||
<https://pre-commit.com/#usage> if `pre-commit` is new to you.
|
||||
- Pass all linter checks.
|
||||
- The code needs to be well-documented to ensure future contributors can easily
|
||||
understand the code.
|
||||
- Include sufficient tests to ensure the project stays correct and robust. This
|
||||
|
||||
@ -1,9 +1,787 @@
|
||||
---
|
||||
toc_depth: 4
|
||||
---
|
||||
|
||||
# Benchmark Suites
|
||||
|
||||
vLLM contains two sets of benchmarks:
|
||||
vLLM provides comprehensive benchmarking tools for performance testing and evaluation:
|
||||
|
||||
- [Performance benchmarks][performance-benchmarks]
|
||||
- [Nightly benchmarks][nightly-benchmarks]
|
||||
- **[Benchmark CLI]**: `vllm bench` CLI tools and specialized benchmark scripts for interactive performance testing
|
||||
- **[Performance benchmarks][performance-benchmarks]**: Automated CI benchmarks for development
|
||||
- **[Nightly benchmarks][nightly-benchmarks]**: Comparative benchmarks against alternatives
|
||||
|
||||
[Benchmark CLI]: #benchmark-cli
|
||||
|
||||
## Benchmark CLI
|
||||
|
||||
This section guides you through running benchmark tests with the extensive
|
||||
datasets supported on vLLM. It's a living document, updated as new features and datasets
|
||||
become available.
|
||||
|
||||
### Dataset Overview
|
||||
|
||||
<style>
|
||||
th {
|
||||
min-width: 0 !important;
|
||||
}
|
||||
</style>
|
||||
|
||||
| Dataset | Online | Offline | Data Path |
|
||||
|---------|--------|---------|-----------|
|
||||
| ShareGPT | ✅ | ✅ | `wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json` |
|
||||
| ShareGPT4V (Image) | ✅ | ✅ | `wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json`<br>Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:<br>`wget http://images.cocodataset.org/zips/train2017.zip` |
|
||||
| ShareGPT4Video (Video) | ✅ | ✅ | `git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video` |
|
||||
| BurstGPT | ✅ | ✅ | `wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv` |
|
||||
| Sonnet (deprecated) | ✅ | ✅ | Local file: `benchmarks/sonnet.txt` |
|
||||
| Random | ✅ | ✅ | `synthetic` |
|
||||
| RandomMultiModal (Image/Video) | 🟡 | 🚧 | `synthetic` |
|
||||
| Prefix Repetition | ✅ | ✅ | `synthetic` |
|
||||
| HuggingFace-VisionArena | ✅ | ✅ | `lmarena-ai/VisionArena-Chat` |
|
||||
| HuggingFace-MMVU | ✅ | ✅ | `yale-nlp/MMVU` |
|
||||
| HuggingFace-InstructCoder | ✅ | ✅ | `likaixin/InstructCoder` |
|
||||
| HuggingFace-AIMO | ✅ | ✅ | `AI-MO/aimo-validation-aime`, `AI-MO/NuminaMath-1.5`, `AI-MO/NuminaMath-CoT` |
|
||||
| HuggingFace-Other | ✅ | ✅ | `lmms-lab/LLaVA-OneVision-Data`, `Aeala/ShareGPT_Vicuna_unfiltered` |
|
||||
| HuggingFace-MTBench | ✅ | ✅ | `philschmid/mt-bench` |
|
||||
| HuggingFace-Blazedit | ✅ | ✅ | `vdaita/edit_5k_char`, `vdaita/edit_10k_char` |
|
||||
| Spec Bench | ✅ | ✅ | `wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl` |
|
||||
| Custom | ✅ | ✅ | Local file: `data.jsonl` |
|
||||
|
||||
Legend:
|
||||
|
||||
- ✅ - supported
|
||||
- 🟡 - Partial support
|
||||
- 🚧 - to be supported
|
||||
|
||||
!!! note
|
||||
HuggingFace dataset's `dataset-name` should be set to `hf`.
|
||||
For local `dataset-path`, please set `hf-name` to its Hugging Face ID like
|
||||
|
||||
```bash
|
||||
--dataset-path /datasets/VisionArena-Chat/ --hf-name lmarena-ai/VisionArena-Chat
|
||||
```
|
||||
|
||||
### Examples
|
||||
|
||||
#### 🚀 Online Benchmark
|
||||
|
||||
<details class="admonition abstract" markdown="1">
|
||||
<summary>Show more</summary>
|
||||
|
||||
First start serving your model
|
||||
|
||||
```bash
|
||||
vllm serve NousResearch/Hermes-3-Llama-3.1-8B
|
||||
```
|
||||
|
||||
Then run the benchmarking script
|
||||
|
||||
```bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--endpoint /v1/completions \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
If successful, you will see the following output
|
||||
|
||||
```text
|
||||
============ Serving Benchmark Result ============
|
||||
Successful requests: 10
|
||||
Benchmark duration (s): 5.78
|
||||
Total input tokens: 1369
|
||||
Total generated tokens: 2212
|
||||
Request throughput (req/s): 1.73
|
||||
Output token throughput (tok/s): 382.89
|
||||
Total Token throughput (tok/s): 619.85
|
||||
---------------Time to First Token----------------
|
||||
Mean TTFT (ms): 71.54
|
||||
Median TTFT (ms): 73.88
|
||||
P99 TTFT (ms): 79.49
|
||||
-----Time per Output Token (excl. 1st token)------
|
||||
Mean TPOT (ms): 7.91
|
||||
Median TPOT (ms): 7.96
|
||||
P99 TPOT (ms): 8.03
|
||||
---------------Inter-token Latency----------------
|
||||
Mean ITL (ms): 7.74
|
||||
Median ITL (ms): 7.70
|
||||
P99 ITL (ms): 8.39
|
||||
==================================================
|
||||
```
|
||||
|
||||
##### Custom Dataset
|
||||
|
||||
If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl
|
||||
|
||||
```json
|
||||
{"prompt": "What is the capital of India?"}
|
||||
{"prompt": "What is the capital of Iran?"}
|
||||
{"prompt": "What is the capital of China?"}
|
||||
```
|
||||
|
||||
```bash
|
||||
# start server
|
||||
VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct
|
||||
```
|
||||
|
||||
```bash
|
||||
# run benchmarking script
|
||||
vllm bench serve --port 9001 --save-result --save-detailed \
|
||||
--backend vllm \
|
||||
--model meta-llama/Llama-3.1-8B-Instruct \
|
||||
--endpoint /v1/completions \
|
||||
--dataset-name custom \
|
||||
--dataset-path <path-to-your-data-jsonl> \
|
||||
--custom-skip-chat-template \
|
||||
--num-prompts 80 \
|
||||
--max-concurrency 1 \
|
||||
--temperature=0.3 \
|
||||
--top-p=0.75 \
|
||||
--result-dir "./log/"
|
||||
```
|
||||
|
||||
You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`.
|
||||
|
||||
##### VisionArena Benchmark for Vision Language Models
|
||||
|
||||
```bash
|
||||
# need a model with vision capability here
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct
|
||||
```
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmarena-ai/VisionArena-Chat \
|
||||
--hf-split train \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
##### InstructCoder Benchmark with Speculative Decoding
|
||||
|
||||
``` bash
|
||||
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--speculative-config $'{"method": "ngram",
|
||||
"num_speculative_tokens": 5, "prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 2}'
|
||||
```
|
||||
|
||||
``` bash
|
||||
vllm bench serve \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--dataset-name hf \
|
||||
--dataset-path likaixin/InstructCoder \
|
||||
--num-prompts 2048
|
||||
```
|
||||
|
||||
##### Spec Bench Benchmark with Speculative Decoding
|
||||
|
||||
``` bash
|
||||
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--speculative-config $'{"method": "ngram",
|
||||
"num_speculative_tokens": 5, "prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 2}'
|
||||
```
|
||||
|
||||
[SpecBench dataset](https://github.com/hemingkx/Spec-Bench)
|
||||
|
||||
Run all categories:
|
||||
|
||||
``` bash
|
||||
# Download the dataset using:
|
||||
# wget https://raw.githubusercontent.com/hemingkx/Spec-Bench/refs/heads/main/data/spec_bench/question.jsonl
|
||||
|
||||
vllm bench serve \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--dataset-name spec_bench \
|
||||
--dataset-path "<YOUR_DOWNLOADED_PATH>/data/spec_bench/question.jsonl" \
|
||||
--num-prompts -1
|
||||
```
|
||||
|
||||
Available categories include `[writing, roleplay, reasoning, math, coding, extraction, stem, humanities, translation, summarization, qa, math_reasoning, rag]`.
|
||||
|
||||
Run only a specific category like "summarization":
|
||||
|
||||
``` bash
|
||||
vllm bench serve \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--dataset-name spec_bench \
|
||||
--dataset-path "<YOUR_DOWNLOADED_PATH>/data/spec_bench/question.jsonl" \
|
||||
--num-prompts -1
|
||||
--spec-bench-category "summarization"
|
||||
```
|
||||
|
||||
##### Other HuggingFaceDataset Examples
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct
|
||||
```
|
||||
|
||||
`lmms-lab/LLaVA-OneVision-Data`:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmms-lab/LLaVA-OneVision-Data \
|
||||
--hf-split train \
|
||||
--hf-subset "chart2text(cauldron)" \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
`Aeala/ShareGPT_Vicuna_unfiltered`:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name hf \
|
||||
--dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
|
||||
--hf-split train \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
`AI-MO/aimo-validation-aime`:
|
||||
|
||||
``` bash
|
||||
vllm bench serve \
|
||||
--model Qwen/QwQ-32B \
|
||||
--dataset-name hf \
|
||||
--dataset-path AI-MO/aimo-validation-aime \
|
||||
--num-prompts 10 \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
`philschmid/mt-bench`:
|
||||
|
||||
``` bash
|
||||
vllm bench serve \
|
||||
--model Qwen/QwQ-32B \
|
||||
--dataset-name hf \
|
||||
--dataset-path philschmid/mt-bench \
|
||||
--num-prompts 80
|
||||
```
|
||||
|
||||
`vdaita/edit_5k_char` or `vdaita/edit_10k_char`:
|
||||
|
||||
``` bash
|
||||
vllm bench serve \
|
||||
--model Qwen/QwQ-32B \
|
||||
--dataset-name hf \
|
||||
--dataset-path vdaita/edit_5k_char \
|
||||
--num-prompts 90 \
|
||||
--blazedit-min-distance 0.01 \
|
||||
--blazedit-max-distance 0.99
|
||||
```
|
||||
|
||||
##### Running With Sampling Parameters
|
||||
|
||||
When using OpenAI-compatible backends such as `vllm`, optional sampling
|
||||
parameters can be specified. Example client command:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--endpoint /v1/completions \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--top-k 10 \
|
||||
--top-p 0.9 \
|
||||
--temperature 0.5 \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
##### Running With Ramp-Up Request Rate
|
||||
|
||||
The benchmark tool also supports ramping up the request rate over the
|
||||
duration of the benchmark run. This can be useful for stress testing the
|
||||
server or finding the maximum throughput that it can handle, given some latency budget.
|
||||
|
||||
Two ramp-up strategies are supported:
|
||||
|
||||
- `linear`: Increases the request rate linearly from a start value to an end value.
|
||||
- `exponential`: Increases the request rate exponentially.
|
||||
|
||||
The following arguments can be used to control the ramp-up:
|
||||
|
||||
- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`).
|
||||
- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark.
|
||||
- `--ramp-up-end-rps`: The request rate at the end of the benchmark.
|
||||
|
||||
</details>
|
||||
|
||||
#### 📈 Offline Throughput Benchmark
|
||||
|
||||
<details class="admonition abstract" markdown="1">
|
||||
<summary>Show more</summary>
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset-name sonnet \
|
||||
--dataset-path vllm/benchmarks/sonnet.txt \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
If successful, you will see the following output
|
||||
|
||||
```text
|
||||
Throughput: 7.15 requests/s, 4656.00 total tokens/s, 1072.15 output tokens/s
|
||||
Total num prompt tokens: 5014
|
||||
Total num output tokens: 1500
|
||||
```
|
||||
|
||||
##### VisionArena Benchmark for Vision Language Models
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmarena-ai/VisionArena-Chat \
|
||||
--num-prompts 1000 \
|
||||
--hf-split train
|
||||
```
|
||||
|
||||
The `num prompt tokens` now includes image token counts
|
||||
|
||||
```text
|
||||
Throughput: 2.55 requests/s, 4036.92 total tokens/s, 326.90 output tokens/s
|
||||
Total num prompt tokens: 14527
|
||||
Total num output tokens: 1280
|
||||
```
|
||||
|
||||
##### InstructCoder Benchmark with Speculative Decoding
|
||||
|
||||
``` bash
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
VLLM_USE_V1=1 \
|
||||
vllm bench throughput \
|
||||
--dataset-name=hf \
|
||||
--dataset-path=likaixin/InstructCoder \
|
||||
--model=meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
--input-len=1000 \
|
||||
--output-len=100 \
|
||||
--num-prompts=2048 \
|
||||
--async-engine \
|
||||
--speculative-config $'{"method": "ngram",
|
||||
"num_speculative_tokens": 5, "prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 2}'
|
||||
```
|
||||
|
||||
```text
|
||||
Throughput: 104.77 requests/s, 23836.22 total tokens/s, 10477.10 output tokens/s
|
||||
Total num prompt tokens: 261136
|
||||
Total num output tokens: 204800
|
||||
```
|
||||
|
||||
##### Other HuggingFaceDataset Examples
|
||||
|
||||
`lmms-lab/LLaVA-OneVision-Data`:
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
--dataset-path lmms-lab/LLaVA-OneVision-Data \
|
||||
--hf-split train \
|
||||
--hf-subset "chart2text(cauldron)" \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
`Aeala/ShareGPT_Vicuna_unfiltered`:
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model Qwen/Qwen2-VL-7B-Instruct \
|
||||
--backend vllm-chat \
|
||||
--dataset-name hf \
|
||||
--dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
|
||||
--hf-split train \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
`AI-MO/aimo-validation-aime`:
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model Qwen/QwQ-32B \
|
||||
--backend vllm \
|
||||
--dataset-name hf \
|
||||
--dataset-path AI-MO/aimo-validation-aime \
|
||||
--hf-split train \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
Benchmark with LoRA adapters:
|
||||
|
||||
``` bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
vllm bench throughput \
|
||||
--model meta-llama/Llama-2-7b-hf \
|
||||
--backend vllm \
|
||||
--dataset_path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--dataset_name sharegpt \
|
||||
--num-prompts 10 \
|
||||
--max-loras 2 \
|
||||
--max-lora-rank 8 \
|
||||
--enable-lora \
|
||||
--lora-path yard1/llama-2-7b-sql-lora-test
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### 🛠️ Structured Output Benchmark
|
||||
|
||||
<details class="admonition abstract" markdown="1">
|
||||
<summary>Show more</summary>
|
||||
|
||||
Benchmark the performance of structured output generation (JSON, grammar, regex).
|
||||
|
||||
##### Server Setup
|
||||
|
||||
```bash
|
||||
vllm serve NousResearch/Hermes-3-Llama-3.1-8B
|
||||
```
|
||||
|
||||
##### JSON Schema Benchmark
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset json \
|
||||
--structured-output-ratio 1.0 \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
##### Grammar-based Generation Benchmark
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset grammar \
|
||||
--structure-type grammar \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
##### Regex-based Generation Benchmark
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset regex \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
##### Choice-based Generation Benchmark
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset choice \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
##### XGrammar Benchmark Dataset
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset xgrammar_bench \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### 📚 Long Document QA Benchmark
|
||||
|
||||
<details class="admonition abstract" markdown="1">
|
||||
<summary>Show more</summary>
|
||||
|
||||
Benchmark the performance of long document question-answering with prefix caching.
|
||||
|
||||
##### Basic Long Document QA Test
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 16 \
|
||||
--document-length 2000 \
|
||||
--output-len 50 \
|
||||
--repeat-count 5
|
||||
```
|
||||
|
||||
##### Different Repeat Modes
|
||||
|
||||
```bash
|
||||
# Random mode (default) - shuffle prompts randomly
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode random
|
||||
|
||||
# Tile mode - repeat entire prompt list in sequence
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode tile
|
||||
|
||||
# Interleave mode - repeat each prompt consecutively
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode interleave
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### 🗂️ Prefix Caching Benchmark
|
||||
|
||||
<details class="admonition abstract" markdown="1">
|
||||
<summary>Show more</summary>
|
||||
|
||||
Benchmark the efficiency of automatic prefix caching.
|
||||
|
||||
##### Fixed Prompt with Prefix Caching
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prefix_caching.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-prompts 1 \
|
||||
--repeat-count 100 \
|
||||
--input-length-range 128:256
|
||||
```
|
||||
|
||||
##### ShareGPT Dataset with Prefix Caching
|
||||
|
||||
```bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
|
||||
python3 benchmarks/benchmark_prefix_caching.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--enable-prefix-caching \
|
||||
--num-prompts 20 \
|
||||
--repeat-count 5 \
|
||||
--input-length-range 128:256
|
||||
```
|
||||
|
||||
##### Prefix Repetition Dataset
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--dataset-name prefix_repetition \
|
||||
--num-prompts 100 \
|
||||
--prefix-repetition-prefix-len 512 \
|
||||
--prefix-repetition-suffix-len 128 \
|
||||
--prefix-repetition-num-prefixes 5 \
|
||||
--prefix-repetition-output-len 128
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### ⚡ Request Prioritization Benchmark
|
||||
|
||||
<details class="admonition abstract" markdown="1">
|
||||
<summary>Show more</summary>
|
||||
|
||||
Benchmark the performance of request prioritization in vLLM.
|
||||
|
||||
##### Basic Prioritization Test
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prioritization.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--input-len 128 \
|
||||
--output-len 64 \
|
||||
--num-prompts 100 \
|
||||
--scheduling-policy priority
|
||||
```
|
||||
|
||||
##### Multiple Sequences per Prompt
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prioritization.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--input-len 128 \
|
||||
--output-len 64 \
|
||||
--num-prompts 100 \
|
||||
--scheduling-policy priority \
|
||||
--n 2
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### 👁️ Multi-Modal Benchmark
|
||||
|
||||
<details class="admonition abstract" markdown="1">
|
||||
<summary>Show more</summary>
|
||||
|
||||
Benchmark the performance of multi-modal requests in vLLM.
|
||||
|
||||
##### Images (ShareGPT4V)
|
||||
|
||||
Start vLLM:
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dtype bfloat16 \
|
||||
--limit-mm-per-prompt '{"image": 1}' \
|
||||
--allowed-local-media-path /path/to/sharegpt4v/images
|
||||
```
|
||||
|
||||
Send requests with images:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \
|
||||
--num-prompts 100 \
|
||||
--save-result \
|
||||
--result-dir ~/vllm_benchmark_results \
|
||||
--save-detailed \
|
||||
--endpoint /v1/chat/completion
|
||||
```
|
||||
|
||||
##### Videos (ShareGPT4Video)
|
||||
|
||||
Start vLLM:
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dtype bfloat16 \
|
||||
--limit-mm-per-prompt '{"video": 1}' \
|
||||
--allowed-local-media-path /path/to/sharegpt4video/videos
|
||||
```
|
||||
|
||||
Send requests with videos:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \
|
||||
--num-prompts 100 \
|
||||
--save-result \
|
||||
--result-dir ~/vllm_benchmark_results \
|
||||
--save-detailed \
|
||||
--endpoint /v1/chat/completion
|
||||
```
|
||||
|
||||
##### Synthetic Random Images (random-mm)
|
||||
|
||||
Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.
|
||||
|
||||
Notes:
|
||||
|
||||
- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`.
|
||||
- Video sampling is not yet implemented.
|
||||
|
||||
Start the server (example):
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--dtype bfloat16 \
|
||||
--max-model-len 16384 \
|
||||
--limit-mm-per-prompt '{"image": 3, "video": 0}' \
|
||||
--mm-processor-kwargs max_pixels=1003520
|
||||
```
|
||||
|
||||
Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`.
|
||||
|
||||
Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx 40 tokens:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--backend openai-chat \
|
||||
--model Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--endpoint /v1/chat/completions \
|
||||
--dataset-name random-mm \
|
||||
--num-prompts 100 \
|
||||
--max-concurrency 10 \
|
||||
--random-prefix-len 25 \
|
||||
--random-input-len 300 \
|
||||
--random-output-len 40 \
|
||||
--random-range-ratio 0.2 \
|
||||
--random-mm-base-items-per-request 2 \
|
||||
--random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
|
||||
--random-mm-bucket-config '{(224, 224, 1): 1.0}' \
|
||||
--request-rate inf \
|
||||
--ignore-eos \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
The number of items per request can be controlled by passing multiple image buckets:
|
||||
|
||||
```bash
|
||||
--random-mm-base-items-per-request 2 \
|
||||
--random-mm-num-mm-items-range-ratio 0.5 \
|
||||
--random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
|
||||
--random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
|
||||
```
|
||||
|
||||
Flags specific to `random-mm`:
|
||||
|
||||
- `--random-mm-base-items-per-request`: base number of multimodal items per request.
|
||||
- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items.
|
||||
- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
|
||||
- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported).
|
||||
|
||||
Behavioral notes:
|
||||
|
||||
- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.
|
||||
|
||||
How sampling works:
|
||||
|
||||
- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits.
|
||||
- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
|
||||
- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing.
|
||||
This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`.
|
||||
- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.
|
||||
|
||||
</details>
|
||||
|
||||
[](){ #performance-benchmarks }
|
||||
|
||||
@ -13,22 +791,22 @@ The performance benchmarks are used for development to confirm whether new chang
|
||||
|
||||
### Manually Trigger the benchmark
|
||||
|
||||
Use [vllm-ci-test-repo images](https://gallery.ecr.aws/q9t5s3a7/vllm-ci-test-repo) with vLLM benchmark suite.
|
||||
Use [vllm-ci-test-repo images](https://gallery.ecr.aws/q9t5s3a7/vllm-ci-test-repo) with vLLM benchmark suite.
|
||||
For CPU environment, please use the image with "-cpu" postfix.
|
||||
|
||||
Here is an example for docker run command for CPU.
|
||||
Here is an example for docker run command for CPU.
|
||||
|
||||
```bash
|
||||
docker run -it --entrypoint /bin/bash -v /data/huggingface:/root/.cache/huggingface -e HF_TOKEN='' --shm-size=16g --name vllm-cpu-ci public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:1da94e673c257373280026f75ceb4effac80e892-cpu
|
||||
```
|
||||
|
||||
Then, run below command inside the docker instance.
|
||||
Then, run below command inside the docker instance.
|
||||
|
||||
```bash
|
||||
bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
|
||||
```
|
||||
|
||||
When run, benchmark script generates results under **benchmark/results** folder, along with the benchmark_results.md and benchmark_results.json.
|
||||
When run, benchmark script generates results under **benchmark/results** folder, along with the benchmark_results.md and benchmark_results.json.
|
||||
|
||||
#### Runtime environment variables
|
||||
|
||||
|
||||
@ -40,6 +40,16 @@ python tools/generate_cmake_presets.py
|
||||
|
||||
The script will prompt you if it cannot automatically determine certain paths (e.g., `nvcc` or a specific Python executable for your vLLM development environment). Follow the on-screen prompts. If an existing `CMakeUserPresets.json` is found, the script will ask for confirmation before overwriting it.
|
||||
|
||||
**Force overwrite existing file:**
|
||||
|
||||
To automatically overwrite an existing `CMakeUserPresets.json` without prompting, use the `--force-overwrite` flag:
|
||||
|
||||
```console
|
||||
python tools/generate_cmake_presets.py --force-overwrite
|
||||
```
|
||||
|
||||
This is particularly useful in automated scripts or CI/CD environments where interactive prompts are not desired.
|
||||
|
||||
After running the script, a `CMakeUserPresets.json` file will be created in the root of your vLLM repository.
|
||||
|
||||
### Example `CMakeUserPresets.json`
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
!!! important
|
||||
Many decoder language models can now be automatically loaded using the [Transformers backend][transformers-backend] without having to implement them in vLLM. See if `vllm serve <model>` works first!
|
||||
|
||||
vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/compatibility_matrix.md) to optimize their performance.
|
||||
vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/README.md#compatibility-matrix) to optimize their performance.
|
||||
|
||||
The complexity of integrating a model into vLLM depends heavily on the model's architecture.
|
||||
The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
|
||||
@ -15,6 +15,7 @@ Read through these pages for a step-by-step guide:
|
||||
- [Registering a Model](registration.md)
|
||||
- [Unit Testing](tests.md)
|
||||
- [Multi-Modal Support](multimodal.md)
|
||||
- [Speech-to-Text Support](transcription.md)
|
||||
|
||||
!!! tip
|
||||
If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues)
|
||||
|
||||
@ -840,7 +840,6 @@ Some HF processors directly insert feature tokens without replacing anything in
|
||||
Examples:
|
||||
|
||||
- BLIP-2 (insert at start of prompt): <gh-file:vllm/model_executor/models/blip2.py>
|
||||
- Florence2 (insert at start of prompt): <gh-file:vllm/model_executor/models/florence2.py>
|
||||
- Molmo (insert after `<|endoftext|>` token): <gh-file:vllm/model_executor/models/molmo.py>
|
||||
|
||||
### Handling prompt updates unrelated to multi-modal data
|
||||
|
||||
276
docs/contributing/model/transcription.md
Normal file
276
docs/contributing/model/transcription.md
Normal file
@ -0,0 +1,276 @@
|
||||
# Speech-to-Text (Transcription/Translation) Support
|
||||
|
||||
This document walks you through the steps to add support for speech-to-text (ASR) models to vLLM’s transcription and translation APIs by implementing [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription].
|
||||
Please refer to the [supported models](../../models/supported_models.md#transcription) for further guidance.
|
||||
|
||||
## Update the base vLLM model
|
||||
|
||||
It is assumed you have already implemented your model in vLLM according to the basic model guide. Extend your model with the [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription] interface and implement the following class attributes and methods.
|
||||
|
||||
### `supported_languages` and `supports_transcription_only`
|
||||
|
||||
Declare supported languages and capabilities:
|
||||
|
||||
- The `supported_languages` mapping is validated at init time.
|
||||
- Set `supports_transcription_only=True` if the model should not serve text generation (eg Whisper).
|
||||
|
||||
??? code "supported_languages and supports_transcription_only"
|
||||
```python
|
||||
from typing import ClassVar, Mapping, Optional, Literal
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.model_executor.models.interfaces import SupportsTranscription
|
||||
|
||||
class YourASRModel(nn.Module, SupportsTranscription):
|
||||
# Map of ISO 639-1 language codes to language names
|
||||
supported_languages: ClassVar[Mapping[str, str]] = {
|
||||
"en": "English",
|
||||
"it": "Italian",
|
||||
# ... add more as needed
|
||||
}
|
||||
|
||||
# If your model only supports audio-conditioned generation
|
||||
# (no text-only generation), enable this flag.
|
||||
supports_transcription_only: ClassVar[bool] = True
|
||||
```
|
||||
|
||||
Provide an ASR configuration via [get_speech_to_text_config][vllm.model_executor.models.interfaces.SupportsTranscription.get_speech_to_text_config].
|
||||
|
||||
This is for controlling general behavior of the API when serving your model:
|
||||
|
||||
??? code "get_speech_to_text_config()"
|
||||
```python
|
||||
class YourASRModel(nn.Module, SupportsTranscription):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_speech_to_text_config(
|
||||
cls,
|
||||
model_config: ModelConfig,
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
) -> SpeechToTextConfig:
|
||||
return SpeechToTextConfig(
|
||||
sample_rate=16_000,
|
||||
max_audio_clip_s=30,
|
||||
# Set to None to disable server-side chunking if your
|
||||
# model/processor handles it already
|
||||
min_energy_split_window_size=None,
|
||||
)
|
||||
```
|
||||
|
||||
See [Audio preprocessing and chunking](#audio-preprocessing-and-chunking) for what each field controls.
|
||||
|
||||
Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.data.PromptType]. There are two common patterns:
|
||||
|
||||
#### Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n)
|
||||
|
||||
Return a dict containing `multi_modal_data` with the audio, and either a `prompt` string or `prompt_token_ids`:
|
||||
|
||||
??? code "get_generation_prompt()"
|
||||
```python
|
||||
class YourASRModel(nn.Module, SupportsTranscription):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_generation_prompt(
|
||||
cls,
|
||||
audio: np.ndarray,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig,
|
||||
language: Optional[str],
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
request_prompt: str,
|
||||
to_language: Optional[str],
|
||||
) -> PromptType:
|
||||
# Example with a free-form instruction prompt
|
||||
task_word = "Transcribe" if task_type == "transcribe" else "Translate"
|
||||
prompt = (
|
||||
"<start_of_turn>user\n"
|
||||
f"{task_word} this audio: <audio_soft_token>"
|
||||
"<end_of_turn>\n<start_of_turn>model\n"
|
||||
)
|
||||
|
||||
return {
|
||||
"multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
|
||||
"prompt": prompt,
|
||||
}
|
||||
```
|
||||
|
||||
For further clarification on multi modal inputs, please refer to [Multi-Modal Inputs](../../features/multimodal_inputs.md).
|
||||
|
||||
#### Encoder–decoder audio-only (e.g., Whisper)
|
||||
|
||||
Return a dict with separate `encoder_prompt` and `decoder_prompt` entries:
|
||||
|
||||
??? code "get_generation_prompt()"
|
||||
```python
|
||||
class YourASRModel(nn.Module, SupportsTranscription):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_generation_prompt(
|
||||
cls,
|
||||
audio: np.ndarray,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig,
|
||||
language: Optional[str],
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
request_prompt: str,
|
||||
to_language: Optional[str],
|
||||
) -> PromptType:
|
||||
if language is None:
|
||||
raise ValueError("Language must be specified")
|
||||
|
||||
prompt = {
|
||||
"encoder_prompt": {
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
"audio": (audio, stt_config.sample_rate),
|
||||
},
|
||||
},
|
||||
"decoder_prompt": (
|
||||
(f"<|prev|>{request_prompt}" if request_prompt else "")
|
||||
+ f"<|startoftranscript|><|{language}|>"
|
||||
+ f"<|{task_type}|><|notimestamps|>"
|
||||
),
|
||||
}
|
||||
return cast(PromptType, prompt)
|
||||
```
|
||||
|
||||
### `validate_language` (optional)
|
||||
|
||||
Language validation via [validate_language][vllm.model_executor.models.interfaces.SupportsTranscription.validate_language]
|
||||
|
||||
If your model requires a language and you want a default, override this method (see Whisper):
|
||||
|
||||
??? code "validate_language()"
|
||||
```python
|
||||
@classmethod
|
||||
def validate_language(cls, language: Optional[str]) -> Optional[str]:
|
||||
if language is None:
|
||||
logger.warning(
|
||||
"Defaulting to language='en'. If you wish to transcribe audio in a different language, pass the `language` field.")
|
||||
language = "en"
|
||||
return super().validate_language(language)
|
||||
```
|
||||
|
||||
### `get_num_audio_tokens` (optional)
|
||||
|
||||
Token accounting for streaming via [get_num_audio_tokens][vllm.model_executor.models.interfaces.SupportsTranscription.get_num_audio_tokens]
|
||||
|
||||
Provide a fast duration→token estimate to improve streaming usage statistics:
|
||||
|
||||
??? code "get_num_audio_tokens()"
|
||||
```python
|
||||
class YourASRModel(nn.Module, SupportsTranscription):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_num_audio_tokens(
|
||||
cls,
|
||||
audio_duration_s: float,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig,
|
||||
) -> Optional[int]:
|
||||
# Return None if unknown; otherwise return an estimate.
|
||||
return int(audio_duration_s * stt_config.sample_rate // 320) # example
|
||||
```
|
||||
|
||||
## Audio preprocessing and chunking
|
||||
|
||||
The API server takes care of basic audio I/O and optional chunking before building prompts:
|
||||
|
||||
- Resampling: Input audio is resampled to `SpeechToTextConfig.sample_rate` using `librosa`.
|
||||
- Chunking: If `SpeechToTextConfig.allow_audio_chunking` is True and the duration exceeds `max_audio_clip_s`, the server splits the audio into overlapping chunks and generates a prompt per chunk. Overlap is controlled by `overlap_chunk_second`.
|
||||
- Energy-aware splitting: When `min_energy_split_window_size` is set, the server finds low-energy regions to minimize cutting within words.
|
||||
|
||||
Relevant server logic:
|
||||
|
||||
??? code "_preprocess_speech_to_text()"
|
||||
```python
|
||||
# vllm/entrypoints/openai/speech_to_text.py
|
||||
async def _preprocess_speech_to_text(...):
|
||||
language = self.model_cls.validate_language(request.language)
|
||||
...
|
||||
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
|
||||
duration = librosa.get_duration(y=y, sr=sr)
|
||||
do_split_audio = (self.asr_config.allow_audio_chunking
|
||||
and duration > self.asr_config.max_audio_clip_s)
|
||||
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
|
||||
prompts = []
|
||||
for chunk in chunks:
|
||||
prompt = self.model_cls.get_generation_prompt(
|
||||
audio=chunk,
|
||||
stt_config=self.asr_config,
|
||||
model_config=self.model_config,
|
||||
language=language,
|
||||
task_type=self.task_type,
|
||||
request_prompt=request.prompt,
|
||||
to_language=to_language,
|
||||
)
|
||||
prompts.append(prompt)
|
||||
return prompts, duration
|
||||
```
|
||||
|
||||
## Exposing tasks automatically
|
||||
|
||||
vLLM automatically advertises transcription support if your model implements the interface:
|
||||
|
||||
```python
|
||||
if supports_transcription(model):
|
||||
if model.supports_transcription_only:
|
||||
return ["transcription"]
|
||||
supported_tasks.append("transcription")
|
||||
```
|
||||
|
||||
When enabled, the server initializes the transcription and translation handlers:
|
||||
|
||||
```python
|
||||
state.openai_serving_transcription = OpenAIServingTranscription(...) if "transcription" in supported_tasks else None
|
||||
state.openai_serving_translation = OpenAIServingTranslation(...) if "transcription" in supported_tasks else None
|
||||
```
|
||||
|
||||
No extra registration is required beyond having your model class available via the model registry and implementing `SupportsTranscription`.
|
||||
|
||||
## Examples in-tree
|
||||
|
||||
- Whisper encoder–decoder (audio-only): <gh-file:vllm/model_executor/models/whisper.py>
|
||||
- Voxtral decoder-only (audio embeddings + LLM): <gh-file:vllm/model_executor/models/voxtral.py>
|
||||
- Gemma3n decoder-only with fixed instruction prompt: <gh-file:vllm/model_executor/models/gemma3n_mm.py>
|
||||
|
||||
## Test with the API
|
||||
|
||||
Once your model implements `SupportsTranscription`, you can test the endpoints (API mimics OpenAI):
|
||||
|
||||
- Transcription (ASR):
|
||||
|
||||
```bash
|
||||
curl -s -X POST \
|
||||
-H "Authorization: Bearer $VLLM_API_KEY" \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F "file=@/path/to/audio.wav" \
|
||||
-F "model=$MODEL_ID" \
|
||||
http://localhost:8000/v1/audio/transcriptions
|
||||
```
|
||||
|
||||
- Translation (source → English unless otherwise supported):
|
||||
|
||||
```bash
|
||||
curl -s -X POST \
|
||||
-H "Authorization: Bearer $VLLM_API_KEY" \
|
||||
-H "Content-Type: multipart/form-data" \
|
||||
-F "file=@/path/to/audio.wav" \
|
||||
-F "model=$MODEL_ID" \
|
||||
http://localhost:8000/v1/audio/translations
|
||||
```
|
||||
|
||||
Or check out more examples in <gh-file:examples/online_serving>.
|
||||
|
||||
!!! note
|
||||
- If your model handles chunking internally (e.g., via its processor or encoder), set `min_energy_split_window_size=None` in the returned `SpeechToTextConfig` to disable server-side chunking.
|
||||
- Implementing `get_num_audio_tokens` improves accuracy of streaming usage metrics (`prompt_tokens`) without an extra forward pass.
|
||||
- For multilingual behavior, keep `supported_languages` aligned with actual model capabilities.
|
||||
@ -1,41 +1,53 @@
|
||||
# Anything LLM
|
||||
# AnythingLLM
|
||||
|
||||
[Anything LLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting.
|
||||
[AnythingLLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting.
|
||||
|
||||
It allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Setup vLLM environment
|
||||
Set up the vLLM environment:
|
||||
|
||||
```bash
|
||||
pip install vllm
|
||||
```
|
||||
|
||||
## Deploy
|
||||
|
||||
- Start the vLLM server with the supported chat completion model, e.g.
|
||||
1. Start the vLLM server with a supported chat-completion model, for example:
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096
|
||||
```
|
||||
```bash
|
||||
vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096
|
||||
```
|
||||
|
||||
- Download and install [Anything LLM desktop](https://anythingllm.com/desktop).
|
||||
1. Download and install [AnythingLLM Desktop](https://anythingllm.com/desktop).
|
||||
|
||||
- On the bottom left of open settings, AI Providers --> LLM:
|
||||
- LLM Provider: Generic OpenAI
|
||||
- Base URL: http://{vllm server host}:{vllm server port}/v1
|
||||
- Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ`
|
||||
1. Configure the AI provider:
|
||||
|
||||

|
||||
- At the bottom, click the 🔧 wrench icon -> **Open settings** -> **AI Providers** -> **LLM**.
|
||||
- Enter the following values:
|
||||
- LLM Provider: Generic OpenAI
|
||||
- Base URL: `http://{vllm server host}:{vllm server port}/v1`
|
||||
- Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ`
|
||||
|
||||
- Back to home page, New Workspace --> create `vllm` workspace, and start to chat:
|
||||

|
||||
|
||||

|
||||
1. Create a workspace:
|
||||
|
||||
- Click the upload button:
|
||||
- upload the doc
|
||||
- select the doc and move to the workspace
|
||||
- save and embed
|
||||
1. At the bottom, click the ↺ back icon and back to workspaces.
|
||||
1. Create a workspace (e.g., `vllm`) and start chatting.
|
||||
|
||||

|
||||

|
||||
|
||||
- Chat again:
|
||||
1. Add a document.
|
||||
|
||||

|
||||
1. Click the 📎 attachment icon.
|
||||
1. Upload a document.
|
||||
1. Select and move the document into your workspace.
|
||||
1. Save and embed it.
|
||||
|
||||

|
||||
|
||||
1. Chat using your document as context.
|
||||
|
||||

|
||||
|
||||
@ -4,9 +4,7 @@
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Setup vLLM environment
|
||||
|
||||
- Setup [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment
|
||||
Set up the vLLM and [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment:
|
||||
|
||||
```bash
|
||||
pip install vllm
|
||||
@ -18,14 +16,14 @@ pip install -U "autogen-agentchat" "autogen-ext[openai]"
|
||||
|
||||
## Deploy
|
||||
|
||||
- Start the vLLM server with the supported chat completion model, e.g.
|
||||
1. Start the vLLM server with the supported chat completion model, e.g.
|
||||
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model mistralai/Mistral-7B-Instruct-v0.2
|
||||
```
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model mistralai/Mistral-7B-Instruct-v0.2
|
||||
```
|
||||
|
||||
- Call it with AutoGen:
|
||||
1. Call it with AutoGen:
|
||||
|
||||
??? code
|
||||
|
||||
|
||||
@ -6,27 +6,31 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Setup vLLM environment
|
||||
Set up the vLLM environment:
|
||||
|
||||
```bash
|
||||
pip install vllm
|
||||
```
|
||||
|
||||
## Deploy
|
||||
|
||||
- Start the vLLM server with the supported chat completion model, e.g.
|
||||
1. Start the vLLM server with the supported chat completion model, e.g.
|
||||
|
||||
```bash
|
||||
vllm serve qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
```bash
|
||||
vllm serve qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
|
||||
- Download and install [Chatbox desktop](https://chatboxai.app/en#download).
|
||||
1. Download and install [Chatbox desktop](https://chatboxai.app/en#download).
|
||||
|
||||
- On the bottom left of settings, Add Custom Provider
|
||||
1. On the bottom left of settings, Add Custom Provider
|
||||
- API Mode: `OpenAI API Compatible`
|
||||
- Name: vllm
|
||||
- API Host: `http://{vllm server host}:{vllm server port}/v1`
|
||||
- API Path: `/chat/completions`
|
||||
- Model: `qwen/Qwen1.5-0.5B-Chat`
|
||||
|
||||

|
||||

|
||||
|
||||
- Go to `Just chat`, and start to chat:
|
||||
1. Go to `Just chat`, and start to chat:
|
||||
|
||||

|
||||

|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user