Compare commits
233 Commits
v0.7.3
...
bind_kv_ca
| Author | SHA1 | Date | |
|---|---|---|---|
| bfff9bcd1d | |||
| 257e200a25 | |||
| 47d4a7e004 | |||
| 7f89a594dd | |||
| 961644e6a8 | |||
| 8d6cd32b7b | |||
| ec79b67c77 | |||
| 32985bed7c | |||
| dae9ec464c | |||
| 6eaf93020d | |||
| 72c62eae5f | |||
| 0a995d5434 | |||
| ade3f7d988 | |||
| 0df25101d6 | |||
| e123aafdf0 | |||
| 5b143d33be | |||
| eb59b5a6cb | |||
| fbfc3ee37e | |||
| 3e1d223626 | |||
| 4f5b059f14 | |||
| 288ca110f6 | |||
| c2bd2196fc | |||
| 550c7ba3dc | |||
| e5b2f1601a | |||
| 9badee53de | |||
| beebf4742a | |||
| f89978ad7c | |||
| b3cf368d79 | |||
| c8525f06fc | |||
| 5db6b2c961 | |||
| 6247bae6c6 | |||
| 3610fb4930 | |||
| 71c4b40562 | |||
| ac65bc92df | |||
| f78c0be80a | |||
| 66233af7b6 | |||
| bf13d40972 | |||
| 989f4f430c | |||
| bb5b640359 | |||
| c060b71408 | |||
| 79e4937c65 | |||
| cd1d3c3df8 | |||
| 19d98e0c7d | |||
| 2b04c209ee | |||
| ae122b1cbd | |||
| 872db2be0e | |||
| 2dfdfed8a0 | |||
| c41d27156b | |||
| 91373a0d15 | |||
| 848a6438ae | |||
| 98175b2816 | |||
| 4167252eaf | |||
| f35f8e2242 | |||
| b87c21fc89 | |||
| e584b85afd | |||
| 09e56f9262 | |||
| cf069aa8aa | |||
| bf33700ecd | |||
| bc6ccb9878 | |||
| 82fbeae92b | |||
| cc5e8f6db8 | |||
| d54990da47 | |||
| b9f1d4294e | |||
| b28246f6ff | |||
| 3b5567a209 | |||
| fdcc405346 | |||
| 8994dabc22 | |||
| 02296f420d | |||
| 6a92ff93e1 | |||
| 6a84164add | |||
| f64ffa8c25 | |||
| bd56c983d6 | |||
| 084bbac8cc | |||
| 28943d36ce | |||
| b526ca6726 | |||
| e7bd944e08 | |||
| c3b6559a10 | |||
| 4be4b26cb7 | |||
| 2aed2c9fa7 | |||
| 9b61dd41e7 | |||
| f7bee5c815 | |||
| e0734387fb | |||
| f58f8b5c96 | |||
| b3f7aaccd0 | |||
| b91660ddb8 | |||
| 76c89fcadd | |||
| b9e41734c5 | |||
| 1088f06242 | |||
| 73e0225ee9 | |||
| 6c85da3a18 | |||
| 67fc426845 | |||
| 9804145cac | |||
| 2e94b9cfbb | |||
| 8294773e48 | |||
| cd813c6d4d | |||
| 38acae6e97 | |||
| a2dd48c386 | |||
| 126f6beeb4 | |||
| 58d1b2aa77 | |||
| f1579b229d | |||
| 7864875879 | |||
| 1dd422b64a | |||
| 06c8f8d885 | |||
| 5677c9bb3e | |||
| 512d77d582 | |||
| 7f0be2aa24 | |||
| edf309ebbe | |||
| 788f284b53 | |||
| 4b1d141f49 | |||
| 10c3b8c1cf | |||
| a7f37314b7 | |||
| cd711c48b2 | |||
| 378b3ef6f8 | |||
| c9944acbf9 | |||
| ca377cf1b9 | |||
| a31614e386 | |||
| f95903909f | |||
| b382a7f28f | |||
| 4cb6fa0a9c | |||
| d08b285adf | |||
| b27122acc2 | |||
| 934bb99c71 | |||
| 3f808cc044 | |||
| ec8a5e5386 | |||
| 215bf150a6 | |||
| 0ecdd98031 | |||
| 7b700ec8c8 | |||
| 7ca1da020f | |||
| 5157338ed9 | |||
| e206b54331 | |||
| 1d35662e6d | |||
| e656f638de | |||
| 145944cb94 | |||
| 094b7d9496 | |||
| e1fe7591f2 | |||
| 5629f26df7 | |||
| 9ba28043b5 | |||
| 24679788ed | |||
| 07c4353057 | |||
| 34e3494e70 | |||
| f75aa72732 | |||
| 340e39e387 | |||
| f4133ce4e5 | |||
| 6522d55b6f | |||
| 6ff518626c | |||
| fa82074167 | |||
| 75e9d49796 | |||
| 32c3b6bfd1 | |||
| 37b6cb4985 | |||
| aabeb2688f | |||
| 2f42a4888c | |||
| 3173c3b34e | |||
| 2d87d7d1ac | |||
| aab392774b | |||
| 6724e79164 | |||
| 03f48b3db6 | |||
| 4d251ad00e | |||
| 18e505930d | |||
| 4a8cfc7551 | |||
| bc32bc73aa | |||
| ab1091d5f2 | |||
| 1e15aaef56 | |||
| 51010a1807 | |||
| 7196a3b1db | |||
| cdc1fa12eb | |||
| f61528d46d | |||
| 1f0ae3ed0a | |||
| db986c19ea | |||
| 227578480d | |||
| befc402d34 | |||
| 444b0f0f62 | |||
| ccc00515fd | |||
| 781096e385 | |||
| 7940d8a6a7 | |||
| c0e3ecd6d2 | |||
| 23eca9cf68 | |||
| 437b76ff59 | |||
| f90a375593 | |||
| e7ef74e26e | |||
| cbae7af552 | |||
| eb24dc4a45 | |||
| 9bebc9512f | |||
| 5a2ba16f5c | |||
| ba5106e519 | |||
| d5ca2110f1 | |||
| 2c5e637b57 | |||
| 322d2a27d6 | |||
| 82e0d601fc | |||
| 78ac0f591d | |||
| b56155e7f3 | |||
| 382f66fb08 | |||
| 8354f6640c | |||
| c904fdddf6 | |||
| 558db8083c | |||
| e109e598c7 | |||
| 8db1b9d0a1 | |||
| 2382ad29d1 | |||
| 3e472d882a | |||
| 7f6bae561c | |||
| 105b8ce4c0 | |||
| 2cb8c1540e | |||
| 1cd981da4f | |||
| fca20841c2 | |||
| da31b5333e | |||
| bb78fb318e | |||
| 8aca27fa11 | |||
| 95c617e04b | |||
| 9a1f1da5d1 | |||
| 68d630a0c7 | |||
| 68d535ef44 | |||
| c6ed93860f | |||
| 0ffdf8ce0c | |||
| 8c0dd3d4df | |||
| ada7c780d5 | |||
| 288cc6c234 | |||
| 900edbfa48 | |||
| b2c3fc5d65 | |||
| 839b27c6cc | |||
| 34ad27fe83 | |||
| 1c3c975766 | |||
| 1cdc88614a | |||
| 31aa045c11 | |||
| a30c093502 | |||
| c7b07a95a6 | |||
| 27a09dc52c | |||
| 981f3c831e | |||
| 44c33f01f3 | |||
| 33170081f1 | |||
| 71face8540 | |||
| bfbc0b32c6 | |||
| 6a417b8600 | |||
| d3ea50113c | |||
| 34aad515c8 |
@ -84,8 +84,13 @@ if __name__ == "__main__":
|
||||
# this result is generated via `benchmark_serving.py`
|
||||
|
||||
# attach the benchmarking command to raw_result
|
||||
with open(test_file.with_suffix(".commands")) as f:
|
||||
command = json.loads(f.read())
|
||||
try:
|
||||
with open(test_file.with_suffix(".commands")) as f:
|
||||
command = json.loads(f.read())
|
||||
except OSError as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
raw_result.update(command)
|
||||
|
||||
# update the test name of this result
|
||||
@ -99,8 +104,13 @@ if __name__ == "__main__":
|
||||
# this result is generated via `benchmark_latency.py`
|
||||
|
||||
# attach the benchmarking command to raw_result
|
||||
with open(test_file.with_suffix(".commands")) as f:
|
||||
command = json.loads(f.read())
|
||||
try:
|
||||
with open(test_file.with_suffix(".commands")) as f:
|
||||
command = json.loads(f.read())
|
||||
except OSError as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
raw_result.update(command)
|
||||
|
||||
# update the test name of this result
|
||||
@ -121,8 +131,13 @@ if __name__ == "__main__":
|
||||
# this result is generated via `benchmark_throughput.py`
|
||||
|
||||
# attach the benchmarking command to raw_result
|
||||
with open(test_file.with_suffix(".commands")) as f:
|
||||
command = json.loads(f.read())
|
||||
try:
|
||||
with open(test_file.with_suffix(".commands")) as f:
|
||||
command = json.loads(f.read())
|
||||
except OSError as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
raw_result.update(command)
|
||||
|
||||
# update the test name of this result
|
||||
|
||||
@ -309,11 +309,14 @@ run_serving_tests() {
|
||||
|
||||
new_test_name=$test_name"_qps_"$qps
|
||||
|
||||
# pass the tensor parallel size to the client so that it can be displayed
|
||||
# on the benchmark dashboard
|
||||
client_command="python3 benchmark_serving.py \
|
||||
--save-result \
|
||||
--result-dir $RESULTS_FOLDER \
|
||||
--result-filename ${new_test_name}.json \
|
||||
--request-rate $qps \
|
||||
--metadata "tensor_parallel_size=$tp" \
|
||||
$client_args"
|
||||
|
||||
echo "Running test case $test_name with qps $qps"
|
||||
|
||||
@ -32,4 +32,4 @@
|
||||
"backend": "vllm"
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
@ -1,4 +1,15 @@
|
||||
steps:
|
||||
- label: "Build wheel - CUDA 12.4"
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag vllm-ci:build-image --target build --progress plain ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/upload-wheels.sh"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- label: "Build wheel - CUDA 12.1"
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
@ -37,7 +48,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
|
||||
|
||||
- label: "Build and publish TPU release image"
|
||||
|
||||
@ -77,7 +77,6 @@ echo "Commands:$commands"
|
||||
#ignore certain kernels tests
|
||||
if [[ $commands == *" kernels "* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/test_attention.py \
|
||||
--ignore=kernels/test_attention_selector.py \
|
||||
--ignore=kernels/test_blocksparse_attention.py \
|
||||
--ignore=kernels/test_causal_conv1d.py \
|
||||
@ -92,7 +91,14 @@ if [[ $commands == *" kernels "* ]]; then
|
||||
--ignore=kernels/test_moe.py \
|
||||
--ignore=kernels/test_prefix_prefill.py \
|
||||
--ignore=kernels/test_rand.py \
|
||||
--ignore=kernels/test_sampler.py"
|
||||
--ignore=kernels/test_sampler.py \
|
||||
--ignore=kernels/test_cascade_flash_attn.py \
|
||||
--ignore=kernels/test_mamba_mixer2.py \
|
||||
--ignore=kernels/test_aqlm.py \
|
||||
--ignore=kernels/test_machete_mm.py \
|
||||
--ignore=kernels/test_mha_attn.py \
|
||||
--ignore=kernels/test_block_fp8.py \
|
||||
--ignore=kernels/test_permute_cols.py"
|
||||
fi
|
||||
|
||||
#ignore certain Entrypoints tests
|
||||
|
||||
@ -134,7 +134,9 @@ steps:
|
||||
- tests/compile/test_basic_correctness
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
commands:
|
||||
- VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
@ -273,10 +275,10 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
- tests/lora
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
|
||||
parallelism: 4
|
||||
|
||||
- label: "PyTorch Fullgraph Smoke Test" # 9min
|
||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -287,7 +289,7 @@ steps:
|
||||
- pytest -v -s compile/piecewise/test_simple.py
|
||||
- pytest -v -s compile/piecewise/test_toy_llama.py
|
||||
|
||||
- label: "PyTorch Fullgraph Test" # 18min
|
||||
- label: PyTorch Fullgraph Test # 18min
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/compile
|
||||
@ -501,6 +503,7 @@ steps:
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
commands:
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
|
||||
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
- pytest -v -s ./compile/test_wrapper.py
|
||||
@ -586,6 +589,7 @@ steps:
|
||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||
- pytest -v -s -x lora/test_llama_tp.py
|
||||
- pytest -v -s -x lora/test_minicpmv_tp.py
|
||||
- pytest -v -s -x lora/test_transfomers_model.py
|
||||
|
||||
|
||||
- label: Weight Loading Multiple GPU Test # 33min
|
||||
|
||||
@ -50,8 +50,11 @@ aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
|
||||
if [[ $normal_wheel == *"cu118"* ]]; then
|
||||
# if $normal_wheel matches cu118, do not upload the index.html
|
||||
echo "Skipping index files for cu118 wheels"
|
||||
elif [[ $normal_wheel == *"cu121"* ]]; then
|
||||
# if $normal_wheel matches cu121, do not upload the index.html
|
||||
echo "Skipping index files for cu121 wheels"
|
||||
else
|
||||
# only upload index.html for cu12 wheels (default wheels)
|
||||
# only upload index.html for cu124 wheels (default wheels)
|
||||
aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html"
|
||||
aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html"
|
||||
fi
|
||||
@ -63,8 +66,11 @@ aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/"
|
||||
if [[ $normal_wheel == *"cu118"* ]]; then
|
||||
# if $normal_wheel matches cu118, do not upload the index.html
|
||||
echo "Skipping index files for cu118 wheels"
|
||||
elif [[ $normal_wheel == *"cu121"* ]]; then
|
||||
# if $normal_wheel matches cu121, do not upload the index.html
|
||||
echo "Skipping index files for cu121 wheels"
|
||||
else
|
||||
# only upload index.html for cu12 wheels (default wheels)
|
||||
# only upload index.html for cu124 wheels (default wheels)
|
||||
aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html"
|
||||
fi
|
||||
|
||||
|
||||
2
.github/dependabot.yml
vendored
2
.github/dependabot.yml
vendored
@ -23,7 +23,7 @@ updates:
|
||||
- dependency-name: "lm-format-enforcer"
|
||||
- dependency-name: "gguf"
|
||||
- dependency-name: "compressed-tensors"
|
||||
- dependency-name: "ray[adag]"
|
||||
- dependency-name: "ray[cgraph]" # Ray Compiled Graph
|
||||
- dependency-name: "lm-eval"
|
||||
groups:
|
||||
minor-update:
|
||||
|
||||
1
.github/mergify.yml
vendored
1
.github/mergify.yml
vendored
@ -5,6 +5,7 @@ pull_request_rules:
|
||||
- or:
|
||||
- files~=^[^/]+\.md$
|
||||
- files~=^docs/
|
||||
- files~=^examples/
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
|
||||
2
.github/workflows/lint-and-deploy.yaml
vendored
2
.github/workflows/lint-and-deploy.yaml
vendored
@ -12,7 +12,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@fe7b79cd5ee1e45176fcad797de68ecaf3ca4814 # v4.2.0
|
||||
uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
default_stages:
|
||||
- pre-commit # Run locally
|
||||
- manual # Run in CI
|
||||
exclude: 'vllm/third_party/.*'
|
||||
repos:
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.43.0
|
||||
@ -8,13 +9,11 @@ repos:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.3
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--output-format, github, --fix]
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.4.0
|
||||
hooks:
|
||||
@ -22,10 +21,9 @@ repos:
|
||||
additional_dependencies: ['tomli']
|
||||
args: ['--toml', 'pyproject.toml']
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.13.2
|
||||
rev: 0a0b7a830386ba6a31c2ec8316849ae4d1b8240d # 6.0.0
|
||||
hooks:
|
||||
- id: isort
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v19.1.7
|
||||
hooks:
|
||||
@ -38,12 +36,16 @@ repos:
|
||||
hooks:
|
||||
- id: pymarkdown
|
||||
args: [fix]
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: v1.7.7
|
||||
hooks:
|
||||
- id: actionlint
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
rev: 0.6.2
|
||||
hooks:
|
||||
- id: pip-compile
|
||||
args: [requirements-test.in, -o, requirements-test.txt]
|
||||
files: ^requirements-test\.(in|txt)$
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy-local
|
||||
@ -53,7 +55,6 @@ repos:
|
||||
types: [python]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
|
||||
stages: [pre-commit] # Don't run in CI
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.9
|
||||
entry: tools/mypy.sh 1 "3.9"
|
||||
@ -61,7 +62,6 @@ repos:
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.10
|
||||
entry: tools/mypy.sh 1 "3.10"
|
||||
@ -69,7 +69,6 @@ repos:
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.11
|
||||
entry: tools/mypy.sh 1 "3.11"
|
||||
@ -77,7 +76,6 @@ repos:
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.12
|
||||
entry: tools/mypy.sh 1 "3.12"
|
||||
@ -85,19 +83,16 @@ repos:
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: shellcheck
|
||||
name: Lint shell scripts
|
||||
entry: tools/shellcheck.sh
|
||||
language: script
|
||||
types: [shell]
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: png-lint
|
||||
name: Lint PNG exports from excalidraw
|
||||
entry: tools/png-lint.sh
|
||||
language: script
|
||||
types: [png]
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: signoff-commit
|
||||
name: Sign-off Commit
|
||||
entry: bash
|
||||
@ -110,13 +105,11 @@ repos:
|
||||
language: system
|
||||
verbose: true
|
||||
stages: [commit-msg]
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: check-spdx-header
|
||||
name: Check SPDX headers
|
||||
entry: python tools/check_spdx_header.py
|
||||
language: python
|
||||
types: [python]
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: check-filenames
|
||||
name: Check for spaces in all filenames
|
||||
entry: bash
|
||||
@ -126,7 +119,6 @@ repos:
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
exclude: 'vllm/third_party/.*'
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
@ -134,5 +126,4 @@ repos:
|
||||
language: system
|
||||
verbose: true
|
||||
pass_filenames: false
|
||||
exclude: 'vllm/third_party/.*'
|
||||
# Insert new entries above the `suggestion` entry
|
||||
|
||||
152
CMakeLists.txt
Executable file → Normal file
152
CMakeLists.txt
Executable file → Normal file
@ -31,7 +31,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
|
||||
|
||||
# Supported NVIDIA architectures.
|
||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
|
||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
|
||||
|
||||
# Supported AMD GPU architectures.
|
||||
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101")
|
||||
@ -174,6 +174,25 @@ include(FetchContent)
|
||||
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
|
||||
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
|
||||
|
||||
#
|
||||
# Set rocm version dev int.
|
||||
#
|
||||
if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
#
|
||||
# Overriding the default -O set up by cmake, adding ggdb3 for the most verbose devug info
|
||||
#
|
||||
set(CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG "${CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG} -O0 -ggdb3")
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb3")
|
||||
|
||||
|
||||
#
|
||||
# Certain HIP functions are marked as [[nodiscard]], yet vllm ignores the result which generates
|
||||
# a lot of warnings that always mask real issues. Suppressing until this is properly addressed.
|
||||
#
|
||||
set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Define other extension targets
|
||||
#
|
||||
@ -229,7 +248,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
|
||||
# Please keep this in sync with FetchContent_Declare line below.
|
||||
set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
|
||||
set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
|
||||
|
||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||
@ -247,7 +266,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
# Please keep this in sync with CUTLASS_REVISION line above.
|
||||
GIT_TAG v3.7.0
|
||||
GIT_TAG v3.8.0
|
||||
GIT_PROGRESS TRUE
|
||||
|
||||
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
|
||||
@ -267,6 +286,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/permute_cols.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp")
|
||||
|
||||
@ -277,7 +297,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# Only build Marlin kernels if we are building for at least some compatible archs.
|
||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||
# are not supported by Machete yet.
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_ARCHS)
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||
@ -297,11 +317,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
" in CUDA target architectures")
|
||||
endif()
|
||||
|
||||
# Only build AllSpark kernels if we are building for at least some compatible archs.
|
||||
cuda_archs_loose_intersection(ALLSPARK_ARCHS "8.0;8.6;8.7;8.9" "${CUDA_ARCHS}")
|
||||
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND ALLSPARK_ARCHS)
|
||||
set(ALLSPARK_SRCS
|
||||
"csrc/quantization/gptq_allspark/allspark_repack.cu"
|
||||
"csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${ALLSPARK_SRCS}"
|
||||
CUDA_ARCHS "${ALLSPARK_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${ALLSPARK_SRCS}")
|
||||
message(STATUS "Building AllSpark kernels for archs: ${ALLSPARK_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building AllSpark kernels as no compatible archs found"
|
||||
" in CUDA target architectures, or CUDA not >= 12.0")
|
||||
endif()
|
||||
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
|
||||
@ -333,7 +369,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
|
||||
# kernels for the remaining archs that are not already built for 3x.
|
||||
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
|
||||
"7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
|
||||
"7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
# subtract out the archs that are already built for 3x
|
||||
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
|
||||
if (SCALED_MM_2X_ARCHS)
|
||||
@ -358,7 +394,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# 2:4 Sparse Kernels
|
||||
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
|
||||
# require CUDA 12.2 or later (and only work on Hopper and Blackwell).
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
@ -381,9 +417,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# FP4 Archs and flags
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
set(SRCS
|
||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
)
|
||||
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
@ -396,6 +432,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
set(FP4_ARCHS)
|
||||
endif()
|
||||
|
||||
# FP8 Blackwell Archs
|
||||
cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${BLACKWELL_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}")
|
||||
else()
|
||||
# clear BLACKWELL_ARCHS
|
||||
set(BLACKWELL_ARCHS)
|
||||
endif()
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
|
||||
@ -477,6 +529,7 @@ define_gpu_extension_target(
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
@ -500,7 +553,7 @@ set_gencode_flags_for_srcs(
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
|
||||
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_MOE_ARCHS)
|
||||
set(MARLIN_MOE_SRC
|
||||
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
|
||||
@ -554,77 +607,8 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
WITH_SOABI)
|
||||
endif()
|
||||
|
||||
# vllm-flash-attn currently only supported on CUDA
|
||||
if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
return()
|
||||
# For CUDA we also build and ship some external projects.
|
||||
if (VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
include(cmake/external_projects/flashmla.cmake)
|
||||
include(cmake/external_projects/vllm_flash_attn.cmake)
|
||||
endif ()
|
||||
|
||||
# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
|
||||
# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
|
||||
# arches in the CUDA case (and instead set the gencodes on a per file basis)
|
||||
# we need to manually set VLLM_GPU_ARCHES here.
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
foreach(_ARCH ${CUDA_ARCHS})
|
||||
string(REPLACE "." "" _ARCH "${_ARCH}")
|
||||
list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real")
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Build vLLM flash attention from source
|
||||
#
|
||||
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
|
||||
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
|
||||
# They should be identical but if they aren't, this is a massive footgun.
|
||||
#
|
||||
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
|
||||
# If no component is specified, vllm-flash-attn is still installed.
|
||||
|
||||
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
||||
# This is to enable local development of vllm-flash-attn within vLLM.
|
||||
# It can be set as an environment variable or passed as a cmake argument.
|
||||
# The environment variable takes precedence.
|
||||
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(VLLM_FLASH_ATTN_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn SOURCE_DIR
|
||||
${VLLM_FLASH_ATTN_SRC_DIR}
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
# Fetch the vllm-flash-attn library
|
||||
FetchContent_MakeAvailable(vllm-flash-attn)
|
||||
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
||||
|
||||
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
|
||||
# case only one is built, in the case both are built redundant work is done)
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa2_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
# Nothing after vllm-flash-attn, see comment about macros above
|
||||
|
||||
32
Dockerfile
32
Dockerfile
@ -28,7 +28,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
|
||||
# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519
|
||||
@ -53,14 +53,14 @@ WORKDIR /workspace
|
||||
# we need to install torch and torchvision from the nightly builds first,
|
||||
# pytorch will not appear as a vLLM dependency in all of the following steps
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \
|
||||
fi
|
||||
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
COPY requirements-cuda.txt requirements-cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements-cuda.txt
|
||||
|
||||
# cuda arch list used by torch
|
||||
@ -81,7 +81,7 @@ ARG TARGETPLATFORM
|
||||
# install build dependencies
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements-build.txt
|
||||
|
||||
COPY . .
|
||||
@ -101,7 +101,7 @@ ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||
echo "Installing sccache..." \
|
||||
@ -121,7 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" != "1" ]; then \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
|
||||
@ -146,7 +146,7 @@ FROM base as dev
|
||||
COPY requirements-lint.txt requirements-lint.txt
|
||||
COPY requirements-test.txt requirements-test.txt
|
||||
COPY requirements-dev.txt requirements-dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements-dev.txt
|
||||
#################### DEV IMAGE ####################
|
||||
|
||||
@ -178,7 +178,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
@ -191,14 +191,14 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
# we need to install torch and torchvision from the nightly builds first,
|
||||
# pytorch will not appear as a vLLM dependency in all of the following steps
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
|
||||
fi
|
||||
|
||||
# Install vllm wheel first, so that torch etc will be installed.
|
||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system dist/*.whl --verbose
|
||||
|
||||
# If we need to build FlashInfer wheel before its release:
|
||||
@ -213,7 +213,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
# $ ls dist
|
||||
# $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post1/flashinfer_python-0.2.1.post1+cu124torch2.5-cp38-abi3-linux_x86_64.whl ; \
|
||||
@ -225,7 +225,7 @@ COPY examples examples
|
||||
# install build dependencies for JIT compilation.
|
||||
# TODO: Remove this once FlashInfer AOT wheel is fixed
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements-build.txt
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
@ -238,15 +238,15 @@ FROM vllm-base AS test
|
||||
ADD . /vllm-workspace/
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements-dev.txt
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -e tests/vllm_test_utils
|
||||
|
||||
# enable fast downloads from hf (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system hf_transfer
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER 1
|
||||
|
||||
@ -266,7 +266,7 @@ RUN mv vllm test_docs/
|
||||
FROM vllm-base AS vllm-openai-base
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
else \
|
||||
|
||||
@ -15,7 +15,11 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
|
||||
---
|
||||
|
||||
We are excited to invite you to our Menlo Park meetup with Meta, evening of Thursday, February 27! Meta engineers will discuss the improvements on top of vLLM, and vLLM contributors will share updates from the v0.7.x series of releases. [Register Now](https://lu.ma/h7g3kuj9)
|
||||
We’re excited to invite you to the first **vLLM China Meetup** on **March 16** in **Beijing**!
|
||||
|
||||
Join us to connect with the **vLLM team** and explore how vLLM is leveraged in **post-training, fine-tuning, and deployment**, including [verl](https://github.com/volcengine/verl), [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), and [vllm-ascend](https://github.com/vllm-project/vllm-ascend).
|
||||
|
||||
👉 **[Register Now](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)** to be part of the discussion!
|
||||
|
||||
---
|
||||
|
||||
|
||||
54
RELEASE.md
Normal file
54
RELEASE.md
Normal file
@ -0,0 +1,54 @@
|
||||
# Releasing vLLM
|
||||
|
||||
vLLM releases offer a reliable version of the code base, packaged into a binary format that can be conveniently accessed via PyPI. These releases also serve as key milestones for the development team to communicate with the community about newly available features, improvements, and upcoming changes that could affect users, including potential breaking changes.
|
||||
|
||||
## Release Versioning
|
||||
|
||||
vLLM uses a “right-shifted” versioning scheme where a new patch release is out every 2 weeks. And patch releases contain features and bug fixes (as opposed to semver where patch release contains only backwards-compatible bug fixes). When critical fixes need to be made, special release post1 is released.
|
||||
|
||||
* _major_ major architectural milestone and when incompatible API changes are made, similar to PyTorch 2.0.
|
||||
* _minor_ major features
|
||||
* _patch_ features and backwards-compatible bug fixes
|
||||
* _post1_ or _patch-1_ backwards-compatible bug fixes, either explicit or implicit post release
|
||||
|
||||
## Release Cadence
|
||||
|
||||
Patch release is released on bi-weekly basis. Post release 1-3 days after patch release and uses same branch as patch release.
|
||||
Following is the release cadence for year 2025. All future release dates below are tentative. Please note: Post releases are optional.
|
||||
|
||||
| Release Date | Patch release versions | Post Release versions |
|
||||
| --- | --- | --- |
|
||||
| Jan 2025 | 0.7.0 | --- |
|
||||
| Feb 2025 | 0.7.1, 0.7.2, 0.7.3 | --- |
|
||||
| Mar 2025 | 0.7.4, 0.7.5 | --- |
|
||||
| Apr 2025 | 0.7.6, 0.7.7 | --- |
|
||||
| May 2025 | 0.7.8, 0.7.9 | --- |
|
||||
| Jun 2025 | 0.7.10, 0.7.11 | --- |
|
||||
| Jul 2025 | 0.7.12, 0.7.13 | --- |
|
||||
| Aug 2025 | 0.7.14, 0.7.15 | --- |
|
||||
| Sep 2025 | 0.7.16, 0.7.17 | --- |
|
||||
| Oct 2025 | 0.7.18, 0.7.19 | --- |
|
||||
| Nov 2025 | 0.7.20, 0.7.21 | --- |
|
||||
| Dec 2025 | 0.7.22, 0.7.23 | --- |
|
||||
|
||||
## Release branch
|
||||
|
||||
Each release is built from a dedicated release branch.
|
||||
|
||||
* For _major_, _minor_, _patch_ releases, the release branch cut is performed 1-2 days before release is live.
|
||||
* For post releases, previously cut release branch is reused
|
||||
* Release builds are triggered via push to RC tag like vX.Y.Z-rc1 . This enables us to build and test multiple RCs for each release.
|
||||
* Final tag : vX.Y.Z does not trigger the build but used for Release notes and assets.
|
||||
* After branch cut is created we monitor the main branch for any reverts and apply these reverts to a release branch.
|
||||
|
||||
## Release Cherry-Pick Criteria
|
||||
|
||||
After branch cut, we approach finalizing the release branch with clear criteria on what cherry picks are allowed in. Note: a cherry pick is a process to land a PR in the release branch after branch cut. These are typically limited to ensure that the team has sufficient time to complete a thorough round of testing on a stable code base.
|
||||
|
||||
* Regression fixes - that address functional/performance regression against the most recent release (e.g. 0.7.0 for 0.7.1 release)
|
||||
* Critical fixes - critical fixes for severe issue such as silent incorrectness, backwards compatibility, crashes, deadlocks, (large) memory leaks
|
||||
* Fixes to new features introduced in the most recent release (e.g. 0.7.0 for 0.7.1 release)
|
||||
* Documentation improvements
|
||||
* Release branch specific changes (e.g. change version identifiers or CI fixes)
|
||||
|
||||
Please note: **No feature work allowed for cherry picks**. All PRs that are considered for cherry-picks need to be merged on trunk, the only exception are Release branch specific changes.
|
||||
@ -6,7 +6,7 @@ import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import huggingface_hub.constants
|
||||
@ -14,6 +14,8 @@ from tqdm.asyncio import tqdm
|
||||
from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast)
|
||||
|
||||
from vllm.model_executor.model_loader.weight_utils import get_lock
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
|
||||
@ -39,8 +41,8 @@ class RequestFuncOutput:
|
||||
latency: float = 0.0
|
||||
output_tokens: int = 0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: List[float] = field(
|
||||
default_factory=list) # List of inter-token latencies
|
||||
itl: list[float] = field(
|
||||
default_factory=list) # list of inter-token latencies
|
||||
tpot: float = 0.0 # avg next-token latencies
|
||||
prompt_len: int = 0
|
||||
error: str = ""
|
||||
@ -430,12 +432,15 @@ def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
|
||||
from modelscope import snapshot_download
|
||||
|
||||
model_path = snapshot_download(
|
||||
model_id=pretrained_model_name_or_path,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(pretrained_model_name_or_path):
|
||||
model_path = snapshot_download(
|
||||
model_id=pretrained_model_name_or_path,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
|
||||
|
||||
return model_path
|
||||
return model_path
|
||||
return pretrained_model_name_or_path
|
||||
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@ import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
@ -39,17 +38,23 @@ class SampleRequest:
|
||||
completion: str = None
|
||||
|
||||
|
||||
def run_vllm(requests: List[SampleRequest],
|
||||
def run_vllm(requests: list[SampleRequest],
|
||||
engine_args: EngineArgs,
|
||||
n: int,
|
||||
guided_decoding_rate: float = 1.0,
|
||||
warmup: bool = False) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**vars(engine_args))
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: List[str] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
prompts: list[str] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
# create a list containing random selected true or false
|
||||
guided_decoding_req_idx = random.sample(
|
||||
range(len(requests)), int(len(requests) * guided_decoding_rate))
|
||||
@ -104,7 +109,7 @@ def run_vllm(requests: List[SampleRequest],
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: List[SampleRequest],
|
||||
requests: list[SampleRequest],
|
||||
engine_args: AsyncEngineArgs,
|
||||
n: int,
|
||||
guided_decoding_rate: float = 1.0,
|
||||
@ -115,9 +120,16 @@ async def run_vllm_async(
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
|
||||
assert all(
|
||||
llm.model_config.max_model_len >= (request.prompt_len +
|
||||
request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: List[str] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
prompts: list[str] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
guided_decoding_req_idx = random.sample(
|
||||
range(len(requests)), int(len(requests) * guided_decoding_rate))
|
||||
|
||||
@ -190,7 +202,7 @@ async def run_vllm_async(
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> List[SampleRequest]:
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
if args.dataset == 'json':
|
||||
if args.json_schema_path is None:
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
@ -274,7 +286,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
|
||||
elif args.dataset == "xgrammar_bench":
|
||||
args.warmup = False
|
||||
requests: List[SampleRequest] = []
|
||||
requests: list[SampleRequest] = []
|
||||
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
|
||||
split="train")
|
||||
print(f"dataset has {len(dataset)} entries")
|
||||
|
||||
@ -7,11 +7,11 @@ import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
@ -22,7 +22,7 @@ from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
results: Dict[str, Any]) -> None:
|
||||
results: dict[str, Any]) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={"latency": results["latencies"]},
|
||||
@ -30,8 +30,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
for k in ["avg_latency", "percentiles"]})
|
||||
if pt_records:
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
with open(pt_file, "w") as f:
|
||||
json.dump(pt_records, f)
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -42,6 +41,10 @@ def main(args: argparse.Namespace):
|
||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
assert llm.llm_engine.model_config.max_model_len >= (
|
||||
args.input_len +
|
||||
args.output_len), ("Please ensure that max_model_len is greater than"
|
||||
" the sum of input_len and output_len.")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
@ -54,7 +57,7 @@ def main(args: argparse.Namespace):
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
size=(args.batch_size,
|
||||
args.input_len))
|
||||
dummy_prompts: List[PromptType] = [{
|
||||
dummy_prompts: list[PromptType] = [{
|
||||
"prompt_token_ids": batch
|
||||
} for batch in dummy_prompt_token_ids.tolist()]
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ import dataclasses
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
@ -77,9 +77,9 @@ def sample_requests_from_dataset(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_length_range: Tuple[int, int],
|
||||
input_length_range: tuple[int, int],
|
||||
fixed_output_len: Optional[int],
|
||||
) -> List[Request]:
|
||||
) -> list[Request]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
@ -99,7 +99,7 @@ def sample_requests_from_dataset(
|
||||
assert min_len >= 0 and max_len >= min_len, "input_length_range too small"
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_requests: List[Request] = []
|
||||
filtered_requests: list[Request] = []
|
||||
|
||||
for i in range(len(dataset)):
|
||||
if len(filtered_requests) == num_requests:
|
||||
@ -122,10 +122,10 @@ def sample_requests_from_dataset(
|
||||
def sample_requests_from_random(
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_length_range: Tuple[int, int],
|
||||
input_length_range: tuple[int, int],
|
||||
fixed_output_len: Optional[int],
|
||||
prefix_len: int,
|
||||
) -> List[Request]:
|
||||
) -> list[Request]:
|
||||
|
||||
requests = []
|
||||
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
|
||||
@ -144,9 +144,9 @@ def sample_requests_from_random(
|
||||
return requests
|
||||
|
||||
|
||||
def repeat_and_sort_requests(requests: List[Request],
|
||||
def repeat_and_sort_requests(requests: list[Request],
|
||||
repeat_count: int,
|
||||
sort: bool = False) -> List[str]:
|
||||
sort: bool = False) -> list[str]:
|
||||
repeated_requests = requests * repeat_count
|
||||
if sort:
|
||||
repeated_requests.sort(key=lambda x: x[1])
|
||||
|
||||
@ -5,7 +5,7 @@ import dataclasses
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
@ -13,12 +13,17 @@ from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
#Select a equi-probable random priority
|
||||
def get_random_flag():
|
||||
return 0 if random.random() < 0.5 else 1
|
||||
|
||||
|
||||
def sample_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
) -> list[tuple[str, int, int]]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
@ -35,7 +40,7 @@ def sample_requests(
|
||||
random.shuffle(dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||
filtered_dataset: list[tuple[str, int, int]] = []
|
||||
for i in range(len(dataset)):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
break
|
||||
@ -55,8 +60,7 @@ def sample_requests(
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
|
||||
#Select a equi-probable random priority
|
||||
priority = 0 if random.random() < 0.5 else 1
|
||||
priority = get_random_flag()
|
||||
|
||||
filtered_dataset.append((prompt, prompt_len, output_len, priority))
|
||||
|
||||
@ -64,13 +68,19 @@ def sample_requests(
|
||||
|
||||
|
||||
def run_vllm(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
requests: list[tuple[str, int, int]],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" input_len and output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts = []
|
||||
sampling_params = []
|
||||
@ -103,8 +113,8 @@ def main(args: argparse.Namespace):
|
||||
if args.dataset is None:
|
||||
# Synthesize a prompt with the given input length.
|
||||
prompt = "hi" * (args.input_len - 1)
|
||||
requests = [(prompt, args.input_len, args.output_len)
|
||||
for _ in range(args.num_prompts)]
|
||||
requests = [(prompt, args.input_len, args.output_len,
|
||||
get_random_flag()) for _ in range(args.num_prompts)]
|
||||
else:
|
||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
||||
args.output_len)
|
||||
|
||||
@ -33,9 +33,10 @@ import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator, Collection
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -56,7 +57,7 @@ try:
|
||||
except ImportError:
|
||||
from argparse import ArgumentParser as FlexibleArgumentParser
|
||||
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
|
||||
@ -73,22 +74,22 @@ class BenchmarkMetrics:
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
std_ttft_ms: float
|
||||
percentiles_ttft_ms: List[Tuple[float, float]]
|
||||
percentiles_ttft_ms: list[tuple[float, float]]
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
std_tpot_ms: float
|
||||
percentiles_tpot_ms: List[Tuple[float, float]]
|
||||
percentiles_tpot_ms: list[tuple[float, float]]
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
std_itl_ms: float
|
||||
percentiles_itl_ms: List[Tuple[float, float]]
|
||||
percentiles_itl_ms: list[tuple[float, float]]
|
||||
# E2EL stands for end-to-end latency per request.
|
||||
# It is the time taken on the client side from sending
|
||||
# a request to receiving a complete response.
|
||||
mean_e2el_ms: float
|
||||
median_e2el_ms: float
|
||||
std_e2el_ms: float
|
||||
percentiles_e2el_ms: List[Tuple[float, float]]
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
def sample_sharegpt_requests(
|
||||
@ -96,7 +97,7 @@ def sample_sharegpt_requests(
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> List[Tuple[str, int, int, None]]:
|
||||
) -> list[tuple[str, int, int, None]]:
|
||||
# Load the dataset.
|
||||
with open(dataset_path, encoding='utf-8') as f:
|
||||
dataset = json.load(f)
|
||||
@ -110,7 +111,7 @@ def sample_sharegpt_requests(
|
||||
random.shuffle(dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||
filtered_dataset: list[tuple[str, int, int]] = []
|
||||
for i in range(len(dataset)):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
break
|
||||
@ -139,7 +140,7 @@ def sample_burstgpt_requests(
|
||||
num_requests: int,
|
||||
random_seed: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> List[Tuple[str, int, int, None]]:
|
||||
) -> list[tuple[str, int, int, None]]:
|
||||
df = pd.read_csv(dataset_path)
|
||||
gpt4_df = df[df["Model"] == "GPT-4"]
|
||||
# Remove the failed requests (i.e., response length is 0)
|
||||
@ -170,7 +171,7 @@ def sample_sonnet_requests(
|
||||
output_len: int,
|
||||
prefix_len: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> List[Tuple[str, str, int, int, None]]:
|
||||
) -> list[tuple[str, str, int, int, None]]:
|
||||
assert (
|
||||
input_len > prefix_len
|
||||
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
|
||||
@ -211,7 +212,7 @@ def sample_sonnet_requests(
|
||||
prefix_lines = poem_lines[:num_prefix_lines]
|
||||
|
||||
# Sample the rest of lines per request.
|
||||
sampled_requests: List[Tuple[str, int, int]] = []
|
||||
sampled_requests: list[tuple[str, int, int]] = []
|
||||
for _ in range(num_requests):
|
||||
num_lines_needed = num_input_lines - num_prefix_lines
|
||||
sampled_lines = "".join(prefix_lines +
|
||||
@ -238,8 +239,8 @@ def sample_vision_arena_requests(
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
|
||||
sampled_requests: List[Tuple[str, int, int, Dict[str,
|
||||
) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]:
|
||||
sampled_requests: list[tuple[str, int, int, dict[str,
|
||||
Collection[str]]]] = []
|
||||
for data in dataset:
|
||||
if len(sampled_requests) == num_requests:
|
||||
@ -285,7 +286,7 @@ def sample_hf_requests(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
random_seed: int,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
|
||||
) -> list[tuple[str, str, int, Optional[dict[str, Collection[str]]]]]:
|
||||
|
||||
# Special case for vision_arena dataset
|
||||
if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
|
||||
@ -307,7 +308,7 @@ def sample_hf_requests(
|
||||
"HF Dataset must have 'conversations' column.")
|
||||
filter_func = lambda x: len(x["conversations"]) >= 2
|
||||
filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
|
||||
sampled_requests: List[Tuple[str, int, int, Dict[str,
|
||||
sampled_requests: list[tuple[str, int, int, dict[str,
|
||||
Collection[str]]]] = []
|
||||
for data in filtered_dataset:
|
||||
if len(sampled_requests) == num_requests:
|
||||
@ -370,7 +371,7 @@ def sample_random_requests(
|
||||
num_prompts: int,
|
||||
range_ratio: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
) -> list[tuple[str, int, int]]:
|
||||
prefix_token_ids = np.random.randint(0,
|
||||
tokenizer.vocab_size,
|
||||
size=prefix_len).tolist()
|
||||
@ -399,10 +400,10 @@ def sample_random_requests(
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
request_rate: float,
|
||||
burstiness: float = 1.0,
|
||||
) -> AsyncGenerator[Tuple[str, int, int], None]:
|
||||
) -> AsyncGenerator[tuple[str, int, int], None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness.
|
||||
@ -443,23 +444,23 @@ async def get_request(
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
outputs: List[RequestFuncOutput],
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
outputs: list[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[float],
|
||||
goodput_config_dict: Dict[str, float],
|
||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||
actual_output_lens: List[int] = []
|
||||
selected_percentile_metrics: list[str],
|
||||
selected_percentiles: list[float],
|
||||
goodput_config_dict: dict[str, float],
|
||||
) -> tuple[BenchmarkMetrics, list[int]]:
|
||||
actual_output_lens: list[int] = []
|
||||
total_input = 0
|
||||
completed = 0
|
||||
good_completed = 0
|
||||
itls: List[float] = []
|
||||
tpots: List[float] = []
|
||||
all_tpots: List[float] = []
|
||||
ttfts: List[float] = []
|
||||
e2els: List[float] = []
|
||||
itls: list[float] = []
|
||||
tpots: list[float] = []
|
||||
all_tpots: list[float] = []
|
||||
ttfts: list[float] = []
|
||||
e2els: list[float] = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
output_len = outputs[i].output_tokens
|
||||
@ -557,19 +558,19 @@ async def benchmark(
|
||||
model_id: str,
|
||||
model_name: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
logprobs: Optional[int],
|
||||
best_of: int,
|
||||
request_rate: float,
|
||||
burstiness: float,
|
||||
disable_tqdm: bool,
|
||||
profile: bool,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[str],
|
||||
selected_percentile_metrics: list[str],
|
||||
selected_percentiles: list[str],
|
||||
ignore_eos: bool,
|
||||
goodput_config_dict: Dict[str, float],
|
||||
goodput_config_dict: dict[str, float],
|
||||
max_concurrency: Optional[int],
|
||||
lora_modules: Optional[List[str]],
|
||||
lora_modules: Optional[list[str]],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -652,7 +653,7 @@ async def benchmark(
|
||||
pbar=pbar)
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: List[asyncio.Task] = []
|
||||
tasks: list[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
prompt, prompt_len, output_len, mm_content = request
|
||||
req_model_id, req_model_name = model_id, model_name
|
||||
@ -674,7 +675,7 @@ async def benchmark(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)))
|
||||
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
if profile:
|
||||
print("Stopping profiler...")
|
||||
@ -820,7 +821,7 @@ def parse_goodput(slo_pairs):
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
results: Dict[str, Any],
|
||||
results: dict[str, Any],
|
||||
file_name: str) -> None:
|
||||
metrics = [
|
||||
"median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
|
||||
@ -841,8 +842,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
if pt_records:
|
||||
# Don't use json suffix here as we don't want CI to pick it up
|
||||
pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
|
||||
with open(pt_file, "w") as f:
|
||||
json.dump(pt_records, f)
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -867,18 +867,10 @@ def main(args: argparse.Namespace):
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next "
|
||||
"release. Please use '--dataset-name' and "
|
||||
"'--dataset-path' in the future runs.",
|
||||
stacklevel=2)
|
||||
input_requests = sample_sharegpt_requests(
|
||||
dataset_path=args.dataset,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
fixed_output_len=args.sharegpt_output_len,
|
||||
)
|
||||
if args.dataset_name is None:
|
||||
raise ValueError(
|
||||
"Please specify '--dataset-name' and the corresponding "
|
||||
"'--dataset-path' if required.")
|
||||
|
||||
elif args.dataset_name == "sharegpt":
|
||||
input_requests = sample_sharegpt_requests(
|
||||
@ -983,7 +975,7 @@ def main(args: argparse.Namespace):
|
||||
|
||||
# Save config and results to json
|
||||
if args.save_result:
|
||||
result_json: Dict[str, Any] = {}
|
||||
result_json: dict[str, Any] = {}
|
||||
|
||||
# Setup
|
||||
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
@ -1052,13 +1044,6 @@ if __name__ == "__main__":
|
||||
default="/v1/completions",
|
||||
help="API endpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the ShareGPT dataset, will be deprecated in the "
|
||||
"next release.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
|
||||
@ -9,7 +9,7 @@ On the server side, run one of the following commands:
|
||||
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
|
||||
|
||||
On the client side, run:
|
||||
python benchmarks/benchmark_serving.py \
|
||||
python benchmarks/benchmark_serving_guided.py \
|
||||
--backend <backend> \
|
||||
--model <your_model> \
|
||||
--dataset json \
|
||||
@ -30,8 +30,9 @@ import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator, List, Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@ -66,22 +67,22 @@ class BenchmarkMetrics:
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
std_ttft_ms: float
|
||||
percentiles_ttft_ms: List[Tuple[float, float]]
|
||||
percentiles_ttft_ms: list[tuple[float, float]]
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
std_tpot_ms: float
|
||||
percentiles_tpot_ms: List[Tuple[float, float]]
|
||||
percentiles_tpot_ms: list[tuple[float, float]]
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
std_itl_ms: float
|
||||
percentiles_itl_ms: List[Tuple[float, float]]
|
||||
percentiles_itl_ms: list[tuple[float, float]]
|
||||
# E2EL stands for end-to-end latency per request.
|
||||
# It is the time taken on the client side from sending
|
||||
# a request to receiving a complete response.
|
||||
mean_e2el_ms: float
|
||||
median_e2el_ms: float
|
||||
std_e2el_ms: float
|
||||
percentiles_e2el_ms: List[Tuple[float, float]]
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -104,7 +105,7 @@ class SampleRequest:
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> List[SampleRequest]:
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
if args.dataset == 'json':
|
||||
if args.json_schema_path is None:
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
@ -187,7 +188,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
]
|
||||
|
||||
elif args.dataset == "xgrammar_bench":
|
||||
requests: List[SampleRequest] = []
|
||||
requests: list[SampleRequest] = []
|
||||
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
|
||||
split="train")
|
||||
print(f"dataset has {len(dataset)} entries")
|
||||
@ -214,10 +215,10 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: List[SampleRequest],
|
||||
input_requests: list[SampleRequest],
|
||||
request_rate: float,
|
||||
burstiness: float = 1.0,
|
||||
) -> AsyncGenerator[Tuple[int, SampleRequest], None]:
|
||||
) -> AsyncGenerator[tuple[int, SampleRequest], None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness.
|
||||
@ -258,22 +259,23 @@ async def get_request(
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
outputs: List[RequestFuncOutput],
|
||||
input_requests: list[tuple[str, int, int]],
|
||||
outputs: list[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[float],
|
||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||
actual_output_lens: List[int] = []
|
||||
selected_percentile_metrics: list[str],
|
||||
selected_percentiles: list[float],
|
||||
goodput_config_dict: Optional[dict[str, float]] = None,
|
||||
) -> tuple[BenchmarkMetrics, list[int]]:
|
||||
actual_output_lens: list[int] = []
|
||||
total_input = 0
|
||||
completed = 0
|
||||
good_completed = 0
|
||||
itls: List[float] = []
|
||||
tpots: List[float] = []
|
||||
all_tpots: List[float] = []
|
||||
ttfts: List[float] = []
|
||||
e2els: List[float] = []
|
||||
itls: list[float] = []
|
||||
tpots: list[float] = []
|
||||
all_tpots: list[float] = []
|
||||
ttfts: list[float] = []
|
||||
e2els: list[float] = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
# We use the tokenizer to count the number of output tokens for all
|
||||
@ -287,10 +289,10 @@ def calculate_metrics(
|
||||
total_input += input_requests[i].prompt_len
|
||||
tpot = 0
|
||||
if output_len > 1:
|
||||
tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
|
||||
1)
|
||||
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
|
||||
tpot = latency_minus_ttft / (output_len - 1)
|
||||
tpots.append(tpot)
|
||||
outputs[i].tpot = sum(tpots) / len(tpots) if len(tpots) else 0
|
||||
outputs[i].tpot = tpot
|
||||
# Note: if output_len <= 1, we regard tpot as 0 for goodput
|
||||
all_tpots.append(tpot)
|
||||
itls += outputs[i].itl
|
||||
@ -300,6 +302,28 @@ def calculate_metrics(
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
|
||||
if goodput_config_dict:
|
||||
valid_metrics = []
|
||||
slo_values = []
|
||||
|
||||
if "ttft" in goodput_config_dict:
|
||||
valid_metrics.append(ttfts)
|
||||
slo_values.append(goodput_config_dict["ttft"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
if "tpot" in goodput_config_dict:
|
||||
valid_metrics.append(all_tpots)
|
||||
slo_values.append(goodput_config_dict["tpot"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
if "e2el" in goodput_config_dict:
|
||||
valid_metrics.append(e2els)
|
||||
slo_values.append(goodput_config_dict["e2el"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
|
||||
for req_metric in zip(*valid_metrics):
|
||||
is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)])
|
||||
if is_good_req:
|
||||
good_completed += 1
|
||||
|
||||
if completed == 0:
|
||||
warnings.warn(
|
||||
"All requests failed. This is likely due to a misconfiguration "
|
||||
@ -345,17 +369,18 @@ async def benchmark(
|
||||
base_url: str,
|
||||
model_id: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: List[SampleRequest],
|
||||
input_requests: list[SampleRequest],
|
||||
request_rate: float,
|
||||
burstiness: float,
|
||||
disable_tqdm: bool,
|
||||
profile: bool,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[str],
|
||||
selected_percentile_metrics: list[str],
|
||||
selected_percentiles: list[str],
|
||||
ignore_eos: bool,
|
||||
max_concurrency: Optional[int],
|
||||
guided_decoding_ratio: float,
|
||||
guided_decoding_backend: str,
|
||||
goodput_config_dict: Optional[dict[str, float]] = None,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -435,8 +460,8 @@ async def benchmark(
|
||||
pbar=pbar)
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: List[asyncio.Task] = []
|
||||
expected: List[str] = []
|
||||
tasks: list[asyncio.Task] = []
|
||||
expected: list[str] = []
|
||||
async for i, request in get_request(input_requests, request_rate,
|
||||
burstiness):
|
||||
extra_body = prepare_extra_body(
|
||||
@ -455,7 +480,7 @@ async def benchmark(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)))
|
||||
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
if profile:
|
||||
print("Stopping profiler...")
|
||||
@ -483,6 +508,7 @@ async def benchmark(
|
||||
tokenizer=tokenizer,
|
||||
selected_percentile_metrics=selected_percentile_metrics,
|
||||
selected_percentiles=selected_percentiles,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
@ -494,6 +520,9 @@ async def benchmark(
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
if goodput_config_dict:
|
||||
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
|
||||
metrics.request_goodput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
metrics.output_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||
@ -617,6 +646,40 @@ def evaluate(ret, args):
|
||||
100) if len(not_none_scores) > 0 else None
|
||||
|
||||
|
||||
def parse_goodput(slo_pairs):
|
||||
goodput_config_dict = {}
|
||||
try:
|
||||
for slo_pair in slo_pairs:
|
||||
slo_name, slo_val = slo_pair.split(":")
|
||||
goodput_config_dict[slo_name] = float(slo_val)
|
||||
except ValueError as err:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format found for service level objectives. "
|
||||
"Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||
"pairs, where the key is a metric name, and the value is a "
|
||||
"number in milliseconds.") from err
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
def check_goodput_args(args):
|
||||
goodput_config_dict = {}
|
||||
VALID_NAMES = ["ttft", "tpot", "e2el"]
|
||||
if args.goodput:
|
||||
goodput_config_dict = parse_goodput(args.goodput)
|
||||
for slo_name, slo_val in goodput_config_dict.items():
|
||||
if slo_name not in VALID_NAMES:
|
||||
raise ValueError(
|
||||
f"Invalid metric name found, {slo_name}: {slo_val}. "
|
||||
"The service level objective name should be one of "
|
||||
f"{str(VALID_NAMES)}. ")
|
||||
if slo_val < 0:
|
||||
raise ValueError(
|
||||
f"Invalid value found, {slo_name}: {slo_val}. "
|
||||
"The service level objective value should be "
|
||||
"non-negative.")
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
@ -661,6 +724,8 @@ def main(args: argparse.Namespace):
|
||||
|
||||
input_requests = sample_requests(tokenizer, args)
|
||||
|
||||
goodput_config_dict = check_goodput_args(args)
|
||||
|
||||
benchmark_result, ret = asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
@ -681,6 +746,7 @@ def main(args: argparse.Namespace):
|
||||
max_concurrency=args.max_concurrency,
|
||||
guided_decoding_ratio=args.guided_decoding_ratio,
|
||||
guided_decoding_backend=args.guided_decoding_backend,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
))
|
||||
|
||||
# Save config and results to json
|
||||
@ -865,6 +931,18 @@ if __name__ == "__main__":
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--goodput",
|
||||
nargs="+",
|
||||
required=False,
|
||||
help="Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||
"pairs, where the key is a metric name, and the value is in "
|
||||
"milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, "
|
||||
"separated by spaces. Allowed request level metric names are "
|
||||
"\"ttft\", \"tpot\", \"e2el\". For more context on the definition of "
|
||||
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
|
||||
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
|
||||
|
||||
parser.add_argument("--no-guided-decoding",
|
||||
action='store_true',
|
||||
default=False,
|
||||
|
||||
@ -7,11 +7,11 @@ import os
|
||||
import random
|
||||
import time
|
||||
from functools import cache
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
@ -74,12 +74,12 @@ def lora_path_on_disk(lora_path: str) -> str:
|
||||
return get_adapter_absolute_path(lora_path)
|
||||
|
||||
|
||||
lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}
|
||||
lora_tokenizer_cache: dict[int, AnyTokenizer] = {}
|
||||
|
||||
|
||||
def get_random_lora_request(
|
||||
args: argparse.Namespace
|
||||
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
|
||||
) -> tuple[LoRARequest, Optional[AnyTokenizer]]:
|
||||
global lora_tokenizer_cache
|
||||
lora_id = random.randint(1, args.max_loras)
|
||||
lora_request = LoRARequest(lora_name=str(lora_id),
|
||||
@ -91,7 +91,7 @@ def get_random_lora_request(
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> List[SampleRequest]:
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
|
||||
dataset_path: str = args.dataset
|
||||
num_requests: int = args.num_prompts
|
||||
@ -109,7 +109,7 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
random.shuffle(dataset)
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: List[SampleRequest] = []
|
||||
filtered_dataset: list[SampleRequest] = []
|
||||
for data in tqdm(dataset,
|
||||
total=len(filtered_dataset),
|
||||
desc="sampling requests"):
|
||||
@ -165,16 +165,21 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
|
||||
|
||||
def run_vllm(
|
||||
requests: List[SampleRequest],
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: EngineArgs,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
# Add the requests to the engine.
|
||||
prompts: List[TextPrompt] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
prompts: list[TextPrompt] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TextPrompt(prompt=request.prompt,
|
||||
@ -187,7 +192,7 @@ def run_vllm(
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
))
|
||||
lora_requests: Optional[List[LoRARequest]] = None
|
||||
lora_requests: Optional[list[LoRARequest]] = None
|
||||
if engine_args.enable_lora:
|
||||
lora_requests = [request.lora_request for request in requests]
|
||||
|
||||
@ -220,7 +225,7 @@ def run_vllm(
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: List[SampleRequest],
|
||||
requests: list[SampleRequest],
|
||||
n: int,
|
||||
engine_args: AsyncEngineArgs,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
@ -229,11 +234,17 @@ async def run_vllm_async(
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
assert all(
|
||||
llm.model_config.max_model_len >= (request.prompt_len +
|
||||
request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: List[TextPrompt] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
lora_requests: List[Optional[LoRARequest]] = []
|
||||
prompts: list[TextPrompt] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
lora_requests: list[Optional[LoRARequest]] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TextPrompt(prompt=request.prompt,
|
||||
@ -265,7 +276,7 @@ async def run_vllm_async(
|
||||
|
||||
|
||||
def run_hf(
|
||||
requests: List[SampleRequest],
|
||||
requests: list[SampleRequest],
|
||||
model: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
n: int,
|
||||
@ -281,7 +292,7 @@ def run_hf(
|
||||
|
||||
pbar = tqdm(total=len(requests))
|
||||
start = time.perf_counter()
|
||||
batch: List[str] = []
|
||||
batch: list[str] = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
for i in range(len(requests)):
|
||||
@ -323,7 +334,7 @@ def run_hf(
|
||||
|
||||
|
||||
def run_mii(
|
||||
requests: List[SampleRequest],
|
||||
requests: list[SampleRequest],
|
||||
model: str,
|
||||
tensor_parallel_size: int,
|
||||
output_len: int,
|
||||
@ -341,7 +352,7 @@ def run_mii(
|
||||
|
||||
|
||||
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
results: Dict[str, Any]) -> None:
|
||||
results: dict[str, Any]) -> None:
|
||||
pt_records = convert_to_pytorch_benchmark_format(
|
||||
args=args,
|
||||
metrics={
|
||||
@ -355,8 +366,7 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
if pt_records:
|
||||
# Don't use json suffix here as we don't want CI to pick it up
|
||||
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
|
||||
with open(pt_file, "w") as f:
|
||||
json.dump(pt_records, f)
|
||||
write_to_json(pt_file, pt_records)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -469,8 +479,8 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset. The dataset is expected to "
|
||||
"be a json in form of List[Dict[..., conversations: "
|
||||
"List[Dict[..., value: <prompt_or_response>]]]]")
|
||||
"be a json in form of list[dict[..., conversations: "
|
||||
"list[dict[..., value: <prompt_or_response>]]]]")
|
||||
parser.add_argument("--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
|
||||
def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
metrics: Dict[str, List],
|
||||
extra_info: Dict[str, Any]) -> List:
|
||||
metrics: dict[str, list],
|
||||
extra_info: dict[str, Any]) -> list:
|
||||
"""
|
||||
Save the benchmark results in the format used by PyTorch OSS benchmark with
|
||||
on metric per record
|
||||
@ -34,6 +36,34 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
"extra_info": extra_info,
|
||||
},
|
||||
}
|
||||
|
||||
tp = record["benchmark"]["extra_info"]["args"].get(
|
||||
"tensor_parallel_size")
|
||||
# Save tensor_parallel_size parameter if it's part of the metadata
|
||||
if not tp and "tensor_parallel_size" in extra_info:
|
||||
record["benchmark"]["extra_info"]["args"][
|
||||
"tensor_parallel_size"] = extra_info["tensor_parallel_size"]
|
||||
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
|
||||
class InfEncoder(json.JSONEncoder):
|
||||
|
||||
def clear_inf(self, o: Any):
|
||||
if isinstance(o, dict):
|
||||
return {k: self.clear_inf(v) for k, v in o.items()}
|
||||
elif isinstance(o, list):
|
||||
return [self.clear_inf(v) for v in o]
|
||||
elif isinstance(o, float) and math.isinf(o):
|
||||
return "inf"
|
||||
return o
|
||||
|
||||
def iterencode(self, o: Any, *args, **kwargs) -> Any:
|
||||
return super().iterencode(self.clear_inf(o), *args, **kwargs)
|
||||
|
||||
|
||||
def write_to_json(filename: str, records: list) -> None:
|
||||
with open(filename, "w") as f:
|
||||
json.dump(records, f, cls=InfEncoder)
|
||||
|
||||
@ -5,7 +5,8 @@ import copy
|
||||
import itertools
|
||||
import pickle as pkl
|
||||
import time
|
||||
from typing import Callable, Iterable, List, Tuple
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
@ -228,7 +229,7 @@ def print_timers(timers: Iterable[TMeasurement]):
|
||||
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||
@ -241,7 +242,7 @@ def run(dtype: torch.dtype,
|
||||
|
||||
# output makers
|
||||
def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None):
|
||||
print(f"== All Results {base_description} ====")
|
||||
@ -282,7 +283,7 @@ def run_model_bench(args):
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
|
||||
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
|
||||
KNs = []
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Cutlass bench utils
|
||||
from typing import Iterable, Tuple
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
@ -27,7 +27,7 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
@ -63,7 +63,7 @@ def prune_to_2_4(tensor):
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
@ -88,7 +88,7 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
|
||||
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
|
||||
m: int, n: int, k: int) -> \
|
||||
Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
|
||||
tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
|
||||
ABs = []
|
||||
for _ in range(num_tensors):
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||
|
||||
@ -5,7 +5,8 @@ import copy
|
||||
import itertools
|
||||
import pickle as pkl
|
||||
import time
|
||||
from typing import Callable, Iterable, List, Optional, Tuple
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
@ -49,7 +50,7 @@ def bench_int8(
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
"""Benchmark INT8-based kernels."""
|
||||
assert dtype == torch.int8
|
||||
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||
@ -101,7 +102,7 @@ def bench_fp8(
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
"""Benchmark FP8-based kernels."""
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
@ -180,7 +181,7 @@ def bench(dtype: torch.dtype,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
if dtype == torch.int8:
|
||||
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
@ -195,8 +196,8 @@ def print_timers(timers: Iterable[TMeasurement]):
|
||||
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype,
|
||||
@ -212,7 +213,7 @@ def run(dtype: torch.dtype,
|
||||
|
||||
|
||||
def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None):
|
||||
print(f"== All Results {base_description} ====")
|
||||
@ -248,7 +249,7 @@ def run_model_bench(args):
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
|
||||
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
|
||||
KNs = []
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
|
||||
@ -2,9 +2,10 @@
|
||||
|
||||
import pickle as pkl
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
from typing import Callable, Iterable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
@ -29,7 +30,7 @@ class bench_params_t:
|
||||
f'x DT {self.dtype}')
|
||||
|
||||
|
||||
def get_bench_params() -> List[bench_params_t]:
|
||||
def get_bench_params() -> list[bench_params_t]:
|
||||
## Test Fixtures
|
||||
NUM_TOKENS = [2**x for x in range(11)]
|
||||
HIDDEN_SIZES = list(range(1024, 8129, 1024))
|
||||
|
||||
@ -9,7 +9,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
@ -61,15 +61,15 @@ def make_rand_lora_weight_tensor(k: int,
|
||||
|
||||
|
||||
def make_rand_tensors(
|
||||
a_shape: Tuple[int],
|
||||
b_shape: Tuple[int],
|
||||
c_shape: Tuple[int],
|
||||
a_shape: tuple[int],
|
||||
b_shape: tuple[int],
|
||||
c_shape: tuple[int],
|
||||
a_dtype: torch.dtype,
|
||||
b_dtype: torch.dtype,
|
||||
c_dtype: torch.dtype,
|
||||
num_slices: int,
|
||||
device: str = "cuda",
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
|
||||
"""
|
||||
Make LoRA input/output matrices.
|
||||
"""
|
||||
@ -89,7 +89,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int,
|
||||
sort_by_lora_id: bool,
|
||||
device: str) -> torch.Tensor:
|
||||
"""
|
||||
All prompts are mapped to a Lora ID in range [0, num_active_loras).
|
||||
All prompts are mapped to a LoRA ID in range [0, num_active_loras).
|
||||
where 0 refers to first lora, 1 refers to second lora and so on.
|
||||
"""
|
||||
assert num_active_loras > 0
|
||||
@ -135,7 +135,7 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int,
|
||||
|
||||
|
||||
def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
|
||||
lora_weights: List[torch.Tensor],
|
||||
lora_weights: list[torch.Tensor],
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
prompt_lora_mapping_cpu: torch.Tensor, scaling: float,
|
||||
add_inputs: Optional[bool]):
|
||||
@ -204,7 +204,7 @@ class OpType(Enum):
|
||||
def is_expand_slice_fn(self) -> bool:
|
||||
return self in [OpType.BGMV_EXPAND_SLICE]
|
||||
|
||||
def num_slices(self) -> List[int]:
|
||||
def num_slices(self) -> list[int]:
|
||||
if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
|
||||
# SGMV kernels supports slices
|
||||
return [1, 2, 3]
|
||||
@ -215,7 +215,7 @@ class OpType(Enum):
|
||||
raise ValueError(f"Unrecognized OpType {self}")
|
||||
|
||||
def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
|
||||
lora_rank: int) -> Tuple[int, int, int]:
|
||||
lora_rank: int) -> tuple[int, int, int]:
|
||||
num_tokens = batch_size * seq_length
|
||||
if self.is_shrink_fn():
|
||||
m = num_tokens
|
||||
@ -230,7 +230,7 @@ class OpType(Enum):
|
||||
|
||||
def matmul_dtypes(
|
||||
self, op_dtype: torch.dtype
|
||||
) -> Tuple[torch.dtype, torch.dtype, torch.dtype]:
|
||||
) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
|
||||
"""
|
||||
return a type, b type and c type for A x B = C
|
||||
"""
|
||||
@ -243,7 +243,7 @@ class OpType(Enum):
|
||||
def matmul_shapes(
|
||||
self, batch_size: int, seq_length: int, hidden_size: int,
|
||||
lora_rank: int, num_loras: int,
|
||||
num_slices: int) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]:
|
||||
num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]:
|
||||
"""
|
||||
Given num_slices, return the shapes of the A, B, and C matrices
|
||||
in A x B = C, for the op_type
|
||||
@ -268,7 +268,7 @@ class OpType(Enum):
|
||||
|
||||
def bench_fn(self) -> Callable:
|
||||
|
||||
def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]):
|
||||
def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
|
||||
for x in kwargs_list:
|
||||
bgmv_expand_slice(**x)
|
||||
|
||||
@ -285,7 +285,7 @@ class OpType(Enum):
|
||||
raise ValueError(f"Unrecognized optype {self}")
|
||||
|
||||
def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
|
||||
lora_weights: List[torch.Tensor],
|
||||
lora_weights: list[torch.Tensor],
|
||||
**kwargs) -> Callable:
|
||||
"""Each benchmark operation expected the input, lora_weights and outputs
|
||||
in a slightly different format. Refer to self.matmul_shapes().
|
||||
@ -384,7 +384,7 @@ class BenchmarkTensors:
|
||||
"""
|
||||
# matmul tensors
|
||||
input: torch.Tensor
|
||||
lora_weights_lst: List[torch.Tensor]
|
||||
lora_weights_lst: list[torch.Tensor]
|
||||
output: torch.Tensor
|
||||
# metadata tensors
|
||||
seq_lens: torch.Tensor
|
||||
@ -469,7 +469,7 @@ class BenchmarkTensors:
|
||||
for i in range(len(self.lora_weights_lst)):
|
||||
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
|
||||
|
||||
def metadata(self) -> Tuple[int, int, int]:
|
||||
def metadata(self) -> tuple[int, int, int]:
|
||||
"""
|
||||
Return num_seqs, num_tokens and max_seq_len
|
||||
"""
|
||||
@ -505,7 +505,7 @@ class BenchmarkTensors:
|
||||
self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
|
||||
self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)
|
||||
|
||||
def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]:
|
||||
def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
|
||||
self.convert_to_sgmv_benchmark_tensors()
|
||||
self.sanity_check()
|
||||
self.to_device(self.input.device)
|
||||
@ -540,7 +540,7 @@ class BenchmarkTensors:
|
||||
'scaling': 1.0,
|
||||
}
|
||||
|
||||
def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
|
||||
def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
|
||||
self.convert_to_sgmv_benchmark_tensors()
|
||||
self.sanity_check()
|
||||
@ -578,7 +578,7 @@ class BenchmarkTensors:
|
||||
'add_inputs': add_inputs,
|
||||
}
|
||||
|
||||
def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]:
|
||||
def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
|
||||
assert len(self.lora_weights_lst) == 1
|
||||
self.to_device(self.input.device)
|
||||
|
||||
@ -634,7 +634,7 @@ class BenchmarkTensors:
|
||||
'add_inputs': add_inputs
|
||||
}
|
||||
|
||||
def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
|
||||
def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
|
||||
|
||||
_, num_tokens, _, num_slices = self.metadata()
|
||||
# Sanity check shapes
|
||||
@ -670,7 +670,7 @@ class BenchmarkTensors:
|
||||
|
||||
def bench_fn_kwargs(self,
|
||||
op_type: OpType,
|
||||
add_inputs: Optional[bool] = None) -> Dict[str, Any]:
|
||||
add_inputs: Optional[bool] = None) -> dict[str, Any]:
|
||||
if op_type.is_shrink_fn():
|
||||
assert add_inputs is None
|
||||
else:
|
||||
@ -734,7 +734,7 @@ def bench_optype(ctx: BenchmarkContext,
|
||||
assert expand_fn_add_inputs is not None
|
||||
|
||||
# BenchmarkContext -> BenchmarkTensors
|
||||
bench_tensors : List[BenchmarkTensors] = \
|
||||
bench_tensors : list[BenchmarkTensors] = \
|
||||
[BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)]
|
||||
for bt in bench_tensors:
|
||||
bt.sanity_check()
|
||||
@ -746,7 +746,7 @@ def bench_optype(ctx: BenchmarkContext,
|
||||
for bt in bench_tensors
|
||||
])
|
||||
|
||||
# BenchmarkTensors -> Dict (kwargs)
|
||||
# BenchmarkTensors -> dict (kwargs)
|
||||
kwargs_list = [
|
||||
bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
|
||||
for bt in bench_tensors
|
||||
@ -841,7 +841,7 @@ def use_cuda_graph_recommendation() -> str:
|
||||
"""
|
||||
|
||||
|
||||
def print_timers(timers: List[TMeasurement],
|
||||
def print_timers(timers: list[TMeasurement],
|
||||
args: Optional[argparse.Namespace] = None):
|
||||
compare = TBenchmark.Compare(timers)
|
||||
compare.print()
|
||||
@ -861,7 +861,7 @@ def print_timers(timers: List[TMeasurement],
|
||||
"small num_loras the goal should be to match the torch.mm numbers.")
|
||||
|
||||
|
||||
def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
|
||||
def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
|
||||
|
||||
if args.cuda_graph_nops is not None:
|
||||
assert args.cuda_graph_nops > 0
|
||||
@ -873,7 +873,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
|
||||
timers = []
|
||||
for bench_ctx in bench_ctxs:
|
||||
for seq_len in args.seq_lengths:
|
||||
bench_ops: List[OpType] = []
|
||||
bench_ops: list[OpType] = []
|
||||
if seq_len == 1:
|
||||
# bench all decode ops
|
||||
bench_ops = [op for op in args.op_types if op.is_decode_op()]
|
||||
@ -921,10 +921,10 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
|
||||
pickle.dump(timers, f)
|
||||
|
||||
|
||||
def as_benchmark_contexts(hidden_sizes: List[int], lora_ranks: List[int],
|
||||
args: argparse.Namespace) -> List[BenchmarkContext]:
|
||||
def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
|
||||
args: argparse.Namespace) -> list[BenchmarkContext]:
|
||||
|
||||
ctxs: List[BenchmarkContext] = []
|
||||
ctxs: list[BenchmarkContext] = []
|
||||
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa
|
||||
args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras,
|
||||
args.sort_by_lora_id):
|
||||
@ -954,7 +954,7 @@ def run_list_bench(args: argparse.Namespace):
|
||||
f" LoRA Ranks {args.lora_ranks}")
|
||||
|
||||
# Get all benchmarking contexts
|
||||
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
|
||||
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
|
||||
hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args)
|
||||
|
||||
run(args, bench_contexts)
|
||||
@ -975,7 +975,7 @@ def run_range_bench(args: argparse.Namespace):
|
||||
f" LoRA Ranks {lora_ranks}")
|
||||
|
||||
# Get all benchmarking contexts
|
||||
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
|
||||
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
|
||||
hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args)
|
||||
|
||||
run(args, bench_contexts)
|
||||
@ -1002,7 +1002,7 @@ def run_model_bench(args: argparse.Namespace):
|
||||
f" LoRA Ranks {args.lora_ranks}")
|
||||
|
||||
# Get all benchmarking contexts
|
||||
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
|
||||
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
|
||||
hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args)
|
||||
|
||||
run(args, bench_contexts)
|
||||
|
||||
@ -7,9 +7,10 @@ import math
|
||||
import os
|
||||
import pickle as pkl
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
from typing import Callable, Iterable, List, Optional, Tuple
|
||||
from typing import Callable, Optional
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
@ -102,8 +103,8 @@ def quantize_and_pack(atype: torch.dtype,
|
||||
return w_ref, w_q, w_s, w_zp
|
||||
|
||||
|
||||
def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig,
|
||||
group_size: Optional[int]) -> List[BenchmarkTensors]:
|
||||
def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig,
|
||||
group_size: Optional[int]) -> list[BenchmarkTensors]:
|
||||
m, n, k = shape
|
||||
|
||||
# we want to make sure that weights don't fit into L2 cache between runs so
|
||||
@ -114,7 +115,7 @@ def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig,
|
||||
|
||||
a = rand_data((m, k), types.act_type, scale=5)
|
||||
|
||||
benchmark_tensors: List[BenchmarkTensors] = []
|
||||
benchmark_tensors: list[BenchmarkTensors] = []
|
||||
for _ in range(num_weights):
|
||||
w = rand_data((k, n), types.act_type, scale=5)
|
||||
|
||||
@ -276,7 +277,7 @@ def machete_create_bench_fn(bt: BenchmarkTensors,
|
||||
|
||||
|
||||
def bench_fns(label: str, sub_label: str, description: str,
|
||||
fns: List[Callable]):
|
||||
fns: list[Callable]):
|
||||
|
||||
min_run_time = 1 if not NVTX_PROFILE else 0.1
|
||||
res = TBenchmark.Timer(
|
||||
@ -311,7 +312,7 @@ def bench(types: TypeConfig,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
sweep_schedules: bool = True) -> List[TMeasurement]:
|
||||
sweep_schedules: bool = True) -> list[TMeasurement]:
|
||||
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
|
||||
sub_label += f", L={len(benchmark_tensors)}"
|
||||
|
||||
@ -414,12 +415,12 @@ def bench(types: TypeConfig,
|
||||
|
||||
|
||||
# runner
|
||||
def print_timers(timers: List[TMeasurement]):
|
||||
def print_timers(timers: list[TMeasurement]):
|
||||
compare = TBenchmark.Compare(timers)
|
||||
compare.print()
|
||||
|
||||
|
||||
def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
types = TypeConfig(
|
||||
act_type=args.act_type,
|
||||
weight_type=scalar_types.uint4b8 if args.group_zero_type is None \
|
||||
@ -431,7 +432,7 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
token_scale_type=args.token_scale_type,
|
||||
)
|
||||
|
||||
results: List[TMeasurement] = []
|
||||
results: list[TMeasurement] = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(types,
|
||||
args.group_size,
|
||||
@ -449,8 +450,8 @@ def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
|
||||
# output makers
|
||||
def make_output(
|
||||
data: List[TMeasurement],
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
data: list[TMeasurement],
|
||||
MKNs: Iterable[tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None,
|
||||
):
|
||||
@ -497,7 +498,7 @@ def run_model_bench(args):
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
|
||||
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
|
||||
KNs = []
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
from benchmark_shapes import WEIGHT_SHAPES
|
||||
@ -10,6 +8,8 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
|
||||
@ -18,18 +18,18 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack, gptq_quantize_weights, sort_weights)
|
||||
gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
|
||||
from vllm.scalar_type import ScalarType
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
|
||||
|
||||
def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
act_order: bool, is_k_full: bool, quant_type: ScalarType,
|
||||
group_size: int, size_m: int, size_k: int, size_n: int):
|
||||
label = "Quant Matmul"
|
||||
@ -81,6 +81,27 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL)
|
||||
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
|
||||
|
||||
# AllSpark W8A16 quant
|
||||
as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||
and group_size == -1 and not act_order and is_k_full)
|
||||
if as_supported_case:
|
||||
properties = torch.cuda.get_device_properties(b.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
supported_arch = (sm_version >= 80 and sm_version < 90)
|
||||
as_supported_case = as_supported_case and supported_arch
|
||||
if supported_arch:
|
||||
has_zp = False
|
||||
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size,
|
||||
has_zp)
|
||||
qw = qw.to(torch.uint8)
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = \
|
||||
ops.allspark_repack_weight(
|
||||
qw, s, zp, has_zp)
|
||||
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||
|
||||
globals = {
|
||||
# Gen params
|
||||
"quant_type": quant_type,
|
||||
@ -109,10 +130,19 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
# GPTQ params
|
||||
"q_w_gptq": q_w_gptq,
|
||||
"repack_sort_indices": repack_sort_indices,
|
||||
# AllSpark W8A16 params
|
||||
"qw_reorder": qw_reorder if as_supported_case else None,
|
||||
"s_reorder": s_reorder if as_supported_case else None,
|
||||
"zp_reorder": zp_reorder if as_supported_case else None,
|
||||
"sm_count": sm_count if as_supported_case else None,
|
||||
"sm_version": sm_version if as_supported_case else None,
|
||||
"CUBLAS_M_THRESHOLD":
|
||||
CUBLAS_M_THRESHOLD if as_supported_case else None,
|
||||
# Kernels
|
||||
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||
"gptq_marlin_repack": ops.gptq_marlin_repack,
|
||||
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
|
||||
}
|
||||
|
||||
min_run_time = 1
|
||||
@ -172,13 +202,24 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
description="gptq_marlin_repack",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
if as_supported_case:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="allspark_w8a16_gemm_fp32",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
|
||||
def main(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
results: List[benchmark.Measurement] = []
|
||||
results: list[benchmark.Measurement] = []
|
||||
|
||||
for model in args.models:
|
||||
for layer in WEIGHT_SHAPES[model]:
|
||||
|
||||
@ -2,9 +2,10 @@
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from itertools import product
|
||||
from typing import Any, Dict, List, Tuple, TypedDict
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
@ -40,6 +41,7 @@ def benchmark_config(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
block_quant_shape: List[int] = None,
|
||||
) -> float:
|
||||
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
@ -81,8 +83,24 @@ def benchmark_config(
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
if use_fp8_w8a8:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
if block_quant_shape:
|
||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||
E = num_experts
|
||||
N = shard_intermediate_size // 2
|
||||
K = hidden_size
|
||||
factor_for_scale = 1e-2
|
||||
n_tiles_w1 = (2 * N + block_n - 1) // block_n
|
||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||
w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
|
||||
dtype=torch.float32) * factor_for_scale
|
||||
w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
|
||||
dtype=torch.float32) * factor_for_scale
|
||||
else:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||
|
||||
@ -111,6 +129,7 @@ def benchmark_config(
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_quant_shape,
|
||||
)
|
||||
|
||||
# JIT compilation & warmup
|
||||
@ -132,7 +151,7 @@ def benchmark_config(
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: List[float] = []
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
prepare(i)
|
||||
torch.cuda.synchronize()
|
||||
@ -175,8 +194,9 @@ def get_rocm_tuning_space(use_fp16):
|
||||
return param_ranges
|
||||
|
||||
|
||||
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
|
||||
configs: List[BenchmarkConfig] = []
|
||||
def get_configs_compute_bound(use_fp16,
|
||||
block_quant_shape) -> list[dict[str, int]]:
|
||||
configs: list[BenchmarkConfig] = []
|
||||
|
||||
if current_platform.is_rocm():
|
||||
param_ranges = get_rocm_tuning_space(use_fp16)
|
||||
@ -204,17 +224,27 @@ def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
|
||||
for config_values in product(*values):
|
||||
config = dict(zip(keys, config_values))
|
||||
configs.append(config)
|
||||
|
||||
# Remove configs that are not compatible with fp8 block quantization
|
||||
# BLOCK_SIZE_K must be a multiple of block_k
|
||||
# BLOCK_SIZE_N must be a multiple of block_n
|
||||
if block_quant_shape is not None and not use_fp16:
|
||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||
for config in configs[:]:
|
||||
if config["BLOCK_SIZE_K"] % block_k != 0 or config[
|
||||
"BLOCK_SIZE_N"] % block_n != 0:
|
||||
configs.remove(config)
|
||||
return configs
|
||||
|
||||
|
||||
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
|
||||
search_space, is_fp16):
|
||||
search_space, is_fp16, topk):
|
||||
N1, K1 = shard_intermediate_size, hidden_size
|
||||
N2, K2 = hidden_size, shard_intermediate_size // 2
|
||||
pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
|
||||
is_fp16)
|
||||
pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
|
||||
is_fp16)
|
||||
pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
|
||||
search_space, is_fp16)
|
||||
pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
|
||||
search_space, is_fp16)
|
||||
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
|
||||
return search_space
|
||||
|
||||
@ -335,7 +365,7 @@ class BenchmarkWorker:
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
) -> Tuple[Dict[str, int], float]:
|
||||
) -> tuple[dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
dtype_str = get_config_dtype_str(dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
@ -371,8 +401,9 @@ class BenchmarkWorker:
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
search_space: List[Dict[str, int]],
|
||||
) -> Dict[str, int]:
|
||||
search_space: list[dict[str, int]],
|
||||
block_quant_shape: list[int],
|
||||
) -> dict[str, int]:
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
if current_platform.is_rocm():
|
||||
@ -380,21 +411,24 @@ class BenchmarkWorker:
|
||||
search_space = prune_rocm_search_space(num_tokens,
|
||||
shard_intermediate_size,
|
||||
hidden_size, search_space,
|
||||
is_fp16)
|
||||
is_fp16, topk)
|
||||
|
||||
with torch.cuda.device(self.device_id):
|
||||
with torch.cuda.device(self.device_id) if current_platform.is_rocm(
|
||||
) else nullcontext():
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=20)
|
||||
kernel_time = benchmark_config(
|
||||
config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=20,
|
||||
block_quant_shape=block_quant_shape)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
@ -434,10 +468,10 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
}
|
||||
|
||||
|
||||
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
|
||||
def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
|
||||
shard_intermediate_size: int, hidden_size: int, topk: int,
|
||||
dtype: torch.dtype, use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool) -> None:
|
||||
dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
|
||||
block_quant_shape: List[int]) -> None:
|
||||
dtype_str = get_config_dtype_str(dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_fp8_w8a8=use_fp8_w8a8)
|
||||
@ -445,7 +479,7 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
|
||||
dtype_str)
|
||||
dtype_str, block_quant_shape)
|
||||
|
||||
print(f"Writing best config to {filename}...")
|
||||
with open(filename, "w") as f:
|
||||
@ -455,7 +489,7 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
block_quant_shape = None
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model, trust_remote_code=args.trust_remote_code)
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
@ -468,11 +502,13 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] == "DeepseekV3ForCausalLM":
|
||||
elif (config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||
or config.architectures[0] == "DeepseekV2ForCausalLM"):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
block_quant_shape = config.quantization_config['weight_block_size']
|
||||
else:
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
@ -497,7 +533,7 @@ def main(args: argparse.Namespace):
|
||||
num_gpus = int(ray.available_resources()["GPU"])
|
||||
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
||||
|
||||
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
|
||||
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
|
||||
outputs = []
|
||||
worker_idx = 0
|
||||
for input_args in inputs:
|
||||
@ -510,27 +546,30 @@ def main(args: argparse.Namespace):
|
||||
|
||||
if args.tune:
|
||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||
search_space = get_configs_compute_bound(is_fp16)
|
||||
search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
|
||||
print(f"Start tuning over {len(search_space)} configurations...")
|
||||
|
||||
start = time.time()
|
||||
configs = _distribute(
|
||||
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
|
||||
for batch_size in batch_sizes])
|
||||
"tune",
|
||||
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
|
||||
use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
|
||||
for batch_size in batch_sizes])
|
||||
best_configs = {
|
||||
M: sort_config(config)
|
||||
for M, config in zip(batch_sizes, configs)
|
||||
}
|
||||
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
|
||||
topk, dtype, use_fp8_w8a8, use_int8_w8a16,
|
||||
block_quant_shape)
|
||||
end = time.time()
|
||||
print(f"Tuning took {end - start:.2f} seconds")
|
||||
else:
|
||||
outputs = _distribute(
|
||||
"benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
|
||||
for batch_size in batch_sizes])
|
||||
"benchmark",
|
||||
[(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
|
||||
use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
|
||||
for batch_size in batch_sizes])
|
||||
|
||||
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
||||
print(f"Batch size: {batch_size}, config: {config}")
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -11,8 +11,9 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
|
||||
create_kv_caches_with_random)
|
||||
|
||||
NUM_BLOCKS = 1024
|
||||
NUM_BLOCKS = 128 * 1024
|
||||
PARTITION_SIZE = 512
|
||||
PARTITION_SIZE_ROCM = 256
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@ -54,7 +55,7 @@ def main(
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables_lst: List[List[int]] = []
|
||||
block_tables_lst: list[list[int]] = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
@ -80,6 +81,12 @@ def main(
|
||||
# Prepare for the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
if version == "v2":
|
||||
if current_platform.is_rocm():
|
||||
global PARTITION_SIZE
|
||||
if not args.custom_paged_attn:
|
||||
PARTITION_SIZE = 1024
|
||||
else:
|
||||
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
||||
@ -123,25 +130,46 @@ def main(
|
||||
v_scale,
|
||||
)
|
||||
elif version == "v2":
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
if not args.custom_paged_attn:
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
ops.paged_attention_rocm(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
torch.cuda.synchronize()
|
||||
@ -195,6 +223,9 @@ if __name__ == '__main__':
|
||||
help="Data type for kv cache storage. If 'auto', will use model "
|
||||
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
|
||||
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
|
||||
parser.add_argument("--custom-paged-attn",
|
||||
action="store_true",
|
||||
help="Use custom paged attention")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@ -22,7 +22,7 @@ class HuggingFaceRMSNorm(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
if residual is not None:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from itertools import accumulate
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import nvtx
|
||||
import torch
|
||||
@ -39,7 +39,7 @@ def benchmark_rope_kernels_multi_lora(
|
||||
})
|
||||
# non-batched RoPE takes only one scaling factor, we create multiple
|
||||
# instances to simulate the same behavior
|
||||
non_batched_ropes: List[RotaryEmbedding] = []
|
||||
non_batched_ropes: list[RotaryEmbedding] = []
|
||||
for scaling_factor in scaling_factors:
|
||||
non_batched_ropes.append(
|
||||
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
|
||||
@ -4,7 +4,6 @@ import math
|
||||
import pickle
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
@ -23,7 +22,7 @@ if __name__ == "__main__":
|
||||
|
||||
with open(args.filename, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
raw_results: List[TMeasurement] = data["results"]
|
||||
raw_results: list[TMeasurement] = data["results"]
|
||||
|
||||
results = defaultdict(lambda: list())
|
||||
for v in raw_results:
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Callable, Iterable, Optional
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
|
||||
66
cmake/external_projects/flashmla.cmake
Normal file
66
cmake/external_projects/flashmla.cmake
Normal file
@ -0,0 +1,66 @@
|
||||
include(FetchContent)
|
||||
|
||||
# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory
|
||||
# instead of downloading.
|
||||
# It can be set as an environment variable or passed as a cmake argument.
|
||||
# The environment variable takes precedence.
|
||||
if (DEFINED ENV{FLASH_MLA_SRC_DIR})
|
||||
set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(FLASH_MLA_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
SOURCE_DIR ${FLASH_MLA_SRC_DIR}
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
|
||||
GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
FetchContent_MakeAvailable(flashmla)
|
||||
message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
|
||||
|
||||
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
|
||||
# Only build FlashMLA kernels if we are building for something compatible with
|
||||
# sm90a
|
||||
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
|
||||
set(FlashMLA_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)
|
||||
|
||||
set(FlashMLA_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc/include)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_SOURCES}"
|
||||
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
|
||||
|
||||
define_gpu_extension_target(
|
||||
_flashmla_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${FlashMLA_SOURCES}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
else()
|
||||
# Create an empty target for setup.py when not targeting sm90a systems
|
||||
add_custom_target(_flashmla_C)
|
||||
endif()
|
||||
|
||||
67
cmake/external_projects/vllm_flash_attn.cmake
Normal file
67
cmake/external_projects/vllm_flash_attn.cmake
Normal file
@ -0,0 +1,67 @@
|
||||
# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
|
||||
# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
|
||||
# arches in the CUDA case (and instead set the gencodes on a per file basis)
|
||||
# we need to manually set VLLM_GPU_ARCHES here.
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
foreach(_ARCH ${CUDA_ARCHS})
|
||||
string(REPLACE "." "" _ARCH "${_ARCH}")
|
||||
list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real")
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Build vLLM flash attention from source
|
||||
#
|
||||
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
|
||||
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
|
||||
# They should be identical but if they aren't, this is a massive footgun.
|
||||
#
|
||||
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
|
||||
# If no component is specified, vllm-flash-attn is still installed.
|
||||
|
||||
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
||||
# This is to enable local development of vllm-flash-attn within vLLM.
|
||||
# It can be set as an environment variable or passed as a cmake argument.
|
||||
# The environment variable takes precedence.
|
||||
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(VLLM_FLASH_ATTN_SRC_DIR)
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn SOURCE_DIR
|
||||
${VLLM_FLASH_ATTN_SRC_DIR}
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
# Fetch the vllm-flash-attn library
|
||||
FetchContent_MakeAvailable(vllm-flash-attn)
|
||||
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
||||
|
||||
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
|
||||
# case only one is built, in the case both are built redundant work is done)
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa2_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
@ -39,3 +39,10 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
||||
// Just for unittest
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
const double scale, const std::string& kv_cache_dtype);
|
||||
|
||||
void gather_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||
@ -2,6 +2,7 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
@ -374,7 +375,7 @@ void reshape_and_cache(
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_tokens = slot_mapping.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
@ -570,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// grid is launched with dimensions (batch, num_splits)
|
||||
template <typename scalar_t>
|
||||
__global__ void gather_cache(
|
||||
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
|
||||
// ENTRIES...]
|
||||
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
|
||||
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
|
||||
const int32_t* __restrict__ cu_seq_lens, // [BATCH+1]
|
||||
const int32_t block_size, const int32_t entry_size,
|
||||
const int64_t block_table_stride, const int64_t cache_block_stride,
|
||||
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
|
||||
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
|
||||
// batch
|
||||
|
||||
const int64_t bid = blockIdx.x; // Batch ID
|
||||
const int32_t num_splits = gridDim.y;
|
||||
const int32_t split = blockIdx.y;
|
||||
const int32_t seq_start = cu_seq_lens[bid];
|
||||
const int32_t seq_end = cu_seq_lens[bid + 1];
|
||||
const int32_t seq_len = seq_end - seq_start;
|
||||
const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
|
||||
const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
|
||||
|
||||
const int32_t split_start = split * split_blocks;
|
||||
const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);
|
||||
|
||||
const bool is_active_split = (split_start < tot_blocks);
|
||||
const bool is_last_split = (split_end == tot_blocks);
|
||||
|
||||
if (!is_active_split) return;
|
||||
|
||||
int32_t full_blocks_end = split_end;
|
||||
int32_t partial_block_size = 0;
|
||||
|
||||
// Adjust the pointer for the block_table for this batch.
|
||||
// If seq_starts is provided, compute an offset based on (seq_starts[bid] /
|
||||
// page_size)
|
||||
const int32_t batch_offset = bid * block_table_stride;
|
||||
int32_t offset = 0;
|
||||
if (seq_starts != nullptr) {
|
||||
offset = seq_starts[bid] / block_size;
|
||||
}
|
||||
const int32_t* batch_block_table = block_table + batch_offset + offset;
|
||||
|
||||
// Adjust dst pointer based on the cumulative sequence lengths.
|
||||
dst += seq_start * dst_entry_stride;
|
||||
|
||||
if (is_last_split) {
|
||||
partial_block_size = seq_len % block_size;
|
||||
if (partial_block_size) full_blocks_end -= 1;
|
||||
}
|
||||
|
||||
auto copy_entry = [&](const scalar_t* __restrict__ _src,
|
||||
scalar_t* __restrict__ _dst) {
|
||||
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
|
||||
_dst[i] = _src[i];
|
||||
};
|
||||
|
||||
for (int pid = split_start; pid < full_blocks_end; ++pid) {
|
||||
auto block_id = batch_block_table[pid];
|
||||
auto block_start_ptr = src_cache + block_id * cache_block_stride;
|
||||
auto block_dst_ptr = dst + pid * block_size * dst_entry_stride;
|
||||
for (int eid = 0; eid < block_size; ++eid) {
|
||||
copy_entry(block_start_ptr + eid * cache_entry_stride,
|
||||
block_dst_ptr + eid * dst_entry_stride);
|
||||
}
|
||||
}
|
||||
|
||||
if (partial_block_size) {
|
||||
auto block_id = batch_block_table[full_blocks_end];
|
||||
auto block_start_ptr = src_cache + block_id * cache_block_stride;
|
||||
auto block_dst_ptr = dst + full_blocks_end * block_size * dst_entry_stride;
|
||||
for (int eid = 0; eid < partial_block_size; ++eid) {
|
||||
copy_entry(block_start_ptr + eid * cache_entry_stride,
|
||||
block_dst_ptr + eid * dst_entry_stride);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// Macro to dispatch the kernel based on the data type.
|
||||
#define CALL_GATHER_CACHE(CPY_DTYPE) \
|
||||
vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
|
||||
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
|
||||
block_size, entry_size, block_table_stride, cache_block_stride, \
|
||||
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
|
||||
|
||||
// Gather sequences from the cache into the destination tensor.
|
||||
// - cu_seq_lens contains the cumulative sequence lengths for each batch
|
||||
// - block_table contains the cache block indices for each sequence
|
||||
// - Optionally, seq_starts (if provided) offsets the starting block index by
|
||||
// (seq_starts[bid] / page_size)
|
||||
void gather_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size,
|
||||
std::optional<torch::Tensor> seq_starts = std::nullopt) {
|
||||
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int32_t block_size = src_cache.size(1);
|
||||
int32_t entry_size = src_cache.flatten(2, -1).size(2);
|
||||
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32,
|
||||
"block_table must be int32");
|
||||
TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
|
||||
"cu_seq_lens must be int32");
|
||||
if (seq_starts.has_value()) {
|
||||
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
|
||||
"seq_starts must be int32");
|
||||
}
|
||||
|
||||
TORCH_CHECK(src_cache.device() == dst.device(),
|
||||
"src_cache and dst must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == block_table.device(),
|
||||
"src_cache and block_table must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
|
||||
"src_cache and cu_seq_lens must be on the same device");
|
||||
if (seq_starts.has_value()) {
|
||||
TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
|
||||
"src_cache and seq_starts must be on the same device");
|
||||
}
|
||||
|
||||
int64_t block_table_stride = block_table.stride(0);
|
||||
int64_t cache_block_stride = src_cache.stride(0);
|
||||
int64_t cache_entry_stride = src_cache.stride(1);
|
||||
int64_t dst_entry_stride = dst.stride(0);
|
||||
|
||||
// Decide on the number of splits based on the batch size.
|
||||
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
||||
dim3 grid(batch_size, num_splits);
|
||||
dim3 block(1024);
|
||||
|
||||
TORCH_CHECK(src_cache.dtype() == dst.dtype(),
|
||||
"src_cache and dst must have the same dtype");
|
||||
|
||||
const int dtype_bits = src_cache.element_size() * 8;
|
||||
const int32_t* seq_starts_ptr =
|
||||
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
|
||||
|
||||
if (dtype_bits == 32) {
|
||||
CALL_GATHER_CACHE(uint32_t);
|
||||
} else if (dtype_bits == 16) {
|
||||
CALL_GATHER_CACHE(uint16_t);
|
||||
} else if (dtype_bits == 8) {
|
||||
CALL_GATHER_CACHE(uint8_t);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,8 +7,3 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline constexpr std::enable_if_t<std::is_integral_v<T>, T> ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
@ -2,6 +2,10 @@
|
||||
#include <torch/all.h>
|
||||
#include <cmath>
|
||||
|
||||
#if defined(__APPLE__)
|
||||
#include "omp.h"
|
||||
#endif
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
|
||||
@ -2,10 +2,14 @@
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
|
||||
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
|
||||
#define DEVICE_INLINE __forceinline__ __device__
|
||||
#define HOST_INLINE __forceinline__ __host__
|
||||
#if defined(__HIPCC__)
|
||||
#define HOST_DEVICE_INLINE __host__ __device__
|
||||
#define DEVICE_INLINE __device__
|
||||
#define HOST_INLINE __host__
|
||||
#elif defined(__CUDACC__) || defined(_NVHPC_CUDA)
|
||||
#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
|
||||
#define DEVICE_INLINE __device__ __forceinline__
|
||||
#define HOST_INLINE __host__ __forceinline__
|
||||
#else
|
||||
#define HOST_DEVICE_INLINE inline
|
||||
#define DEVICE_INLINE inline
|
||||
@ -25,3 +29,13 @@
|
||||
int64_t get_device_attribute(int64_t attribute, int64_t device_id);
|
||||
|
||||
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
|
||||
|
||||
namespace cuda_utils {
|
||||
|
||||
template <typename T>
|
||||
HOST_DEVICE_INLINE constexpr std::enable_if_t<std::is_integral_v<T>, T>
|
||||
ceil_div(T a, T b) {
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
}; // namespace cuda_utils
|
||||
@ -122,8 +122,8 @@ struct ScaledEpilogue
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args};
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
@ -167,8 +167,8 @@ struct ScaledEpilogueBias
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
@ -230,9 +230,10 @@ struct ScaledEpilogueBiasAzp
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{
|
||||
b_args, evt_azp_args, {}};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
@ -309,11 +310,12 @@ struct ScaledEpilogueBiasAzpToken
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{
|
||||
b_args, evt_acc_args, {}};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace vllm::c2x
|
||||
}; // namespace vllm::c2x
|
||||
|
||||
@ -22,7 +22,7 @@ struct identity {
|
||||
T operator()(T lhs) const { return lhs; }
|
||||
};
|
||||
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct TrivialEpilogue {
|
||||
private:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
@ -44,32 +44,30 @@ struct TrivialEpilogue {
|
||||
* This class provides the common load descriptors for the
|
||||
* ScaledEpilogue[...] classes
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
template <typename T>
|
||||
using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
template <typename T>
|
||||
using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
0 /*Stages*/, TileShape, T, T, Stride<Int<1>, Int<0>, Int<0>>,
|
||||
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
|
||||
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
// from a tensor. It can handle both row and column, as well as row/column or
|
||||
@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
@ -146,8 +144,8 @@ struct ScaledEpilogue
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args};
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
@ -160,11 +158,11 @@ struct ScaledEpilogue
|
||||
* The bias tensor must be per-output channel.
|
||||
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBias
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
@ -193,8 +191,8 @@ struct ScaledEpilogueBias
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
@ -203,11 +201,11 @@ struct ScaledEpilogueBias
|
||||
* bias is a column vector instead of a row vector. Useful e.g. if we are
|
||||
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueColumnBias
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
@ -236,8 +234,8 @@ struct ScaledEpilogueColumnBias
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
||||
typename EVTCompute0::Arguments evt0_args{b_args};
|
||||
return ArgumentType{a_args, evt0_args, bias_args};
|
||||
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||
return ArgumentType{a_args, evt0_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBiasAzp
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
@ -297,9 +295,10 @@ struct ScaledEpilogueBiasAzp
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args, {}};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{
|
||||
b_args, evt_azp_args, {}};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
@ -313,11 +312,11 @@ struct ScaledEpilogueBiasAzp
|
||||
*
|
||||
* This epilogue also supports bias, which remains per-channel.
|
||||
*/
|
||||
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||
template <typename ElementAcc, typename ElementD, typename TileShape>
|
||||
struct ScaledEpilogueBiasAzpToken
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||
: private ScaledEpilogueBase<ElementAcc, ElementD, TileShape> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, TileShape>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
|
||||
using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
|
||||
@ -374,10 +373,11 @@ struct ScaledEpilogueBiasAzpToken
|
||||
auto azp_adj_args =
|
||||
SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
|
||||
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args};
|
||||
typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args, {}};
|
||||
typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args, {}};
|
||||
typename EVTComputeScaleB::Arguments evt_scale_b_args{
|
||||
b_args, evt_acc_args, {}};
|
||||
return ArgumentType{a_args, evt_scale_b_args, bias_args, {}};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
from typing import Dict, Union
|
||||
from typing import Union
|
||||
|
||||
from cutlass_library import *
|
||||
|
||||
@ -21,7 +21,7 @@ class MixedInputKernelScheduleType(enum.Enum):
|
||||
TmaWarpSpecializedCooperative = enum_auto()
|
||||
|
||||
|
||||
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
|
||||
**DataTypeNames, # type: ignore
|
||||
**{
|
||||
VLLMDataType.u4b8: "u4b8",
|
||||
@ -29,7 +29,7 @@ VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
}
|
||||
}
|
||||
|
||||
VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||
**DataTypeTag, # type: ignore
|
||||
**{
|
||||
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
|
||||
@ -37,7 +37,7 @@ VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
}
|
||||
}
|
||||
|
||||
VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
|
||||
VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
|
||||
**DataTypeSize, # type: ignore
|
||||
**{
|
||||
VLLMDataType.u4b8: 4,
|
||||
@ -45,7 +45,7 @@ VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = {
|
||||
}
|
||||
}
|
||||
|
||||
VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||
VLLMDataType.u4b8: "vllm::kU4B8",
|
||||
VLLMDataType.u8b128: "vllm::kU8B128",
|
||||
DataType.u4: "vllm::kU4",
|
||||
@ -56,7 +56,7 @@ VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
DataType.bf16: "vllm::kBfloat16",
|
||||
}
|
||||
|
||||
VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
|
||||
DataType.u8: "at::ScalarType::Byte",
|
||||
DataType.s8: "at::ScalarType::Char",
|
||||
DataType.e4m3: "at::ScalarType::Float8_e4m3fn",
|
||||
@ -66,7 +66,7 @@ VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
DataType.f32: "at::ScalarType::Float",
|
||||
}
|
||||
|
||||
VLLMKernelScheduleTag: Dict[Union[
|
||||
VLLMKernelScheduleTag: dict[Union[
|
||||
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
||||
**KernelScheduleTag, # type: ignore
|
||||
**{
|
||||
|
||||
@ -152,6 +152,11 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||
int64_t row);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
@ -30,12 +31,18 @@ static inline cute::Shape<int, int, int, int> get_problem_shape(
|
||||
}
|
||||
|
||||
template <typename GemmKernel>
|
||||
void cutlass_gemm_caller(torch::Device device,
|
||||
cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args) {
|
||||
void cutlass_gemm_caller(
|
||||
torch::Device device, cute::Shape<int, int, int, int> prob_shape,
|
||||
typename GemmKernel::MainloopArguments mainloop_args,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args,
|
||||
typename GemmKernel::TileSchedulerArguments scheduler = {}) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
prob_shape,
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
hw_info,
|
||||
scheduler};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
@ -58,22 +65,28 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementC = typename Gemm::ElementC;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
||||
using StrideB = cute::Stride<int64_t, cute::Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, cute::Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, cute::Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}};
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using StrideD = StrideC;
|
||||
using StrideAux = StrideC;
|
||||
|
||||
typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b);
|
||||
auto [M, N, K, L] = prob_shape;
|
||||
|
||||
StrideA a_stride =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
|
||||
StrideB b_stride =
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
|
||||
StrideC c_stride =
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
|
||||
StrideD d_stride =
|
||||
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
|
||||
StrideAux aux_stride = d_stride;
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
@ -81,10 +94,11 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
// auto d_ptr = static_cast<ElementC*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, c_stride, c_ptr, c_stride};
|
||||
c_ptr, c_stride, c_ptr, d_stride};
|
||||
|
||||
cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
|
||||
@ -40,12 +40,7 @@ struct cutlass_3x_gemm {
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using ElementC = void;
|
||||
@ -88,4 +83,65 @@ struct cutlass_3x_gemm {
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm_sm100 {
|
||||
using ElementAB = ElementAB_;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<ElementD_>::value;
|
||||
|
||||
using ElementD = ElementD_;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Epilogue types
|
||||
using ElementBias = cutlass::half_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAux = ElementD;
|
||||
using LayoutAux = LayoutD;
|
||||
using ElementAmax = float;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB,
|
||||
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@ -22,8 +22,9 @@ namespace vllm {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename OutType, int GroupSizeM_, int GroupSizeN_, int GroupSizeK_,
|
||||
int TileSizeM_ = 128, class ClusterShape = Shape<_1, _2, _1>>
|
||||
template <typename SchedulerType, typename OutType, int GroupSizeM_,
|
||||
int GroupSizeN_, int GroupSizeK_, int TileSizeM_ = 128,
|
||||
class ClusterShape = Shape<_1, _2, _1>>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using GroupSizeM = Int<GroupSizeM_>;
|
||||
using GroupSizeN = Int<GroupSizeN_>;
|
||||
@ -84,7 +85,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
SchedulerType>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
|
||||
@ -150,8 +151,24 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::TileSchedulerArguments scheduler;
|
||||
|
||||
static constexpr bool UsesStreamKScheduler =
|
||||
cute::is_same_v<typename GemmKernel::TileSchedulerTag,
|
||||
cutlass::gemm::StreamKScheduler>;
|
||||
|
||||
if constexpr (UsesStreamKScheduler) {
|
||||
using DecompositionMode = typename cutlass::gemm::kernel::detail::
|
||||
PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
||||
using ReductionMode = typename cutlass::gemm::kernel::detail::
|
||||
PersistentTileSchedulerSm90StreamKParams::ReductionMode;
|
||||
|
||||
scheduler.decomposition_mode = DecompositionMode::StreamK;
|
||||
scheduler.reduction_mode = ReductionMode::Nondeterministic;
|
||||
}
|
||||
|
||||
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
epilogue_args, scheduler);
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
@ -160,9 +177,18 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
cutlass_gemm_caller_blockwise<
|
||||
cutlass_3x_gemm_fp8_blockwise<OutType, 1, 128, 128>>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
auto k = a.size(1);
|
||||
auto n = b.size(1);
|
||||
|
||||
if (k > 3 * n) {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
cutlass::gemm::StreamKScheduler, OutType, 1, 128, 128>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
} else {
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
cutlass::gemm::PersistentScheduler, OutType, 1, 128, 128>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -30,4 +30,10 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
Normal file
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
Normal file
@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm100_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM100 (fp8) based on the
|
||||
* Gemm shape.
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_default {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_256, _128, _64>;
|
||||
using ClusterShape = Shape<_2, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm100_fp8_config_default<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,7 +1,7 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cuda_utils.h"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
@ -33,7 +33,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
auto make_group_shape = [](torch::Tensor const& x,
|
||||
torch::Tensor const& s) -> GroupShape {
|
||||
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
|
||||
return {ceil_div(x.size(0), s.size(0)), ceil_div(x.size(1), s.size(1))};
|
||||
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
|
||||
cuda_utils::ceil_div(x.size(1), s.size(1))};
|
||||
};
|
||||
|
||||
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
|
||||
@ -70,3 +71,28 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
vllm::cutlass_scaled_mm_azp_sm90_int8(out, a, b, a_scales, b_scales, azp_adj,
|
||||
azp, bias);
|
||||
}
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
||||
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
TORCH_CHECK(
|
||||
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -29,6 +29,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
@ -86,7 +91,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
// and at least SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 90) {
|
||||
if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
}
|
||||
#endif
|
||||
@ -120,10 +125,22 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
// Guard against compilation issues for sm90 kernels
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
if (version_num >= 90) {
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION < 12080
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#else
|
||||
if (version_num >= 90 && version_num < 100) {
|
||||
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
} else if (version_num >= 100) {
|
||||
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
|
||||
|
||||
@ -348,10 +348,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
|
||||
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
||||
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
||||
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
|
||||
auto stream = at::cuda::getStreamFromPool(false, input.get_device());
|
||||
if (stream == nullptr) {
|
||||
std::cerr << "Warning: Null CUDA stream" << std::endl;
|
||||
}
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
// We don't support e8m0 scales at this moment.
|
||||
bool useUE8M0 = false;
|
||||
|
||||
38
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
Normal file
38
csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
Normal file
@ -0,0 +1,38 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B, torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha) {
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"No compiled nvfp4 mm kernel, vLLM should "
|
||||
"be compiled using CUDA 12.8 and target "
|
||||
"compute capability 100 or above.");
|
||||
}
|
||||
281
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
Normal file
281
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
Normal file
@ -0,0 +1,281 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
// Kernel Perf config
|
||||
template <typename T>
|
||||
struct KernelTraits;
|
||||
|
||||
template <>
|
||||
struct KernelTraits<float> {
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelTraits<cutlass::half_t> {
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_4, _4, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelTraits<cutlass::bfloat16_t> {
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_4, _4, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Fp4GemmSm100 {
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
using LayoutATag = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA = 32;
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
using LayoutBTag = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB = 32;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = T;
|
||||
using ElementC = T;
|
||||
using LayoutCTag = cutlass::layout::RowMajor;
|
||||
using LayoutDTag = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
|
||||
using ClusterShape = typename KernelTraits<T>::ClusterShape;
|
||||
using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
|
||||
LayoutDTag, AlignmentD,
|
||||
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB,
|
||||
LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
|
||||
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));
|
||||
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Gemm::Arguments args_from_options(
|
||||
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
|
||||
int64_t M, int64_t N, int64_t K) {
|
||||
using ElementA = typename T::Gemm::ElementA;
|
||||
using ElementB = typename T::Gemm::ElementB;
|
||||
using ElementSFA = cutlass::float_ue4m3_t;
|
||||
using ElementSFB = cutlass::float_ue4m3_t;
|
||||
using ElementD = typename T::Gemm::ElementD;
|
||||
using ElementCompute = float;
|
||||
using StrideA = typename T::StrideA;
|
||||
using StrideB = typename T::StrideB;
|
||||
using StrideD = typename T::StrideD;
|
||||
using Sm100BlkScaledConfig =
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
|
||||
|
||||
int m = static_cast<int>(M);
|
||||
int n = static_cast<int>(N);
|
||||
int k = static_cast<int>(K);
|
||||
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
|
||||
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
|
||||
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
|
||||
|
||||
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(
|
||||
cute::make_shape(m, n, k, 1));
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
|
||||
cute::make_shape(m, n, k, 1));
|
||||
|
||||
typename T::Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{m, n, k, 1},
|
||||
{// Mainloop arguments
|
||||
static_cast<ElementA const*>(A.data_ptr()), stride_A,
|
||||
static_cast<ElementB const*>(B.data_ptr()), stride_B,
|
||||
static_cast<ElementSFA const*>(A_sf.data_ptr()), layout_SFA,
|
||||
static_cast<ElementSFB const*>(B_sf.data_ptr()), layout_SFB},
|
||||
{ // Epilogue arguments
|
||||
{}, // epilogue.thread
|
||||
static_cast<ElementD const*>(D.data_ptr()),
|
||||
stride_D,
|
||||
static_cast<ElementD*>(D.data_ptr()),
|
||||
stride_D}};
|
||||
auto& fusion_args = arguments.epilogue.thread;
|
||||
fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
typename Fp4GemmSm100<T>::Gemm gemm;
|
||||
|
||||
auto arguments =
|
||||
args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
|
||||
size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
#else
|
||||
template <typename T>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
}
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
|
||||
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x, m) \
|
||||
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
|
||||
#define CHECK_INPUT(x, st, m) \
|
||||
CHECK_TH_CUDA(x, m); \
|
||||
CHECK_CONTIGUOUS(x, m); \
|
||||
CHECK_TYPE(x, st, m)
|
||||
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
|
||||
void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha) {
|
||||
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
|
||||
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
|
||||
|
||||
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
|
||||
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
|
||||
|
||||
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
|
||||
|
||||
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
|
||||
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
|
||||
TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
|
||||
"a and b shapes cannot be multiplied (", A.sizes()[0], "x",
|
||||
A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
|
||||
|
||||
auto const m = A.sizes()[0];
|
||||
auto const n = B.sizes()[0];
|
||||
auto const k = A.sizes()[1] * 2;
|
||||
|
||||
constexpr int alignment = 32;
|
||||
TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
|
||||
", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
|
||||
"), k: ", k, ".");
|
||||
TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
|
||||
", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
|
||||
|
||||
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
|
||||
int rounded_m = round_up(m, 128);
|
||||
int rounded_n = round_up(n, 128);
|
||||
// Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an
|
||||
// integer.
|
||||
int rounded_k = round_up(k / 16, 4);
|
||||
|
||||
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
|
||||
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
|
||||
TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
|
||||
"scale_a and scale_b shapes cannot be multiplied (",
|
||||
A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
|
||||
"x", B_sf.sizes()[1], ")");
|
||||
TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
|
||||
"scale_a must be padded and swizzled to a shape (", rounded_m,
|
||||
"x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
|
||||
A_sf.sizes()[1], ")");
|
||||
TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
|
||||
"scale_b must be padded and swizzled to a shape (", rounded_n,
|
||||
"x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
|
||||
B_sf.sizes()[1], ")");
|
||||
|
||||
auto out_dtype = D.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
|
||||
if (out_dtype == at::ScalarType::Half) {
|
||||
runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
||||
runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (out_dtype == at::ScalarType::Float) {
|
||||
runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
|
||||
}
|
||||
}
|
||||
@ -1,137 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#include <hip/hip_runtime.h>
|
||||
#else
|
||||
#include <type_traits>
|
||||
#include <stdint.h>
|
||||
#include <math.h>
|
||||
#include <iostream>
|
||||
#endif
|
||||
|
||||
#include "hip_float8_impl.h"
|
||||
|
||||
struct alignas(1) hip_fp8 {
|
||||
struct from_bits_t {};
|
||||
HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
|
||||
return from_bits_t();
|
||||
}
|
||||
uint8_t data;
|
||||
|
||||
hip_fp8() = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
|
||||
explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
|
||||
: data(v) {}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// NOTE: ON-DEVICE... always optimal bias
|
||||
explicit HIP_FP8_DEVICE hip_fp8(float v)
|
||||
: data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
|
||||
|
||||
explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
|
||||
: hip_fp8(static_cast<float>(v)) {}
|
||||
|
||||
// Host only implementation using s/w simulation
|
||||
explicit HIP_FP8_HOST
|
||||
#else // __HIP__MI300__
|
||||
// both Host and DEVICE for non-MI300 using s/w simulation
|
||||
explicit HIP_FP8_HOST_DEVICE
|
||||
#endif // __HIP__MI300__
|
||||
hip_fp8(float v) {
|
||||
data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
|
||||
true /*clip*/>(v);
|
||||
}
|
||||
|
||||
explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
|
||||
: hip_fp8(static_cast<float>(v)) {}
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
// upcast using device specific intrinsic
|
||||
explicit inline HIP_FP8_DEVICE operator float() const {
|
||||
float fval;
|
||||
uint32_t i32val = static_cast<uint32_t>(data);
|
||||
|
||||
// upcast
|
||||
asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
|
||||
: "=v"(fval)
|
||||
: "v"(i32val));
|
||||
|
||||
return fval;
|
||||
}
|
||||
|
||||
explicit inline HIP_FP8_HOST operator float() const
|
||||
#else // __HIP__MI300__
|
||||
explicit inline HIP_FP8_HOST_DEVICE operator float() const
|
||||
#endif // __HIP__MI300__
|
||||
{
|
||||
return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
|
||||
data);
|
||||
}
|
||||
};
|
||||
|
||||
namespace std {
|
||||
inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
|
||||
inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
|
||||
HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
|
||||
} // namespace std
|
||||
|
||||
// Special operator overloading
|
||||
inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
|
||||
return os << float(f8);
|
||||
}
|
||||
|
||||
// all + operator overloading with mixed types
|
||||
// mixed types, always converts to f32, does computation in f32, and returns
|
||||
// float
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
|
||||
return (fa + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
|
||||
return (float(a) + fb);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
|
||||
return hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
|
||||
return a = hip_fp8(float(a) + float(b));
|
||||
}
|
||||
|
||||
// overloading multiplication, always returns float,
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
|
||||
return float(a) * float(b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
|
||||
return (a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
|
||||
return (float(a) * b);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
|
||||
return ((float)a * float(b));
|
||||
}
|
||||
|
||||
// overloading for compare
|
||||
inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
|
||||
return (a.data == b.data);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
|
||||
return (a.data != b.data);
|
||||
}
|
||||
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
|
||||
return static_cast<float>(a) >= static_cast<float>(b);
|
||||
}
|
||||
inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
|
||||
return static_cast<float>(a) > static_cast<float>(b);
|
||||
}
|
||||
@ -1,315 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__HIPCC__) && defined(__gfx942__)
|
||||
#define __HIP__MI300__
|
||||
#endif
|
||||
|
||||
#ifdef __HIPCC__
|
||||
#define HIP_FP8_HOST_DEVICE __host__ __device__
|
||||
#define HIP_FP8_HOST __host__
|
||||
#define HIP_FP8_DEVICE __device__
|
||||
#else
|
||||
#define HIP_FP8_HOST_DEVICE
|
||||
#define HIP_FP8_HOST
|
||||
#define HIP_FP8_DEVICE
|
||||
#endif
|
||||
|
||||
namespace hip_fp8_impl {
|
||||
|
||||
#ifdef __HIP__MI300__
|
||||
HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
|
||||
uint8_t i8data;
|
||||
union {
|
||||
float fval;
|
||||
uint32_t i32val;
|
||||
uint8_t i8val[4]; // NOTE: not endian independent
|
||||
} val;
|
||||
|
||||
uint32_t ival = 0;
|
||||
val.fval = v;
|
||||
|
||||
if ((val.i32val & 0x7F800000) !=
|
||||
0x7F800000) { /// propagate NAN/INF, no clipping
|
||||
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
|
||||
}
|
||||
|
||||
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
|
||||
false); // false -> WORD0
|
||||
val.i32val = ival;
|
||||
i8data = val.i8val[0];
|
||||
|
||||
return i8data;
|
||||
}
|
||||
#endif // __HIP__MI300__
|
||||
|
||||
HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
|
||||
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
|
||||
HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
|
||||
#endif
|
||||
|
||||
template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
|
||||
HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
|
||||
uint32_t rng = 0) {
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(wm + we == 7, "wm+we==7");
|
||||
static_assert(is_half || is_float, "Only half and float can be cast to f8");
|
||||
|
||||
const int mfmt = (sizeof(T) == 4) ? 23 : 10;
|
||||
uint32_t x;
|
||||
if (sizeof(T) == 4) {
|
||||
x = reinterpret_cast<uint32_t&>(_x);
|
||||
} else {
|
||||
x = reinterpret_cast<uint16_t&>(_x);
|
||||
}
|
||||
|
||||
uint32_t head, mantissa;
|
||||
int exponent, bias;
|
||||
uint32_t sign;
|
||||
|
||||
if (sizeof(T) == 4) {
|
||||
head = x & 0xFF800000;
|
||||
mantissa = x & 0x7FFFFF;
|
||||
exponent = (head >> 23) & 0xFF;
|
||||
sign = head >> 31;
|
||||
bias = 127;
|
||||
} else {
|
||||
head = x & 0xFC00;
|
||||
mantissa = x & 0x3FF;
|
||||
exponent = (head >> 10) & 0x1F;
|
||||
sign = head >> 15;
|
||||
bias = 15;
|
||||
}
|
||||
|
||||
uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
|
||||
|
||||
// Deal with inf and NaNs
|
||||
if (negative_zero_nan) {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return 0x80;
|
||||
}
|
||||
} else {
|
||||
// if(__hisinf(x) || __hisnan(x))
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return 0x80;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (sizeof(T) == 4) {
|
||||
if ((x & 0x7F800000) == 0x7F800000) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
} else {
|
||||
if ((x & 0x7C00) == 0x7C00) {
|
||||
return signed_inf + (mantissa != 0 ? 1 : 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// First need to check if it is normal or denorm as there is a difference of
|
||||
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
|
||||
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
|
||||
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
|
||||
// need to check whether there is carry and adjust exponent and mantissa again
|
||||
|
||||
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
|
||||
// bits
|
||||
const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
|
||||
const int f8_denormal_act_exponent =
|
||||
1 - f8_bias; // actual exponent of f8 denormal
|
||||
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
|
||||
// f8_exponent is the converted f8 exponent with bias encoding
|
||||
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
|
||||
// the difference needs to be adjusted and mantissa shifted
|
||||
int act_exponent, f8_exponent, exponent_diff;
|
||||
|
||||
if (exponent == 0) { // fp32/fp16 is in denormal.
|
||||
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
|
||||
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
|
||||
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
|
||||
exponent bias 16. It means that there are some numbers in fp16 denormal but they
|
||||
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
|
||||
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
|
||||
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff =
|
||||
f8_denormal_act_exponent -
|
||||
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
|
||||
} else { // fp32/fp16 is normal with implicit 1
|
||||
act_exponent = exponent - bias;
|
||||
if (act_exponent <= f8_denormal_act_exponent) {
|
||||
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
|
||||
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
|
||||
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
|
||||
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
|
||||
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
|
||||
exponent_diff = f8_denormal_act_exponent - act_exponent;
|
||||
} else { // both fp32/fp16 and f8 are in normal range
|
||||
exponent_diff = 0; // exponent_diff=0 does not mean there is no
|
||||
// difference for this case, act_exponent could be
|
||||
// larger. Just that it does not need shift mantissa
|
||||
}
|
||||
mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
|
||||
}
|
||||
|
||||
bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
|
||||
static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
|
||||
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
|
||||
done before we shift right as shift right could rip off some residual part
|
||||
and make something not midpoint look like midpoint. For example, the fp16
|
||||
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
|
||||
shift right by 4 bits, it would look like midpoint.
|
||||
*/
|
||||
|
||||
if (exponent_diff > 0) {
|
||||
mantissa >>= exponent_diff;
|
||||
} else if (exponent_diff == -1) {
|
||||
mantissa <<= -exponent_diff;
|
||||
}
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
// if there is no implicit 1, it means the f8 is denormal and need to adjust
|
||||
// to denorm exponent
|
||||
f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
|
||||
f8_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
// Now we have the exponent and mantissa adjusted
|
||||
uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
|
||||
bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
|
||||
// that is not truncated is 1
|
||||
mantissa +=
|
||||
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
|
||||
drop_mask;
|
||||
|
||||
// Now we deal with overflow
|
||||
if (f8_exponent == 0) {
|
||||
if ((1 << mfmt) & mantissa) {
|
||||
f8_exponent = 1; // denormal overflow to become normal, promote exponent
|
||||
}
|
||||
} else {
|
||||
if ((1 << (mfmt + 1)) & mantissa) {
|
||||
mantissa >>= 1;
|
||||
f8_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - wm);
|
||||
|
||||
// above range: quantize to maximum possible float of the same sign
|
||||
const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
|
||||
if (f8_exponent > max_exp) {
|
||||
if (clip) {
|
||||
mantissa = (1 << wm) - 1;
|
||||
f8_exponent = max_exp;
|
||||
} else {
|
||||
return signed_inf;
|
||||
}
|
||||
}
|
||||
|
||||
if (f8_exponent == 0 && mantissa == 0) {
|
||||
return negative_zero_nan ? 0 : (sign << 7);
|
||||
}
|
||||
mantissa &= (1 << wm) - 1;
|
||||
return (sign << 7) | (f8_exponent << wm) | mantissa;
|
||||
}
|
||||
|
||||
template <int we, int wm, typename T = float, bool negative_zero_nan = true>
|
||||
inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
|
||||
#ifdef __HIPCC__
|
||||
constexpr bool is_half = std::is_same<T, _Float16>::value;
|
||||
#else
|
||||
constexpr bool is_half = false;
|
||||
#endif
|
||||
constexpr bool is_float = std::is_same<T, float>::value;
|
||||
static_assert(is_half || is_float, "only half and float are supported");
|
||||
|
||||
constexpr int weo = is_half ? 5 : 8;
|
||||
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
|
||||
|
||||
T fInf, fNegInf, fNaN, fNeg0;
|
||||
|
||||
#ifdef __HIPCC__
|
||||
if (is_half) {
|
||||
const uint16_t ihInf = 0x7C00;
|
||||
const uint16_t ihNegInf = 0xFC00;
|
||||
const uint16_t ihNaN = 0x7C01;
|
||||
const uint16_t ihNeg0 = 0x8000;
|
||||
fInf = reinterpret_cast<const _Float16&>(ihInf);
|
||||
fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
|
||||
fNaN = reinterpret_cast<const _Float16&>(ihNaN);
|
||||
fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
|
||||
} else
|
||||
#endif
|
||||
if (is_float) {
|
||||
const uint32_t ifInf = 0x7F800000;
|
||||
const uint32_t ifNegInf = 0xFF800000;
|
||||
const uint32_t ifNaN = 0x7F800001;
|
||||
const uint32_t ifNeg0 = 0x80000000;
|
||||
fInf = reinterpret_cast<const float&>(ifInf);
|
||||
fNegInf = reinterpret_cast<const float&>(ifNegInf);
|
||||
fNaN = reinterpret_cast<const float&>(ifNaN);
|
||||
fNeg0 = reinterpret_cast<const float&>(ifNeg0);
|
||||
}
|
||||
|
||||
if (x == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t sign = x >> 7;
|
||||
uint32_t mantissa = x & ((1 << wm) - 1);
|
||||
int exponent = (x & 0x7F) >> wm;
|
||||
if (negative_zero_nan) {
|
||||
if (x == 0x80) {
|
||||
return fNaN;
|
||||
}
|
||||
} else {
|
||||
if (x == 0x80) {
|
||||
return fNeg0;
|
||||
}
|
||||
if (exponent == ((1 << we) - 1)) {
|
||||
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
|
||||
}
|
||||
}
|
||||
typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
|
||||
if (we == 5 && is_half && !negative_zero_nan) {
|
||||
retval = x << 8;
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
const int exp_low_cutoff =
|
||||
(1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
|
||||
|
||||
// subnormal input
|
||||
if (exponent == 0) {
|
||||
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
|
||||
int sh = 1 + clz(mantissa) - (32 - wm);
|
||||
mantissa <<= sh;
|
||||
exponent += 1 - sh;
|
||||
mantissa &= ((1 << wm) - 1);
|
||||
}
|
||||
exponent += exp_low_cutoff - 1;
|
||||
mantissa <<= wmo - wm;
|
||||
|
||||
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
|
||||
if (exponent <= 0) {
|
||||
mantissa |= 1 << wmo;
|
||||
mantissa >>= 1 - exponent;
|
||||
exponent = 0;
|
||||
}
|
||||
|
||||
if (sizeof(T) == 2) {
|
||||
retval = (sign << 15) | (exponent << 10) | mantissa;
|
||||
} else {
|
||||
retval = (sign << 31) | (exponent << 23) | mantissa;
|
||||
}
|
||||
return reinterpret_cast<const T&>(retval);
|
||||
}
|
||||
|
||||
} // namespace hip_fp8_impl
|
||||
@ -1,13 +1,11 @@
|
||||
#pragma once
|
||||
#include "hip_float8.h"
|
||||
#include <hip/hip_fp8.h>
|
||||
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_bfloat16.h>
|
||||
|
||||
#include "../../../attention/dtype_fp8.cuh"
|
||||
#include "../../../attention/dtype_float32.cuh"
|
||||
#include "../../../attention/dtype_bfloat16.cuh"
|
||||
#include "../../../attention/attention_dtypes.h"
|
||||
|
||||
namespace vllm {
|
||||
#ifdef USE_ROCM
|
||||
@ -26,40 +24,31 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
|
||||
return x;
|
||||
}
|
||||
|
||||
#if HIP_FP8_TYPE_OCP
|
||||
using fp8_type = __hip_fp8_e4m3;
|
||||
using fp8x2_type = __hip_fp8x2_e4m3;
|
||||
#else
|
||||
using fp8_type = __hip_fp8_e4m3_fnuz;
|
||||
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
|
||||
#endif
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8);
|
||||
return res.x;
|
||||
return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0];
|
||||
tmp.h2r.y.data = f2[1];
|
||||
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
|
||||
tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
@ -92,9 +81,9 @@ using __nv_bfloat16 = __hip_bfloat16;
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f);
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return __float2bfloat16(static_cast<float>(f8));
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
@ -136,27 +125,18 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8);
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return static_cast<float>(f8);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2
|
||||
vec_conversion<float2, uint16_t>(const uint16_t& a) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0];
|
||||
res.y = f2[1];
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
|
||||
res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
|
||||
return res;
|
||||
#endif
|
||||
fp8x2_type f8x2;
|
||||
f8x2.__x = a;
|
||||
return static_cast<float2>(f8x2);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
@ -169,6 +149,15 @@ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
|
||||
@ -189,33 +178,36 @@ __inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data)};
|
||||
return f8.data;
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
|
||||
union {
|
||||
uint32_t ui32;
|
||||
__half2_raw h2r;
|
||||
} tmp;
|
||||
tmp.ui32 = a;
|
||||
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
|
||||
hip_fp8 res{__bfloat162float(a)};
|
||||
return res.data;
|
||||
return __hip_cvt_float_to_fp8(__bfloat162float(a),
|
||||
fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
|
||||
hip_fp8 f8(a);
|
||||
return f8.data;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
vec_conversion<float4, uint32_t>(const uint32_t& a) {
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// float2 -> half2
|
||||
@ -307,90 +299,22 @@ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
|
||||
|
||||
*/
|
||||
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
__half_raw res;
|
||||
res.data = static_cast<float>(f8) * scale;
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
|
||||
const uint16_t& a, const float scale) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r.x.data = f2[0] * scale;
|
||||
tmp.h2r.y.data = f2[1] * scale;
|
||||
return tmp.ui32;
|
||||
#else
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
|
||||
tmp.u16[0] =
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(
|
||||
static_cast<uint8_t>(a >> 8U), scale);
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||
tmp.u32[1] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4
|
||||
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat16
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
|
||||
const float scale) {
|
||||
hip_fp8 f8{a, hip_fp8::from_bits()};
|
||||
float f{f8};
|
||||
return __float2bfloat16(f * scale);
|
||||
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return __float2bfloat16(static_cast<float>(f8) * scale);
|
||||
}
|
||||
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template <>
|
||||
__inline__ __device__ __nv_bfloat162
|
||||
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
||||
const float scale) {
|
||||
float scale) {
|
||||
__nv_bfloat162 res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
|
||||
res.y =
|
||||
@ -400,8 +324,8 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
||||
const uint32_t& a, const float scale) {
|
||||
__inline__ __device__ bf16_4_t
|
||||
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
|
||||
bf16_4_t res;
|
||||
res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
|
||||
res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
|
||||
@ -412,7 +336,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
|
||||
// fp8x8 -> bf16_8_t
|
||||
template <>
|
||||
__inline__ __device__ bf16_8_t
|
||||
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
|
||||
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
|
||||
@ -427,29 +351,19 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
|
||||
// fp8 -> float
|
||||
template <>
|
||||
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
|
||||
const uint8_t& a, const float scale) {
|
||||
hip_fp8 fp8{a, hip_fp8::from_bits()};
|
||||
return static_cast<float>(fp8) * scale;
|
||||
const uint8_t& a, float scale) {
|
||||
fp8_type f8;
|
||||
f8.__x = a;
|
||||
return static_cast<float>(f8) * scale;
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template <>
|
||||
__inline__ __device__ float2
|
||||
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
|
||||
#if defined(__HIP__MI300__) && \
|
||||
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
|
||||
float2 res;
|
||||
const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
|
||||
res.x = f2[0] * scale;
|
||||
res.y = f2[1] * scale;
|
||||
return res;
|
||||
#else
|
||||
float2 res;
|
||||
res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
|
||||
res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U),
|
||||
scale);
|
||||
return res;
|
||||
#endif
|
||||
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
|
||||
fp8x2_type f8x2;
|
||||
f8x2.__x = a;
|
||||
return static_cast<float2>(f8x2) * scale;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
@ -462,10 +376,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
|
||||
Float4_ res = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||
return {res.x.x, res.x.y, res.y.x, res.y.y};
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template <>
|
||||
__inline__ __device__ Float8_
|
||||
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
|
||||
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
|
||||
tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
|
||||
@ -477,44 +399,184 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
|
||||
return res;
|
||||
}
|
||||
|
||||
/* Quantize(HP / scale) => FP8 */
|
||||
// fp8 -> half
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
|
||||
__half_raw res;
|
||||
res.data = scaled_vec_conversion<float, uint8_t>(a, scale);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// TODO(Hai): vectorized to add
|
||||
// fp8x2 -> half2
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
|
||||
__half2_raw h2r =
|
||||
__hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
|
||||
tmp.h2r.x.data *= scale;
|
||||
tmp.h2r.y.data *= scale;
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
|
||||
tmp.u32[1] =
|
||||
scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template <>
|
||||
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
|
||||
float scale) {
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
|
||||
tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// half -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
|
||||
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
tmp.data /= scale;
|
||||
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
hip_fp8 f8{static_cast<float>(tmp.data) / scale};
|
||||
return f8.data;
|
||||
// halfx2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
|
||||
union {
|
||||
uint32_t ui32;
|
||||
__half2_raw h2r;
|
||||
} tmp;
|
||||
tmp.ui32 = a;
|
||||
tmp.h2r.x.data /= scale;
|
||||
tmp.h2r.y.data /= scale;
|
||||
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// half2x2 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// half2x4 -> fp8x8
|
||||
template <>
|
||||
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
|
||||
float scale) {
|
||||
union {
|
||||
uint2 ui2[2];
|
||||
uint4 ui4;
|
||||
} tmp;
|
||||
tmp.ui4 = a;
|
||||
uint2 res;
|
||||
res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
|
||||
res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
const __nv_bfloat16& a, const float scale) {
|
||||
hip_fp8 res{__bfloat162float(a) / scale};
|
||||
return res.data;
|
||||
const __nv_bfloat16& a, float scale) {
|
||||
return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
|
||||
fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// bf16x2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
|
||||
const __nv_bfloat162& a, float scale) {
|
||||
union {
|
||||
uint8_t ui8[2];
|
||||
uint16_t ui16;
|
||||
} tmp;
|
||||
tmp.ui8[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
|
||||
tmp.ui8[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
|
||||
return tmp.ui16;
|
||||
}
|
||||
|
||||
// bf16x4 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
|
||||
// bf16x8 -> fp8x8
|
||||
template <>
|
||||
__inline__ __device__ uint2
|
||||
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
|
||||
uint2 res;
|
||||
res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
|
||||
res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
|
||||
return res;
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template <>
|
||||
__inline__ __device__ uint8_t
|
||||
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
|
||||
hip_fp8 f8(a / scale);
|
||||
return f8.data;
|
||||
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
|
||||
return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
// floatx2 -> fp8x2
|
||||
template <>
|
||||
__inline__ __device__ float4
|
||||
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
|
||||
Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
__inline__ __device__ uint16_t
|
||||
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
|
||||
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
|
||||
fp8_type::__default_interpret);
|
||||
}
|
||||
|
||||
// floatx4 -> fp8x4
|
||||
template <>
|
||||
__inline__ __device__ uint32_t
|
||||
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
|
||||
union {
|
||||
uint16_t ui16[2];
|
||||
uint32_t ui32;
|
||||
} tmp;
|
||||
tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
|
||||
tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
|
||||
return tmp.ui32;
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
|
||||
std::numeric_limits<FP8_TYPE>::max();
|
||||
#else
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
#include "amd/hip_float8.h"
|
||||
#include "amd/quant_utils.cuh"
|
||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||
// Using the default max value from pytorch (240.0) will cause accuracy
|
||||
// issue when running dynamic quantization. Here use 224.0f for rocm.
|
||||
@ -47,8 +47,10 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
#else
|
||||
// Use hardware cvt instruction for fp8 on rocm
|
||||
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
return c10::Float8_e4m3fnuz(
|
||||
__hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation,
|
||||
fp8::fp8_type::__default_interpret),
|
||||
c10::Float8_e4m3fnuz::from_bits());
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@ -37,6 +37,8 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t *
|
||||
return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
|
||||
}
|
||||
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||
|
||||
#define VDR_Q4_0_Q8_1_MMVQ 2
|
||||
#define VDR_Q4_0_Q8_1_MMQ 4
|
||||
|
||||
1008
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
Normal file
1008
csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu
Normal file
File diff suppressed because it is too large
Load Diff
163
csrc/quantization/gptq_allspark/allspark_repack.cu
Normal file
163
csrc/quantization/gptq_allspark/allspark_repack.cu
Normal file
@ -0,0 +1,163 @@
|
||||
#include "allspark_utils.cuh"
|
||||
#include <torch/all.h>
|
||||
#include "core/registration.h"
|
||||
|
||||
namespace allspark {
|
||||
|
||||
// Rearrange B to facilitate Ampere Tensor Core load data
|
||||
// reorder B from (K, N) to (N_32align / 4, K * 4)
|
||||
// K % 16 == 0, N % 16 == 0, N_32align % 32 == 0
|
||||
template <typename FType>
|
||||
__global__ void __launch_bounds__(128)
|
||||
rearrange_kn_weight_as_n32k16_order_ldg16_kernel(
|
||||
const uint8_t* B, const FType* B_scale, const FType* B_zero,
|
||||
uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
|
||||
const int K, const int N, const int N_32align) {
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int warp_id = threadIdx.x / 32;
|
||||
|
||||
if (blockIdx.x != gridDim.x - 1) {
|
||||
// Load B
|
||||
// per block process 64(k) * 128(n) B elements
|
||||
// per warp process 16(k) * 128 B elements
|
||||
const int src_row_base_idx =
|
||||
blockIdx.x * 64 + warp_id * 16 + ((lane_id % 8) / 2) * 2;
|
||||
const int src_col_idx =
|
||||
blockIdx.y * 128 + (lane_id / 8) * 32 + (lane_id % 2) * 16;
|
||||
uint8_t B_frag[4][16];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
int src_row_idx = src_row_base_idx + (i / 2) * 8 + (i % 2);
|
||||
int src_offset = src_row_idx * N + src_col_idx;
|
||||
bool guard = src_row_idx < K && src_col_idx < N;
|
||||
ldg128_cg_0(*reinterpret_cast<uint32_t*>(B_frag[i]),
|
||||
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 1),
|
||||
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 2),
|
||||
*(reinterpret_cast<uint32_t*>(B_frag[i]) + 3), B + src_offset,
|
||||
guard);
|
||||
}
|
||||
|
||||
// reorder B
|
||||
uint8_t B_reorder_frag[8][8];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 16; ++j) {
|
||||
int dst_i = j % 8;
|
||||
int dst_j = i + (j / 8) * 4;
|
||||
B_reorder_frag[dst_i][dst_j] = B_frag[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
// Store B
|
||||
const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
|
||||
const int dst_col_idx =
|
||||
blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8;
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
int dst_row_idx = dst_row_base_idx + i;
|
||||
int dst_offset = dst_row_idx * K * 4 + dst_col_idx;
|
||||
bool guard = (dst_row_base_idx < N_32align / 4) && (dst_col_idx < K * 4);
|
||||
if (guard) {
|
||||
*reinterpret_cast<int2*>(B_result + dst_offset) =
|
||||
*reinterpret_cast<int2*>(B_reorder_frag[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Load B_scale and B_zero
|
||||
FType b_scale_reg, b_zero_reg;
|
||||
int src_offset = blockIdx.y * 128 + threadIdx.x;
|
||||
ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N);
|
||||
if (B_zero != nullptr)
|
||||
ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N);
|
||||
int dst_offset =
|
||||
blockIdx.y * 128 + warp_id * 32 + (lane_id % 8) * 4 + lane_id / 8;
|
||||
if (dst_offset < N_32align) {
|
||||
B_scale_result[dst_offset] = b_scale_reg;
|
||||
if (B_zero != nullptr) B_zero_result[dst_offset] = b_zero_reg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FType>
|
||||
void rearrange_kn_weight_as_n32k16_order_ldg16(
|
||||
const uint8_t* B, const FType* B_scale, const FType* B_zero,
|
||||
uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
|
||||
const int64_t K, const int64_t N, const int64_t N_32align,
|
||||
cudaStream_t stream) {
|
||||
if (N % 16 != 0 || K % 16 != 0) {
|
||||
std::cerr << "Now only support N and K is multiples of 16" << std::endl;
|
||||
}
|
||||
const int BLOCK = 128;
|
||||
int grid_x = (K + 64 - 1) / 64 + 1;
|
||||
int grid_y = (N + 128 - 1) / 128;
|
||||
dim3 grid(grid_x, grid_y);
|
||||
|
||||
rearrange_kn_weight_as_n32k16_order_ldg16_kernel<FType>
|
||||
<<<grid, BLOCK, 0, stream>>>(B, B_scale, B_zero, B_result, B_scale_result,
|
||||
B_zero_result, K, N, N_32align);
|
||||
}
|
||||
} // namespace allspark
|
||||
|
||||
void rearrange_kn_weight_as_n32k16_order(
|
||||
torch::Tensor const& b_qweight, torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& b_zeros, bool has_zp,
|
||||
torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder,
|
||||
c10::optional<torch::Tensor> const& b_zeros_reorder, const int64_t K,
|
||||
const int64_t N, const int64_t N_32align) {
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU");
|
||||
TORCH_CHECK(b_qweight.is_contiguous(), "b_qweight is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
||||
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_qweight_reorder.device().is_cuda(),
|
||||
"b_qweight_reorder is not on GPU");
|
||||
TORCH_CHECK(b_qweight_reorder.is_contiguous(),
|
||||
"b_qweight_reorder is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_scales_reorder.device().is_cuda(),
|
||||
"b_scales_reorder is not on GPU");
|
||||
TORCH_CHECK(b_scales_reorder.is_contiguous(),
|
||||
"b_scales_reorder is not contiguous");
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(b_zeros.value().device().is_cuda(), "b_zeros is not on GPU");
|
||||
TORCH_CHECK(b_zeros.value().is_contiguous(), "b_zeros is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_zeros_reorder.value().device().is_cuda(),
|
||||
"b_zeros_reorder is not on GPU");
|
||||
TORCH_CHECK(b_zeros_reorder.value().is_contiguous(),
|
||||
"b_zeros_reorder is not contiguous");
|
||||
}
|
||||
|
||||
const uint8_t* matB = reinterpret_cast<const uint8_t*>(b_qweight.data_ptr());
|
||||
const void* b_scale = b_scales.data_ptr();
|
||||
const void* b_zero = has_zp ? b_zeros.value().data_ptr() : nullptr;
|
||||
|
||||
uint8_t* matB_reorder =
|
||||
reinterpret_cast<uint8_t*>(b_qweight_reorder.data_ptr());
|
||||
void* b_scale_reorder = b_scales_reorder.data_ptr();
|
||||
void* b_zero_reorder = has_zp ? b_zeros_reorder.value().data_ptr() : nullptr;
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
if (b_scales.dtype() == at::ScalarType::Half) {
|
||||
allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__half>(
|
||||
matB, reinterpret_cast<const __half*>(b_scale),
|
||||
reinterpret_cast<const __half*>(b_zero), matB_reorder,
|
||||
reinterpret_cast<__half*>(b_scale_reorder),
|
||||
reinterpret_cast<__half*>(b_zero_reorder), K, N, N_32align, stream);
|
||||
} else if (b_scales.dtype() == at::ScalarType::BFloat16) {
|
||||
allspark::rearrange_kn_weight_as_n32k16_order_ldg16<__nv_bfloat16>(
|
||||
matB, reinterpret_cast<const __nv_bfloat16*>(b_scale),
|
||||
reinterpret_cast<const __nv_bfloat16*>(b_zero), matB_reorder,
|
||||
reinterpret_cast<__nv_bfloat16*>(b_scale_reorder),
|
||||
reinterpret_cast<__nv_bfloat16*>(b_zero_reorder), K, N, N_32align,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("rearrange_kn_weight_as_n32k16_order",
|
||||
&rearrange_kn_weight_as_n32k16_order);
|
||||
}
|
||||
408
csrc/quantization/gptq_allspark/allspark_utils.cuh
Normal file
408
csrc/quantization/gptq_allspark/allspark_utils.cuh
Normal file
@ -0,0 +1,408 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace allspark {
|
||||
|
||||
#define CHECK_CUDA(cmd) \
|
||||
do { \
|
||||
cudaError_t cuda_status = cmd; \
|
||||
if (cuda_status != cudaSuccess) { \
|
||||
std::string err_str = cudaGetErrorString(cuda_status); \
|
||||
std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
|
||||
<< err_str; \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CHECK_CUBLAS(cmd) \
|
||||
do { \
|
||||
cublasStatus_t cublas_status = cmd; \
|
||||
if (cublas_status != CUBLAS_STATUS_SUCCESS) { \
|
||||
std::cerr << "Failed: " << __FILE__ << ":" << __LINE__ << " " \
|
||||
<< cublas_status << std::endl; \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
template <typename FType, typename QType>
|
||||
struct SM8x_GEMM_W8A16_Splitk_Params {
|
||||
const FType* A_ptr;
|
||||
const QType* B_ptr;
|
||||
const FType* B_scale_ptr;
|
||||
const FType* B_zero_ptr;
|
||||
FType* C_ptr;
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
int SplitK;
|
||||
int GroupCnt;
|
||||
int GroupSize;
|
||||
FType* C_split_ptr; // for non-fused splitk reduce
|
||||
float* C_tmp_ptr; // for fused splitk reduce
|
||||
uint32_t* red_count_ptr; // for fused splitk reduce
|
||||
};
|
||||
|
||||
struct alignas(16) BlockTileSplitkParams {
|
||||
int Mtile;
|
||||
int Ntile;
|
||||
int SplitK;
|
||||
bool EnableFuse;
|
||||
};
|
||||
|
||||
template <typename FType, int BLOCK, int N_MATRIX>
|
||||
__global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
|
||||
uint32_t n, uint32_t n_matrix,
|
||||
uint32_t matrix_size) {
|
||||
int idx = blockIdx.x * BLOCK + threadIdx.x;
|
||||
|
||||
if (idx >= matrix_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
FType sum(0);
|
||||
|
||||
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
|
||||
for (int i = 0; i < n_mat; ++i) {
|
||||
sum += C_split[idx + i * matrix_size];
|
||||
}
|
||||
|
||||
C[idx] = sum;
|
||||
}
|
||||
|
||||
template <typename FType>
|
||||
void f16_gemm_splitk_reduce(const FType* C_split, FType* C, const uint32_t m,
|
||||
const uint32_t n, const uint32_t n_matrix,
|
||||
cudaStream_t stream) {
|
||||
const int BLOCK = 128;
|
||||
uint32_t matrix_size = m * n;
|
||||
int grid = (matrix_size + BLOCK - 1) / BLOCK;
|
||||
|
||||
void (*kernel)(const FType*, FType*, uint32_t, uint32_t, uint32_t) = nullptr;
|
||||
|
||||
switch (n_matrix) {
|
||||
case 4:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 4>;
|
||||
break;
|
||||
case 5:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 5>;
|
||||
break;
|
||||
case 6:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 6>;
|
||||
break;
|
||||
case 7:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 7>;
|
||||
break;
|
||||
case 8:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 8>;
|
||||
break;
|
||||
case 9:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 9>;
|
||||
break;
|
||||
case 10:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 10>;
|
||||
break;
|
||||
case 11:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 11>;
|
||||
break;
|
||||
case 12:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, 12>;
|
||||
break;
|
||||
default:
|
||||
kernel = f16_gemm_splitk_reduce_kernel<FType, BLOCK, -1>;
|
||||
break;
|
||||
}
|
||||
|
||||
kernel<<<grid, BLOCK, 0, stream>>>(C_split, C, n, n_matrix, matrix_size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct HalfType;
|
||||
template <>
|
||||
struct HalfType<half> {
|
||||
using T1 = __half;
|
||||
using T2 = __half2;
|
||||
};
|
||||
template <>
|
||||
struct HalfType<__nv_bfloat16> {
|
||||
using T1 = __nv_bfloat16;
|
||||
using T2 = __nv_bfloat162;
|
||||
};
|
||||
|
||||
// convert 64-bit pointer to 32-bit smem addr
|
||||
__device__ __forceinline__ uint32_t smem_u32addr(const void* smem_ptr) {
|
||||
uint32_t addr;
|
||||
asm("{.reg .u64 u64addr;\n"
|
||||
" cvta.to.shared.u64 u64addr, %1;\n"
|
||||
" cvt.u32.u64 %0, u64addr;}\n"
|
||||
: "=r"(addr)
|
||||
: "l"(smem_ptr));
|
||||
|
||||
return addr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ldg16_cg_0(T& r0, const void* ptr, bool guard) {
|
||||
static_assert(sizeof(T) == 2, "ldg16_cg_0: invalid T");
|
||||
|
||||
asm volatile(
|
||||
"{.reg .pred p;\n"
|
||||
" setp.ne.b32 p, %2, 0;\n"
|
||||
" @!p mov.b16 %0, 0;\n"
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
|
||||
__CUDA_ARCH__ >= 750
|
||||
" @p ld.global.cg.L2::128B.b16 {%0}, [%1];}\n"
|
||||
#else
|
||||
" @p ld.global.ca.b16 {%0}, [%1];}\n"
|
||||
#endif
|
||||
: "=h"(reinterpret_cast<uint16_t&>(r0))
|
||||
: "l"(ptr), "r"((int)guard));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ldg64_ca(T& r0, T& r1, const void* ptr,
|
||||
bool guard) {
|
||||
static_assert(sizeof(T) == 4, "ldg64_ca: invalid T");
|
||||
|
||||
asm volatile(
|
||||
"{.reg .pred p;\n"
|
||||
" setp.ne.b32 p, %3, 0;\n"
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
|
||||
__CUDA_ARCH__ >= 750
|
||||
" @p ld.global.ca.L2::128B.v2.b32 {%0, %1}, [%2];}\n"
|
||||
#else
|
||||
" @p ld.global.ca.v2.b32 {%0, %1}, [%2];}\n"
|
||||
#endif
|
||||
: "=r"(reinterpret_cast<uint32_t&>(r0)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r1))
|
||||
: "l"(ptr), "r"((int)guard));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ldg128_cg_0(T& r0, T& r1, T& r2, T& r3,
|
||||
const void* ptr, bool guard) {
|
||||
static_assert(sizeof(T) == 4, "ldg128_cg_0: invalid T");
|
||||
|
||||
asm volatile(
|
||||
"{.reg .pred p;\n"
|
||||
" setp.ne.b32 p, %5, 0;\n"
|
||||
" @!p mov.b32 %0, 0;\n"
|
||||
" @!p mov.b32 %1, 0;\n"
|
||||
" @!p mov.b32 %2, 0;\n"
|
||||
" @!p mov.b32 %3, 0;\n"
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDACC_VER_MINOR__ >= 4 && \
|
||||
__CUDA_ARCH__ >= 750
|
||||
" @p ld.global.cg.L2::128B.v4.b32 {%0, %1, %2, %3}, [%4];}\n"
|
||||
#else
|
||||
" @p ld.global.cg.v4.b32 {%0, %1, %2, %3}, [%4];}\n"
|
||||
#endif
|
||||
: "=r"(reinterpret_cast<uint32_t&>(r0)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r1)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r2)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r3))
|
||||
: "l"(ptr), "r"((int)guard));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void lds128(T& reg0, T& reg1, T& reg2, T& reg3,
|
||||
const uint32_t addr) {
|
||||
static_assert(sizeof(T) == 4, "lds128: invalid T");
|
||||
|
||||
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(reinterpret_cast<uint32_t&>(reg0)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(reg1)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(reg2)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(reg3))
|
||||
: "r"(addr));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void stg128(const T& r0, const T& r1, const T& r2,
|
||||
const T& r3, const void* ptr,
|
||||
bool guard) {
|
||||
static_assert(sizeof(T) == 4, "stg128: invalid T");
|
||||
|
||||
asm volatile(
|
||||
"{.reg .pred p;\n"
|
||||
" setp.ne.b32 p, %1, 0;\n"
|
||||
" @p st.global.v4.b32 [%0], {%2, %3, %4, %5};}\n"
|
||||
:
|
||||
: "l"(ptr), "r"((int)guard), "r"(reinterpret_cast<const uint32_t&>(r0)),
|
||||
"r"(reinterpret_cast<const uint32_t&>(r1)),
|
||||
"r"(reinterpret_cast<const uint32_t&>(r2)),
|
||||
"r"(reinterpret_cast<const uint32_t&>(r3)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void ldsm_4(T& r0, T& r1, T& r2, T& r3,
|
||||
const uint32_t& addr) {
|
||||
static_assert(sizeof(T) == 4, "ldsm_4: invalid T");
|
||||
#if (__CUDA_ARCH__ >= 750) && (__CUDACC_VER_MAJOR__ >= 11)
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(reinterpret_cast<uint32_t&>(r0)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r1)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r2)),
|
||||
"=r"(reinterpret_cast<uint32_t&>(r3))
|
||||
: "r"(addr));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename FType>
|
||||
__device__ __forceinline__ void hmma16816_f32(float (&d)[4],
|
||||
const uint32_t (&a)[4],
|
||||
const uint32_t (&b)[2]);
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void hmma16816_f32<__half>(float (&d)[4],
|
||||
const uint32_t (&a)[4],
|
||||
const uint32_t (&b)[2]) {
|
||||
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, "
|
||||
"{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
|
||||
: "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ void hmma16816_f32<__nv_bfloat16>(
|
||||
float (&d)[4], const uint32_t (&a)[4], const uint32_t (&b)[2]) {
|
||||
#if (__CUDA_ARCH__ >= 800) && (__CUDACC_VER_MAJOR__ >= 11)
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, "
|
||||
"{%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
|
||||
: "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
|
||||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int SIZE_IN_BYTES>
|
||||
__device__ __forceinline__ void cp_async(const uint32_t smem_addr,
|
||||
const void* gmem_ptr,
|
||||
const int src_in_bytes, bool guard) {
|
||||
static_assert(
|
||||
(SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16),
|
||||
"Size is not supported");
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile(
|
||||
"{.reg.pred p;\n"
|
||||
" setp.ne.b32 p, %4, 0;\n"
|
||||
#if __CUDACC_VER_MINOR__ >= 4
|
||||
" @p cp.async.cg.shared.global.L2::256B [%0], [%1], %2, %3;}\n"
|
||||
#else
|
||||
" @p cp.async.cg.shared.global [%0], [%1], %2, %3;}\n"
|
||||
#endif
|
||||
::"r"(smem_addr),
|
||||
"l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int SIZE_IN_BYTES>
|
||||
__device__ __forceinline__ void cp_async_ca(const uint32_t smem_addr,
|
||||
const void* gmem_ptr,
|
||||
const int src_in_bytes,
|
||||
bool guard) {
|
||||
static_assert(
|
||||
(SIZE_IN_BYTES == 4 || SIZE_IN_BYTES == 8 || SIZE_IN_BYTES == 16),
|
||||
"Size is not supported");
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile(
|
||||
"{.reg.pred p;\n"
|
||||
" setp.ne.b32 p, %4, 0;\n"
|
||||
#if __CUDACC_VER_MINOR__ >= 4
|
||||
" @p cp.async.ca.shared.global.L2::256B [%0], [%1], %2, %3;}\n"
|
||||
#else
|
||||
" @p cp.async.ca.shared.global [%0], [%1], %2, %3;}\n"
|
||||
#endif
|
||||
::"r"(smem_addr),
|
||||
"l"(gmem_ptr), "n"(SIZE_IN_BYTES), "r"(src_in_bytes), "r"((int)guard));
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void cp_async_commit_group() {
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cp.async.commit_group;\n");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ __forceinline__ void cp_asyc_wait_group() {
|
||||
#if __CUDACC_VER_MAJOR__ >= 11 && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cp.async.wait_group %0;\n" : : "n"(N));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128(const uint32_t& idata,
|
||||
T* fdata);
|
||||
|
||||
template <>
|
||||
// fast conversion: 4xuint8 to 4xhalf, subtracting bias = 128
|
||||
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__half2>(
|
||||
const uint32_t& idata, __half2* fdata) {
|
||||
uint32_t i10, i32;
|
||||
asm volatile(
|
||||
"prmt.b32 %0, %2, 0x64, 0x4140;"
|
||||
"prmt.b32 %1, %2, 0x64, 0x4342;"
|
||||
: "=r"(i10), "=r"(i32)
|
||||
: "r"(idata));
|
||||
|
||||
static constexpr uint32_t MAGIC_NUM = 0x64806480;
|
||||
fdata[0] = __hsub2(reinterpret_cast<const __half2&>(i10),
|
||||
reinterpret_cast<const __half2&>(MAGIC_NUM));
|
||||
fdata[1] = __hsub2(reinterpret_cast<const __half2&>(i32),
|
||||
reinterpret_cast<const __half2&>(MAGIC_NUM));
|
||||
}
|
||||
|
||||
template <>
|
||||
// fast conversion: 4xuint8 to 4xbfloat16, subtracting bias = 128
|
||||
// reference from marlin fast implementation
|
||||
__device__ __forceinline__ void cvt_8bx4_to_16bx4_bias128<__nv_bfloat162>(
|
||||
const uint32_t& idata, __nv_bfloat162* fdata) {
|
||||
float fp32_imd[4];
|
||||
uint32_t* fp32_imd_casted = reinterpret_cast<uint32_t*>(fp32_imd);
|
||||
asm volatile(
|
||||
"prmt.b32 %0, %4, 0x4B000000, 0x7650;"
|
||||
"prmt.b32 %1, %4, 0x4B000000, 0x7651;"
|
||||
"prmt.b32 %2, %4, 0x4B000000, 0x7652;"
|
||||
"prmt.b32 %3, %4, 0x4B000000, 0x7653;"
|
||||
: "=r"(fp32_imd_casted[0]), "=r"(fp32_imd_casted[1]),
|
||||
"=r"(fp32_imd_casted[2]), "=r"(fp32_imd_casted[3])
|
||||
: "r"(idata));
|
||||
|
||||
fp32_imd[0] -= 8388736.f;
|
||||
fp32_imd[1] -= 8388736.f;
|
||||
fp32_imd[2] -= 8388736.f;
|
||||
fp32_imd[3] -= 8388736.f;
|
||||
|
||||
uint32_t* bf16_res = reinterpret_cast<uint32_t*>(fdata);
|
||||
asm volatile(
|
||||
"prmt.b32 %0, %2, %3, 0x7632;"
|
||||
"prmt.b32 %1, %4, %5, 0x7632;"
|
||||
: "=r"(bf16_res[0]), "=r"(bf16_res[1])
|
||||
: "r"(fp32_imd_casted[0]), "r"(fp32_imd_casted[1]),
|
||||
"r"(fp32_imd_casted[2]), "r"(fp32_imd_casted[3]));
|
||||
}
|
||||
|
||||
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return __bfloat162bfloat162(x);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
static __device__ half2 inline num2num2(const half x) {
|
||||
return __half2half2(x);
|
||||
}
|
||||
|
||||
} // namespace allspark
|
||||
@ -8,7 +8,7 @@ from collections.abc import Iterable
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import reduce
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import jinja2
|
||||
# yapf conflicts with isort for this block
|
||||
@ -247,8 +247,8 @@ TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleConfig:
|
||||
tile_shape_mn: Tuple[int, int]
|
||||
cluster_shape_mnk: Tuple[int, int, int]
|
||||
tile_shape_mn: tuple[int, int]
|
||||
cluster_shape_mnk: tuple[int, int, int]
|
||||
kernel_schedule: MixedInputKernelScheduleType
|
||||
epilogue_schedule: EpilogueScheduleType
|
||||
tile_scheduler: TileSchedulerType
|
||||
@ -277,8 +277,8 @@ class PrepackTypeConfig:
|
||||
@dataclass
|
||||
class ImplConfig:
|
||||
types: TypeConfig
|
||||
schedules: List[ScheduleConfig]
|
||||
heuristic: List[Tuple[Optional[str], ScheduleConfig]]
|
||||
schedules: list[ScheduleConfig]
|
||||
heuristic: list[tuple[Optional[str], ScheduleConfig]]
|
||||
|
||||
|
||||
def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
@ -333,7 +333,7 @@ def is_power_of_two(n):
|
||||
return (n != 0) and (n & (n - 1) == 0)
|
||||
|
||||
|
||||
def to_cute_constant(value: List[int]):
|
||||
def to_cute_constant(value: list[int]):
|
||||
|
||||
def _to_cute_constant(value: int):
|
||||
if is_power_of_two(value):
|
||||
@ -347,7 +347,7 @@ def to_cute_constant(value: List[int]):
|
||||
return _to_cute_constant(value)
|
||||
|
||||
|
||||
def unique_schedules(impl_configs: List[ImplConfig]):
|
||||
def unique_schedules(impl_configs: list[ImplConfig]):
|
||||
return list(
|
||||
set(sch for impl_config in impl_configs
|
||||
for sch in impl_config.schedules))
|
||||
@ -391,7 +391,7 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
|
||||
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
|
||||
|
||||
|
||||
def create_sources(impl_configs: List[ImplConfig], num_impl_files=8):
|
||||
def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
|
||||
sources = []
|
||||
|
||||
sources.append((
|
||||
@ -435,7 +435,7 @@ def create_sources(impl_configs: List[ImplConfig], num_impl_files=8):
|
||||
num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
|
||||
num_impls_per_file = math.ceil(num_impls / num_impl_files)
|
||||
|
||||
files_impls: List[List[ImplConfig]] = [[]]
|
||||
files_impls: list[list[ImplConfig]] = [[]]
|
||||
|
||||
curr_num_impls_assigned = 0
|
||||
curr_impl_in_file = 0
|
||||
@ -515,7 +515,7 @@ def generate():
|
||||
for cond, tile_config in default_tile_heuristic_config.items()
|
||||
]
|
||||
|
||||
def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):
|
||||
def get_unique_schedules(heuristic: dict[str, ScheduleConfig]):
|
||||
# Do not use schedules = list(set(...)) because we need to make sure
|
||||
# the output list is deterministic; otherwise the generated kernel file
|
||||
# will be non-deterministic and causes ccache miss.
|
||||
|
||||
@ -126,15 +126,10 @@ struct MacheteKernelTemplate {
|
||||
std::is_same_v<ElementSChannel, ElementSToken>),
|
||||
"Currently token and channel scales (if present) must be the same type");
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
|
||||
// Currently only supports float scales
|
||||
using ChTokScalesEpilogue =
|
||||
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
||||
EpilogueDescriptor>;
|
||||
TileShape>;
|
||||
static_assert((with_channel_scales || with_token_scales) ||
|
||||
(std::is_same_v<ElementSChannel, float> &&
|
||||
std::is_same_v<ElementSToken, float>),
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -65,12 +65,7 @@ struct cutlass_sparse_3x_gemm {
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
|
||||
@ -302,6 +302,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"SymInt size_k) -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// CUTLASS nvfp4 block scaled GEMM
|
||||
ops.def(
|
||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
" Tensor block_scale_a, Tensor block_scale_b,"
|
||||
" Tensor alpha) -> ()");
|
||||
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization, as well as bias
|
||||
ops.def(
|
||||
@ -440,6 +447,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"Tensor!? azp) -> ()");
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
|
||||
&dynamic_scaled_int8_quant);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||
ops.def(
|
||||
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
||||
"Tensor? b_zeros, "
|
||||
"bool has_zp, Tensor! b_qweight_reorder, Tensor! b_scales_reorder, "
|
||||
"Tensor!? b_zeros_reorder, "
|
||||
"int K, int N, int N_32align) -> ()");
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// AllSpark quantization ops
|
||||
ops.def(
|
||||
"allspark_w8a16_gemm(Tensor a, Tensor b_qweight, Tensor b_scales, "
|
||||
"Tensor? b_qzeros, "
|
||||
"SymInt n, SymInt group_size, SymInt sm_count, SymInt sm_version, SymInt "
|
||||
"CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) -> Tensor");
|
||||
// conditionally compiled so impl in source file
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
@ -493,6 +519,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
||||
"str kv_cache_dtype) -> ()");
|
||||
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
|
||||
|
||||
// Gather cache blocks from src_cache to dst.
|
||||
cache_ops.def(
|
||||
"gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
|
||||
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 115 KiB After Width: | Height: | Size: 118 KiB |
BIN
docs/source/assets/design/v1/metrics/intervals-1.png
Normal file
BIN
docs/source/assets/design/v1/metrics/intervals-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 185 KiB |
BIN
docs/source/assets/design/v1/metrics/intervals-2.png
Normal file
BIN
docs/source/assets/design/v1/metrics/intervals-2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 162 KiB |
BIN
docs/source/assets/design/v1/metrics/intervals-3.png
Normal file
BIN
docs/source/assets/design/v1/metrics/intervals-3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 161 KiB |
@ -17,7 +17,6 @@ import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from sphinx.ext import autodoc
|
||||
@ -58,7 +57,7 @@ templates_path = ['_templates']
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns: List[str] = ["**/*.template.md", "**/*.inc.md"]
|
||||
exclude_patterns: list[str] = ["**/*.template.md", "**/*.inc.md"]
|
||||
|
||||
# Exclude the prompt "$" when copying code
|
||||
copybutton_prompt_text = r"\$ "
|
||||
|
||||
@ -74,8 +74,6 @@ def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
```
|
||||
|
||||
@ -16,8 +16,6 @@ Further update the model as follows:
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
+ pixel_values: torch.Tensor,
|
||||
) -> SamplerOutput:
|
||||
```
|
||||
@ -722,13 +720,13 @@ def _get_mm_fields_config(
|
||||
|
||||
:::::
|
||||
|
||||
### Prompt replacements
|
||||
### Prompt updates
|
||||
|
||||
Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to
|
||||
return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances.
|
||||
Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` to
|
||||
return a list of {class}`~vllm.multimodal.processing.PromptUpdate` instances.
|
||||
|
||||
Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace
|
||||
operation performed by the HF processor.
|
||||
Each {class}`~vllm.multimodal.processing.PromptUpdate` instance specifies an update operation
|
||||
(e.g.: insertion, replacement) performed by the HF processor.
|
||||
|
||||
::::{tab-set}
|
||||
:::{tab-item} Basic example: LLaVA
|
||||
@ -745,15 +743,15 @@ for sample in text:
|
||||
```
|
||||
|
||||
It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
|
||||
Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` as follows:
|
||||
Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` as follows:
|
||||
|
||||
```python
|
||||
def _get_prompt_replacements(
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
@ -861,7 +859,7 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
|
||||
)
|
||||
```
|
||||
|
||||
To accommodate this, instead of a string you can return an instance of `PromptReplacementDetails`
|
||||
To accommodate this, instead of a string you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`
|
||||
with different `full` and `feature` attributes:
|
||||
|
||||
```python
|
||||
@ -880,7 +878,7 @@ def get_replacement_fuyu(item_idx: int):
|
||||
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||
[_NEWLINE_TOKEN_ID]) * nrows
|
||||
|
||||
return PromptReplacementDetails(
|
||||
return PromptUpdateDetails(
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
)
|
||||
@ -890,12 +888,12 @@ Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the
|
||||
we can search for it to conduct the replacement at the start of the string:
|
||||
|
||||
```python
|
||||
def _get_prompt_replacements(
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
bos_token_id = hf_config.bos_token_id
|
||||
assert isinstance(bos_token_id, int)
|
||||
@ -915,7 +913,7 @@ def _get_prompt_replacements(
|
||||
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||
[_NEWLINE_TOKEN_ID]) * nrows
|
||||
|
||||
return PromptReplacementDetails(
|
||||
return PromptUpdateDetails(
|
||||
full=image_tokens + [bos_token_id],
|
||||
features=image_tokens,
|
||||
)
|
||||
@ -950,3 +948,35 @@ to register them to the multi-modal registry:
|
||||
+ dummy_inputs=YourDummyInputsBuilder)
|
||||
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
### Inserting feature tokens without replacement
|
||||
|
||||
Some HF processors directly insert feature tokens without replacing anything in the original prompt. In that case, you can use {class}`~vllm.multimodal.processing.PromptInsertion` instead of {class}`~vllm.multimodal.processing.PromptReplacement` inside {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`.
|
||||
|
||||
Examples:
|
||||
|
||||
- BLIP-2 (insert at start of prompt): <gh-file:vllm/model_executor/models/blip2.py>
|
||||
- Florence2 (insert at start of prompt): <gh-file:vllm/model_executor/models/florence2.py>
|
||||
- Molmo (insert after `<|endoftext|>` token): <gh-file:vllm/model_executor/models/molmo.py>
|
||||
|
||||
### Handling prompt updates unrelated to multi-modal data
|
||||
|
||||
{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` assumes that each application of prompt update corresponds to one multi-modal item. If the HF processor performs additional processing regardless of how many multi-modal items there are, you should override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_tokens_only` so that the processed token inputs are consistent with the result of applying the HF processor on text inputs. This is because token inputs bypass the HF processor according to [our design](#mm-processing).
|
||||
|
||||
Examples:
|
||||
|
||||
- Chameleon (appends `sep_token`): <gh-file:vllm/model_executor/models/chameleon.py>
|
||||
- Fuyu (appends `boa_token`): <gh-file:vllm/model_executor/models/fuyu.py>
|
||||
- Molmo (applies chat template which is not defined elsewhere): <gh-file:vllm/model_executor/models/molmo.py>
|
||||
|
||||
### Custom HF processor
|
||||
|
||||
Some models don't define a HF processor class on HF Hub. In that case, you can define a custom HF processor that has the same call signature as HF processors and pass it to {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor`.
|
||||
|
||||
Examples:
|
||||
|
||||
- DeepSeek-VL2: <gh-file:vllm/model_executor/models/deepseek_vl2.py>
|
||||
- InternVL: <gh-file:vllm/model_executor/models/internvl.py>
|
||||
- Qwen-VL: <gh-file:vllm/model_executor/models/qwen_vl.py>
|
||||
|
||||
@ -145,6 +145,9 @@ review process:
|
||||
- Please respond to all comments within a reasonable time frame. If a comment
|
||||
isn't clear or you disagree with a suggestion, feel free to ask for
|
||||
clarification or discuss the suggestion.
|
||||
- Note that not all CI checks will be executed due to limited computational
|
||||
resources. The reviewer will add `ready` label to the PR when the PR is
|
||||
ready to merge or a full CI run is needed.
|
||||
|
||||
## Thank You
|
||||
|
||||
|
||||
@ -27,6 +27,36 @@ container to access the host's shared memory. vLLM uses PyTorch, which uses shar
|
||||
memory to share data between processes under the hood, particularly for tensor parallel inference.
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
Optional dependencies are not included in order to avoid licensing issues (e.g. <gh-issue:8030>).
|
||||
|
||||
If you need to use those dependencies (having accepted the license terms),
|
||||
create a custom Dockerfile on top of the base image with an extra layer that installs them:
|
||||
|
||||
```Dockerfile
|
||||
FROM vllm/vllm-openai:v0.7.3
|
||||
|
||||
# e.g. install the `audio` and `video` optional dependencies
|
||||
# NOTE: Make sure the version of vLLM matches the base image!
|
||||
RUN uv pip install --system vllm[audio,video]==0.7.3
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
:::{tip}
|
||||
Some new models may only be available on the main branch of [HF Transformers](https://github.com/huggingface/transformers).
|
||||
|
||||
To use the development version of `transformers`, create a custom Dockerfile on top of the base image
|
||||
with an extra layer that installs their code from source:
|
||||
|
||||
```Dockerfile
|
||||
FROM vllm/vllm-openai:latest
|
||||
|
||||
RUN uv pip install --system git+https://github.com/huggingface/transformers.git
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
(deployment-docker-build-image-from-source)=
|
||||
|
||||
## Building vLLM's Docker Image from Source
|
||||
|
||||
@ -6,4 +6,6 @@
|
||||
kserve
|
||||
kubeai
|
||||
llamastack
|
||||
llmaz
|
||||
production-stack
|
||||
:::
|
||||
|
||||
7
docs/source/deployment/integrations/llmaz.md
Normal file
7
docs/source/deployment/integrations/llmaz.md
Normal file
@ -0,0 +1,7 @@
|
||||
(deployment-llmaz)=
|
||||
|
||||
# llmaz
|
||||
|
||||
[llmaz](https://github.com/InftyAI/llmaz) is an easy-to-use and advanced inference platform for large language models on Kubernetes, aimed for production use. It uses vLLM as the default model serving backend.
|
||||
|
||||
Please refer to the [Quick Start](https://github.com/InftyAI/llmaz?tab=readme-ov-file#quick-start) for more details.
|
||||
154
docs/source/deployment/integrations/production-stack.md
Normal file
154
docs/source/deployment/integrations/production-stack.md
Normal file
@ -0,0 +1,154 @@
|
||||
(deployment-production-stack)=
|
||||
|
||||
# Production stack
|
||||
|
||||
Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine learning models. This guide walks you through deploying vLLM using the [vLLM production stack](https://github.com/vllm-project/production-stack). Born out of a Berkeley-UChicago collaboration, [vLLM production stack](https://github.com/vllm-project/production-stack) is an officially released, production-optimized codebase under the [vLLM project](https://github.com/vllm-project), designed for LLM deployment with:
|
||||
|
||||
* **Upstream vLLM compatibility** – It wraps around upstream vLLM without modifying its code.
|
||||
* **Ease of use** – Simplified deployment via Helm charts and observability through Grafana dashboards.
|
||||
* **High performance** – Optimized for LLM workloads with features like multi-model support, model-aware and prefix-aware routing, fast vLLM bootstrapping, and KV cache offloading with [LMCache](https://github.com/LMCache/LMCache), among others.
|
||||
|
||||
If you are new to Kubernetes, don't worry: in the vLLM production stack [repo](https://github.com/vllm-project/production-stack), we provide a step-by-step [guide](https://github.com/vllm-project/production-stack/blob/main/tutorials/00-install-kubernetes-env.md) and a [short video](https://www.youtube.com/watch?v=EsTJbQtzj0g) to set up everything and get started in **4 minutes**!
|
||||
|
||||
## Pre-requisite
|
||||
|
||||
Ensure that you have a running Kubernetes environment with GPU (you can follow [this tutorial](https://github.com/vllm-project/production-stack/blob/main/tutorials/00-install-kubernetes-env.md) to install a Kubernetes environment on a bare-medal GPU machine).
|
||||
|
||||
## Deployment using vLLM production stack
|
||||
|
||||
The standard vLLM production stack install uses a Helm chart. You can run this [bash script](https://github.com/vllm-project/production-stack/blob/main/tutorials/install-helm.sh) to install Helm on your GPU server.
|
||||
|
||||
To install the vLLM production stack, run the following commands on your desktop:
|
||||
|
||||
```bash
|
||||
sudo helm repo add vllm https://vllm-project.github.io/production-stack
|
||||
sudo helm install vllm vllm/vllm-stack -f tutorials/assets/values-01-minimal-example.yaml
|
||||
```
|
||||
|
||||
This will instantiate a vLLM-production-stack-based deployment named `vllm` that runs a small LLM (Facebook opt-125M model).
|
||||
|
||||
### Validate Installation
|
||||
|
||||
Monitor the deployment status using:
|
||||
|
||||
```bash
|
||||
sudo kubectl get pods
|
||||
```
|
||||
|
||||
And you will see that pods for the `vllm` deployment will transit to `Running` state.
|
||||
|
||||
```text
|
||||
NAME READY STATUS RESTARTS AGE
|
||||
vllm-deployment-router-859d8fb668-2x2b7 1/1 Running 0 2m38s
|
||||
vllm-opt125m-deployment-vllm-84dfc9bd7-vb9bs 1/1 Running 0 2m38s
|
||||
```
|
||||
|
||||
**NOTE**: It may take some time for the containers to download the Docker images and LLM weights.
|
||||
|
||||
### Send a Query to the Stack
|
||||
|
||||
Forward the `vllm-router-service` port to the host machine:
|
||||
|
||||
```bash
|
||||
sudo kubectl port-forward svc/vllm-router-service 30080:80
|
||||
```
|
||||
|
||||
And then you can send out a query to the OpenAI-compatible API to check the available models:
|
||||
|
||||
```bash
|
||||
curl -o- http://localhost:30080/models
|
||||
```
|
||||
|
||||
Expected output:
|
||||
|
||||
```json
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "facebook/opt-125m",
|
||||
"object": "model",
|
||||
"created": 1737428424,
|
||||
"owned_by": "vllm",
|
||||
"root": null
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
To send an actual chatting request, you can issue a curl request to the OpenAI `/completion` endpoint:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:30080/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "facebook/opt-125m",
|
||||
"prompt": "Once upon a time,",
|
||||
"max_tokens": 10
|
||||
}'
|
||||
```
|
||||
|
||||
Expected output:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "completion-id",
|
||||
"object": "text_completion",
|
||||
"created": 1737428424,
|
||||
"model": "facebook/opt-125m",
|
||||
"choices": [
|
||||
{
|
||||
"text": " there was a brave knight who...",
|
||||
"index": 0,
|
||||
"finish_reason": "length"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Uninstall
|
||||
|
||||
To remove the deployment, run:
|
||||
|
||||
```bash
|
||||
sudo helm uninstall vllm
|
||||
```
|
||||
|
||||
------
|
||||
|
||||
### (Advanced) Configuring vLLM production stack
|
||||
|
||||
The core vLLM production stack configuration is managed with YAML. Here is the example configuration used in the installation above:
|
||||
|
||||
```yaml
|
||||
servingEngineSpec:
|
||||
runtimeClassName: ""
|
||||
modelSpec:
|
||||
- name: "opt125m"
|
||||
repository: "vllm/vllm-openai"
|
||||
tag: "latest"
|
||||
modelURL: "facebook/opt-125m"
|
||||
|
||||
replicaCount: 1
|
||||
|
||||
requestCPU: 6
|
||||
requestMemory: "16Gi"
|
||||
requestGPU: 1
|
||||
|
||||
pvcStorage: "10Gi"
|
||||
```
|
||||
|
||||
In this YAML configuration:
|
||||
* **`modelSpec`** includes:
|
||||
* `name`: A nickname that you prefer to call the model.
|
||||
* `repository`: Docker repository of vLLM.
|
||||
* `tag`: Docker image tag.
|
||||
* `modelURL`: The LLM model that you want to use.
|
||||
* **`replicaCount`**: Number of replicas.
|
||||
* **`requestCPU` and `requestMemory`**: Specifies the CPU and memory resource requests for the pod.
|
||||
* **`requestGPU`**: Specifies the number of GPUs required.
|
||||
* **`pvcStorage`**: Allocates persistent storage for the model.
|
||||
|
||||
**NOTE:** If you intend to set up two pods, please refer to this [YAML file](https://github.com/vllm-project/production-stack/blob/main/tutorials/assets/values-01-2pods-minimal-example.yaml).
|
||||
|
||||
**NOTE:** vLLM production stack offers many more features (*e.g.* CPU offloading and a wide range of routing algorithms). Please check out these [examples and tutorials](https://github.com/vllm-project/production-stack/tree/main/tutorials) and our [repo](https://github.com/vllm-project/production-stack) for more details!
|
||||
@ -2,17 +2,21 @@
|
||||
|
||||
# Using Kubernetes
|
||||
|
||||
Using Kubernetes to deploy vLLM is a scalable and efficient way to serve machine learning models. This guide will walk you through the process of deploying vLLM with Kubernetes, including the necessary prerequisites, steps for deployment, and testing.
|
||||
Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine learning models. This guide walks you through deploying vLLM using native Kubernetes.
|
||||
|
||||
## Prerequisites
|
||||
--------
|
||||
|
||||
Before you begin, ensure that you have the following:
|
||||
Alternatively, you can also deploy Kubernetes using [helm chart](https://docs.vllm.ai/en/latest/deployment/frameworks/helm.html). There are also open-source projects available to make your deployment even smoother.
|
||||
|
||||
- A running Kubernetes cluster
|
||||
- NVIDIA Kubernetes Device Plugin (`k8s-device-plugin`): This can be found at `https://github.com/NVIDIA/k8s-device-plugin/`
|
||||
- Available GPU resources in your cluster
|
||||
* [vLLM production-stack](https://github.com/vllm-project/production-stack): Born out of a Berkeley-UChicago collaboration, vLLM production stack is a project that contains latest research and community effort, while still delivering production-level stability and performance. Checkout the [documentation page](https://docs.vllm.ai/en/latest/deployment/integrations/production-stack.html) for more details and examples.
|
||||
|
||||
## Deployment Steps
|
||||
--------
|
||||
|
||||
## Pre-requisite
|
||||
|
||||
Ensure that you have a running Kubernetes environment with GPU (you can follow [this tutorial](https://github.com/vllm-project/production-stack/blob/main/tutorials/00-install-kubernetes-env.md) to install a Kubernetes environment on a bare-medal GPU machine).
|
||||
|
||||
## Deployment using native K8s
|
||||
|
||||
1. Create a PVC, Secret and Deployment for vLLM
|
||||
|
||||
|
||||
@ -95,14 +95,14 @@ Notes:
|
||||
|
||||
- If you have your HuggingFace models cached somewhere else, update `hf_cache_dir` below.
|
||||
- If you don't have an existing HuggingFace cache you will want to start `vllm0` and wait for the model to complete downloading and the server to be ready. This will ensure that `vllm1` can leverage the model you just downloaded and it won't have to be downloaded again.
|
||||
- The below example assumes GPU backend used. If you are using CPU backend, remove `--gpus all`, add `VLLM_CPU_KVCACHE_SPACE` and `VLLM_CPU_OMP_THREADS_BIND` environment variables to the docker run command.
|
||||
- The below example assumes GPU backend used. If you are using CPU backend, remove `--gpus device=ID`, add `VLLM_CPU_KVCACHE_SPACE` and `VLLM_CPU_OMP_THREADS_BIND` environment variables to the docker run command.
|
||||
- Adjust the model name that you want to use in your vLLM servers if you don't want to use `Llama-2-7b-chat-hf`.
|
||||
|
||||
```console
|
||||
mkdir -p ~/.cache/huggingface/hub/
|
||||
hf_cache_dir=~/.cache/huggingface/
|
||||
docker run -itd --ipc host --privileged --network vllm_nginx --gpus all --shm-size=10.24gb -v $hf_cache_dir:/root/.cache/huggingface/ -p 8081:8000 --name vllm0 vllm --model meta-llama/Llama-2-7b-chat-hf
|
||||
docker run -itd --ipc host --privileged --network vllm_nginx --gpus all --shm-size=10.24gb -v $hf_cache_dir:/root/.cache/huggingface/ -p 8082:8000 --name vllm1 vllm --model meta-llama/Llama-2-7b-chat-hf
|
||||
docker run -itd --ipc host --network vllm_nginx --gpus device=0 --shm-size=10.24gb -v $hf_cache_dir:/root/.cache/huggingface/ -p 8081:8000 --name vllm0 vllm --model meta-llama/Llama-2-7b-chat-hf
|
||||
docker run -itd --ipc host --network vllm_nginx --gpus device=1 --shm-size=10.24gb -v $hf_cache_dir:/root/.cache/huggingface/ -p 8082:8000 --name vllm1 vllm --model meta-llama/Llama-2-7b-chat-hf
|
||||
```
|
||||
|
||||
:::{note}
|
||||
|
||||
@ -6,11 +6,16 @@ To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefi
|
||||
|
||||
Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`:
|
||||
|
||||
## Prompt Replacement Detection
|
||||
## Prompt Update Detection
|
||||
|
||||
One of the main responsibilies of HF processor is to replace input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals to the feature size). The information about which tokens have been replaced is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
|
||||
One of the main responsibilies of HF processor is to update the prompt with placeholder tokens. For example:
|
||||
|
||||
In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptReplacement` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. Given this specification, we can automatically detect whether HF has replaced the input placeholder tokens by checking whether the feature placeholder tokens exist in the prompt.
|
||||
- Insert feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals to the feature size) at the start of the string.
|
||||
- Replace existing input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals to the feature size).
|
||||
|
||||
The information about which tokens have been updated is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
|
||||
|
||||
In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptUpdate` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. We can automatically detect whether HF has updated the prompt by checking the existence of the updated tokens.
|
||||
|
||||
## Tokenized Prompt Inputs
|
||||
|
||||
@ -22,7 +27,7 @@ Consider that HF processors follow these main steps:
|
||||
|
||||
1. Tokenize the text
|
||||
2. Process multi-modal inputs
|
||||
3. Perform prompt replacement
|
||||
3. Perform prompt updates
|
||||
|
||||
And we require that:
|
||||
|
||||
@ -44,16 +49,16 @@ Moreover, since the tokenized text has not passed through the HF processor, we h
|
||||
|
||||
We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data.
|
||||
|
||||
(mm-automatic-prompt-replacement)=
|
||||
(mm-automatic-prompt-updating)=
|
||||
|
||||
### Automatic prompt replacement
|
||||
### Automatic prompt updating
|
||||
|
||||
We address the second issue by implementing model-agnostic code in
|
||||
{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_replacements` to automatically replace input placeholder tokens with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`.
|
||||
{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_updates` to automatically update the prompt with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`.
|
||||
|
||||
### Summary
|
||||
|
||||
With the help of dummy text and automatic prompt replacement, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
|
||||
With the help of dummy text and automatic prompt updating, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
|
||||
|
||||
## Processor Output Caching
|
||||
|
||||
@ -61,4 +66,4 @@ Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238)
|
||||
|
||||
When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache.
|
||||
|
||||
Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt replacement code, we apply [automatic prompt replacement](#mm-automatic-prompt-replacement) afterwards to keep the output tokens and multi-modal data consistent with each other.
|
||||
Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating](#mm-automatic-prompt-updating) afterwards to keep the output tokens and multi-modal data consistent with each other.
|
||||
|
||||
712
docs/source/design/v1/metrics.md
Normal file
712
docs/source/design/v1/metrics.md
Normal file
@ -0,0 +1,712 @@
|
||||
# Metrics
|
||||
|
||||
Ensure the v1 LLM Engine exposes a superset of the metrics available in v0.
|
||||
|
||||
## Objectives
|
||||
|
||||
- Achieve parity of metrics between v0 and v1.
|
||||
- The priority use case is accessing these metrics via Prometheus as this is what we expect to be used in production environments.
|
||||
- Logging support - i.e. printing metrics to the info log - is provided for more ad-hoc testing, debugging, development, and exploratory use cases.
|
||||
|
||||
## Background
|
||||
|
||||
Metrics in vLLM can be categorized as follows:
|
||||
|
||||
1. Server-level metrics: these are global metrics that track the state and performance of the LLM engine. These are typically exposed as Gauges or Counters in Prometheus.
|
||||
2. Request-level metrics: these are metrics that track the characteristics - e.g. size and timing - of individual requests. These are typically exposed as Histrograms in Prometheus, and are often the SLO that an SRE monitoring vLLM will be tracking.
|
||||
|
||||
The mental model is that the "Server-level Metrics" explain why the "Request-level Metrics" are what they are.
|
||||
|
||||
### v0 Metrics
|
||||
|
||||
In v0, the following metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix:
|
||||
|
||||
- `vllm:num_requests_running` (Gauge)
|
||||
- `vllm:num_requests_swapped` (Gauge)
|
||||
- `vllm:num_requests_waiting` (Gauge)
|
||||
- `vllm:gpu_cache_usage_perc` (Gauge)
|
||||
- `vllm:cpu_cache_usage_perc` (Gauge)
|
||||
- `vllm:gpu_prefix_cache_hit_rate` (Gauge)
|
||||
- `vllm:cpu_prefix_cache_hit_rate` (Gauge)
|
||||
- `vllm:prompt_tokens_total` (Counter)
|
||||
- `vllm:generation_tokens_total` (Counter)
|
||||
- `vllm:request_success_total` (Counter)
|
||||
- `vllm:request_prompt_tokens` (Histogram)
|
||||
- `vllm:request_generation_tokens` (Histogram)
|
||||
- `vllm:time_to_first_token_seconds` (Histogram)
|
||||
- `vllm:time_per_output_token_seconds` (Histogram)
|
||||
- `vllm:e2e_request_latency_seconds` (Histogram)
|
||||
- `vllm:request_queue_time_seconds` (Histogram)
|
||||
- `vllm:request_inference_time_seconds` (Histogram)
|
||||
- `vllm:request_prefill_time_seconds` (Histogram)
|
||||
- `vllm:request_decode_time_seconds` (Histogram)
|
||||
- `vllm:request_max_num_generation_tokens` (Histogram)
|
||||
- `vllm:num_preemptions_total` (Counter)
|
||||
- `vllm:cache_config_info` (Gauge)
|
||||
- `vllm:lora_requests_info` (Gauge)
|
||||
- `vllm:tokens_total` (Counter)
|
||||
- `vllm:iteration_tokens_total` (Histogram)
|
||||
- `vllm:time_in_queue_requests` (Histogram)
|
||||
- `vllm:model_forward_time_milliseconds` (Histogram
|
||||
- `vllm:model_execute_time_milliseconds` (Histogram)
|
||||
- `vllm:request_params_n` (Histogram)
|
||||
- `vllm:request_params_max_tokens` (Histogram)
|
||||
- `vllm:spec_decode_draft_acceptance_rate` (Gauge)
|
||||
- `vllm:spec_decode_efficiency` (Gauge)
|
||||
- `vllm:spec_decode_num_accepted_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_draft_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)
|
||||
|
||||
These are documented under [Inferencing and Serving -> Production Metrics](project:../../serving/metrics.md).
|
||||
|
||||
### Grafana Dashboard
|
||||
|
||||
vLLM also provides [a reference example](https://docs.vllm.ai/en/latest/getting_started/examples/prometheus_grafana.html) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard.
|
||||
|
||||
The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important:
|
||||
|
||||
- `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds
|
||||
- `vllm:prompt_tokens_total` - Prompt Tokens/Sec
|
||||
- `vllm:generation_tokens_total` - Generation Tokens/Sec
|
||||
- `vllm:time_per_output_token_seconds` - Inter token latency (Time Per Output Token, TPOT) in second.
|
||||
- `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds.
|
||||
- `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in RUNNING, WAITING, and SWAPPED state
|
||||
- `vllm:gpu_cache_usage_perc` - Percentage of used cache blocks by vLLM.
|
||||
- `vllm:request_prompt_tokens` - Request prompt length
|
||||
- `vllm:request_generation_tokens` - request generation length
|
||||
- `vllm:request_success_total` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached
|
||||
- `vllm:request_queue_time_seconds` - Queue Time
|
||||
- `vllm:request_prefill_time_seconds` - Requests Prefill Time
|
||||
- `vllm:request_decode_time_seconds` - Requests Decode Time
|
||||
- `vllm:request_max_num_generation_tokens` - Max Generation Token in Sequence Group
|
||||
|
||||
See [the PR which added this Dashboard](gh-pr:2316) for interesting and useful background on the choices made here.
|
||||
|
||||
### Prometheus Client Library
|
||||
|
||||
Prometheus support was initially added [using the aioprometheus library](gh-pr:1890), but a switch was made quickly to [prometheus_client](gh-pr:2730). The rationale is discussed in both linked PRs.
|
||||
|
||||
### Multi-process Mode
|
||||
|
||||
In v0, metrics are collected in the engine core process and we use multi-process mode to make them available in the API server process. See <gh-pr:7279>.
|
||||
|
||||
### Built in Python/Process Metrics
|
||||
|
||||
The following metrics are supported by default by `prometheus_client`, but the are not exposed with multiprocess mode is used:
|
||||
|
||||
- `python_gc_objects_collected_total`
|
||||
- `python_gc_objects_uncollectable_total`
|
||||
- `python_gc_collections_total`
|
||||
- `python_info`
|
||||
- `process_virtual_memory_bytes`
|
||||
- `process_resident_memory_bytes`
|
||||
- `process_start_time_seconds`
|
||||
- `process_cpu_seconds_total`
|
||||
- `process_open_fds`
|
||||
- `process_max_fds`
|
||||
|
||||
This is relevant because if we move away from multiprocess mode in v1,
|
||||
we get these back. However, it's questionable how relevant these are
|
||||
if they don't aggregate these stats for all processes that make up a
|
||||
vLLM instance.
|
||||
|
||||
### v0 PRs and Issues
|
||||
|
||||
For background, these are some of the relevant PRs which added the v0 metrics:
|
||||
|
||||
- <gh-pr:1890>
|
||||
- <gh-pr:2316>
|
||||
- <gh-pr:2730>
|
||||
- <gh-pr:4464>
|
||||
- <gh-pr:7279>
|
||||
|
||||
Also note the ["Even Better Observability"](gh-issue:3616) feature where e.g. [a detailed roadmap was laid out](gh-issue:3616#issuecomment-2030858781).
|
||||
|
||||
## v1 Design
|
||||
|
||||
### v1 PRs
|
||||
|
||||
For background, here are the relevant v1 PRs relating to the v1
|
||||
metrics issue <gh-issue:10582>:
|
||||
|
||||
- <gh-pr:11962>
|
||||
- <gh-pr:11973>
|
||||
- <gh-pr:10907>
|
||||
- <gh-pr:12416>
|
||||
- <gh-pr:12478>
|
||||
- <gh-pr:12516>
|
||||
- <gh-pr:12530>
|
||||
- <gh-pr:12561>
|
||||
- <gh-pr:12579>
|
||||
- <gh-pr:12592>
|
||||
- <gh-pr:12644>
|
||||
|
||||
### Metrics Collection
|
||||
|
||||
In v1, we wish to move computation and overhead out of the engine core
|
||||
process to minimize the time between each forward pass.
|
||||
|
||||
The overall idea of V1 EngineCore design is:
|
||||
- EngineCore is the inner loop. Performance is most critical here
|
||||
- AsyncLLM is the outer loop. This is overlapped with GPU execution
|
||||
(ideally), so this is where any "overheads" should be if
|
||||
possible. So AsyncLLM.output_handler_loop is the ideal place for the
|
||||
metrics bookkeeping if possible.
|
||||
|
||||
We will achieve this by collecting metrics in the frontend API server,
|
||||
and base these metrics on information we can glean from the
|
||||
`EngineCoreOutputs` returned by the engine core process to the
|
||||
frontend.
|
||||
|
||||
### Interval Calculations
|
||||
|
||||
Many of our metrics are the time interval between various events in
|
||||
the processing of a request. It is best practice to use timestamps
|
||||
based on "monotonic time" (`time.monotonic()`) rather than "wall-clock
|
||||
time" (`time.time()`) to calculate intervals as the former is
|
||||
unaffected by system clock changes (e.g. from NTP).
|
||||
|
||||
It's also important to note that monotonic clocks differ between
|
||||
processes - each process has its own reference. point. So it is
|
||||
meaningless to compare monotonic timestamps from different processes.
|
||||
|
||||
Therefore, in order to calculate an interval, we must compare two
|
||||
monotonic timestamps from the same process.
|
||||
|
||||
### Scheduler Stats
|
||||
|
||||
The engine core process will collect some key statistics from the
|
||||
scheduler - e.g. the number of requests that were scheduled or waiting
|
||||
after the last scheduler pass - and include those statistics in
|
||||
`EngineCoreOutputs`.
|
||||
|
||||
### Engine Core Events
|
||||
|
||||
The engine core will also record the timestamp of certain per-request
|
||||
events so that the frontend can calculate the interval between these
|
||||
events.
|
||||
|
||||
The events are:
|
||||
|
||||
- `QUEUED` - when the request was received by the engine core and
|
||||
added to the scheduler queue.
|
||||
- `SCHEDULED` - when the request was first scheduled for execution.
|
||||
- `PREEMPTED` - the request has been put back in the waiting queue
|
||||
in order to make room for other requests to complete. It will be
|
||||
re-scheduled in future and re-start its prefill phase.
|
||||
- `NEW_TOKENS` - when the output included in `EngineCoreOutput` was
|
||||
generated. Since this is common to all requests in a given
|
||||
iteration, we use a single timestamp on `EngineCoreOutputs` to
|
||||
record this event.
|
||||
|
||||
And the calculated intervals are:
|
||||
|
||||
- Queue interval - between `QUEUED` and most recent `SCHEDULED`.
|
||||
- Prefill interval - between most recent `SCHEDULED` and the subsequent
|
||||
first `NEW_TOKENS`.
|
||||
- Decode interval - between first (after the most recent `SCHEDULED`) and
|
||||
last `NEW_TOKENS`.
|
||||
- Inference interval - between most recent `SCHEDULED` and last `NEW_TOKENS`.
|
||||
- Inter-token interval - between successive `NEW_TOKENS`.
|
||||
|
||||
Put another way:
|
||||
|
||||
:::{image} /assets/design/v1/metrics/intervals-1.png
|
||||
:alt: Interval calculations - common case
|
||||
:::
|
||||
|
||||
We explored the possibility of having the frontend calculate these
|
||||
intervals using the timing of events visible by the frontend. However,
|
||||
the frontend does not have visibility into the timing of the `QUEUED`
|
||||
and `SCHEDULED` events and, since we need to calculate intervals based
|
||||
on monotonic timestamps from the same process ... we need the engine
|
||||
core to record timestamps for all of these events.
|
||||
|
||||
#### Interval Calculations vs Preemptions
|
||||
|
||||
When a preemption occurs during decode, since any already generated
|
||||
tokens are reused, we consider the preemption as affecting the
|
||||
inter-token, decode, and inference intervals.
|
||||
|
||||
:::{image} /assets/design/v1/metrics/intervals-2.png
|
||||
:alt: Interval calculations - preempted decode
|
||||
:::
|
||||
|
||||
When a preemption occurs during prefill (assuming such an event
|
||||
is possible), we consider the preemption as affecting the
|
||||
time-to-first-token and prefill intervals.
|
||||
|
||||
:::{image} /assets/design/v1/metrics/intervals-3.png
|
||||
:alt: Interval calculations - preempted prefill
|
||||
:::
|
||||
|
||||
### Frontend Stats Collection
|
||||
|
||||
As the frontend processes a single `EngineCoreOutputs` - i.e. the
|
||||
output from a single engine core iteration - it collects various
|
||||
statistics relating to that iteration:
|
||||
|
||||
- The total number of new tokens generated in this iteration.
|
||||
- The total number of prompt tokens processed by the prefills that
|
||||
completed in this iteration.
|
||||
- The queue intervals for any requests that were scheduled in this
|
||||
iteration.
|
||||
- The prefill intervals for any requests that completed prefill in
|
||||
this iteration.
|
||||
- The inter-token intervals (Time Per Output Token, TPOT), for all
|
||||
requests included in this iteration.
|
||||
- The Time-To-First-Token (TTFT) for any requests that completed
|
||||
prefill in this iteration. However, we calculate this interval
|
||||
relative to when the request was first received by the frontend
|
||||
(`arrival_time`) in order to account for input processing time.
|
||||
|
||||
For any requests that were completed in a given iteration, we also
|
||||
record:
|
||||
|
||||
- The inference and decode intervals - relative to the scheduled and
|
||||
first token events, as described above.
|
||||
- End-to-end latency - the interval between frontend `arrival_time`
|
||||
and the frontend receiving the final token.
|
||||
|
||||
### Metrics Publishing - Logging
|
||||
|
||||
The `LoggingStatLogger` metrics publisher outputs a log `INFO` message
|
||||
every 5 seconds with some key metrics:
|
||||
|
||||
- The current number of running/waiting requests
|
||||
- The current GPU cache usage
|
||||
- The number of prompt tokens processed per second over the past 5
|
||||
seconds
|
||||
- The number of new tokens generated per second over the past 5
|
||||
seconds
|
||||
- The prefix cache hit rate over the most recent 1k kv-cache block queries
|
||||
|
||||
### Metrics Publishing - Prometheus
|
||||
|
||||
The `PrometheusStatLogger` metrics publisher makes the metrics
|
||||
available via a `/metrics` HTTP endpoint in a Prometheus-compatible
|
||||
format. A Prometheus instance can then be configured to poll this
|
||||
endpoint (e.g. every second) and record the values in its time-series
|
||||
database. Prometheus is often used via Grafana, allowing these metrics
|
||||
to be graphed over time.
|
||||
|
||||
Prometheus supports the following metric types:
|
||||
|
||||
- Counter: a value that will increase over time, never reducing, and
|
||||
generally reset to zero when the vLLM instance restarts. For
|
||||
example, the number of tokens generated over the lifetime of the
|
||||
instance.
|
||||
- Gauge: a value that goes up and down, for example the number of
|
||||
requests currently scheduled for execution.
|
||||
- Histogram: a count of metric samples, recorded in buckets. For
|
||||
example, the number of requests whose TTFT was <1ms, <5ms, <10ms,
|
||||
<20ms, and so on.
|
||||
|
||||
Prometheus metrics can also be labelled, allowing metrics to be
|
||||
combined according to matching labels. In vLLM, we add a `model_name`
|
||||
label to every metric which includes the name of the model served by
|
||||
that instance.
|
||||
|
||||
Example output:
|
||||
|
||||
```bash
|
||||
$ curl http://0.0.0.0:8000/metrics
|
||||
# HELP vllm:num_requests_running Number of requests in model execution batches.
|
||||
# TYPE vllm:num_requests_running gauge
|
||||
vllm:num_requests_running{model_name="meta-llama/Llama-3.1-8B-Instruct"} 8.0
|
||||
...
|
||||
# HELP vllm:generation_tokens_total Number of generation tokens processed.
|
||||
# TYPE vllm:generation_tokens_total counter
|
||||
vllm:generation_tokens_total{model_name="meta-llama/Llama-3.1-8B-Instruct"} 27453.0
|
||||
...
|
||||
# HELP vllm:request_success_total Count of successfully processed requests.
|
||||
# TYPE vllm:request_success_total counter
|
||||
vllm:request_success_total{finished_reason="stop",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0
|
||||
vllm:request_success_total{finished_reason="length",model_name="meta-llama/Llama-3.1-8B-Instruct"} 131.0
|
||||
vllm:request_success_total{finished_reason="abort",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0
|
||||
...
|
||||
# HELP vllm:time_to_first_token_seconds Histogram of time to first token in seconds.
|
||||
# TYPE vllm:time_to_first_token_seconds histogram
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.001",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.005",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.01",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.02",model_name="meta-llama/Llama-3.1-8B-Instruct"} 13.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.04",model_name="meta-llama/Llama-3.1-8B-Instruct"} 97.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.06",model_name="meta-llama/Llama-3.1-8B-Instruct"} 123.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.08",model_name="meta-llama/Llama-3.1-8B-Instruct"} 138.0
|
||||
vllm:time_to_first_token_seconds_bucket{le="0.1",model_name="meta-llama/Llama-3.1-8B-Instruct"} 140.0
|
||||
vllm:time_to_first_token_seconds_count{model_name="meta-llama/Llama-3.1-8B-Instruct"} 140.0
|
||||
```
|
||||
|
||||
Note - the choice of histogram buckets to be most useful to users
|
||||
across a broad set of use cases is not straightforward and will
|
||||
require refinement over time.
|
||||
|
||||
### Cache Config Info
|
||||
|
||||
`prometheus_client` has support for [Info
|
||||
metrics](https://prometheus.github.io/client_python/instrumenting/info/)
|
||||
which are equivalent to a `Gauge` whose value is permanently set to 1,
|
||||
but exposes interesting key/value pair information via labels. This is
|
||||
used for information about an instance that does not change - so it
|
||||
only needs to be observed at startup - and allows comparing across
|
||||
instances in Prometheus.
|
||||
|
||||
We use this concept for the `vllm:cache_config_info` metric:
|
||||
|
||||
```
|
||||
# HELP vllm:cache_config_info Information of the LLMEngine CacheConfig
|
||||
# TYPE vllm:cache_config_info gauge
|
||||
vllm:cache_config_info{block_size="16",cache_dtype="auto",calculate_kv_scales="False",cpu_offload_gb="0",enable_prefix_caching="False",gpu_memory_utilization="0.9",...} 1.0
|
||||
|
||||
```
|
||||
|
||||
However, `prometheus_client` has [never supported Info metrics in
|
||||
multiprocessing
|
||||
mode](https://github.com/prometheus/client_python/pull/300) - for
|
||||
[unclear
|
||||
reasons](gh-pr:7279#discussion_r1710417152). We
|
||||
simply use a `Gauge` metric set to 1 and
|
||||
`multiprocess_mode="mostrecent"` instead.
|
||||
|
||||
### LoRA Metrics
|
||||
|
||||
The `vllm:lora_requests_info` `Gauge` is somewhat similar, except the
|
||||
value is the current wall-clock time, and is updated every iteration.
|
||||
|
||||
The label names used are:
|
||||
|
||||
- `running_lora_adapters`: a per-adapter count of the number requests
|
||||
running using that adapter, formatted as a comma-separated string.
|
||||
- `waiting_lora_adapters`: similar, except counting requests that are
|
||||
waiting to be scheduled.
|
||||
- `max_lora` - the static "max number of LoRAs in a single batch."
|
||||
configuration.
|
||||
|
||||
Encoding a running/waiting counts for multiple adapters in a
|
||||
comma-separated string seems quite misguided - we could use labels to
|
||||
distinguish between per-adapter counts. This should be revisited.
|
||||
|
||||
Note that `multiprocess_mode="livemostrecent"` is used - the most
|
||||
recent metric is used, but only from currently running processes.
|
||||
|
||||
This was added in
|
||||
<gh-pr:9477> and there is
|
||||
[at least one known
|
||||
user](https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/54). If
|
||||
we revisit this design and deprecate the old metric, we should reduce
|
||||
the need for a significant deprecation period by making the change in
|
||||
v0 also and asking this project to move to the new metric.
|
||||
|
||||
### Prefix Cache metrics
|
||||
|
||||
The discussion in <gh-issue:10582> about adding prefix cache metrics yielded
|
||||
some interesting points which may be relevant to how we approach
|
||||
future metrics.
|
||||
|
||||
Every time the prefix cache is queried, we record the number of blocks
|
||||
queried and the number of queried blocks present in the cache
|
||||
(i.e. hits).
|
||||
|
||||
However, the metric of interest is the hit rate - i.e. the number of
|
||||
hits per query.
|
||||
|
||||
In the case of logging, we expect the user is best served by
|
||||
calculating the hit rate over a fixed number of the most recent
|
||||
queries (the interval is fixed to 1k most recent queries for now).
|
||||
|
||||
In the case of Prometheus though, we should take advantage of the
|
||||
time-series nature of Prometheus and allow the user to calculate the
|
||||
hit rate over an interval of their choosing. For example, a PromQL
|
||||
query to calculate the hit interval of the past 5 minutes:
|
||||
|
||||
```text
|
||||
rate(cache_query_hit[5m]) / rate(cache_query_total[5m])
|
||||
```
|
||||
|
||||
To achieve this, we should record the queries and hits as counters in
|
||||
Prometheus, rather than recording the hit rate as a gauge.
|
||||
|
||||
## Deprecated Metrics
|
||||
|
||||
### How To Deprecate
|
||||
|
||||
Deprecating metrics shouldn't be taken lightly. Users may not notice a
|
||||
metric has been deprecated, and may be quite inconvenienced when it is
|
||||
suddenly (from their perspective) when it is removed, even if there is
|
||||
an equivalent metric for them to use.
|
||||
|
||||
As an example, see how `vllm:avg_prompt_throughput_toks_per_s` was
|
||||
[deprecated](gh-pr:2764) (with a
|
||||
comment in the code),
|
||||
[removed](gh-pr:12383), and then
|
||||
[noticed by a
|
||||
user](gh-issue:13218).
|
||||
|
||||
In general:
|
||||
|
||||
1) We should be cautious about deprecating metrics, especially since
|
||||
it can be hard to predict the user impact.
|
||||
2) We should include a prominent deprecation notice in the help string
|
||||
that is included in the `/metrics' output.
|
||||
3) We should list deprecated metrics in user-facing documentation and
|
||||
release notes.
|
||||
4) We should consider hiding deprecated metrics behind a CLI argument
|
||||
in order to give administrators [an escape
|
||||
hatch](https://kubernetes.io/docs/concepts/cluster-administration/system-metrics/#show-hidden-metrics)
|
||||
for some time before deleting them.
|
||||
|
||||
### Unimplemented - `vllm:tokens_total`
|
||||
|
||||
Added by <gh-pr:4464>, but apparently never implemented. This can just be
|
||||
removed.
|
||||
|
||||
### Duplicated - Queue Time
|
||||
|
||||
The `vllm:time_in_queue_requests` Histogram metric was added by
|
||||
<gh-pr:9659> and its calculation is:
|
||||
|
||||
```
|
||||
self.metrics.first_scheduled_time = now
|
||||
self.metrics.time_in_queue = now - self.metrics.arrival_time
|
||||
```
|
||||
|
||||
Two weeks later, <gh-pr:4464> added `vllm:request_queue_time_seconds` leaving
|
||||
us with:
|
||||
|
||||
```
|
||||
if seq_group.is_finished():
|
||||
if (seq_group.metrics.first_scheduled_time is not None and
|
||||
seq_group.metrics.first_token_time is not None):
|
||||
time_queue_requests.append(
|
||||
seq_group.metrics.first_scheduled_time -
|
||||
seq_group.metrics.arrival_time)
|
||||
...
|
||||
if seq_group.metrics.time_in_queue is not None:
|
||||
time_in_queue_requests.append(
|
||||
seq_group.metrics.time_in_queue)
|
||||
```
|
||||
|
||||
This seems duplicative, and one of them should be removed. The latter
|
||||
is used by the Grafana dashboard, so we should deprecate or remove the
|
||||
former from v0.
|
||||
|
||||
### Prefix Cache Hit Rate
|
||||
|
||||
See above - we now expose 'queries' and 'hits' counters rather than a
|
||||
'hit rate' gauge.
|
||||
|
||||
### KV Cache Offloading
|
||||
|
||||
Two v0 metrics relate to a "swapped" preemption mode that is no
|
||||
longer relevant in v1:
|
||||
|
||||
- `vllm:num_requests_swapped`
|
||||
- `vllm:cpu_cache_usage_perc`
|
||||
|
||||
In this mode, when a request is preempted (e.g. to make room in KV
|
||||
cache to complete other requests), we swap kv cache blocks out to CPU
|
||||
memory. This is also known as "KV cache offloading" and is configured
|
||||
with `--swap-space` and `--preemption-mode`.
|
||||
|
||||
In v0, [VLLM has long supported beam
|
||||
search](gh-issue:6226). The
|
||||
SequenceGroup encapsulated the idea of N Sequences which
|
||||
all shared the same prompt kv blocks. This enabled KV cache block
|
||||
sharing between requests, and copy-on-write to do branching. CPU
|
||||
swapping was intended for these beam search like cases.
|
||||
|
||||
Later, the concept of prefix caching was introduced, which allowed KV
|
||||
cache blocks to be shared implicitly. This proved to be a better
|
||||
option than CPU swapping since blocks can be evicted slowly on demand
|
||||
and the part of the prompt that was evicted can be recomputed.
|
||||
|
||||
SequenceGroup was removed in V1, although a replacement will be
|
||||
required for "parallel sampling" (`n>1`). [Beam search was moved out of
|
||||
the core (in
|
||||
V0)](gh-issue:8306). There was a
|
||||
lot of complex code for a very uncommon feature.
|
||||
|
||||
In V1, with prefix caching being better (zero over head) and therefore
|
||||
on by default, the preemption and recompute strategy should work
|
||||
better.
|
||||
|
||||
## Future Work
|
||||
|
||||
### Parallel Sampling
|
||||
|
||||
Some v0 metrics are only relevant in the context of "parallel
|
||||
sampling". This is where the `n` parameter in a request is used to
|
||||
request multiple completions from the same prompt.
|
||||
|
||||
As part of adding parallel sampling support in <gh-pr:10980> we should
|
||||
also add these metrics.
|
||||
|
||||
- `vllm:request_params_n` (Histogram)
|
||||
|
||||
Observes the value of the 'n' parameter of every finished request.
|
||||
|
||||
- `vllm:request_max_num_generation_tokens` (Histogram)
|
||||
|
||||
Observes the maximum output length of all sequences in every finished
|
||||
sequence group. In the absence of parallel sampling, this is
|
||||
equivalent to `vllm:request_generation_tokens`.
|
||||
|
||||
### Speculative Decoding
|
||||
|
||||
Some v0 metrics are specific to "speculative decoding". This is where
|
||||
we generate candidate tokens using a faster, approximate method or
|
||||
model and then validate those tokens with the larger model.
|
||||
|
||||
- `vllm:spec_decode_draft_acceptance_rate` (Gauge)
|
||||
- `vllm:spec_decode_efficiency` (Gauge)
|
||||
- `vllm:spec_decode_num_accepted_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_draft_tokens_total` (Counter)
|
||||
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)
|
||||
|
||||
There is a PR under review (<gh-pr:12193>) to add "prompt lookup (ngram)"
|
||||
seculative decoding to v1. Other techniques will follow. We should
|
||||
revisit the v0 metrics in this context.
|
||||
|
||||
Note - we should probably expose acceptance rate as separate accepted
|
||||
and draft counters, like we do for prefix caching hit rate. Efficiency
|
||||
likely also needs similar treatment.
|
||||
|
||||
### Autoscaling and Load-balancing
|
||||
|
||||
A common use case for our metrics is to support automated scaling of
|
||||
vLLM instances.
|
||||
|
||||
For related discussion from the [Kubernetes Serving Working
|
||||
Group](https://github.com/kubernetes/community/tree/master/wg-serving),
|
||||
see:
|
||||
|
||||
- [Standardizing Large Model Server Metrics in
|
||||
Kubernetes](https://docs.google.com/document/d/1SpSp1E6moa4HSrJnS4x3NpLuj88sMXr2tbofKlzTZpk)
|
||||
- [Benchmarking LLM Workloads for Performance Evaluation and
|
||||
Autoscaling in
|
||||
Kubernetes](https://docs.google.com/document/d/1k4Q4X14hW4vftElIuYGDu5KDe2LtV1XammoG-Xi3bbQ)
|
||||
- [Inference
|
||||
Perf](https://github.com/kubernetes-sigs/wg-serving/tree/main/proposals/013-inference-perf)
|
||||
- <gh-issue:5041> and <gh-pr:12726>.
|
||||
|
||||
This is a non-trivial topic. Consider this comment from Rob:
|
||||
|
||||
> I think this metric should focus on trying to estimate what the max
|
||||
> concurrency that will cause the average request length > queries per
|
||||
> second ... since this is really what will "saturate" the server.
|
||||
|
||||
A clear goal is that we should expose the metrics required to detect
|
||||
this saturation point, so administrators can implement auto-scaling
|
||||
rules based on those. However, in order to do so, we need to have a
|
||||
clear view on how an administrator (and automated monitoring system)
|
||||
should judge an instance as approaching saturation:
|
||||
|
||||
> To identify, what is the saturation point for model server compute
|
||||
> (the inflection point where we cannot get more throughput with a
|
||||
> higher request rate, but start to incur additional latency) so we
|
||||
> can autoscale effectively?
|
||||
|
||||
### Metric Naming
|
||||
|
||||
Our approach to naming metrics probably deserves to be revisited:
|
||||
|
||||
1. The use of colons in metric names seems contrary to ["colons are
|
||||
reserved for user defined recording
|
||||
rules"](https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels)
|
||||
2. Most of our metrics follow the convention of ending with units, but
|
||||
not all do.
|
||||
3. Some of our metric names end with `_total`:
|
||||
|
||||
```
|
||||
If there is a suffix of `_total` on the metric name, it will be removed. When
|
||||
exposing the time series for counter, a `_total` suffix will be added. This is
|
||||
for compatibility between OpenMetrics and the Prometheus text format, as OpenMetrics
|
||||
requires the `_total` suffix.
|
||||
```
|
||||
|
||||
### Adding More Metrics
|
||||
|
||||
There is no shortage of ideas for new metrics:
|
||||
|
||||
- Examples from other projects like
|
||||
[TGI](https://github.com/IBM/text-generation-inference?tab=readme-ov-file#metrics)
|
||||
- Proposals arising from specific use cases, like the Kubernetes
|
||||
auto-scaling topic above
|
||||
- Proposals that might arise out of standardisation efforts like
|
||||
[OpenTelemetry Semantic Conventions for Gen
|
||||
AI](https://github.com/open-telemetry/semantic-conventions/tree/main/docs/gen-ai).
|
||||
|
||||
We should be cautious in our approach to adding new metrics. While
|
||||
metrics are often relatively straightforward to add:
|
||||
|
||||
1. They can be difficult to remove - see the section on deprecation
|
||||
above.
|
||||
2. They can have a meaningful performance impact when enabled. And
|
||||
metrics are usually of very limited use unless they can be enabled
|
||||
by default and in production.
|
||||
3. They have an impact on development and maintenance of the
|
||||
project. Every metric added to v0 has made this v1 effort more
|
||||
time-consuming, and perhaps not all metrics justify this ongoing
|
||||
investment in their maintenance.
|
||||
|
||||
## Tracing - OpenTelemetry
|
||||
|
||||
Metrics provide an aggregated view over time of the system's
|
||||
performance and health. Tracing, on the other hand, tracks individual
|
||||
requests as they move through different services and components. Both
|
||||
fall under the more general heading of "Observability".
|
||||
|
||||
v0 has support for OpenTelemetry tracing:
|
||||
|
||||
- Added by <gh-pr:4687>
|
||||
- Configured with `--oltp-traces-endpoint` and
|
||||
`--collect-detailed-traces`
|
||||
- [OpenTelemetry blog
|
||||
post](https://opentelemetry.io/blog/2024/llm-observability/)
|
||||
- [User-facing
|
||||
docs](https://docs.vllm.ai/en/latest/getting_started/examples/opentelemetry.html)
|
||||
- [Blog
|
||||
post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f)
|
||||
- [IBM product
|
||||
docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview)
|
||||
|
||||
OpenTelemetry has a [Gen AI Working
|
||||
Group](https://github.com/open-telemetry/community/blob/main/projects/gen-ai.md).
|
||||
|
||||
Since metrics is a big enough topic on its own, we are going to tackle
|
||||
the topic of tracing in v1 separately.
|
||||
|
||||
### OpenTelemetry Model Forward vs Execute Time
|
||||
|
||||
In v0, we have the following two metrics:
|
||||
|
||||
- `vllm:model_forward_time_milliseconds` (Histogram) - The time spent
|
||||
in the model forward pass when this request was in the batch.
|
||||
- `vllm:model_execute_time_milliseconds` (Histogram) - The time spent
|
||||
in the model execute function. This will include model forward,
|
||||
block/sync across workers, cpu-gpu sync time and sampling time.
|
||||
|
||||
These metrics are only enabled when OpenTelemetry tracing is enabled
|
||||
and if `--collect-detailed-traces=all/model/worker` is used. The
|
||||
documentation for this option states:
|
||||
|
||||
> collect detailed traces for the specified "modules. This involves
|
||||
> use of possibly costly and or blocking operations and hence might
|
||||
> have a performance impact.
|
||||
|
||||
The metrics were added by <gh-pr:7089> and who up in an OpenTelemetry trace
|
||||
as:
|
||||
|
||||
```
|
||||
-> gen_ai.latency.time_in_scheduler: Double(0.017550230026245117)
|
||||
-> gen_ai.latency.time_in_model_forward: Double(3.151565277099609)
|
||||
-> gen_ai.latency.time_in_model_execute: Double(3.6468167304992676)
|
||||
```
|
||||
|
||||
We already have `inference_time` and `decode_time` metrics, so the
|
||||
question is whether there are sufficiently common use cases for the
|
||||
higher-resolution timings to justify the overhead.
|
||||
|
||||
Since we are going to treat the question of OpenTelemetry support
|
||||
separately, we will include these particular metrics under that topic.
|
||||
@ -170,7 +170,7 @@ Now, you can specify a base_model_name alongside the name and path using JSON fo
|
||||
|
||||
To provide the backward compatibility support, you can still use the old key-value format (name=path), but the `base_model_name` will remain unspecified in that case.
|
||||
|
||||
## Lora model lineage in model card
|
||||
## LoRA model lineage in model card
|
||||
|
||||
The new format of `--lora-modules` is mainly to support the display of parent model information in the model card. Here's an explanation of how your current response supports this:
|
||||
|
||||
|
||||
@ -3,16 +3,16 @@
|
||||
# AutoAWQ
|
||||
|
||||
To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github.com/casper-hansen/AutoAWQ).
|
||||
Quantizing reduces the model's precision from FP16 to INT4 which effectively reduces the file size by ~70%.
|
||||
Quantization reduces the model's precision from BF16/FP16 to INT4 which effectively reduces the total model memory footprint.
|
||||
The main benefits are lower latency and memory usage.
|
||||
|
||||
You can quantize your own models by installing AutoAWQ or picking one of the [400+ models on Huggingface](https://huggingface.co/models?sort=trending&search=awq).
|
||||
You can quantize your own models by installing AutoAWQ or picking one of the [6500+ models on Huggingface](https://huggingface.co/models?sort=trending&search=awq).
|
||||
|
||||
```console
|
||||
pip install autoawq
|
||||
```
|
||||
|
||||
After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`:
|
||||
After installing AutoAWQ, you are ready to quantize a model. Please refer to the `AutoAWQ documentation <https://casper-hansen.github.io/AutoAWQ/examples/#basic-quantization>`_ for further details. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`:
|
||||
|
||||
```python
|
||||
from awq import AutoAWQForCausalLM
|
||||
|
||||
@ -29,6 +29,13 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlam
|
||||
We recommend using the tokenizer from base model instead of GGUF model. Because the tokenizer conversion from GGUF is time-consuming and unstable, especially for some models with large vocab size.
|
||||
:::
|
||||
|
||||
GGUF assumes that huggingface can convert the metadata to a config file. In case huggingface doesn't support your model you can manually create a config and pass it as hf-confing-path
|
||||
|
||||
```console
|
||||
# If you model is not supported by huggingface you can manually provide a huggingface compatible config path
|
||||
vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf --tokenizer TinyLlama/TinyLlama-1.1B-Chat-v1.0 --hf-config-path Tinyllama/TInyLlama-1.1B-Chat-v1.0
|
||||
```
|
||||
|
||||
You can also use the GGUF model directly through the LLM entrypoint:
|
||||
|
||||
```python
|
||||
|
||||
83
docs/source/features/quantization/gptqmodel.md
Normal file
83
docs/source/features/quantization/gptqmodel.md
Normal file
@ -0,0 +1,83 @@
|
||||
(gptqmodel)=
|
||||
|
||||
# GPTQModel
|
||||
|
||||
To create a new 4-bit or 8-bit GPTQ quantized model, you can leverage [GPTQModel](https://github.com/ModelCloud/GPTQModel) from ModelCloud.AI.
|
||||
|
||||
Quantization reduces the model's precision from BF16/FP16 (16-bits) to INT4 (4-bits) or INT8 (8-bits) which significantly reduces the
|
||||
total model memory footprint while at-the-same-time increasing inference performance.
|
||||
|
||||
Compatible GPTQModel quantized models can leverage the `Marlin` and `Machete` vLLM custom kernels to maximize batching
|
||||
transactions-per-second `tps` and token-latency performance for both Ampere (A100+) and Hopper (H100+) Nvidia GPUs.
|
||||
These two kernels are highly optimized by vLLM and NeuralMagic (now part of Redhat) to allow world-class inference performance of quantized GPTQ
|
||||
models.
|
||||
|
||||
GPTQModel is one of the few quantization toolkits in the world that allows `Dynamic` per-module quantization where different layers and/or modules within a llm model can be further optimized with custom quantization parameters. `Dynamic` quantization
|
||||
is fully integrated into vLLM and backed up by support from the ModelCloud.AI team. Please refer to [GPTQModel readme](https://github.com/ModelCloud/GPTQModel?tab=readme-ov-file#dynamic-quantization-per-module-quantizeconfig-override)
|
||||
for more details on this and other advanced features.
|
||||
|
||||
You can quantize your own models by installing [GPTQModel](https://github.com/ModelCloud/GPTQModel) or picking one of the [5000+ models on Huggingface](https://huggingface.co/models?sort=trending&search=gptq).
|
||||
|
||||
```console
|
||||
pip install -U gptqmodel --no-build-isolation -v
|
||||
```
|
||||
|
||||
After installing GPTQModel, you are ready to quantize a model. Please refer to the [GPTQModel readme](https://github.com/ModelCloud/GPTQModel/?tab=readme-ov-file#quantization) for further details.
|
||||
|
||||
Here is an example of how to quantize `meta-llama/Llama-3.2-1B-Instruct`:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from gptqmodel import GPTQModel, QuantizeConfig
|
||||
|
||||
model_id = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit"
|
||||
|
||||
calibration_dataset = load_dataset(
|
||||
"allenai/c4",
|
||||
data_files="en/c4-train.00001-of-01024.json.gz",
|
||||
split="train"
|
||||
).select(range(1024))["text"]
|
||||
|
||||
quant_config = QuantizeConfig(bits=4, group_size=128)
|
||||
|
||||
model = GPTQModel.load(model_id, quant_config)
|
||||
|
||||
# increase `batch_size` to match gpu/vram specs to speed up quantization
|
||||
model.quantize(calibration_dataset, batch_size=2)
|
||||
|
||||
model.save(quant_path)
|
||||
```
|
||||
|
||||
To run an GPTQModel quantized model with vLLM, you can use [DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2](https://huggingface.co/ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2) with the following command:
|
||||
|
||||
```console
|
||||
python examples/offline_inference/llm_engine_example.py --model DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2
|
||||
```
|
||||
|
||||
GPTQModel quantized models are also supported directly through the LLM entrypoint:
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.6, top_p=0.9)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2")
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
```
|
||||
@ -12,6 +12,7 @@ supported_hardware
|
||||
auto_awq
|
||||
bnb
|
||||
gguf
|
||||
gptqmodel
|
||||
int4
|
||||
int8
|
||||
fp8
|
||||
|
||||
@ -76,7 +76,13 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
|
||||
}
|
||||
```
|
||||
|
||||
Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests.
|
||||
Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
|
||||
|
||||
## Limitations
|
||||
|
||||
- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
|
||||
- It is not compatible with [`tool_calling`](#tool_calling).
|
||||
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
|
||||
|
||||
## How to support a new reasoning model
|
||||
|
||||
@ -117,7 +123,7 @@ class ExampleParser(ReasoningParser):
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Extract reasoning content from a complete model-generated string.
|
||||
|
||||
@ -132,20 +138,41 @@ class ExampleParser(ReasoningParser):
|
||||
The request object that was used to generate the model_output.
|
||||
|
||||
Returns:
|
||||
Tuple[Optional[str], Optional[str]]
|
||||
tuple[Optional[str], Optional[str]]
|
||||
A tuple containing the reasoning content and the content.
|
||||
"""
|
||||
```
|
||||
|
||||
After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint.
|
||||
Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class DeepSeekReasoner(Reasoner):
|
||||
"""
|
||||
Reasoner for DeepSeek R series models.
|
||||
"""
|
||||
start_token_id: int
|
||||
end_token_id: int
|
||||
|
||||
start_token: str = "<think>"
|
||||
end_token: str = "</think>"
|
||||
|
||||
@classmethod
|
||||
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
|
||||
return cls(start_token_id=tokenizer.encode(
|
||||
"<think>", add_special_tokens=False)[0],
|
||||
end_token_id=tokenizer.encode("</think>",
|
||||
add_special_tokens=False)[0])
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.end_token_id in input_ids
|
||||
```
|
||||
|
||||
The structured output engine like xgrammar will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case.
|
||||
|
||||
Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
|
||||
|
||||
```bash
|
||||
vllm serve <model_tag> \
|
||||
--enable-reasoning --reasoning-parser example
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
|
||||
- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features.
|
||||
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
|
||||
|
||||
@ -16,7 +16,7 @@ The following parameters are supported, which must be added as extra parameters:
|
||||
- `guided_json`: the output will follow the JSON schema.
|
||||
- `guided_grammar`: the output will follow the context free grammar.
|
||||
- `guided_whitespace_pattern`: used to override the default whitespace pattern for guided json decoding.
|
||||
- `guided_decoding_backend`: used to select the guided decoding backend to use.
|
||||
- `guided_decoding_backend`: used to select the guided decoding backend to use. Additional backend-specific options can be supplied in a comma separated list following a colon after the backend name. For example `"xgrammar:no-fallback"` will not allow vLLM to fallback to a different backend on error.
|
||||
|
||||
You can see the complete list of supported parameters on the [OpenAI-Compatible Server](#openai-compatible-server)page.
|
||||
|
||||
@ -193,7 +193,7 @@ class Step(BaseModel):
|
||||
|
||||
|
||||
class MathResponse(BaseModel):
|
||||
steps: List[Step]
|
||||
steps: list[Step]
|
||||
final_answer: str
|
||||
|
||||
|
||||
|
||||
@ -74,7 +74,7 @@ class Example:
|
||||
path (Path): The path to the main directory or file.
|
||||
category (str): The category of the document.
|
||||
main_file (Path): The main file in the directory.
|
||||
other_files (list[Path]): List of other files in the directory.
|
||||
other_files (list[Path]): list of other files in the directory.
|
||||
title (str): The title of the document.
|
||||
|
||||
Methods:
|
||||
|
||||
@ -6,7 +6,14 @@ sudo apt-get install -y gcc-12 g++-12 libnuma-dev
|
||||
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
||||
```
|
||||
|
||||
Second, install Python packages for vLLM CPU backend building:
|
||||
Second, clone vLLM project:
|
||||
|
||||
```console
|
||||
git clone https://github.com/vllm-project/vllm.git vllm_source
|
||||
cd vllm_source
|
||||
```
|
||||
|
||||
Third, install Python packages for vLLM CPU backend building:
|
||||
|
||||
```console
|
||||
pip install --upgrade pip
|
||||
|
||||
@ -23,12 +23,12 @@ Therefore, it is recommended to install vLLM with a **fresh new** environment. I
|
||||
You can install vLLM using either `pip` or `uv pip`:
|
||||
|
||||
```console
|
||||
# Install vLLM with CUDA 12.1.
|
||||
# Install vLLM with CUDA 12.4.
|
||||
pip install vllm # If you are using pip.
|
||||
uv pip install vllm # If you are using uv.
|
||||
```
|
||||
|
||||
As of now, vLLM's binaries are compiled with CUDA 12.1 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 11.8 and public PyTorch release versions:
|
||||
As of now, vLLM's binaries are compiled with CUDA 12.4 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 12.1, 11.8, and public PyTorch release versions:
|
||||
|
||||
```console
|
||||
# Install vLLM with CUDA 11.8.
|
||||
|
||||
@ -53,9 +53,9 @@ Currently, there are no pre-built ROCm wheels.
|
||||
If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent.
|
||||
:::
|
||||
|
||||
2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile)
|
||||
2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention)
|
||||
|
||||
Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
|
||||
Install ROCm's flash attention (v2.7.2) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention#amd-rocm-support)
|
||||
Alternatively, wheels intended for vLLM use can be accessed under the releases.
|
||||
|
||||
For example, for ROCm 6.3, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.
|
||||
|
||||
@ -24,6 +24,12 @@ source myenv/bin/activate
|
||||
uv pip install vllm
|
||||
```
|
||||
|
||||
Another delightful way is to use `uv run` with `--with [dependency]` option, which allows you to run commands such as `vllm serve` without creating an environment:
|
||||
|
||||
```console
|
||||
uv run --with vllm vllm --help
|
||||
```
|
||||
|
||||
You can also use [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html) to create and manage Python environments.
|
||||
|
||||
```console
|
||||
@ -184,3 +190,13 @@ chat_response = client.chat.completions.create(
|
||||
)
|
||||
print("Chat response:", chat_response)
|
||||
```
|
||||
|
||||
## On Attention Backends
|
||||
|
||||
Currently, vLLM supports multiple backends for efficient Attention computation across different platforms and accelerator architectures. It automatically selects the most performant backend compatible with your system and model specifications.
|
||||
|
||||
If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`.
|
||||
|
||||
```{attention}
|
||||
There are no pre-built vllm wheels containing Flash Infer, so you must install it in your environment first. Refer to the [Flash Infer official docs](https://docs.flashinfer.ai/) or see [Dockerfile](https://github.com/vllm-project/vllm/blob/main/Dockerfile) for instructions on how to install it.
|
||||
```
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user