Compare commits
160 Commits
v0.6.5
...
v1-blockta
24 .buildkite/generate_index.py (Normal file)
@ -0,0 +1,24 @@
import argparse
import os

template = """<!DOCTYPE html>
<html>
    <body>
        <h1>Links for vLLM</h1/>
        <a href="../{wheel_html_escaped}">{wheel}</a><br/>
    </body>
</html>
"""

parser = argparse.ArgumentParser()
parser.add_argument("--wheel", help="The wheel path.", required=True)
args = parser.parse_args()

filename = os.path.basename(args.wheel)

with open("index.html", "w") as f:
    print(f"Generated index.html for {args.wheel}")
    # cloudfront requires escaping the '+' character
    f.write(
        template.format(wheel=filename,
                        wheel_html_escaped=filename.replace("+", "%2B")))
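The generated index is a single anchor tag whose href percent-encodes the '+' in the wheel's local version segment; as the comment in the script notes, CloudFront requires escaping the '+' character. A minimal sketch of that escaping, using a hypothetical wheel filename rather than one taken from this diff:

```python
# Illustration only; the filename below is hypothetical.
filename = "vllm-1.0.0.dev+cu124-cp38-abi3-manylinux1_x86_64.whl"
href = "../" + filename.replace("+", "%2B")
print(href)  # ../vllm-1.0.0.dev%2Bcu124-cp38-abi3-manylinux1_x86_64.whl
```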
@ -65,15 +65,15 @@ steps:
- VLLM_USAGE_SOURCE
- HF_TOKEN

- block: "Run H100 Benchmark"
key: block-h100
depends_on: ~
#- block: "Run H100 Benchmark"
#key: block-h100
#depends_on: ~

- label: "H100"
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
agents:
queue: H100
depends_on: block-h100
depends_on: ~
plugins:
- docker#v5.12.0:
image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT

@ -55,3 +55,18 @@ steps:
password-env: DOCKERHUB_TOKEN
env:
DOCKER_BUILDKIT: "1"

- block: "Build CPU release image"
key: block-cpu-release-image-build
depends_on: ~

- label: "Build and publish CPU release image"
depends_on: block-cpu-release-image-build
agents:
queue: cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION"
env:
DOCKER_BUILDKIT: "1"
@ -4,6 +4,9 @@
# It serves a sanity check for compilation and basic model usage.
set -ex

# Skip the new torch installation during build since we are using the specified version for arm64 in the Dockerfile
python3 use_existing_torch.py

# Try building the docker image
DOCKER_BUILDKIT=1 docker build . \
--target vllm-openai \
@ -106,14 +106,12 @@ steps:
source_file_dependencies:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

@ -224,8 +222,12 @@ steps:
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/model_executor/layers
- vllm/model_executor/guided_decoding
- tests/test_logits_processor
command: pytest -v -s test_logits_processor.py
- tests/model_executor/test_guided_processors
commands:
- pytest -v -s test_logits_processor.py
- pytest -v -s model_executor/test_guided_processors.py

- label: Speculative decoding tests # 30min
source_file_dependencies:
@ -329,8 +331,6 @@ steps:
- vllm/
- tests/models
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py

@ -356,7 +356,7 @@ steps:
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
- pytest -v -s models/embedding/language -m 'not core_model'

- label: Multi-Modal Models Test (Standard) # 28min
- label: Multi-Modal Models Test (Standard) # 40min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
@ -372,7 +372,7 @@ steps:
- pytest -v -s models/encoder_decoder/language -m core_model
- pytest -v -s models/encoder_decoder/vision_language -m core_model

- label: Multi-Modal Models Test (Extended) 1 # 1h16m
- label: Multi-Modal Models Test (Extended) 1 # 48m
optional: true
source_file_dependencies:
- vllm/
@ -465,11 +465,28 @@ steps:
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py

- label: Plugin Tests (2 GPUs) # 40min
working_dir: "/vllm-workspace/tests"
num_gpus: 2
fast_check: true
source_file_dependencies:
- vllm/plugins/
- tests/plugins/
commands:
# begin platform plugin tests, all the code in-between runs on dummy platform
- pip install -e ./plugins/vllm_add_dummy_platform
- pytest -v -s plugins_tests/test_platform_plugins.py
- pip uninstall vllm_add_dummy_platform -y
# end platform plugin tests
# other tests continue here:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
- pytest -v -s models/test_oot_registration.py # it needs a clean process

- label: Multi-step Tests (4 GPUs) # 36min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
@ -23,6 +23,8 @@ wheel="$new_wheel"
version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2)
echo "Version: $version"

normal_wheel="$wheel" # Save the original wheel filename

# If the version contains "dev", rename it to v1.0.0.dev for consistency
if [[ $version == *dev* ]]; then
suffix="${version##*.}"
@ -32,12 +34,38 @@ if [[ $version == *dev* ]]; then
new_version="1.0.0.dev"
fi
new_wheel="${wheel/$version/$new_version}"
mv -- "$wheel" "$new_wheel"
# use cp to keep both files in the artifacts directory
cp -- "$wheel" "$new_wheel"
wheel="$new_wheel"
version="$new_version"
fi

# Upload the wheel to S3
python3 .buildkite/generate_index.py --wheel "$normal_wheel"

# generate index for this commit
aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"

if [[ $normal_wheel == *"cu118"* ]]; then
# if $normal_wheel matches cu118, do not upload the index.html
echo "Skipping index files for cu118 wheels"
else
# only upload index.html for cu12 wheels (default wheels)
aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html"
aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html"
fi

# generate index for nightly
aws s3 cp "$wheel" "s3://vllm-wheels/nightly/"
aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/"

if [[ $normal_wheel == *"cu118"* ]]; then
# if $normal_wheel matches cu118, do not upload the index.html
echo "Skipping index files for cu118 wheels"
else
# only upload index.html for cu12 wheels (default wheels)
aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html"
fi

aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
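The hunk above switches the dev-wheel rename from mv to cp so both the original and the renamed wheel stay in the artifacts directory, and it reuses the renamed name for the per-version upload. A rough Python sketch of the bash substitution new_wheel="${wheel/$version/$new_version}"; the filename and version are hypothetical, and the sketch ignores the suffix handling earlier in the script:

```python
# Illustration of the dev-version rename in upload-wheels.sh; values are made up.
wheel = "vllm-0.6.5.dev100+gabcdef12-cp38-abi3-manylinux1_x86_64.whl"
version = "0.6.5.dev100+gabcdef12"

if "dev" in version:
    new_version = "1.0.0.dev"
    new_wheel = wheel.replace(version, new_version, 1)  # bash: ${wheel/$version/$new_version}
    print(new_wheel)  # vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl
```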
105 .github/workflows/publish.yml (vendored)
@ -39,67 +39,68 @@ jobs:
const script = require('.github/workflows/scripts/create_release.js')
await script(github, context, core)

wheel:
name: Build Wheel
runs-on: ${{ matrix.os }}
needs: release
# NOTE(simon): No longer build wheel using Github Actions. See buildkite's release workflow.
# wheel:
# name: Build Wheel
# runs-on: ${{ matrix.os }}
# needs: release

strategy:
fail-fast: false
matrix:
os: ['ubuntu-20.04']
python-version: ['3.9', '3.10', '3.11', '3.12']
pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
cuda-version: ['11.8', '12.1']
# strategy:
# fail-fast: false
# matrix:
# os: ['ubuntu-20.04']
# python-version: ['3.9', '3.10', '3.11', '3.12']
# pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
# cuda-version: ['11.8', '12.1']

steps:
- name: Checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# steps:
# - name: Checkout
# uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

- name: Setup ccache
uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
with:
create-symlink: true
key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
# - name: Setup ccache
# uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
# with:
# create-symlink: true
# key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}

- name: Set up Linux Env
if: ${{ runner.os == 'Linux' }}
run: |
bash -x .github/workflows/scripts/env.sh
# - name: Set up Linux Env
# if: ${{ runner.os == 'Linux' }}
# run: |
# bash -x .github/workflows/scripts/env.sh

- name: Set up Python
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
# - name: Set up Python
# uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
# with:
# python-version: ${{ matrix.python-version }}

- name: Install CUDA ${{ matrix.cuda-version }}
run: |
bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
# - name: Install CUDA ${{ matrix.cuda-version }}
# run: |
# bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}

- name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
run: |
bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
# - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
# run: |
# bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}

- name: Build wheel
shell: bash
env:
CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
run: |
bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename)
asset_name=${wheel_name//"linux"/"manylinux1"}
echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV"
echo "asset_name=${asset_name}" >> "$GITHUB_ENV"
# - name: Build wheel
# shell: bash
# env:
# CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
# run: |
# bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
# wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename)
# asset_name=${wheel_name//"linux"/"manylinux1"}
# echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV"
# echo "asset_name=${asset_name}" >> "$GITHUB_ENV"

- name: Upload Release Asset
uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ needs.release.outputs.upload_url }}
asset_path: ./dist/${{ env.wheel_name }}
asset_name: ${{ env.asset_name }}
asset_content_type: application/*
# - name: Upload Release Asset
# uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# with:
# upload_url: ${{ needs.release.outputs.upload_url }}
# asset_path: ./dist/${{ env.wheel_name }}
# asset_name: ${{ env.asset_name }}
# asset_content_type: application/*

# (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
# - name: Publish package
2 .gitignore (vendored)
@ -81,6 +81,8 @@ instance/
docs/_build/
docs/source/getting_started/examples/*.rst
!**/*.template.rst
docs/source/getting_started/examples/*.md
!**/*.template.md

# PyBuilder
.pybuilder/
@ -193,6 +193,7 @@ set(VLLM_EXT_SRC
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
@ -200,13 +201,14 @@ set(VLLM_EXT_SRC
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/prepare_inputs/copy_subranges.cu"
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")

# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@ -223,7 +225,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG v3.5.1
GIT_TAG v3.6.0
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
@ -241,7 +243,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_compressor_entry.cu"
"csrc/cutlass_extensions/common.cpp")

set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
@ -270,7 +275,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()

#
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
@ -323,6 +327,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

#
# 2:4 Sparse Kernels

# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
"if you intend on running FP8 sparse quantized models on Hopper.")
else()
message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()


#
# Machete kernels
@ -404,7 +433,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
47 Dockerfile
@ -2,7 +2,7 @@
# to run the OpenAI compatible server.

# Please update any changes made here to
# docs/source/dev/dockerfile/dockerfile.rst and
# docs/source/dev/dockerfile/dockerfile.md and
# docs/source/assets/dev/dockerfile-stages-dependency.png

ARG CUDA_VERSION=12.4.1
@ -45,17 +45,21 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
WORKDIR /workspace

# install build and runtime dependencies
COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt

# arm64 (GH200) build follows the practice of "use existing pytorch" build,
# we need to install torch and torchvision from the nightly builds first,
# pytorch will not appear as a vLLM dependency in all of the following steps
# after this step
RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
python3 -m pip install -r requirements-cuda-arm64.txt; \
python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
fi

COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt

# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
@ -77,11 +81,6 @@ COPY requirements-build.txt requirements-build.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-build.txt

RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
python3 -m pip install -r requirements-cuda-arm64.txt; \
fi

COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
@ -157,8 +156,6 @@ WORKDIR /vllm-workspace
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETPLATFORM

COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt

RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment

@ -166,7 +163,7 @@ RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
@ -183,17 +180,20 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

# arm64 (GH200) build follows the practice of "use existing pytorch" build,
# we need to install torch and torchvision from the nightly builds first,
# pytorch will not appear as a vLLM dependency in all of the following steps
# after this step
RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
fi

# Install vllm wheel first, so that torch etc will be installed.
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose

RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
pip uninstall -y torch && \
python3 -m pip install -r requirements-cuda-arm64.txt; \
fi

RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
@ -240,10 +240,11 @@ FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10'; \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
else \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10'; \
pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
fi

ENV VLLM_USAGE_SOURCE production-docker-image

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
@ -26,10 +26,10 @@ RUN pip install intel_extension_for_pytorch==2.5.0

WORKDIR /workspace

COPY requirements-build.txt requirements-build.txt
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
pip install --upgrade pip && \
pip install -r requirements-build.txt

@ -37,9 +37,9 @@ FROM cpu-test-1 AS build

WORKDIR /workspace/vllm

COPY requirements-common.txt requirements-common.txt
COPY requirements-cpu.txt requirements-cpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
--mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
pip install -v -r requirements-cpu.txt

COPY . .
@ -1,6 +1,6 @@
# default base image
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04"
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.21.0-ubuntu22.04"

FROM $BASE_IMAGE

@ -22,9 +22,9 @@ WORKDIR ${APP_MOUNT}/vllm

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
RUN python3 -m pip install sentencepiece transformers==4.45.2 -U
RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U

COPY . .
ARG GIT_REPO_CHECK=0
@ -60,7 +60,7 @@ vLLM is flexible and easy to use with:

vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
- Mixture-of-Expert LLMs (e.g., Mixtral)
- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
- Embedding Models (e.g. E5-Mistral)
- Multi-modal LLMs (e.g., LLaVA)
184 benchmarks/benchmark_long_document_qa_throughput.py (Normal file)
@ -0,0 +1,184 @@
|
||||
"""
|
||||
Offline benchmark to test the long document QA throughput.
|
||||
|
||||
Example usage:
|
||||
# This command run the vllm with 50GB CPU memory for offloading
|
||||
# The workload samples 8 different prompts with a default input
|
||||
# length of 20000 tokens, then replicates each prompt 2 times
|
||||
# in random order.
|
||||
python benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--repeat-count 2
|
||||
|
||||
Commandline arguments:
|
||||
--num-documents: The number of documents to sample prompts from.
|
||||
|
||||
--document-length: The length of each document in tokens.
|
||||
(Optional, default: 20000)
|
||||
|
||||
--output-len: The number of tokens to generate for each prompt.
|
||||
(Optional, default: 10)
|
||||
|
||||
--repeat-count: The number of times to repeat each prompt.
|
||||
(Optional, default: 2)
|
||||
|
||||
--repeat-mode: The mode to repeat prompts. The supported modes are:
|
||||
- 'random': shuffle the prompts randomly. (Default)
|
||||
- 'tile': the entire prompt list is repeated in sequence. (Potentially
|
||||
lowest cache hit)
|
||||
- 'interleave': each prompt is repeated consecutively before
|
||||
moving to the next element. (Highest cache hit)
|
||||
|
||||
--shuffle-seed: Random seed when the repeat mode is "random".
|
||||
(Optional, default: 0)
|
||||
|
||||
In the meantime, it also supports all the vLLM engine args to initialize the
|
||||
LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more
|
||||
details.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import random
|
||||
import time
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
|
||||
"""
|
||||
Test long document QA with the given prompts and sampling parameters.
|
||||
Print the time spent in processing all the prompts.
|
||||
|
||||
Args:
|
||||
llm: The language model used for generating responses.
|
||||
sampling_params: Sampling parameter used to generate the response.
|
||||
prompts: A list of prompt strings to be processed by the LLM.
|
||||
"""
|
||||
start_time = time.time()
|
||||
llm.generate(prompts, sampling_params=sampling_params)
|
||||
end_time = time.time()
|
||||
print(f"Time to execute all requests: {end_time - start_time:.4f} secs")
|
||||
|
||||
|
||||
def repeat_prompts(prompts, repeat_count, mode: str):
|
||||
"""
|
||||
Repeat each prompt in the list for a specified number of times.
|
||||
The order of prompts in the output list depends on the mode.
|
||||
|
||||
Args:
|
||||
prompts: A list of prompts to be repeated.
|
||||
repeat_count: The number of times each prompt is repeated.
|
||||
mode: The mode of repetition. Supported modes are:
|
||||
- 'random': Shuffle the prompts randomly after repetition.
|
||||
- 'tile': Repeat the entire prompt list in sequence.
|
||||
Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
||||
- 'interleave': Repeat each prompt consecutively before moving to
|
||||
the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
||||
|
||||
Returns:
|
||||
A list of repeated prompts in the specified order.
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid mode is provided.
|
||||
"""
|
||||
print("Repeat mode: ", mode)
|
||||
if mode == 'random':
|
||||
repeated_prompts = prompts * repeat_count
|
||||
random.shuffle(repeated_prompts)
|
||||
return repeated_prompts
|
||||
elif mode == 'tile':
|
||||
return prompts * repeat_count
|
||||
elif mode == 'interleave':
|
||||
repeated_prompts = []
|
||||
for prompt in prompts:
|
||||
repeated_prompts.extend([prompt] * repeat_count)
|
||||
return repeated_prompts
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}, only support "
|
||||
"'random', 'tile', 'interleave'")
|
||||
|
||||
|
||||
def main(args):
|
||||
random.seed(args.shuffle_seed)
|
||||
|
||||
# Prepare the prompts:
|
||||
# we append the document id at the beginning to avoid any of the document
|
||||
# being the prefix of other documents
|
||||
prompts = [
|
||||
str(i) + ' '.join(['hi'] * args.document_length)
|
||||
for i in range(args.num_documents)
|
||||
]
|
||||
|
||||
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
|
||||
|
||||
warmup_prompts = [
|
||||
"This is warm up request " + str(i) + \
|
||||
' '.join(['hi'] * args.document_length)
|
||||
for i in range(args.num_documents)]
|
||||
|
||||
# Create the LLM engine
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||
|
||||
print("------warm up------")
|
||||
test_long_document_qa(
|
||||
llm=llm,
|
||||
prompts=warmup_prompts,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
print("------start generating------")
|
||||
test_long_document_qa(
|
||||
llm=llm,
|
||||
prompts=prompts,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description=
|
||||
'Benchmark the performance with or without automatic prefix caching.')
|
||||
|
||||
parser.add_argument(
|
||||
'--document-length',
|
||||
type=int,
|
||||
# Roughly the number of tokens for a system paper,
|
||||
# excluding images
|
||||
default=20000,
|
||||
help='The length of each document in tokens.')
|
||||
|
||||
parser.add_argument('--num-documents',
|
||||
type=int,
|
||||
default=8,
|
||||
help='The number of documents to sample prompts from.')
|
||||
|
||||
parser.add_argument('--output-len', type=int, default=10)
|
||||
|
||||
parser.add_argument('--repeat-count',
|
||||
type=int,
|
||||
default=2,
|
||||
help='Number of times to repeat each prompt')
|
||||
|
||||
parser.add_argument("--repeat-mode",
|
||||
type=str,
|
||||
default='random',
|
||||
help='The mode to repeat prompts. The supported '
|
||||
'modes are "random", "tile", and "interleave". '
|
||||
'See repeat_prompts() in the source code for details.')
|
||||
|
||||
parser.add_argument("--shuffle-seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help='Random seed when the repeat mode is "random"')
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
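The three repeat modes documented in repeat_prompts() above differ only in the order of the returned list; a minimal sketch with toy stand-ins for the generated document prompts:

```python
import random

prompts = ["doc0", "doc1", "doc2"]   # toy stand-ins, not the benchmark's real prompts
repeat_count = 2

tiled = prompts * repeat_count                                   # 'tile': doc0, doc1, doc2, doc0, doc1, doc2
interleaved = [p for p in prompts for _ in range(repeat_count)]  # 'interleave': doc0, doc0, doc1, doc1, doc2, doc2
shuffled = prompts * repeat_count                                # 'random': tiled order shuffled with --shuffle-seed
random.seed(0)
random.shuffle(shuffled)
```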
@ -4,7 +4,8 @@ import dataclasses
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from functools import cache
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
@ -17,8 +18,11 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
@ -28,15 +32,17 @@ class SampleRequest:
|
||||
|
||||
Attributes:
|
||||
prompt: The input text prompt for the model.
|
||||
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
|
||||
images).
|
||||
prompt_len: The length of the prompt in tokens.
|
||||
expected_output_len: The expected length of the output in tokens.
|
||||
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
|
||||
images).
|
||||
lora_request: Optional LoRARequest specifying the LoRA to use.
|
||||
"""
|
||||
prompt: str
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
multi_modal_data: Optional[MultiModalDataDict] = None
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
|
||||
|
||||
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
|
||||
@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str:
|
||||
raise ValueError(f"Unsupported model {model}")
|
||||
|
||||
|
||||
@cache
|
||||
def lora_path_on_disk(lora_path: str) -> str:
|
||||
return get_adapter_absolute_path(lora_path)
|
||||
|
||||
|
||||
lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}
|
||||
|
||||
|
||||
def get_random_lora_request(
|
||||
args: argparse.Namespace
|
||||
) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
|
||||
global lora_tokenizer_cache
|
||||
lora_id = random.randint(1, args.max_loras)
|
||||
lora_request = LoRARequest(lora_name=str(lora_id),
|
||||
lora_int_id=lora_id,
|
||||
lora_path=lora_path_on_disk(args.lora_path))
|
||||
if lora_id not in lora_tokenizer_cache:
|
||||
lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
|
||||
return lora_request, lora_tokenizer_cache[lora_id]
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> List[SampleRequest]:
|
||||
|
||||
dataset_path: str = args.dataset
|
||||
num_requests: int = args.num_prompts
|
||||
fixed_output_len: Optional[int] = args.output_len
|
||||
@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: List[SampleRequest] = []
|
||||
for data in dataset:
|
||||
for data in tqdm(dataset,
|
||||
total=len(filtered_dataset),
|
||||
desc="sampling requests"):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
break
|
||||
|
||||
@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
continue
|
||||
prompt = _get_prompt_for_image_model(question=prompt, model=model)
|
||||
|
||||
request_tokenizer = tokenizer
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
if args.enable_lora:
|
||||
lora_request, lora_tokenizer = get_random_lora_request(args)
|
||||
if lora_tokenizer:
|
||||
request_tokenizer = lora_tokenizer
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_token_ids = request_tokenizer(prompt).input_ids
|
||||
completion_token_ids = request_tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=prompt_len,
|
||||
expected_output_len=output_len,
|
||||
multi_modal_data=multi_modal_data))
|
||||
multi_modal_data=multi_modal_data,
|
||||
lora_request=lora_request))
|
||||
|
||||
return filtered_dataset
|
||||
|
||||
@ -146,14 +184,21 @@ def run_vllm(
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
))
|
||||
lora_requests: Optional[List[LoRARequest]] = None
|
||||
if engine_args.enable_lora:
|
||||
lora_requests = [request.lora_request for request in requests]
|
||||
|
||||
use_beam_search = False
|
||||
|
||||
if not use_beam_search:
|
||||
start = time.perf_counter()
|
||||
llm.generate(prompts, sampling_params, use_tqdm=True)
|
||||
llm.generate(prompts,
|
||||
sampling_params,
|
||||
lora_request=lora_requests,
|
||||
use_tqdm=True)
|
||||
end = time.perf_counter()
|
||||
else:
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
prompts = [request.prompt for request in requests]
|
||||
# output_len should be the same for all requests.
|
||||
output_len = requests[0][2]
|
||||
@ -185,6 +230,7 @@ async def run_vllm_async(
|
||||
# Add the requests to the engine.
|
||||
prompts: List[TextPrompt] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
lora_requests: List[Optional[LoRARequest]] = []
|
||||
for request in requests:
|
||||
prompts.append(
|
||||
TextPrompt(prompt=request.prompt,
|
||||
@ -197,11 +243,16 @@ async def run_vllm_async(
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
))
|
||||
lora_requests.append(request.lora_request)
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
for i, (prompt, sp,
|
||||
lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
|
||||
generator = llm.generate(prompt,
|
||||
sp,
|
||||
lora_request=lr,
|
||||
request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
@ -297,6 +348,14 @@ def main(args: argparse.Namespace):
|
||||
vocab_size = tokenizer.vocab_size
|
||||
requests = []
|
||||
for _ in range(args.num_prompts):
|
||||
|
||||
request_tokenizer = tokenizer
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
if args.enable_lora:
|
||||
lora_request, lora_tokenizer = get_random_lora_request(args)
|
||||
if lora_tokenizer:
|
||||
request_tokenizer = lora_tokenizer
|
||||
|
||||
# Synthesize a prompt with the given input length.
|
||||
candidate_ids = [
|
||||
random.randint(0, vocab_size - 1)
|
||||
@ -305,8 +364,8 @@ def main(args: argparse.Namespace):
|
||||
# As tokenizer may add additional tokens like BOS, we need to try
|
||||
# different lengths to get the desired input length.
|
||||
for _ in range(5): # Max attempts to correct
|
||||
candidate_prompt = tokenizer.decode(candidate_ids)
|
||||
tokenized_len = len(tokenizer.encode(candidate_prompt))
|
||||
candidate_prompt = request_tokenizer.decode(candidate_ids)
|
||||
tokenized_len = len(request_tokenizer.encode(candidate_prompt))
|
||||
|
||||
if tokenized_len == args.input_len:
|
||||
break
|
||||
@ -323,7 +382,8 @@ def main(args: argparse.Namespace):
|
||||
requests.append(
|
||||
SampleRequest(prompt=candidate_prompt,
|
||||
prompt_len=args.input_len,
|
||||
expected_output_len=args.output_len))
|
||||
expected_output_len=args.output_len,
|
||||
lora_request=lora_request))
|
||||
else:
|
||||
requests = sample_requests(tokenizer, args)
|
||||
|
||||
@ -422,6 +482,14 @@ if __name__ == "__main__":
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the lora adapters to use. This can be an absolute path, "
|
||||
"a relative path, or a Hugging Face model identifier.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
@ -431,6 +499,8 @@ if __name__ == "__main__":
|
||||
assert args.output_len is not None
|
||||
else:
|
||||
assert args.input_len is None
|
||||
if args.enable_lora:
|
||||
assert args.lora_path is not None
|
||||
|
||||
if args.backend == "vllm":
|
||||
if args.hf_max_batch_size is not None:
|
||||
@ -440,6 +510,9 @@ if __name__ == "__main__":
|
||||
raise ValueError("HF max batch size is required for HF backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.enable_lora is not None:
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM"
|
||||
" backend")
|
||||
elif args.backend == "mii":
|
||||
if args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
@ -452,4 +525,7 @@ if __name__ == "__main__":
|
||||
if args.tokenizer != args.model:
|
||||
raise ValueError("Tokenizer must be the same as the model for MII "
|
||||
"backend.")
|
||||
if args.enable_lora is not None:
|
||||
raise ValueError("LoRA benchmarking is only supported for vLLM"
|
||||
" backend")
|
||||
main(args)
|
||||
|
||||
384 benchmarks/cutlass_benchmarks/sparse_benchmarks.py (Normal file)
@ -0,0 +1,384 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import pickle as pkl
|
||||
import time
|
||||
from typing import Callable, Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import make_rand_sparse_tensors
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
**kwargs) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
"args": args,
|
||||
"kwargs": kwargs,
|
||||
"fn": fn,
|
||||
}
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn(*args, **kwargs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.int8
|
||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
|
||||
torch.bfloat16)
|
||||
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
if not torch.allclose(out, out_ref):
|
||||
print("Incorrect results")
|
||||
print(out)
|
||||
print(out_ref)
|
||||
else:
|
||||
print("Correct results")
|
||||
|
||||
timers = []
|
||||
# pytorch impl - bfloat16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16),
|
||||
b.to(dtype=torch.bfloat16)))
|
||||
|
||||
# pytorch impl - float16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
|
||||
# cutlass with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
|
||||
# cutlass sparse impl
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16))
|
||||
|
||||
# cutlass sparse with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16, bias))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
assert dtype == torch.float8_e4m3fn
|
||||
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
|
||||
k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
|
||||
torch.bfloat16)
|
||||
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
if not torch.allclose(out, out_ref):
|
||||
print("Incorrect results")
|
||||
print(out)
|
||||
print(out_ref)
|
||||
else:
|
||||
print("Correct results")
|
||||
|
||||
timers = []
|
||||
|
||||
# pytorch impl w. bf16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda")))
|
||||
|
||||
# pytorch impl: bf16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16))
|
||||
|
||||
# pytorch impl: bf16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# pytorch impl: fp16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16))
|
||||
|
||||
# pytorch impl: fp16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16))
|
||||
|
||||
# cutlass impl: fp16 output
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.float16))
|
||||
|
||||
# cutlass impl: bf16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.bfloat16, bias))
|
||||
|
||||
# cutlass impl: fp16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
|
||||
ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
|
||||
scale_b, torch.float16, bias.to(dtype=torch.float16)))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
sub_label: str) -> Iterable[TMeasurement]:
|
||||
if dtype == torch.int8:
|
||||
return bench_int8(dtype, m, k, n, label, sub_label)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return bench_fp8(dtype, m, k, n, label, sub_label)
|
||||
raise ValueError("unsupported type")
|
||||
|
||||
|
||||
# runner
|
||||
def print_timers(timers: Iterable[TMeasurement]):
|
||||
compare = TBenchmark.Compare(timers)
|
||||
compare.print()
|
||||
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||
f"MKN=({m}x{k}x{n})")
|
||||
print_timers(timers)
|
||||
results.extend(timers)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# output makers
|
||||
def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None):
|
||||
print(f"== All Results {base_description} ====")
|
||||
print_timers(data)
|
||||
|
||||
# pickle all the results
|
||||
timestamp = int(time.time()) if timestamp is None else timestamp
|
||||
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
|
||||
pkl.dump(data, f)
|
||||
|
||||
|
||||
# argparse runners
|
||||
|
||||
|
||||
def run_square_bench(args):
|
||||
dim_sizes = list(
|
||||
range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||
|
||||
|
||||
def run_range_bench(args):
|
||||
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
|
||||
n = len(dim_sizes)
|
||||
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
|
||||
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
|
||||
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
|
||||
MKNs = list(zip(Ms, Ks, Ns))
|
||||
data = run(args.dtype, MKNs)
|
||||
|
||||
make_output(data, MKNs, f"range_bench-{args.dtype}")
|
||||
|
||||
|
||||
def run_model_bench(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
|
||||
KNs = []
|
||||
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||
KNs.append(KN)
|
||||
return KNs
|
||||
|
||||
model_bench_data = []
|
||||
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||
for model, tp_size in models_tps:
|
||||
Ms = args.batch_sizes
|
||||
KNs = model_shapes(model, tp_size)
|
||||
MKNs = []
|
||||
for m in Ms:
|
||||
for k, n in KNs:
|
||||
MKNs.append((m, k, n))
|
||||
|
||||
data = run(args.dtype, MKNs)
|
||||
model_bench_data.append(data)
|
||||
|
||||
# Print all results
|
||||
for data, model_tp in zip(model_bench_data, models_tps):
|
||||
model, tp_size = model_tp
|
||||
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
|
||||
print_timers(data)
|
||||
|
||||
timestamp = int(time.time())
|
||||
|
||||
all_data = []
|
||||
for d in model_bench_data:
|
||||
all_data.extend(d)
|
||||
# pickle all data
|
||||
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
|
||||
pkl.dump(all_data, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def to_torch_dtype(dt):
|
||||
if dt == "int8":
|
||||
return torch.int8
|
||||
if dt == "fp8":
|
||||
return torch.float8_e4m3fn
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="""
|
||||
Benchmark Cutlass GEMM.
|
||||
|
||||
To run square GEMMs:
|
||||
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
|
||||
|
||||
To run constant N and K and sweep M:
|
||||
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
|
||||
|
||||
To run dimensions from a model:
|
||||
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
|
||||
|
||||
Output:
|
||||
- a .pkl file, which is a list of raw torch.utils.benchmark.Measurement objects for the pytorch and cutlass implementations of the various GEMMs.
|
||||
""", # noqa: E501
|
||||
formatter_class=argparse.RawTextHelpFormatter)
|
||||
|
||||
parser.add_argument("--dtype",
|
||||
type=to_torch_dtype,
|
||||
required=True,
|
||||
help="Available options are ['int8', 'fp8']")
|
||||
subparsers = parser.add_subparsers(dest="cmd")
|
||||
|
||||
square_parser = subparsers.add_parser("square_bench")
|
||||
square_parser.add_argument("--dim-start", type=int, required=True)
|
||||
square_parser.add_argument("--dim-end", type=int, required=True)
|
||||
square_parser.add_argument("--dim-increment", type=int, required=True)
|
||||
square_parser.set_defaults(func=run_square_bench)
|
||||
|
||||
range_parser = subparsers.add_parser("range_bench")
|
||||
range_parser.add_argument("--dim-start", type=int, required=True)
|
||||
range_parser.add_argument("--dim-end", type=int, required=True)
|
||||
range_parser.add_argument("--dim-increment", type=int, required=True)
|
||||
range_parser.add_argument("--m-constant", type=int, default=None)
|
||||
range_parser.add_argument("--n-constant", type=int, default=None)
|
||||
range_parser.add_argument("--k-constant", type=int, default=None)
|
||||
range_parser.set_defaults(func=run_range_bench)
|
||||
|
||||
model_parser = subparsers.add_parser("model_bench")
|
||||
model_parser.add_argument("--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=DEFAULT_MODELS,
|
||||
choices=WEIGHT_SHAPES.keys())
|
||||
model_parser.add_argument("--tp-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_TP_SIZES)
|
||||
model_parser.add_argument("--batch-sizes",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=DEFAULT_BATCH_SIZES)
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
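For reference, the .pkl files described in the help text above can be reloaded and re-compared offline. A minimal sketch; the file name below is hypothetical and only follows the "<base_description>-<timestamp>.pkl" pattern used by make_output():

import pickle

import torch.utils.benchmark as TBenchmark

# Hypothetical file name produced by an earlier square_bench run.
with open("square_bench-torch.float8_e4m3fn-1700000000.pkl", "rb") as f:
    measurements = pickle.load(f)

# Re-print the same comparison table the benchmark printed while running.
TBenchmark.Compare(measurements).print()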
96
benchmarks/cutlass_benchmarks/utils.py
Normal file
@ -0,0 +1,96 @@
|
||||
# Cutlass bench utils
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.float16)
|
||||
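A quick illustration of the saturating conversion above (not part of the diff): values outside the fp8 e4m3 representable range are clamped before rounding and casting.

import torch

finfo = torch.finfo(torch.float8_e4m3fn)
x = torch.tensor([0.3, 1000.0, -1000.0])
print(finfo.max, finfo.min)         # 448.0 -448.0
print(to_fp8(x).to(torch.float32))  # tensor([0., 448., -448.])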
|
||||
|
||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
if dtype == torch.int8:
|
||||
return to_int8(a), to_int8(b)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return to_fp8(a), to_fp8(b)
|
||||
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
|
||||
def prune_to_2_4(tensor):
|
||||
# Reshape tensor to [N, 4] where N is number of groups of 4
|
||||
original_shape = tensor.shape
|
||||
reshaped = tensor.reshape(-1, 4)
|
||||
|
||||
# Get indices of top 2 absolute values in each group of 4
|
||||
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
|
||||
|
||||
# Create binary mask
|
||||
mask = torch.zeros_like(reshaped)
|
||||
mask.scatter_(dim=1,
|
||||
index=indices,
|
||||
src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
|
||||
# Apply mask and reshape back
|
||||
pruned = reshaped * mask
|
||||
|
||||
# Turn all -0.0 to 0.0
|
||||
pruned[pruned == -0.0] = 0.0
|
||||
|
||||
return pruned.reshape(original_shape)
|
||||
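prune_to_2_4 enforces 2:4 structured sparsity by keeping only the two largest-magnitude values in every contiguous group of four. A small sanity check (illustrative):

import torch

x = torch.randn(128, 128, device="cuda")
pruned = prune_to_2_4(x)

# Every group of 4 consecutive values should now hold at most 2 non-zeros,
# which is the pattern the CUTLASS 2:4 sparse kernels expect.
nonzeros_per_group = (pruned.reshape(-1, 4) != 0).sum(dim=1)
assert torch.all(nonzeros_per_group <= 2)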
|
||||
|
||||
def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
a, b = to_int8(a), to_int8(b)
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
a, b = to_fp8(a), to_fp8(b)
|
||||
elif dtype == torch.float16:
|
||||
a, b = to_fp16(a), to_fp16(b)
|
||||
elif dtype == torch.bfloat16:
|
||||
a, b = to_bf16(a), to_bf16(b)
|
||||
else:
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||
|
||||
# Compressed B, Metadata, Original A, B
|
||||
return b_compressed, e, a, b
|
||||
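Illustrative use of the returned tensors, mirroring the call pattern exercised by the sparse benchmark above; the per-tensor scales and output dtype here are assumptions, not values taken from the diff.

import torch

import vllm._custom_ops as ops

m, n, k = 16, 4096, 4096
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)

# Assumed per-tensor scales for the illustration.
scale_a = torch.ones(1, device="cuda", dtype=torch.float32)
scale_b = torch.ones(1, device="cuda", dtype=torch.float32)

out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, torch.bfloat16)
print(out.shape)  # (m, n)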
|
||||
|
||||
def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
|
||||
m: int, n: int, k: int) -> \
|
||||
Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
|
||||
ABs = []
|
||||
for _ in range(num_tensors):
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||
if b_comp is not None:
|
||||
ABs.append(make_rand_sparse_tensors(dtype, m, n, k))
|
||||
BComps, Es, As, Bs = zip(*ABs)
|
||||
return list(BComps), list(Es), list(As), list(Bs)
|
||||
@ -8,6 +8,7 @@ from typing import Callable, Iterable, List, Tuple
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from utils import make_rand_tensors
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
@ -17,31 +18,6 @@ DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
if dtype == torch.int8:
|
||||
return to_int8(a), to_int8(b)
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return to_fp8(a), to_fp8(b)
|
||||
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
@ -386,4 +362,4 @@ Benchmark Cutlass GEMM.
|
||||
model_parser.set_defaults(func=run_model_bench)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
args.func(args)
|
||||
@ -40,4 +40,4 @@ WEIGHT_SHAPES = {
|
||||
([8192, 57344], 1),
|
||||
([28672, 8192], 0),
|
||||
],
|
||||
}
|
||||
}
|
||||
@ -10,7 +10,8 @@ set -ex
|
||||
|
||||
kill_gpu_processes() {
|
||||
# kill all processes on GPU.
|
||||
pkill -f pt_main_thread
|
||||
pgrep pt_main_thread | xargs -r kill -9
|
||||
pgrep python3 | xargs -r kill -9
|
||||
sleep 10
|
||||
|
||||
# remove vllm config file
|
||||
@ -54,7 +55,7 @@ benchmark() {
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 \
|
||||
-m vllm.entrypoints.openai.api_server \
|
||||
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
--model $model \
|
||||
--port 8100 \
|
||||
--max-model-len 10000 \
|
||||
--gpu-memory-utilization 0.6 \
|
||||
@ -64,7 +65,7 @@ benchmark() {
|
||||
|
||||
CUDA_VISIBLE_DEVICES=1 python3 \
|
||||
-m vllm.entrypoints.openai.api_server \
|
||||
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
|
||||
--model $model \
|
||||
--port 8200 \
|
||||
--max-model-len 10000 \
|
||||
--gpu-memory-utilization 0.6 \
|
||||
@ -87,7 +88,7 @@ benchmark() {
|
||||
--port 8100 \
|
||||
--save-result \
|
||||
--result-dir $results_folder \
|
||||
--result-filename disagg_prefill_2xtp4.json \
|
||||
--result-filename disagg_prefill_tp1.json \
|
||||
--request-rate "inf"
|
||||
|
||||
|
||||
@ -105,7 +106,7 @@ benchmark() {
|
||||
--port 8200 \
|
||||
--save-result \
|
||||
--result-dir $results_folder \
|
||||
--result-filename disagg_prefill_2xtp4.json \
|
||||
--result-filename disagg_prefill_tp1_overhead.json \
|
||||
--request-rate "$qps"
|
||||
kill_gpu_processes
|
||||
|
||||
@ -118,7 +119,7 @@ main() {
|
||||
(which jq) || (apt-get -y install jq)
|
||||
(which socat) || (apt-get -y install socat)
|
||||
|
||||
pip install quart httpx
|
||||
pip install quart httpx datasets
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
|
||||
@ -1,13 +1,12 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Requirement: 8x H100 GPUs.
|
||||
# Requirement: 2x GPUs.
|
||||
|
||||
|
||||
# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV
|
||||
# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests
|
||||
# Resource: 8x H100
|
||||
# Model: meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||
# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8, 100 requests
|
||||
# Resource: 2x GPU
|
||||
# Approaches:
|
||||
# 1. Chunked prefill: 1 vllm instance with tp=8
|
||||
# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4
|
||||
# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance
|
||||
# Prefilling instance: max_output_token=1
|
||||
@ -114,7 +113,6 @@ benchmark() {
|
||||
--request-rate "$qps"
|
||||
|
||||
sleep 2
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -123,8 +121,9 @@ main() {
|
||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||
(which jq) || (apt-get -y install jq)
|
||||
(which socat) || (apt-get -y install socat)
|
||||
(which lsof) || (apt-get -y install lsof)
|
||||
|
||||
pip install quart httpx matplotlib aiohttp
|
||||
pip install quart httpx matplotlib aiohttp datasets
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
|
||||
7
csrc/core/math.hpp
Normal file
@ -0,0 +1,7 @@
|
||||
#include <climits>
|
||||
#include <iostream>
|
||||
|
||||
inline uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
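next_pow_2 rounds up to the next power of two (leaving 0 and 1 unchanged) via count-leading-zeros. A Python reference of the same behaviour (illustrative):

def next_pow_2(num: int) -> int:
    # Mirrors the __builtin_clz formulation above.
    if num <= 1:
        return num
    return 1 << (num - 1).bit_length()

assert [next_pow_2(x) for x in (0, 1, 2, 3, 64, 65)] == [0, 1, 2, 4, 64, 128]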
@ -47,3 +47,11 @@
|
||||
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#endif
|
||||
|
||||
// #ifndef USE_ROCM
|
||||
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
|
||||
// cudaHostGetDevicePointer(device_ptr, host_ptr, flags)
|
||||
// #else
|
||||
// #define VLLM_cudaHostGetDevicePointer(device_ptr, host_ptr, flags) \
|
||||
// hipHostGetDevicePointer(device_ptr, host_ptr, flags)
|
||||
// #endif
|
||||
|
||||
43
csrc/cuda_view.cu
Normal file
@ -0,0 +1,43 @@
|
||||
#include <torch/all.h>
|
||||
#include <torch/cuda.h>
|
||||
|
||||
// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
|
||||
// memory, and that UVA (Unified Virtual Addressing) is enabled.
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
|
||||
TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");
|
||||
TORCH_CHECK(cpu_tensor.is_contiguous(), "Input tensor must be contiguous");
|
||||
|
||||
// Get raw host pointer from CPU tensor
|
||||
void* host_ptr = cpu_tensor.data_ptr();
|
||||
|
||||
// Get a device pointer corresponding to the pinned host memory
|
||||
void* device_ptr = nullptr;
|
||||
cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
|
||||
TORCH_CHECK(err == cudaSuccess,
|
||||
"cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));
|
||||
|
||||
// Construct a CUDA tensor from the device pointer.
|
||||
// We'll use the same sizes, strides, and dtype as the CPU tensor.
|
||||
auto sizes = cpu_tensor.sizes();
|
||||
auto strides = cpu_tensor.strides();
|
||||
auto options =
|
||||
cpu_tensor.options().device(torch::kCUDA); // Change device to CUDA
|
||||
|
||||
// from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
|
||||
// const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
|
||||
// memory, so we don't free it here.
|
||||
auto deleter = [](void*) {
|
||||
// no-op, since the memory is owned by the original CPU tensor
|
||||
};
|
||||
|
||||
torch::Tensor cuda_tensor =
|
||||
torch::from_blob(device_ptr, sizes, strides, deleter, options);
|
||||
|
||||
TORCH_CHECK(cuda_tensor.device().is_cuda(),
|
||||
"Resulting tensor is not on CUDA device");
|
||||
TORCH_CHECK(cuda_tensor.sizes().equals(sizes), "Size mismatch");
|
||||
TORCH_CHECK(cuda_tensor.strides().equals(strides), "Stride mismatch");
|
||||
TORCH_CHECK(cuda_tensor.dtype() == cpu_tensor.dtype(), "Dtype mismatch");
|
||||
|
||||
return cuda_tensor;
|
||||
}
|
||||
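A hedged usage sketch: the function requires a contiguous, pinned CPU tensor and UVA support, and the returned CUDA tensor aliases the host allocation. The Python binding name below (torch.ops._C.get_cuda_view_from_cpu_tensor) is an assumption; the actual registration is not shown in this diff.

import torch

cpu_t = torch.zeros(4, 1024, pin_memory=True)  # pinned host memory is required

# Assumed binding; the op returns a CUDA tensor that aliases cpu_t's storage.
cuda_view = torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_t)

cpu_t.fill_(1.0)  # writes through the host pointer are visible to CUDA kernels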
11
csrc/cutlass_extensions/common.cpp
Normal file
@ -0,0 +1,11 @@
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
|
||||
int32_t get_sm_version_num() {
|
||||
int32_t major_capability, minor_capability;
|
||||
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
|
||||
0);
|
||||
cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
|
||||
0);
|
||||
int32_t version_num = major_capability * 10 + minor_capability;
|
||||
return version_num;
|
||||
}
|
||||
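get_sm_version_num packs the compute capability of device 0 as major * 10 + minor (for example 90 on sm90/Hopper). The same value can be computed from Python (illustrative):

import torch

major, minor = torch.cuda.get_device_capability(0)
sm_version_num = major * 10 + minor
print(sm_version_num)  # e.g. 90 on H100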
35
csrc/cutlass_extensions/common.hpp
Normal file
@ -0,0 +1,35 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include <climits>
|
||||
#include "cuda_runtime.h"
|
||||
#include <iostream>
|
||||
|
||||
/**
|
||||
* Helper function for checking CUTLASS errors
|
||||
*/
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
TORCH_CHECK(error == cutlass::Status::kSuccess, \
|
||||
cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
/**
|
||||
* Panic wrapper for unwinding CUDA runtime errors
|
||||
*/
|
||||
#define CUDA_CHECK(status) \
|
||||
{ \
|
||||
cudaError_t error = status; \
|
||||
TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
|
||||
}
|
||||
|
||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
int max_shared_mem_per_block_opt_in = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin,
|
||||
device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
|
||||
|
||||
int32_t get_sm_version_num();
|
||||
@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
|
||||
|
||||
/*
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||
|
||||
/*
|
||||
@ -36,13 +38,13 @@ struct ScaledEpilogueBase {
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
|
||||
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// Don't want to support nullptr by default
|
||||
template <typename T, bool EnableNullPtr = false>
|
||||
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
|
||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||
|
||||
// This utility function constructs the arguments for the load descriptors
|
||||
|
||||
@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum):
|
||||
|
||||
|
||||
class MixedInputKernelScheduleType(enum.Enum):
|
||||
TmaWarpSpecializedMixedInput = enum_auto()
|
||||
TmaWarpSpecializedPingpongMixedInput = enum_auto()
|
||||
TmaWarpSpecializedCooperativeMixedInput = enum_auto()
|
||||
TmaWarpSpecialized = enum_auto()
|
||||
TmaWarpSpecializedPingpong = enum_auto()
|
||||
TmaWarpSpecializedCooperative = enum_auto()
|
||||
|
||||
|
||||
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
@ -68,11 +68,11 @@ VLLMKernelScheduleTag: Dict[Union[
|
||||
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
||||
**KernelScheduleTag, # type: ignore
|
||||
**{
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecialized:
|
||||
"cutlass::gemm::KernelTmaWarpSpecialized",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
||||
}
|
||||
}
|
||||
|
||||
@ -113,6 +113,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(simon): this is temporarily adapted from
|
||||
// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
|
||||
// we did this to unblock Deepseek V3 but there should be a better
|
||||
// implementation to manage shared memory.
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_global_mem_kernel(
|
||||
scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
|
||||
int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
|
||||
int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
|
||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* In the first step we compute token_cnts[thread_index + 1][expert_index],
|
||||
* which counts how many tokens in the token shard of thread_index are
|
||||
* assigned to expert expert_index.
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// For each expert we accumulate the token counts from the different threads.
|
||||
if (threadIdx.x < num_experts) {
|
||||
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
|
||||
for (int i = 1; i <= blockDim.x; ++i) {
|
||||
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
|
||||
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// We accumulate the token counts of all experts in thread 0.
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
cumsum[i] = cumsum[i - 1] +
|
||||
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
|
||||
block_size) *
|
||||
block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/**
|
||||
* For each expert, each thread processes the tokens of the corresponding
|
||||
* blocks and stores the corresponding expert_id for each block.
|
||||
*/
|
||||
if (threadIdx.x < num_experts) {
|
||||
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
|
||||
i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Each thread processes a token shard, calculating the index of each token
|
||||
* after sorting by expert number. Given the example topk_ids =
|
||||
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
|
||||
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
|
||||
* padding value (preset in Python).
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
/** The cumsum[expert_id] stores the starting index of the tokens that the
|
||||
* expert with expert_id needs to process, and
|
||||
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
|
||||
* processed by the expert with expert_id within the current thread's token
|
||||
* shard.
|
||||
*/
|
||||
int32_t rank_post_pad =
|
||||
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
|
||||
cumsum[expert_id];
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
|
||||
}
|
||||
}
|
||||
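A single-threaded Python reference of the alignment the kernel computes (illustrative), reproducing the worked example from the comment above; the padding value is whatever the Python caller preset sorted_token_ids to, here taken to be numel(topk_ids).

import math

def moe_align_block_size_ref(topk_ids, num_experts, block_size):
    pad = len(topk_ids)  # assumed padding value
    counts = [0] * num_experts
    for e in topk_ids:
        counts[e] += 1

    # Per-expert block-aligned offsets (the `cumsum` array in the kernel).
    cumsum = [0]
    for e in range(num_experts):
        cumsum.append(cumsum[-1] + math.ceil(counts[e] / block_size) * block_size)

    sorted_token_ids = [pad] * cumsum[-1]
    expert_ids = [0] * (cumsum[-1] // block_size)
    for e in range(num_experts):
        for blk in range(cumsum[e] // block_size, cumsum[e + 1] // block_size):
            expert_ids[blk] = e

    offsets = list(cumsum[:-1])
    for i, e in enumerate(topk_ids):
        sorted_token_ids[offsets[e]] = i
        offsets[e] += 1
    return sorted_token_ids, expert_ids, cumsum[-1]

# sorted_token_ids == [0, 6, 9, 9, 1, 3, 9, 9, 2, 4, 9, 9, 5, 7, 9, 9, 8, 9, 9, 9]
print(moe_align_block_size_ref([0, 1, 2, 1, 2, 3, 0, 3, 4], num_experts=5, block_size=4))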
|
||||
template <typename scalar_t, int TOPK>
|
||||
__global__ void moe_sum_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
@ -137,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem =
|
||||
((num_thread + 1) * num_experts + (num_experts + 1)) *
|
||||
sizeof(int32_t);
|
||||
|
||||
// set dynamic shared mem
|
||||
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem));
|
||||
kernel<<<1, num_thread, shared_mem, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel());
|
||||
});
|
||||
// If we have very large number of experts, we can no longer use shared
|
||||
// memory.
|
||||
// TODO(simon): the right solution should be calculating the exact right
|
||||
// amount of shared memory and use that. The num_experts >= 256 is just a
|
||||
// temporary solution to unblock Deepseek V3.
|
||||
if (num_experts >= 256) {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
|
||||
const int32_t mem_tokens_cnts =
|
||||
((num_experts + 1) * num_experts) * sizeof(int32_t);
|
||||
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
|
||||
// allocate global memory
|
||||
int32_t* tokens_cnts;
|
||||
int32_t* cumsum;
|
||||
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
|
||||
cudaMalloc(&cumsum, mem_cumsum);
|
||||
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
|
||||
kernel<<<1, num_thread, 0, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel(), tokens_cnts, cumsum);
|
||||
cudaFree(tokens_cnts);
|
||||
cudaFree(cumsum);
|
||||
});
|
||||
} else {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem =
|
||||
((num_thread + 1) * num_experts + (num_experts + 1)) *
|
||||
sizeof(int32_t);
|
||||
|
||||
// set dynamic shared mem
|
||||
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem));
|
||||
kernel<<<1, num_thread, shared_mem, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
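A back-of-the-envelope check of why the num_experts >= 256 case falls back to global memory (illustrative): the tokens_cnts plus cumsum footprint grows roughly quadratically with the expert count and overflows the per-block dynamic shared memory opt-in limit, which is on the order of 100-228 KB depending on the GPU generation.

def moe_align_shared_mem_bytes(num_experts: int, warp_size: int = 32) -> int:
    # Same formula as the shared_mem computation above: tokens_cnts is
    # (num_thread + 1) x num_experts and cumsum is (num_experts + 1) int32s.
    num_thread = max(num_experts, warp_size)
    return ((num_thread + 1) * num_experts + (num_experts + 1)) * 4

print(moe_align_shared_mem_bytes(64))   # 16,900 bytes: fits in shared memory
print(moe_align_shared_mem_bytes(256))  # 264,196 bytes: too large, use global memory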
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
||||
|
||||
16
csrc/ops.h
@ -115,6 +115,11 @@ void advance_step_flashinfer(
|
||||
torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
|
||||
torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
|
||||
|
||||
void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
|
||||
torch::Tensor& matrix_tgt, int64_t n);
|
||||
|
||||
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
@ -162,6 +167,17 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
|
||||
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& e,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
|
||||
bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
|
||||
torch::Tensor& e, torch::Tensor const& a);
|
||||
#endif
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
|
||||
75
csrc/prepare_inputs/copy_subranges.cu
Normal file
@ -0,0 +1,75 @@
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
namespace vllm {
|
||||
__global__ void copy_subranges_kernel(const int* __restrict__ matrix_src,
|
||||
const int* __restrict__ matrix_diff,
|
||||
int* __restrict__ matrix_tgt, int64_t M) {
|
||||
int row_id = blockIdx.x;
|
||||
int row_offset = row_id * M;
|
||||
|
||||
int start = matrix_diff[row_id * 2];
|
||||
int length = matrix_diff[row_id * 2 + 1];
|
||||
int end = start + length;
|
||||
int thread_idx = threadIdx.x;
|
||||
for (int i = start + thread_idx; i < end; i += blockDim.x) {
|
||||
int idx = row_offset + i;
|
||||
matrix_tgt[idx] = matrix_src[idx];
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
|
||||
void copy_subranges(torch::Tensor& matrix_src, torch::Tensor& matrix_diff,
|
||||
torch::Tensor& matrix_tgt, int64_t n) {
|
||||
// NOTE(woosuk): Here, we skip most of the error checking to minimize the
|
||||
// CPU overheads. We assume that the caller will pass the correct inputs.
|
||||
|
||||
// Check tensor properties
|
||||
// TORCH_CHECK(matrix_src.is_cuda(), "matrix_src must be a CUDA tensor");
|
||||
// TORCH_CHECK(matrix_diff.is_cuda(), "matrix_diff must be a CUDA tensor");
|
||||
// TORCH_CHECK(matrix_tgt.is_cuda(), "matrix_tgt must be a CUDA tensor");
|
||||
// TORCH_CHECK(matrix_src.is_contiguous(), "matrix_src must be contiguous");
|
||||
// TORCH_CHECK(matrix_diff.is_contiguous(), "matrix_diff must be contiguous");
|
||||
// TORCH_CHECK(matrix_tgt.is_contiguous(), "matrix_tgt must be contiguous");
|
||||
|
||||
auto src_sizes = matrix_src.sizes();
|
||||
auto diff_sizes = matrix_diff.sizes();
|
||||
auto tgt_sizes = matrix_tgt.sizes();
|
||||
|
||||
// TORCH_CHECK(src_sizes.size() == 2, "matrix_src must be 2D");
|
||||
// TORCH_CHECK(diff_sizes.size() == 2, "matrix_diff must be 2D");
|
||||
// TORCH_CHECK(tgt_sizes.size() == 2, "matrix_tgt must be 2D");
|
||||
|
||||
int64_t N = src_sizes[0];
|
||||
int64_t M = src_sizes[1];
|
||||
|
||||
// TORCH_CHECK(diff_sizes[0] == N, "matrix_diff first dim must match N");
|
||||
// TORCH_CHECK(diff_sizes[1] == 2, "matrix_diff second dim must be 2");
|
||||
// TORCH_CHECK(tgt_sizes[0] == N && tgt_sizes[1] == M,
|
||||
// "matrix_tgt must have same shape as matrix_src");
|
||||
|
||||
// TORCH_CHECK(n <= N, "n must be <= N");
|
||||
|
||||
const int* d_matrix_src = matrix_src.data_ptr<int>();
|
||||
const int* d_matrix_diff = matrix_diff.data_ptr<int>();
|
||||
int* d_matrix_tgt = matrix_tgt.data_ptr<int>();
|
||||
|
||||
// One thread block per row.
|
||||
int blocks = n;
|
||||
int threads;
|
||||
if (blocks < 128) {
|
||||
threads = 1024;
|
||||
} else if (blocks < 256) {
|
||||
threads = 512;
|
||||
} else if (blocks < 512) {
|
||||
threads = 256;
|
||||
} else {
|
||||
threads = 128;
|
||||
}
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(matrix_tgt));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
vllm::copy_subranges_kernel<<<blocks, threads, 0, stream>>>(
|
||||
d_matrix_src, d_matrix_diff, d_matrix_tgt, M);
|
||||
}
|
||||
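A pure-PyTorch reference of what copy_subranges computes (illustrative): for each of the first n rows, the [start, start + length) slice described by matrix_diff is copied from matrix_src into matrix_tgt in place.

import torch

def copy_subranges_ref(matrix_src: torch.Tensor, matrix_diff: torch.Tensor,
                       matrix_tgt: torch.Tensor, n: int) -> None:
    for row in range(n):
        start, length = matrix_diff[row].tolist()
        matrix_tgt[row, start:start + length] = matrix_src[row, start:start + length]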
@ -1,27 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include <climits>
|
||||
|
||||
/**
|
||||
* Helper function for checking CUTLASS errors
|
||||
*/
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, \
|
||||
cutlassGetStatusString(status)) \
|
||||
}
|
||||
|
||||
inline uint32_t next_pow_2(uint32_t const num) {
|
||||
if (num <= 1) return num;
|
||||
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
|
||||
}
|
||||
|
||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
int max_shared_mem_per_block_opt_in = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin,
|
||||
device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
|
||||
|
||||
@ -21,15 +21,16 @@
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
|
||||
#include "common.hpp"
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
Epilogue functions can be defined to post-process the output before it is
|
||||
written to GPU memory.
|
||||
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
|
||||
Epilogues defined in,
|
||||
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
|
||||
must contain a public type named EVTCompute of type Sm80EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
@ -1,384 +1,18 @@
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
|
||||
#include <torch/all.h>
|
||||
#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
using namespace vllm;
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
|
||||
Epilogue functions can be defined to post-process the output before it is
|
||||
written to GPU memory.
|
||||
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
|
||||
// A wrapper for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using ElementC = void;
|
||||
using StrideC = StrideD;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
|
||||
EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
// clang-format off
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, 16,
|
||||
ElementAB, cutlass::layout::ColumnMajor, 16,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
// clang-format on
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_default {
|
||||
// M in (128, inf)
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M128 {
|
||||
// M in (64, 128]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M64 {
|
||||
// M in [1, 64]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _128>;
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_default {
|
||||
// For M > 128 and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M128 {
|
||||
// For M in (64, 128] and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule =
|
||||
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M64 {
|
||||
// For M in (32, 64] and any N
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M32_NBig {
|
||||
// For M in [1, 32] and N >= 8192
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _4, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_int8_config_M32_NSmall {
|
||||
// For M in [1, 32] and N < 8192
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_fp8_config_default<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 64) {
|
||||
// m in [1, 64]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// m in (64, 128]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// m in (128, inf)
|
||||
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm90_int8_config_default<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM32NBig =
|
||||
typename sm90_int8_config_M32_NBig<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM32NSmall =
|
||||
typename sm90_int8_config_M32_NSmall<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
bool const is_small_n = n < 8192;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 32) {
|
||||
// m in [1, 32]
|
||||
if (is_small_n) {
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (32, 64]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// m in (64, 128]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// m in (128, inf)
|
||||
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
||||
160
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh
Normal file
@ -0,0 +1,160 @@
|
||||
#pragma once
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
// clang-format on
|
||||
|
||||
/*
|
||||
Epilogues defined in,
|
||||
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
|
||||
must contain a public type named EVTCompute of type Sm90EVT, as well as a
|
||||
static prepare_args function that constructs an EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// A wrapper for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
|
||||
using StrideD = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using ElementC = void;
|
||||
using StrideC = StrideD;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
|
||||
EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
// clang-format off
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, 16,
|
||||
ElementAB, cutlass::layout::ColumnMajor, 16,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
// clang-format on
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
cutlass::gemm::PersistentScheduler>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
||||
b_stride};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, c_stride, c_ptr, c_stride};
|
||||
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -0,0 +1,96 @@
#pragma once

#include "scaled_mm_c3x.cuh"

/**
 * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
 * shape.
 */

namespace vllm {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;

  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

}  // namespace vllm
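For reference, the fp8 path above picks a kernel configuration purely from the row count M, rounded up to the next power of two and clamped to at least 64. The following is a minimal standalone sketch of that bucketing; `next_pow_2` is re-implemented locally here as an assumption about the helper used by the dispatch, and the returned strings merely name the config structs above.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Local stand-in for vLLM's next_pow_2 helper (assumed behaviour).
static uint32_t next_pow_2(uint32_t x) {
  uint32_t p = 1;
  while (p < x) p <<= 1;
  return p;
}

// Mirrors cutlass_gemm_sm90_fp8_dispatch: M64 config for m <= 64,
// M128 config for m in (64, 128], default config otherwise.
static const char* sm90_fp8_bucket(uint32_t m) {
  uint32_t mp2 = std::max(64u, next_pow_2(m));
  if (mp2 <= 64) return "sm90_fp8_config_M64";
  if (mp2 <= 128) return "sm90_fp8_config_M128";
  return "sm90_fp8_config_default";
}

int main() {
  for (uint32_t m : {1u, 64u, 65u, 128u, 129u, 4096u})
    std::printf("m=%u -> %s\n", m, sm90_fp8_bucket(m));
  return 0;
}
```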
@ -0,0 +1,140 @@
#pragma once

#include "scaled_mm_c3x.cuh"

/**
 * This file defines Gemm kernel configurations for SM90 (int8) based on the
 * Gemm shape.
 */

namespace vllm {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {
  // For M > 128 and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
  // For M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
  // For M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
  // For M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
  // For M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
                                            EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass3xGemmDefault =
      typename sm90_int8_config_default<InType, OutType,
                                        Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.size(1);
  bool const is_small_n = n < 8192;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

}  // namespace vllm
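The int8 dispatch above layers one more decision onto the same M bucketing: in the smallest bucket it also branches on the output width N, using a wider cluster when N >= 8192. A hedged sketch of just that decision table (config names only, no kernel types; `next_pow_2` is again a local stand-in):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <string>

// Assumed re-implementation of vLLM's next_pow_2 helper, for illustration.
static uint32_t next_pow_2(uint32_t x) {
  uint32_t p = 1;
  while (p < x) p <<= 1;
  return p;
}

// Mirrors cutlass_gemm_sm90_int8_dispatch's bucket selection.
static std::string sm90_int8_bucket(uint32_t m, uint32_t n) {
  uint32_t mp2 = std::max(32u, next_pow_2(m));
  if (mp2 <= 32) return n < 8192 ? "M32_NSmall" : "M32_NBig";
  if (mp2 <= 64) return "M64";
  if (mp2 <= 128) return "M128";
  return "default";
}

int main() {
  std::printf("m=16, n=4096  -> %s\n", sm90_int8_bucket(16, 4096).c_str());
  std::printf("m=16, n=16384 -> %s\n", sm90_int8_bucket(16, 16384).c_str());
  std::printf("m=512, n=4096 -> %s\n", sm90_int8_bucket(512, 4096).c_str());
  return 0;
}
```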
@ -3,6 +3,8 @@
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass_extensions/common.hpp"

void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
  return false;
}

int32_t get_sm_version_num() {
  int32_t major_capability, minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         0);
  int32_t version_num = major_capability * 10 + minor_capability;
  return version_num;
}

void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
  {{DataTypeTag[t.b_group_zeropoint]}},  // GroupZeroT
  {{DataTypeTag[t.b_channel_scale]}},  // ChannelScaleT
  {{DataTypeTag[t.a_token_scale]}},  // TokenScaleT
  cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
  cutlass::gemm::KernelTmaWarpSpecializedCooperative,
  Sch>;

{% for sch in schs %}
@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
    {{DataTypeTag[t.convert]}},  // ElementConvert
    {{DataTypeTag[t.accumulator]}},  // Accumulator
    cutlass::layout::ColumnMajor,
    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
    cutlass::gemm::KernelTmaWarpSpecializedCooperative>
  >(args.B);
}
{%- endfor %}
@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
};  // namespace machete
"""

TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative


@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
# mostly unique shorter sch_sig
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
    kernel_terse_names_replace = {
        "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
        "KernelTmaWarpSpecializedCooperative": "TmaMI_",
        "TmaWarpSpecializedCooperative_": "TmaCoop_",
        "StreamKScheduler": "streamK",
    }

@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
    ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
    KernelScheduleType,
    cute::enable_if_t<(
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
        cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
        cute::is_same_v<KernelScheduleType,
                        KernelTmaWarpSpecializedMixedInput> ||
        cute::is_same_v<KernelScheduleType,
                        KernelTmaWarpSpecializedPingpongMixedInput> ||
        cute::is_same_v<KernelScheduleType,
                        KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
                        KernelTmaWarpSpecializedCooperative>)>> {
  using CollectiveOp = machete::MacheteCollectiveMma<
      ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
      AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
      StageCountType, KernelScheduleType>;
};

};  // namespace cutlass::gemm::collective
};  // namespace cutlass::gemm::collective

@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
  using Schedule = KernelScheduleType;
  static_assert(
      cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
          cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
          cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
          cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
          cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
          cute::is_same_v<Schedule,
                          KernelTmaWarpSpecializedPingpongMixedInput> ||
          cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
          cute::is_same_v<Schedule,
                          KernelTmaWarpSpecializedCooperativeMixedInput>,
          cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
      "KernelSchedule must be one of the warp specialized policies");

 public:
@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
  // For coop schedules we have two warp groups cooperatively issuing wgmma
  // instructions so we use 2 atoms along the M dim (one for each warpgroup)
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelScheduleType,
                      KernelTmaWarpSpecializedCooperativeMixedInput>,
      cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(

@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
  // For coop schedules we have two warp groups cooperatively issuing wgmma
  // instructions so we use 2 atoms along the M dim (one for each warpgroup)
  using AtomLayoutMNK = cute::conditional_t<
      cute::is_same_v<KernelSchedule,
                      KernelTmaWarpSpecializedCooperativeMixedInput>,
      cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
      Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(
@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
  }
};

};  // namespace machete
};  // namespace machete
165
csrc/sparse/cutlass/sparse_compressor_c3x.cu
Normal file
@ -0,0 +1,165 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"

#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
// clang-format on

using namespace cute;
using namespace vllm;

/// Make A structured sparse by replacing elements with 0 and compress it
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                             torch::Tensor const& a) {
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
              a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
  TORCH_CHECK(a.dim() == 2)
  // Check for strides and alignment
  TORCH_CHECK(a.stride(0) % 4 == 0)  // Required for semi-structured sparsity
  TORCH_CHECK(a.stride(1) == 1)

  int m = a.size(0);
  int k = a.size(1);

  // Sparse kernel setup; this kernel is not used for matmul,
  // but just for setting up the compressor utility
  // A matrix configuration
  using ElementA = ElementA_;
  using LayoutTagA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  // B matrix configuration
  using ElementB = ElementA;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  // C/D matrix configuration
  using ElementC = float;
  using LayoutTagC = cutlass::layout::ColumnMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  // Core kernel configurations
  using ElementAccumulator = ElementAcc_;
  using TileShape = Shape<_128, _128, _128>;
  using TileShapeRef = Shape<_128, _128, _64>;
  using ClusterShape = Shape<_1, _2, _1>;
  using KernelSchedule = typename std::conditional<
      std::is_same_v<ElementA, cutlass::float_e4m3_t>,
      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
      cutlass::gemm::KernelTmaWarpSpecialized>::type;

  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
  using ProblemShape = Shape<int, int, int, int>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
          AlignmentC, ElementC, LayoutTagC, AlignmentC,
          EpilogueSchedule>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
          LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
                                           CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
  using StrideE = StrideA;

  using StrideA = Stride<int64_t, Int<1>, int64_t>;

  // The n (=1) dimension does not matter for the compressor
  typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};

  using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
  using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;

  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
  using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;

  // Offline compressor kernel
  using CompressorUtility =
      cutlass::transform::kernel::StructuredSparseCompressorUtility<
          ProblemShape, ElementA, LayoutTagA, SparseConfig>;

  using CompressorKernel =
      cutlass::transform::kernel::StructuredSparseCompressor<
          ProblemShape, ElementA, LayoutTagA, SparseConfig,
          cutlass::arch::Sm90>;

  using Compressor =
      cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;

  auto [M, N, K, L] = prob_shape;

  StrideA stride_A;
  stride_A =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));

  CompressorUtility compressor_utility(prob_shape, stride_A);

  int ME = compressor_utility.get_metadata_m_physical();
  int KE = compressor_utility.get_metadata_k_physical();
  int KC = compressor_utility.get_tensorA_k_physical();

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());

  auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
  auto a_meta_ptr = static_cast<typename Gemm::CollectiveMainloop::ElementE*>(
      a_meta.data_ptr());

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = 0;
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);
  typename Compressor::Arguments arguments{
      prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};

  Compressor compressor_op;
  size_t workspace_size = Compressor::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  CUTLASS_CHECK(compressor_op.can_implement(arguments));
  CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
  CUTLASS_CHECK(compressor_op.run());
  CUDA_CHECK(cudaDeviceSynchronize());

  return true;
}

bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                  torch::Tensor const& a) {
  if (a.dtype() == torch::kBFloat16) {
    return cutlass_sparse_compress<cutlass::bfloat16_t, float>(a_nzs, a_meta,
                                                               a);
  } else if (a.dtype() == torch::kFloat16) {
    return cutlass_sparse_compress<cutlass::half_t, float>(a_nzs, a_meta, a);
  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
    return cutlass_sparse_compress<cutlass::float_e4m3_t, float>(a_nzs, a_meta,
                                                                 a);
  } else if (a.dtype() == torch::kInt8) {
    return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
  }
  return false;
}
#endif
42
csrc/sparse/cutlass/sparse_compressor_entry.cu
Normal file
@ -0,0 +1,42 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass_extensions/common.hpp"

#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                  torch::Tensor const& a);
#endif

bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                   torch::Tensor const& a) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2);
  TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) &&
              a_nzs.size(1) * 2 == a.size(1) &&
              a_meta.size(1) * 2 * 4 == a.size(1));
  // Considering elemsPerMetaElem = 8b / 2b_per_nz = 4

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 &&
              a_meta.stride(1) == 1);  // Row-major
  TORCH_CHECK(a.stride(0) % 8 == 0);  // 8 Byte Alignment for Compression

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
  if (version_num >= 90) {
    return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_sparse_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
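The conformality checks above pin down the shape contract for 2:4 structured sparsity: `a_nzs` keeps half of A's columns (the retained nonzero values) and `a_meta` packs four 2-bit selectors per byte, so it has one eighth of A's columns. A small helper that computes those expected widths follows; the names and the helper itself are illustrative, not part of the vLLM API.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Expected column counts for 2:4 compression of a row-major (m x k) matrix,
// matching the checks in cutlass_sparse_compress_entry:
//   a_nzs.size(1) * 2      == a.size(1)   -> k/2 kept values per row
//   a_meta.size(1) * 2 * 4 == a.size(1)   -> k/8 metadata bytes per row
struct SparseShapes {
  int64_t nzs_cols;
  int64_t meta_cols;
};

static SparseShapes expected_24_shapes(int64_t k) {
  assert(k % 8 == 0);  // assumption: k divisible by 8 so both divisions are exact
  return {k / 2, k / 8};
}

int main() {
  SparseShapes s = expected_24_shapes(4096);
  std::printf("k=4096 -> a_nzs cols=%lld, a_meta cols=%lld\n",
              static_cast<long long>(s.nzs_cols),
              static_cast<long long>(s.meta_cols));
  return 0;
}
```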
303
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
Normal file
@ -0,0 +1,303 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on

using namespace cute;
using namespace vllm;

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& bt_nzs,
                                    torch::Tensor const& bt_meta,
                                    EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM256 =
      typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM512 =
      typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;

  using Cutlass3xGemm1 =
      typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm2 =
      typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm3 =
      typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm4 =
      typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm5 =
      typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm6 =
      typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm7 =
      typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm8 =
      typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const n = bt_nzs.size(0);
  uint32_t const m = a.size(0);  // Batch size
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    if (n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm2>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 4096 || n == 6144) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm1>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 128) {
    if (n == 4096) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm3>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm5>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 6144) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm4>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 256) {
    if (n == 4096) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm6>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 6144) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  } else {
    if (n == 6144 || n == 28672) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm8>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else if (n == 4096) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemm7>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  }

  // Otherwise the default heuristic
  if (mp2 <= 64) {
    // n in [1, 64]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // n in (64, 128]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // n in (128, 256]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM256>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else {
    // n in (256, inf)
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM512>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_fp16_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& bt_nzs,
                                     torch::Tensor const& bt_meta,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::half_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat16);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;

  // m in (128, inf)
  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_bf16_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& bt_nzs,
                                     torch::Tensor const& bt_meta,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::bfloat16_t>());
  TORCH_CHECK(a.dtype() == torch::kBFloat16);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;

  // m in (128, inf)
  return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
      out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
}

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& bt_nzs,
                                     torch::Tensor const& bt_meta,
                                     EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  uint32_t const n = out.size(1);
  bool const is_small_n = n < 8192;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NSmall>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    } else {
      return cutlass_sparse_gemm_caller<Cutlass3xGemmM32NBig>(
          out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM64>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_sparse_gemm_caller<Cutlass3xGemmM128>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_sparse_gemm_caller<Cutlass3xGemmDefault>(
        out, a, bt_nzs, bt_meta, std::forward<EpilogueArgs>(args)...);
  }
}

template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& bt_nzs,
                                            torch::Tensor const& bt_meta,
                                            EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else if (a.dtype() == torch::kFloat16) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t,
                                             cutlass::bfloat16_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp16_dispatch<cutlass::half_t, cutlass::half_t,
                                             Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {  // a.dtype() == torch::kBFloat16
    TORCH_CHECK(a.dtype() == torch::kBFloat16);
    TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
                                             cutlass::bfloat16_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_bf16_dispatch<cutlass::bfloat16_t,
                                             cutlass::half_t, Epilogue>(
          out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}

void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& bt_nzs,
                                   torch::Tensor const& bt_meta,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
        out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias);
  } else {
    return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
        out, a, bt_nzs, bt_meta, b_scales, a_scales);
  }
}

#endif
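The sparse fp8 dispatch above is a two-stage heuristic: a set of cherry-picked configurations keyed on exact output widths (4096, 6144, 28672) is tried first, and only then does control fall through to a generic bucketing on M alone. A compact sketch of that selection order follows, returning config names rather than instantiating kernels, with `next_pow_2` again assumed and re-implemented locally.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <string>

static uint32_t next_pow_2(uint32_t x) {  // assumed vLLM helper, re-implemented
  uint32_t p = 1;
  while (p < x) p <<= 1;
  return p;
}

// Two-stage selection mirroring the sparse fp8 dispatch: exact-N overrides
// first, then the generic M-bucket fallback.
static std::string sparse_fp8_config(uint32_t m, uint32_t n) {
  uint32_t mp2 = std::max(64u, next_pow_2(m));
  if (mp2 <= 64) {
    if (n == 28672) return "sm90_fp8_config_2";
    if (n == 4096 || n == 6144) return "sm90_fp8_config_1";
  } else if (mp2 <= 128) {
    if (n == 4096) return "sm90_fp8_config_3";
    if (n == 6144) return "sm90_fp8_config_4";
    if (n == 28672) return "sm90_fp8_config_5";
  } else if (mp2 <= 256) {
    if (n == 4096) return "sm90_fp8_config_6";
    if (n == 6144) return "sm90_fp8_config_7";
    if (n == 28672) return "sm90_fp8_config_8";
  } else {
    if (n == 4096) return "sm90_fp8_config_7";
    if (n == 6144 || n == 28672) return "sm90_fp8_config_8";
  }
  // Fallback heuristic keyed on M only.
  if (mp2 <= 64) return "sm90_fp8_config_M64";
  if (mp2 <= 128) return "sm90_fp8_config_M128";
  if (mp2 <= 256) return "sm90_fp8_config_M256";
  return "sm90_fp8_config_M512";
}

int main() {
  std::printf("m=32,  n=28672 -> %s\n", sparse_fp8_config(32, 28672).c_str());
  std::printf("m=200, n=4096  -> %s\n", sparse_fp8_config(200, 4096).c_str());
  std::printf("m=200, n=5120  -> %s\n", sparse_fp8_config(200, 5120).c_str());
  return 0;
}
```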
496
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
Normal file
@ -0,0 +1,496 @@
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
This file defines sparse quantized GEMM operations using the CUTLASS 3.x API,
|
||||
for NVIDIA GPUs with sm90a (Hopper) or later.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
|
||||
// A wrapper for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm90_or_later : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule, typename AccType,
|
||||
typename TileSchedule = cutlass::gemm::PersistentScheduler,
|
||||
GemmUniversalMode Mode_ = GemmUniversalMode::kGemm>
|
||||
struct cutlass_sparse_3x_gemm {
|
||||
static const GemmUniversalMode Mode = Mode_;
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
using ElementAcc = AccType;
|
||||
|
||||
using EpilogueDescriptor =
|
||||
cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
|
||||
ElementD, EpilogueSchedule>;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutD = LayoutC;
|
||||
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
|
||||
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
|
||||
|
||||
using LayoutC_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
|
||||
using LayoutD_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
static constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD =
|
||||
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAcc, ElementAcc, ElementC, LayoutC_Transpose, AlignmentCD,
|
||||
ElementD, LayoutD_Transpose, AlignmentCD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
// clang-format off
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
|
||||
ElementAB, cutlass::layout::RowMajor, AlignmentA,
|
||||
ElementAB, cutlass::layout::ColumnMajor, AlignmentB,
|
||||
ElementAcc, TileShape, ClusterShape,
|
||||
Stages,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
// clang-format on
|
||||
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
|
||||
TileSchedule>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& bt_nzs,
|
||||
torch::Tensor const& bt_meta,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
// Interface stride expected from the argument a (will get transposed)
|
||||
// We compute C^T = B^T * A^T, but we assume B is transposed before
|
||||
// compression and hence the bt_* naming
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
|
||||
using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
|
||||
|
||||
auto layout_A = make_cute_layout<StrideA>(a, "A");
|
||||
auto layout_D = make_cute_layout<StrideD>(out, "D");
|
||||
|
||||
// Transpose A and D
|
||||
// A doesn't need to be transposed since cutlass expects a NxK matrix
|
||||
// for B (which is At)
|
||||
auto stride_At = layout_A.stride();
|
||||
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
typename GemmKernel::ProblemShape prob_shape{
|
||||
static_cast<int>(bt_nzs.size(0)), static_cast<int>(size<0>(layout_A)),
|
||||
static_cast<int>(size<1>(layout_A)), 1};
|
||||
|
||||
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
|
||||
using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
|
||||
|
||||
LayoutB b_layout = SparseConfig::fill_layoutA(prob_shape);
|
||||
LayoutE e_layout = SparseConfig::fill_layoutE(prob_shape);
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(bt_nzs.data_ptr());
|
||||
auto e_ptr = static_cast<ElementE*>(bt_meta.data_ptr());
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
b_ptr, b_layout, a_ptr, stride_At, e_ptr, e_layout};
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
std::forward<EpilogueArgs>(epilogue_params)...),
|
||||
c_ptr, stride_Dt, c_ptr, stride_Dt};
|
||||
|
||||
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
prob_shape, mainloop_args, epilogue_args};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default {};
|
||||
|
||||
template <typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default<half_t, OutType, Epilogue> {
|
||||
// M in (128, inf)
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
};
|
||||
|
||||
template <typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
|
||||
// M in (128, inf)
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule,
|
||||
float>;
|
||||
};
|
||||
|
||||
//////////////////////// Cherry-Picking Kernels ////////////////////////
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_1 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_2 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_128, _64, _256>;
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_3 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_4 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_5 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_6 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_7 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_8 {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_128, _256, _128>;
|
||||
using ClusterShape = Shape<_8, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float>;
|
||||
};
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
|
||||
// M in (128, inf)
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
|
||||
TileShape, ClusterShape, KernelSchedule,
|
||||
EpilogueSchedule, float>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M64 {
|
||||
// M in [1, 64]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
typename cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
using TileSchedule = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule, float,
|
||||
TileSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M128 {
|
||||
// M in (64, 128]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
||||
using TileShape = Shape<_64, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
|
||||
using TileSchedule = cutlass::gemm::PersistentScheduler;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M256 {
  // M in (128, 256]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;

  using TileSchedule = cutlass::gemm::PersistentScheduler;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M512 {
  // M in (256, ]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
  using EpilogueSchedule =
      typename cutlass::epilogue::TmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;

  using TileSchedule = cutlass::gemm::PersistentScheduler;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, float,
                             TileSchedule>;
};

template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<int8_t, OutType, Epilogue> {
  // For M > 128 and any N
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<int8_t, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
  // For M in (64, 128] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule =
      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
  // For M in (32, 64] and any N
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
  // For M in [1, 32] and N >= 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _256>;
  using ClusterShape = Shape<_1, _4, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
  // For M in [1, 32] and N < 8192
  static_assert(std::is_same<InType, int8_t>());
  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _256>;
  using ClusterShape = Shape<_1, _8, _1>;
  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule, int32_t>;
};

}  // namespace
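The comments in these structs encode a shape-based selection policy: each struct fixes the tile shape, cluster shape, and kernel/epilogue schedule for a range of problem sizes. As a non-authoritative illustration of that policy only (the actual dispatch is done in C++ elsewhere in this diff; names below simply mirror the struct comments), the int8 selection could be summarized as:

```python
def select_sm90_int8_config(m: int, n: int) -> str:
    """Illustrative only: mirrors the M/N ranges stated in the struct comments above."""
    if m <= 32:
        # Small M: wider clusters along N, picked by how large N is.
        return "sm90_int8_config_M32_NBig" if n >= 8192 else "sm90_int8_config_M32_NSmall"
    if m <= 64:
        return "sm90_int8_config_M64"
    if m <= 128:
        return "sm90_int8_config_M128"
    return "sm90_config_default"  # M > 128, any N
```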
70  csrc/sparse/cutlass/sparse_scaled_mm_entry.cu  Normal file
@@ -0,0 +1,70 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass_extensions/common.hpp"

bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
  // sparse CUTLASS kernels need at least
  // CUDA 12.2 and SM90 (Hopper)

#if defined CUDA_VERSION
  return CUDA_VERSION >= 12020 && cuda_device_capability >= 90;
#endif

  return false;
}

#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                                   torch::Tensor const& b,
                                   torch::Tensor const& e,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   c10::optional<torch::Tensor> const& bias);
#endif

void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& bt_nzs,
                              torch::Tensor const& bt_meta,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
                              c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&
              a.size(0) == c.size(0));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == bt_nzs.size(0));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && bt_nzs.stride(1) == 1 &&
              c.stride(1) == 1);            // Row-major
  TORCH_CHECK(c.stride(0) % 16 == 0);       // 16 Byte Alignment
  TORCH_CHECK(bt_nzs.stride(0) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == bt_nzs.size(0) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
  if (version_num >= 90) {
    cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
                                  bias);
    return;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_sparse_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
@@ -21,6 +21,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("weak_ref_tensor(Tensor input) -> Tensor");
  ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);

  ops.def("get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor");
  ops.impl("get_cuda_view_from_cpu_tensor", torch::kCPU,
           &get_cuda_view_from_cpu_tensor);

  // Attention ops
  // Compute the attention between an input query and the cached
  // keys/values using PagedAttention.
@@ -98,6 +102,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      ") -> ()");
  ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

  ops.def(
      "copy_subranges(Tensor matrix_src, Tensor matrix_diff, Tensor! "
      "matrix_tgt, int n) -> ()");
  ops.impl("copy_subranges", torch::kCUDA, &copy_subranges);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
@@ -321,6 +330,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

  // Check if cutlass sparse scaled_mm is supported for CUDA devices of the
  // given capability
  ops.def(
      "cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool");
  ops.impl("cutlass_sparse_scaled_mm_supported",
           &cutlass_sparse_scaled_mm_supported);

  // CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
  // quantization, as well as bias
  ops.def(
      "cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
      " Tensor bt_nzs,"
      " Tensor bt_meta, Tensor a_scales,"
      " Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);

  // CUTLASS sparse matrix compressor
  ops.def(
      "cutlass_sparse_compress_entry(Tensor! a_nzs, Tensor! a_meta,"
      " Tensor a) -> bool");
  ops.impl("cutlass_sparse_compress_entry", &cutlass_sparse_compress_entry);

  // Mamba selective scan kernel
  ops.def(
      "selective_scan_fwd(Tensor! u, Tensor! delta,"
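A rough, hypothetical sketch of how these registrations could be exercised from Python once the extension is built. The `torch.ops._C` namespace, the metadata dtype and width, and the scale layouts below are illustrative assumptions, not taken from this diff; the shapes simply satisfy the conformality checks in `cutlass_scaled_sparse_mm` above.

```python
import torch

M, N, K = 16, 256, 512
device = "cuda"

# Dense activations [M, K] in FP8; B is stored transposed in 2:4-sparse form.
a = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
bt_nzs = torch.empty(N, K // 2, device=device, dtype=a.dtype)      # non-zero values of B^T (placeholder)
bt_meta = torch.empty(N, K // 8, device=device, dtype=torch.uint8)  # sparsity metadata (dtype/width assumed)
a_scales = torch.ones(M, device=device, dtype=torch.float32)        # per-row scales (or a single scalar)
b_scales = torch.ones(N, device=device, dtype=torch.float32)        # per-column scales (or a single scalar)
out = torch.empty(M, N, device=device, dtype=torch.bfloat16)        # [M, N] row-major output

# Only attempt the call on hardware the kernel claims to support (SM90, CUDA >= 12.2).
major, minor = torch.cuda.get_device_capability()
if torch.ops._C.cutlass_sparse_scaled_mm_supported(major * 10 + minor):
    torch.ops._C.cutlass_scaled_sparse_mm(out, a, bt_nzs, bt_meta,
                                          a_scales, b_scales, None)
```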
@@ -1,7 +1,7 @@
sphinx==6.2.1
sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2
myst-parser==2.0.0
myst-parser==3.0.1
sphinx-argparse==0.4.0
msgspec
cloudpickle
@@ -19,3 +19,4 @@ openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
requests
zmq
102  docs/source/automatic_prefix_caching/apc.md  Normal file
@@ -0,0 +1,102 @@
(apc)=

# Introduction

## What is Automatic Prefix Caching

Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part.

```{note}
Technical details on how vLLM implements APC are in the next page.
```

## Enabling APC in vLLM

Set `enable_prefix_caching=True` in vLLM engine to enable APC. Here is an example:

```python
import time
from vllm import LLM, SamplingParams


# A prompt containing a large markdown table. The table is randomly generated by GPT-4.
LONG_PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + """
| ID | Name | Age | Occupation | Country | Email | Phone Number | Address |
|-----|---------------|-----|---------------|---------------|------------------------|----------------|------------------------------|
| 1 | John Doe | 29 | Engineer | USA | john.doe@example.com | 555-1234 | 123 Elm St, Springfield, IL |
| 2 | Jane Smith | 34 | Doctor | Canada | jane.smith@example.com | 555-5678 | 456 Oak St, Toronto, ON |
| 3 | Alice Johnson | 27 | Teacher | UK | alice.j@example.com | 555-8765 | 789 Pine St, London, UK |
| 4 | Bob Brown | 45 | Artist | Australia | bob.b@example.com | 555-4321 | 321 Maple St, Sydney, NSW |
| 5 | Carol White | 31 | Scientist | New Zealand | carol.w@example.com | 555-6789 | 654 Birch St, Wellington, NZ |
| 6 | Dave Green | 28 | Lawyer | Ireland | dave.g@example.com | 555-3456 | 987 Cedar St, Dublin, IE |
| 7 | Emma Black | 40 | Musician | USA | emma.b@example.com | 555-1111 | 246 Ash St, New York, NY |
| 8 | Frank Blue | 37 | Chef | Canada | frank.b@example.com | 555-2222 | 135 Spruce St, Vancouver, BC |
| 9 | Grace Yellow | 50 | Engineer | UK | grace.y@example.com | 555-3333 | 864 Fir St, Manchester, UK |
| 10 | Henry Violet | 32 | Artist | Australia | henry.v@example.com | 555-4444 | 753 Willow St, Melbourne, VIC|
| 11 | Irene Orange | 26 | Scientist | New Zealand | irene.o@example.com | 555-5555 | 912 Poplar St, Auckland, NZ |
| 12 | Jack Indigo | 38 | Teacher | Ireland | jack.i@example.com | 555-6666 | 159 Elm St, Cork, IE |
| 13 | Karen Red | 41 | Lawyer | USA | karen.r@example.com | 555-7777 | 357 Cedar St, Boston, MA |
| 14 | Leo Brown | 30 | Chef | Canada | leo.b@example.com | 555-8888 | 246 Oak St, Calgary, AB |
| 15 | Mia Green | 33 | Musician | UK | mia.g@example.com | 555-9999 | 975 Pine St, Edinburgh, UK |
| 16 | Noah Yellow | 29 | Doctor | Australia | noah.y@example.com | 555-0000 | 864 Birch St, Brisbane, QLD |
| 17 | Olivia Blue | 35 | Engineer | New Zealand | olivia.b@example.com | 555-1212 | 753 Maple St, Hamilton, NZ |
| 18 | Peter Black | 42 | Artist | Ireland | peter.b@example.com | 555-3434 | 912 Fir St, Limerick, IE |
| 19 | Quinn White | 28 | Scientist | USA | quinn.w@example.com | 555-5656 | 159 Willow St, Seattle, WA |
| 20 | Rachel Red | 31 | Teacher | Canada | rachel.r@example.com | 555-7878 | 357 Poplar St, Ottawa, ON |
| 21 | Steve Green | 44 | Lawyer | UK | steve.g@example.com | 555-9090 | 753 Elm St, Birmingham, UK |
| 22 | Tina Blue | 36 | Musician | Australia | tina.b@example.com | 555-1213 | 864 Cedar St, Perth, WA |
| 23 | Umar Black | 39 | Chef | New Zealand | umar.b@example.com | 555-3435 | 975 Spruce St, Christchurch, NZ|
| 24 | Victor Yellow | 43 | Engineer | Ireland | victor.y@example.com | 555-5657 | 246 Willow St, Galway, IE |
| 25 | Wendy Orange | 27 | Artist | USA | wendy.o@example.com | 555-7879 | 135 Elm St, Denver, CO |
| 26 | Xavier Green | 34 | Scientist | Canada | xavier.g@example.com | 555-9091 | 357 Oak St, Montreal, QC |
| 27 | Yara Red | 41 | Teacher | UK | yara.r@example.com | 555-1214 | 975 Pine St, Leeds, UK |
| 28 | Zack Blue | 30 | Lawyer | Australia | zack.b@example.com | 555-3436 | 135 Birch St, Adelaide, SA |
| 29 | Amy White | 33 | Musician | New Zealand | amy.w@example.com | 555-5658 | 159 Maple St, Wellington, NZ |
| 30 | Ben Black | 38 | Chef | Ireland | ben.b@example.com | 555-7870 | 246 Fir St, Waterford, IE |
"""


def get_generation_time(llm, sampling_params, prompts):
    # time the generation
    start_time = time.time()
    output = llm.generate(prompts, sampling_params=sampling_params)
    end_time = time.time()
    # print the output and generation time
    print(f"Output: {output[0].outputs[0].text}")
    print(f"Generation time: {end_time - start_time} seconds.")


# set enable_prefix_caching=True to enable APC
llm = LLM(
    model='lmsys/longchat-13b-16k',
    enable_prefix_caching=True
)

sampling_params = SamplingParams(temperature=0, max_tokens=100)

# Querying the age of John Doe
get_generation_time(
    llm,
    sampling_params,
    LONG_PROMPT + "Question: what is the age of John Doe? Your answer: The age of John Doe is ",
)

# Querying the age of Zack Blue
# This query will be faster since vllm avoids computing the KV cache of LONG_PROMPT again.
get_generation_time(
    llm,
    sampling_params,
    LONG_PROMPT + "Question: what is the age of Zack Blue? Your answer: The age of Zack Blue is ",
)
```

## Example workloads

We describe two example workloads, where APC can provide huge performance benefit:

- Long document query, where the user repeatedly queries the same long document (e.g. software manual or annual report) with different queries. In this case, instead of processing the long document again and again, APC allows vLLM to process this long document *only once*, and all future requests can avoid recomputing this long document by reusing its KV cache. This allows vLLM to serve future requests with much higher throughput and much lower latency.
- Multi-round conversation, where the user may chat with the application multiple times in the same chatting session. In this case, instead of processing the whole chatting history again and again, APC allows vLLM to reuse the processing results of the chat history across all future rounds of conversation, allowing vLLM to serve future requests with much higher throughput and much lower latency.

## Limits

APC in general does not reduce the performance of vLLM. With that being said, APC only reduces the time of processing the queries (the prefilling phase) and does not reduce the time of generating new tokens (the decoding phase). So APC does not bring performance gain when vLLM spends most of the time generating answers to the queries (e.g. when the length of the answer is long), or new queries do not share the same prefix with any of existing queries (so that the computation cannot be reused).
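As a rough sketch of the multi-round conversation workload described above (illustrative prompts; it reuses the same `LLM` and `SamplingParams` setup as the example earlier on this page):

```python
from vllm import LLM, SamplingParams

llm = LLM(model='lmsys/longchat-13b-16k', enable_prefix_caching=True)
sampling_params = SamplingParams(temperature=0, max_tokens=100)

# A growing chat transcript: every round re-sends the full history, but with APC
# the KV cache of the shared prefix (the earlier rounds) is reused instead of recomputed.
history = "You are a helpful assistant.\n"
for question in [
    "User: What is automatic prefix caching?\nAssistant:",
    "User: When does it help the most?\nAssistant:",
]:
    history += question
    answer = llm.generate(history, sampling_params=sampling_params)[0].outputs[0].text
    history += answer + "\n"
    print(answer)
```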
@@ -1,110 +0,0 @@ (the former reStructuredText version of the APC page, deleted; its content matches the new apc.md above)
15  docs/source/community/meetups.md  Normal file
@@ -0,0 +1,15 @@
(meetups)=

# vLLM Meetups

We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:

- [The seventh vLLM meetup](https://lu.ma/h0qvrajz), with Snowflake, November 14th 2024. [[Slides]](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing)
- [The sixth vLLM meetup](https://lu.ma/87q3nvnh), with NVIDIA, September 9th 2024. [[Slides]](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing)
- [The fifth vLLM meetup](https://lu.ma/lp0gyjqr), with AWS, July 24th 2024. [[Slides]](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing)
- [The fourth vLLM meetup](https://lu.ma/agivllm), with Cloudflare and BentoML, June 11th 2024. [[Slides]](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing)
- [The third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/), with Roblox, April 2nd 2024. [[Slides]](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing)
- [The second vLLM meetup](https://lu.ma/ygxbpzhl), with IBM Research, January 31st 2024. [[Slides]](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing) [[Video (vLLM Update)]](https://youtu.be/Y0C-DUvEnZQ) [[Video (IBM Research & torch.compile)]](https://youtu.be/m0dMtFLI-dg)
- [The first vLLM meetup](https://lu.ma/first-vllm-meetup), with a16z, October 5th 2023. [[Slides]](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing)

We are always looking for speakers and sponsors at San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu).
@@ -1,16 +0,0 @@ (the former reStructuredText version of the meetups page, deleted; its content matches the new meetups.md above)
@@ -51,7 +51,7 @@ templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns: List[str] = ["**/*.template.rst"]
exclude_patterns: List[str] = ["**/*.template.md"]

# Exclude the prompt "$" when copying code
copybutton_prompt_text = r"\$ "
@@ -74,6 +74,35 @@ html_theme_options = {
html_static_path = ["_static"]
html_js_files = ["custom.js"]

myst_url_schemes = {
    'http': None,
    'https': None,
    'mailto': None,
    'ftp': None,
    "gh-issue": {
        "url":
        "https://github.com/vllm-project/vllm/issues/{{path}}#{{fragment}}",
        "title": "Issue #{{path}}",
        "classes": ["github"],
    },
    "gh-pr": {
        "url":
        "https://github.com/vllm-project/vllm/pull/{{path}}#{{fragment}}",
        "title": "Pull Request #{{path}}",
        "classes": ["github"],
    },
    "gh-dir": {
        "url": "https://github.com/vllm-project/vllm/tree/main/{{path}}",
        "title": "{{path}}",
        "classes": ["github"],
    },
    "gh-file": {
        "url": "https://github.com/vllm-project/vllm/blob/main/{{path}}",
        "title": "{{path}}",
        "classes": ["github"],
    },
}

# see https://docs.readthedocs.io/en/stable/reference/environment-variables.html  # noqa
READTHEDOCS_VERSION_TYPE = os.environ.get('READTHEDOCS_VERSION_TYPE')
if READTHEDOCS_VERSION_TYPE == "tag":
@@ -162,6 +191,7 @@ def linkcode_resolve(domain, info):

# Mock out external dependencies here, otherwise the autodoc pages may be blank.
autodoc_mock_imports = [
    "blake3",
    "compressed_tensors",
    "cpuinfo",
    "cv2",
@@ -178,7 +208,7 @@ autodoc_mock_imports = [
    "tensorizer",
    "pynvml",
    "outlines",
    "xgrammar,"
    "xgrammar",
    "librosa",
    "soundfile",
    "gguf",
50  docs/source/contributing/dockerfile/dockerfile.md  Normal file
@@ -0,0 +1,50 @@
# Dockerfile

We provide a <gh-file:Dockerfile> to construct the image for running an OpenAI compatible server with vLLM.
More information about deploying with Docker can be found [here](../../serving/deploying_with_docker.md).

Below is a visual representation of the multi-stage Dockerfile. The build graph contains the following nodes:

- All build stages
- The default build target (highlighted in grey)
- External images (with dashed borders)

The edges of the build graph represent:

- `FROM ...` dependencies (with a solid line and a full arrow head)

- `COPY --from=...` dependencies (with a dashed line and an empty arrow head)

- `RUN --mount=(.\*)from=...` dependencies (with a dotted line and an empty diamond arrow head)

> ```{figure} ../../assets/dev/dockerfile-stages-dependency.png
> :align: center
> :alt: query
> :width: 100%
> ```
>
> Made using: <https://github.com/patrickhoefler/dockerfilegraph>
>
> Commands to regenerate the build graph (make sure to run it **from the \`root\` directory of the vLLM repository** where the dockerfile is present):
>
> ```bash
> dockerfilegraph -o png --legend --dpi 200 --max-label-length 50 --filename Dockerfile
> ```
>
> or in case you want to run it directly with the docker image:
>
> ```bash
> docker run \
>    --rm \
>    --user "$(id -u):$(id -g)" \
>    --workdir /workspace \
>    --volume "$(pwd)":/workspace \
>    ghcr.io/patrickhoefler/dockerfilegraph:alpine \
>    --output png \
>    --dpi 200 \
>    --max-label-length 50 \
>    --filename Dockerfile \
>    --legend
> ```
>
> (To run it for a different file, you can pass in a different argument to the flag `--filename`.)
@@ -1,50 +0,0 @@ (the former reStructuredText version of the Dockerfile page, deleted; its content matches the new dockerfile.md above)
@ -1,5 +1,4 @@
|
||||
Contributing to vLLM
|
||||
=====================
|
||||
# Contributing to vLLM
|
||||
|
||||
Thank you for your interest in contributing to vLLM! Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large. There are several ways you can contribute to the project:
|
||||
|
||||
@ -12,132 +11,121 @@ We also believe in the power of community support; thus, answering queries, offe
|
||||
|
||||
Finally, one of the most impactful ways to support us is by raising awareness about vLLM. Talk about it in your blog posts and highlight how it's driving your incredible projects. Express your support on social media if you're using vLLM, or simply offer your appreciation by starring our repository!
|
||||
|
||||
License
|
||||
-------
|
||||
## License
|
||||
|
||||
See `LICENSE <https://github.com/vllm-project/vllm/tree/main/LICENSE>`_.
|
||||
See <gh-file:LICENSE>.
|
||||
|
||||
Developing
|
||||
----------
|
||||
## Developing
|
||||
|
||||
Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation. Check out the `building from source <https://docs.vllm.ai/en/latest/getting_started/installation.html#build-from-source>`_ documentation for details.
|
||||
Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation.
|
||||
Check out the [building from source](#build-from-source) documentation for details.
|
||||
|
||||
Testing
|
||||
-------
|
||||
## Testing
|
||||
|
||||
.. code-block:: bash
|
||||
```bash
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
pip install -r requirements-dev.txt
|
||||
# linting and formatting
|
||||
bash format.sh
|
||||
# Static type checking
|
||||
mypy
|
||||
# Unit tests
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
# linting and formatting
|
||||
bash format.sh
|
||||
# Static type checking
|
||||
mypy
|
||||
# Unit tests
|
||||
pytest tests/
|
||||
```{note}
|
||||
Currently, the repository is not fully checked by `mypy`.
|
||||
```
|
||||
|
||||
.. note:: Currently, the repository does not pass the ``mypy`` tests.
|
||||
# Contribution Guidelines
|
||||
|
||||
Contribution Guidelines
|
||||
=======================
|
||||
## Issues
|
||||
|
||||
Issues
|
||||
------
|
||||
If you encounter a bug or have a feature request, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible.
|
||||
|
||||
If you encounter a bug or have a feature request, please `search existing issues <https://github.com/vllm-project/vllm/issues?q=is%3Aissue>`_ first to see if it has already been reported. If not, please `file a new issue <https://github.com/vllm-project/vllm/issues/new/choose>`_, providing as much relevant information as possible.
|
||||
```{important}
|
||||
If you discover a security vulnerability, please follow the instructions [here](gh-file:SECURITY.md#reporting-a-vulnerability).
|
||||
```
|
||||
|
||||
.. important::
|
||||
If you discover a security vulnerability, please follow the instructions `here <https://github.com/vllm-project/vllm/tree/main/SECURITY.md#reporting-a-vulnerability>`_.
|
||||
|
||||
Pull Requests & Code Reviews
|
||||
----------------------------
|
||||
## Pull Requests & Code Reviews
|
||||
|
||||
Thank you for your contribution to vLLM! Before submitting the pull request,
|
||||
please ensure the PR meets the following criteria. This helps vLLM maintain the
|
||||
code quality and improve the efficiency of the review process.
|
||||
|
||||
DCO and Signed-off-by
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
### DCO and Signed-off-by
|
||||
|
||||
When contributing changes to this project, you must agree to the `DCO <https://github.com/vllm-project/vllm/tree/main/DCO>`_.
|
||||
Commits must include a ``Signed-off-by:`` header which certifies agreement with
|
||||
the terms of the `DCO <https://github.com/vllm-project/vllm/tree/main/DCO>`_.
|
||||
When contributing changes to this project, you must agree to the <gh-file:DCO>.
|
||||
Commits must include a `Signed-off-by:` header which certifies agreement with
|
||||
the terms of the DCO.
|
||||
|
||||
Using ``-s`` with ``git commit`` will automatically add this header.
|
||||
Using `-s` with `git commit` will automatically add this header.
|
||||
|
||||
PR Title and Classification
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
### PR Title and Classification
|
||||
|
||||
Only specific types of PRs will be reviewed. The PR title is prefixed
|
||||
appropriately to indicate the type of change. Please use one of the following:
|
||||
|
||||
- ``[Bugfix]`` for bug fixes.
|
||||
- ``[CI/Build]`` for build or continuous integration improvements.
|
||||
- ``[Doc]`` for documentation fixes and improvements.
|
||||
- ``[Model]`` for adding a new model or improving an existing model. Model name
|
||||
- `[Bugfix]` for bug fixes.
|
||||
- `[CI/Build]` for build or continuous integration improvements.
|
||||
- `[Doc]` for documentation fixes and improvements.
|
||||
- `[Model]` for adding a new model or improving an existing model. Model name
|
||||
should appear in the title.
|
||||
- ``[Frontend]`` For changes on the vLLM frontend (e.g., OpenAI API server,
|
||||
``LLM`` class, etc.)
|
||||
- ``[Kernel]`` for changes affecting CUDA kernels or other compute kernels.
|
||||
- ``[Core]`` for changes in the core vLLM logic (e.g., ``LLMEngine``,
|
||||
``AsyncLLMEngine``, ``Scheduler``, etc.)
|
||||
- ``[Hardware][Vendor]`` for hardware-specific changes. Vendor name should
|
||||
appear in the prefix (e.g., ``[Hardware][AMD]``).
|
||||
- ``[Misc]`` for PRs that do not fit the above categories. Please use this
|
||||
- `[Frontend]` For changes on the vLLM frontend (e.g., OpenAI API server,
|
||||
`LLM` class, etc.)
|
||||
- `[Kernel]` for changes affecting CUDA kernels or other compute kernels.
|
||||
- `[Core]` for changes in the core vLLM logic (e.g., `LLMEngine`,
|
||||
`AsyncLLMEngine`, `Scheduler`, etc.)
|
||||
- `[Hardware][Vendor]` for hardware-specific changes. Vendor name should
|
||||
appear in the prefix (e.g., `[Hardware][AMD]`).
|
||||
- `[Misc]` for PRs that do not fit the above categories. Please use this
|
||||
sparingly.
|
||||
|
||||
.. note::
|
||||
If the PR spans more than one category, please include all relevant prefixes.
|
||||
```{note}
|
||||
If the PR spans more than one category, please include all relevant prefixes.
|
||||
```
|
||||
|
||||
Code Quality
|
||||
^^^^^^^^^^^^
|
||||
### Code Quality
|
||||
|
||||
The PR needs to meet the following code quality standards:
|
||||
|
||||
- We adhere to `Google Python style guide
|
||||
<https://google.github.io/styleguide/pyguide.html>`_ and `Google C++ style guide
|
||||
<https://google.github.io/styleguide/cppguide.html>`_.
|
||||
- Pass all linter checks. Please use `format.sh
|
||||
<https://github.com/vllm-project/vllm/blob/main/format.sh>`_ to format your
|
||||
code.
|
||||
- We adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
|
||||
- Pass all linter checks. Please use <gh-file:format.sh> to format your code.
|
||||
- The code needs to be well-documented to ensure future contributors can easily
|
||||
understand the code.
|
||||
- Include sufficient tests to ensure the project stays correct and robust. This
|
||||
includes both unit tests and integration tests.
|
||||
- Please add documentation to ``docs/source/`` if the PR modifies the
|
||||
- Please add documentation to `docs/source/` if the PR modifies the
|
||||
user-facing behaviors of vLLM. It helps vLLM users understand and utilize the
|
||||
new features or changes.
|
||||
|
||||
Adding or Changing Kernels
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
### Adding or Changing Kernels
|
||||
|
||||
Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.
|
||||
|
||||
- Make sure custom ops are registered following PyTorch guidelines:
|
||||
`Custom C++ and CUDA Operators <https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial>`_
|
||||
and `The Custom Operators Manual <https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU>`_.
|
||||
- Custom operations that return ``Tensors`` require meta-functions.
|
||||
[Custom C++ and CUDA Operators](https://pytorch.org/tutorials/advanced/cpp_custom_ops.html#cpp-custom-ops-tutorial)
|
||||
and [The Custom Operators Manual](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU).
|
||||
- Custom operations that return `Tensors` require meta-functions.
|
||||
Meta-functions should be implemented and registered in Python so that dynamic
|
||||
dims can be handled automatically. See above documents for a description of
|
||||
meta-functions.
|
||||
- Use `torch.library.opcheck() <https://pytorch.org/docs/stable/library.html#torch.library.opcheck>`_
|
||||
- Use [torch.library.opcheck()](https://pytorch.org/docs/stable/library.html#torch.library.opcheck)
|
||||
to test the function registration and meta-function for any registered ops.
|
||||
See ``tests/kernels`` for examples.
|
||||
See `tests/kernels` for examples.
|
||||
- When changing the C++ signature of an existing op, the schema must be updated
|
||||
to reflect the changes.
|
||||
- If a new custom type is needed, see the following document:
|
||||
`Custom Class Support in PT2 <https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA>`_.
|
||||
[Custom Class Support in PT2](https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA).
|
||||
|
||||
Notes for Large Changes
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
### Notes for Large Changes
|
||||
|
||||
Please keep the changes as concise as possible. For major architectural changes
|
||||
(>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue
|
||||
(RFC) discussing the technical design and justification. Otherwise, we will tag
|
||||
it with ``rfc-required`` and might not go through the PR.
|
||||
it with `rfc-required` and might not go through the PR.
|
||||
|
||||
What to Expect for the Reviews
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
### What to Expect for the Reviews
|
||||
|
||||
The goal of the vLLM team is to be a *transparent reviewing machine*. We would
|
||||
like to make the review process transparent and efficient and make sure no
|
||||
@ -150,15 +138,14 @@ review process:
|
||||
- After the PR is assigned, the reviewer will provide status updates every 2-3
|
||||
days. If the PR is not reviewed within 7 days, please feel free to ping the
|
||||
reviewer or the vLLM team.
|
||||
- After the review, the reviewer will put an ``action-required`` label on the PR
|
||||
- After the review, the reviewer will put an `action-required` label on the PR
|
||||
if there are changes required. The contributor should address the comments and
|
||||
ping the reviewer to re-review the PR.
|
||||
- Please respond to all comments within a reasonable time frame. If a comment
|
||||
isn't clear or you disagree with a suggestion, feel free to ask for
|
||||
clarification or discuss the suggestion.
|
||||
|
||||
Thank You
|
||||
---------
|
||||
## Thank You
|
||||
|
||||
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
|
||||
All of your contributions help make vLLM a great tool and community for everyone!
|
||||
41  docs/source/contributing/profiling/profiling_index.md  Normal file
@@ -0,0 +1,41 @@
# Profiling vLLM

We support tracing vLLM workers using the `torch.profiler` module. You can enable tracing by setting the `VLLM_TORCH_PROFILER_DIR` environment variable to the directory where you want to save the traces: `VLLM_TORCH_PROFILER_DIR=/mnt/traces/`

The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.

When using `benchmarks/benchmark_serving.py`, you can enable profiling by passing the `--profile` flag.

```{warning}
Only enable profiling in a development environment.
```

Traces can be visualized using <https://ui.perfetto.dev/>.

```{tip}
Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly.
```

```{tip}
To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100.
Set the env variable VLLM_RPC_TIMEOUT to a big number before you start the server. Say something like 30 minutes.
`export VLLM_RPC_TIMEOUT=1800000`
```

## Example commands and usage

### Offline Inference

Refer to <gh-file:examples/offline_inference_with_profiler.py> for an example.
### OpenAI Server

```bash
VLLM_TORCH_PROFILER_DIR=./vllm_profile python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B
```

benchmark_serving.py:

```bash
python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2
```
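A minimal sketch of what the offline-inference example referenced above boils down to; the `LLM.start_profile()` / `LLM.stop_profile()` helpers and the small model used here are assumptions for illustration, so see the referenced example for the authoritative version.

```python
import os

# Must be set before the engine is created so the workers pick it up.
os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")

llm.start_profile()  # begin collecting torch.profiler traces
llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
llm.stop_profile()   # flush the trace files to VLLM_TORCH_PROFILER_DIR
```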
@@ -1,48 +0,0 @@ (the former reStructuredText version of the profiling page, deleted; its content matches the new profiling_index.md above)
@ -1,25 +1,24 @@
|
||||
.. _arch_overview:
|
||||
(arch-overview)=
|
||||
|
||||
Architecture Overview
|
||||
======================
|
||||
# Architecture Overview
|
||||
|
||||
This document provides an overview of the vLLM architecture.
|
||||
|
||||
.. contents:: Table of Contents
|
||||
:local:
|
||||
:depth: 2
|
||||
```{contents} Table of Contents
|
||||
:depth: 2
|
||||
:local: true
|
||||
```
|
||||
|
||||
Entrypoints
|
||||
-----------
|
||||
## Entrypoints
|
||||
|
||||
vLLM provides a number of entrypoints for interacting with the system. The
|
||||
following diagram shows the relationship between them.
|
||||
|
||||
.. image:: /assets/design/arch_overview/entrypoints.excalidraw.png
|
||||
:alt: Entrypoints Diagram
|
||||
```{image} /assets/design/arch_overview/entrypoints.excalidraw.png
|
||||
:alt: Entrypoints Diagram
|
||||
```
|
||||
|
||||
LLM Class
|
||||
^^^^^^^^^
|
||||
### LLM Class
|
||||
|
||||
The LLM class provides the primary Python interface for doing offline inference,
|
||||
which is interacting with a model without using a separate model inference
|
||||
@ -27,75 +26,70 @@ server.
|
||||
|
||||
Here is a sample of `LLM` class usage:
|
||||
|
||||
.. code-block:: python
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
# Define a list of input prompts
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The capital of France is",
|
||||
"The largest ocean is",
|
||||
]
|
||||
|
||||
# Define a list of input prompts
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The capital of France is",
|
||||
"The largest ocean is",
|
||||
]
|
||||
# Define sampling parameters
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Define sampling parameters
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
# Initialize the LLM engine with the OPT-125M model
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
|
||||
# Initialize the LLM engine with the OPT-125M model
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
# Generate outputs for the input prompts
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Generate outputs for the input prompts
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the generated outputs
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
```
|
||||
|
||||
# Print the generated outputs
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
More API details can be found in the :doc:`Offline Inference
|
||||
More API details can be found in the {doc}`Offline Inference
|
||||
</dev/offline_inference/offline_index>` section of the API docs.
|
||||
|
||||
The code for the `LLM` class can be found in `vllm/entrypoints/llm.py
|
||||
<https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py>`_.
|
||||
The code for the `LLM` class can be found in <gh-file:vllm/entrypoints/llm.py>.
|
||||
|
||||
OpenAI-compatible API server
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
### OpenAI-compatible API server
|
||||
|
||||
The second primary interface to vLLM is via its OpenAI-compatible API server.
|
||||
This server can be started using the `vllm serve` command.
|
||||
|
||||
.. code-block:: bash
|
||||
```bash
|
||||
vllm serve <model>
|
||||
```
|
||||
|
||||
vllm serve <model>
|
||||
|
||||
The code for the `vllm` CLI can be found in `vllm/scripts.py
|
||||
<https://github.com/vllm-project/vllm/blob/main/vllm/scripts.py>`_.
|
||||
The code for the `vllm` CLI can be found in <gh-file:vllm/scripts.py>.
|
||||
|
||||
Sometimes you may see the API server entrypoint used directly instead of via the
|
||||
`vllm` CLI command. For example:
|
||||
|
||||
.. code-block:: bash
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server --model <model>
|
||||
```
|
||||
|
||||
python -m vllm.entrypoints.openai.api_server --model <model>
|
||||
That code can be found in <gh-file:vllm/entrypoints/openai/api_server.py>.
|
||||
|
||||
That code can be found in `vllm/entrypoints/openai/api_server.py
|
||||
<https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py>`_.
|
||||
|
||||
More details on the API server can be found in the :doc:`OpenAI Compatible
|
||||
More details on the API server can be found in the {doc}`OpenAI Compatible
|
||||
Server </serving/openai_compatible_server>` document.
|
||||
|
||||
LLM Engine
|
||||
----------
|
||||
## LLM Engine
|
||||
|
||||
The `LLMEngine` and `AsyncLLMEngine` classes are central to the functioning of
|
||||
the vLLM system, handling model inference and asynchronous request processing.
|
||||
|
||||
.. image:: /assets/design/arch_overview/llm_engine.excalidraw.png
|
||||
:alt: LLMEngine Diagram
|
||||
```{image} /assets/design/arch_overview/llm_engine.excalidraw.png
|
||||
:alt: LLMEngine Diagram
|
||||
```
|
||||
|
||||
LLMEngine
|
||||
^^^^^^^^^
|
||||
### LLMEngine
|
||||
|
||||
The `LLMEngine` class is the core component of the vLLM engine. It is
|
||||
responsible for receiving requests from clients and generating outputs from the
|
||||
@ -105,21 +99,15 @@ processing.
|
||||
|
||||
- **Input Processing**: Handles tokenization of input text using the specified
|
||||
tokenizer.
|
||||
|
||||
- **Scheduling**: Chooses which requests are processed in each step.
|
||||
|
||||
- **Model Execution**: Manages the execution of the language model, including
|
||||
distributed execution across multiple GPUs.
|
||||
|
||||
- **Output Processing**: Processes the outputs generated by the model, decoding the
|
||||
token IDs from a language model into human-readable text.
|
||||
|
||||
The code for `LLMEngine` can be found in `vllm/engine/llm_engine.py`_.
|
||||
The code for `LLMEngine` can be found in <gh-file:vllm/engine/llm_engine.py>.
|
||||
|
||||
.. _vllm/engine/llm_engine.py: https://github.com/vllm-project/vllm/tree/main/vllm/engine/llm_engine.py
|
||||
|
||||
AsyncLLMEngine
|
||||
^^^^^^^^^^^^^^
|
||||
### AsyncLLMEngine
|
||||
|
||||
The `AsyncLLMEngine` class is an asynchronous wrapper for the `LLMEngine` class.
|
||||
It uses `asyncio` to create a background loop that continuously processes
|
||||
@ -127,55 +115,46 @@ incoming requests. The `AsyncLLMEngine` is designed for online serving, where it
|
||||
can handle multiple concurrent requests and stream outputs to clients.
|
||||
|
||||
The OpenAI-compatible API server uses the `AsyncLLMEngine`. There is also a demo
|
||||
API server that serves as a simpler example in
|
||||
`vllm/entrypoints/api_server.py`_.
|
||||
API server that serves as a simpler example in <gh-file:vllm/entrypoints/api_server.py>.
|
||||
|
||||
.. _vllm/entrypoints/api_server.py: https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/api_server.py
|
||||
The code for `AsyncLLMEngine` can be found in <gh-file:vllm/engine/async_llm_engine.py>.
|
||||
|
||||
The code for `AsyncLLMEngine` can be found in `vllm/engine/async_llm_engine.py`_.
|
||||
|
||||
.. _vllm/engine/async_llm_engine.py: https://github.com/vllm-project/vllm/tree/main/vllm/engine/async_llm_engine.py
|
||||
|
||||
## Worker

A worker is a process that runs the model inference. vLLM follows the common
practice of using one process to control one accelerator device, such as GPUs.
For example, if we use tensor parallelism of size 2 and pipeline parallelism of
size 2, we will have 4 workers in total. Workers are identified by their
`rank` and `local_rank`. `rank` is used for global orchestration, while
`local_rank` is mainly used for assigning the accelerator device and accessing
local resources such as the file system and shared memory.
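The distinction can be illustrated with a toy assignment (illustrative only, not vLLM's actual bookkeeping): with TP=2 and PP=2 spread over two nodes with two GPUs each, one possible layout is:

```python
# Illustrative only: rank is globally unique, local_rank selects the GPU on
# each node.
workers = [
    # (rank, node, local_rank)
    (0, "node-0", 0),
    (1, "node-0", 1),
    (2, "node-1", 0),
    (3, "node-1", 1),
]
for rank, node, local_rank in workers:
    print(f"worker rank={rank} runs on {node}, device cuda:{local_rank}")
```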
## Model Runner

Every worker has one model runner object, responsible for loading and running
the model. Much of the model execution logic resides here, such as preparing
input tensors and capturing cudagraphs.

## Model

Every model runner object has one model object, which is the actual
`torch.nn.Module` instance. See [huggingface_integration](#huggingface-integration) for how various
configurations affect the class we ultimately get.

## Class Hierarchy

The following figure shows the class hierarchy of vLLM:

> ```{figure} /assets/design/hierarchy.png
> :align: center
> :alt: query
> :width: 100%
> ```

There are several important design choices behind this class hierarchy:

1\. **Extensibility**: All classes in the hierarchy accept a configuration object
containing all the necessary information. The [VllmConfig](https://github.com/vllm-project/vllm/blob/d1c6799b8870e513bf4f2305cbf6cda9fc3d773b/vllm/config.py#L2036)
class is the main configuration object that is passed around. The class
hierarchy is quite deep, and every class needs to read the configuration it is
interested in. By encapsulating all configurations in one object, we can easily
@@ -188,7 +167,7 @@ the `VllmConfig` class, and the model runner can access it directly. We don't
need to change the constructor of the engine, worker, or model class to pass the
new configuration option.
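As a hypothetical illustration of this pattern (toy classes, not actual vLLM code), a component deep in the hierarchy can pick up a newly added option directly from the shared config object, without any constructor changes along the way:

```python
from dataclasses import dataclass, field


# Toy stand-ins for VllmConfig and its sub-configs, just to show the pattern.
@dataclass
class ToyCacheConfig:
    block_size: int = 16


@dataclass
class ToyVllmConfig:
    cache_config: ToyCacheConfig = field(default_factory=ToyCacheConfig)


class ToyModelRunner:
    def __init__(self, vllm_config: ToyVllmConfig):
        # Reads only the fields it cares about; adding a new field to the
        # config does not change this constructor's signature.
        self.block_size = vllm_config.cache_config.block_size


runner = ToyModelRunner(ToyVllmConfig())
print(runner.block_size)
```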
2\. **Uniformity**: The model runner needs a unified interface to create and
initialize the model. vLLM supports more than 50 types of popular open-source
models. Each model has its own initialization logic. If the constructor
signature varies with models, the model runner does not know how to call the
@@ -200,46 +179,46 @@ of a vision model and a language model. By making the constructor uniform, we
can easily create a vision model and a language model and compose them into a
vision-language model.

````{note}
To support this change, all vLLM models' signatures have been updated to:

```python
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
```

To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one:

```python
class MyOldModel(nn.Module):
    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        ...

from vllm.config import VllmConfig
class MyNewModel(MyOldModel):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        super().__init__(config, cache_config, quant_config, lora_config, prefix)

if __version__ >= "0.6.4":
    MyModel = MyNewModel
else:
    MyModel = MyOldModel
```

This way, the model can work with both old and new versions of vLLM.
````

3\. **Sharding and Quantization at Initialization**: Certain features require
changing the model weights. For example, tensor parallelism needs to shard the
model weights, and quantization needs to quantize the model weights. There are
two possible ways to implement this feature. One way is to change the model
@@ -252,23 +231,23 @@ initialized, we need to load the full 810GB weights to every GPU and then shard
the weights, leading to a huge memory overhead. Instead, if we shard the weights
during the model initialization, every layer will only create a shard of the
weights it needs, leading to a much smaller memory overhead. The same idea
applies to quantization. Note that we also add an additional argument `prefix`
to the model's constructor so that the model can initialize itself differently
based on the prefix. This is useful for non-uniform quantization, where
different parts of the model are quantized differently. The `prefix` is
usually an empty string for the top-level model and a string like `"vision"`
or `"language"` for the sub-models. In general, it matches the name of the
module's state dict in the checkpoint file.
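A rough, hypothetical sketch of how a composite model might propagate prefixes (toy classes only, and with the `vllm_config` argument omitted to focus on the prefix):

```python
import torch.nn as nn


class ToySubModel(nn.Module):
    def __init__(self, *, prefix: str = ""):
        super().__init__()
        # A real model would use the prefix to select per-module quantization
        # settings and to locate its weights, e.g. f"{prefix}.weight".
        self.prefix = prefix


class ToyVisionLanguageModel(nn.Module):
    def __init__(self, *, prefix: str = ""):
        super().__init__()
        # The sub-model prefixes match the names of the corresponding entries
        # in the checkpoint's state dict.
        self.vision = ToySubModel(prefix="vision")
        self.language = ToySubModel(prefix="language")


model = ToyVisionLanguageModel()
print(model.vision.prefix, model.language.prefix)
```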
One disadvantage of this design is that it is hard to write unit tests for
individual components in vLLM because every component needs to be initialized by
a complete config object. We solve this problem by providing a default
initialization function that creates a default config object with all fields set
to `None`. If the component we want to test only cares about a few fields in
the config object, we can create a default config object and set the fields we
care about. This way, we can test the component in isolation. Note that many
tests in vLLM are end-to-end tests that test the whole system, so this is not a
big problem.

In summary, the complete config object `VllmConfig` can be treated as an
engine-level global state that is shared among all vLLM classes.
36
docs/source/design/huggingface_integration.md
Normal file
@@ -0,0 +1,36 @@

(huggingface-integration)=

# Integration with HuggingFace

This document describes how vLLM integrates with HuggingFace libraries. We will explain step by step what happens under the hood when we run `vllm serve`.

Let's say we want to serve the popular QWen model by running `vllm serve Qwen/Qwen2-7B`.

1. The `model` argument is `Qwen/Qwen2-7B`. vLLM determines whether this model exists by checking for the corresponding config file `config.json`. See this [code snippet](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L162-L182) for the implementation. Within this process:

    - If the `model` argument corresponds to an existing local path, vLLM will load the config file directly from this path.
    - If the `model` argument is a HuggingFace model ID consisting of a username and model name, vLLM will first try to use the config file from the HuggingFace local cache, using the `model` argument as the model name and the `--revision` argument as the revision. See [their website](https://huggingface.co/docs/huggingface_hub/en/package_reference/environment_variables#hfhome) for more information on how the HuggingFace cache works.
    - If the `model` argument is a HuggingFace model ID but it is not found in the cache, vLLM will download the config file from the HuggingFace model hub. Refer to [this function](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L91) for the implementation. The input arguments include the `model` argument as the model name, the `--revision` argument as the revision, and the environment variable `HF_TOKEN` as the token to access the model hub. In our case, vLLM will download the [config.json](https://huggingface.co/Qwen/Qwen2-7B/blob/main/config.json) file.

2. After confirming the existence of the model, vLLM loads its config file and converts it into a dictionary. See this [code snippet](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L185-L186) for the implementation.

3. Next, vLLM [inspects](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L189) the `model_type` field in the config dictionary to [generate](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#190-L216) the config object to use. There are some `model_type` values that vLLM directly supports; see [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/transformers_utils/config.py#L48) for the list. If the `model_type` is not in the list, vLLM will use [AutoConfig.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained) to load the config class, with `model`, `--revision`, and `--trust_remote_code` as the arguments. Please note that:

    - HuggingFace also has its own logic to determine the config class to use. It will again use the `model_type` field to search for the class name in the transformers library; see [here](https://github.com/huggingface/transformers/tree/main/src/transformers/models) for the list of supported models. If the `model_type` is not found, HuggingFace will use the `auto_map` field from the config JSON file to determine the class name. Specifically, it is the `AutoConfig` field under `auto_map`. See [DeepSeek](https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/main/config.json) for an example.
    - The `AutoConfig` field under `auto_map` points to a module path in the model's repository. To create the config class, HuggingFace will import the module and use the `from_pretrained` method to load the config class. This can generally cause arbitrary code execution, so it is only executed when `--trust_remote_code` is enabled.

4. Subsequently, vLLM applies some historical patches to the config object. These are mostly related to RoPE configuration; see [here](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/config.py#L244) for the implementation.

5. Finally, vLLM can reach the model class we want to initialize. vLLM uses the `architectures` field in the config object to determine the model class to initialize, as it maintains the mapping from architecture name to model class in [its registry](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/model_executor/models/registry.py#L80). If the architecture name is not found in the registry, it means this model architecture is not supported by vLLM. For `Qwen/Qwen2-7B`, the `architectures` field is `["Qwen2ForCausalLM"]`, which corresponds to the `Qwen2ForCausalLM` class in [vLLM's code](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/model_executor/models/qwen2.py#L364). This class will initialize itself depending on various configs.
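The `AutoConfig.from_pretrained` fallback mentioned in step 3 amounts, roughly, to the following standalone call (vLLM passes additional arguments that are omitted here):

```python
from transformers import AutoConfig

# trust_remote_code=True is only needed for repositories that ship custom
# config code via auto_map.
config = AutoConfig.from_pretrained(
    "Qwen/Qwen2-7B",
    revision="main",
    trust_remote_code=False,
)
print(type(config).__name__, config.architectures)
```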
Beyond that, there are two more things vLLM depends on HuggingFace for.

1. **Tokenizer**: vLLM uses the tokenizer from HuggingFace to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check HuggingFace's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24).

2. **Model weight**: vLLM downloads the model weight from the HuggingFace model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights.

    - It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). Please note that:
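A simplified sketch of the tokenizer loading described in item 1; vLLM's `get_tokenizer` wraps roughly this call and then caches the expensive attributes:

```python
from transformers import AutoTokenizer

# Load the tokenizer for the served model (or for --tokenizer, if given).
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B", revision="main")
print(tokenizer("Hello vLLM").input_ids)
```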
This completes the integration between vLLM and HuggingFace.

In summary, vLLM reads the config file `config.json`, tokenizer, and model weight from the HuggingFace model hub or a local directory. It uses the config class from vLLM or HuggingFace transformers, or loads the config class from the model's repository.
@@ -0,0 +1,19 @@

(input-processing-pipeline)=

# Input Processing Pipeline

1. Input data is passed to {class}`~vllm.LLMEngine` (or {class}`~vllm.AsyncLLMEngine`).

2. Tokenize the data if necessary.

3. Process the inputs using {meth}`INPUT_REGISTRY.process_input <vllm.inputs.registry.InputRegistry.process_input>`.

    - For example, add placeholder tokens to reserve KV cache for multi-modal embeddings.

4. Send the processed inputs to {class}`~vllm.executor.executor_base.ExecutorBase`.

5. Distribute the inputs via {class}`~vllm.worker.worker_base.WorkerBase` to {class}`~vllm.worker.model_runner_base.ModelRunnerBase`.

6. If the data contains multi-modal data, convert it into keyword arguments using {meth}`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.

    - For example, convert a {class}`PIL.Image.Image` input to its pixel values for a vision model.
43
docs/source/design/input_processing/model_inputs_index.md
Normal file
@@ -0,0 +1,43 @@

(input-processing)=

# Input Processing

```{eval-rst}
.. currentmodule:: vllm.inputs
```

Each model can override parts of vLLM's [input processing pipeline](#input-processing-pipeline) via
{data}`~vllm.inputs.INPUT_REGISTRY` and {data}`~vllm.multimodal.MULTIMODAL_REGISTRY`.

Currently, this mechanism is only utilized in [multi-modal](#multi-modality) models for preprocessing multi-modal input
data in addition to input prompt, but it can be extended to text-only language models when needed.

## Guides

```{toctree}
:maxdepth: 1

input_processing_pipeline
```

## Module Contents

### LLM Engine Inputs

```{eval-rst}
.. autoclass:: vllm.inputs.DecoderOnlyInputs
    :members:
    :show-inheritance:
```

### Registry

```{eval-rst}
.. autodata:: vllm.inputs.INPUT_REGISTRY
```

```{eval-rst}
.. automodule:: vllm.inputs.registry
    :members:
    :show-inheritance:
```
527
docs/source/design/kernel/paged_attention.md
Normal file
@@ -0,0 +1,527 @@

# vLLM Paged Attention

- Currently, vLLM utilizes its own implementation of a multi-head query attention kernel (`csrc/attention/attention_kernels.cu`). This kernel is designed to be compatible with vLLM's paged KV caches, where the key and value cache are stored in separate blocks (note that this block concept differs from the GPU thread block, so later in this document I will refer to the vLLM paged attention block as "block" and to the GPU thread block as "thread block").
- To achieve high performance, this kernel relies on a specially designed memory layout and access method, specifically when threads read data from global memory to shared memory. The purpose of this document is to provide a high-level explanation of the kernel implementation step by step, aiding those who wish to learn about the vLLM multi-head query attention kernel. After going through this document, users will likely have a better understanding and find it easier to follow the actual implementation.
- Please note that this document may not cover all details, such as how to calculate the correct index for the corresponding data or the dot multiplication implementation. However, after reading this document and becoming familiar with the high-level logic flow, it should be easier for you to read the actual code and understand the details.

## Inputs

- The kernel function takes a list of arguments for the current thread to perform its assigned work. The three most important arguments are the input pointers `q`, `k_cache`, and `v_cache`, which point to query, key, and value data on global memory that need to be read and processed. The output pointer `out` points to global memory where the result should be written. These four pointers actually refer to multi-dimensional arrays, but each thread only accesses the portion of data assigned to it. I have omitted all other runtime parameters here for simplicity.

```cpp
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE = 0>
__device__ void paged_attention_kernel(
  ... // Other side args.
  const scalar_t* __restrict__ out,       // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  ... // Other side args.
)
```

- There is also a list of template arguments above the function signature that are determined at compilation time. `scalar_t` represents the data type of the query, key, and value data elements, such as FP16. `HEAD_SIZE` indicates the number of elements in each head. `BLOCK_SIZE` refers to the number of tokens in each block. `NUM_THREADS` denotes the number of threads in each thread block. `PARTITION_SIZE` represents the number of tensor parallel GPUs (for simplicity, we assume this is 0 and tensor parallelism is disabled).
- With these arguments, we need to perform a sequence of preparations. This includes calculating the current head index, block index, and other necessary variables. However, for now, we can ignore these preparations and proceed directly to the actual calculations. It will be easier to understand them once we grasp the entire flow.

## Concepts

- Just before we dive into the calculation flow, I want to describe a few concepts that are needed for later sections. However, you may skip this section and return later if you encounter any confusing terminology.
- **Sequence**: A sequence represents a client request. For example, the data pointed to by `q` has a shape of `[num_seqs, num_heads, head_size]`, which means that `q` points to a total of `num_seqs` query sequences. Since this kernel is a single-query attention kernel, each sequence only has one query token. Hence, `num_seqs` equals the total number of tokens that are processed in the batch.
- **Context**: The context consists of the generated tokens from the sequence. For instance, `["What", "is", "your"]` are the context tokens, and the input query token is `"name"`. The model might generate the token `"?"`.
- **Vec**: A vec is a list of elements that are fetched and calculated together. For query and key data, the vec size (`VEC_SIZE`) is determined so that each thread group can fetch and calculate 16 bytes of data at a time. For value data, the vec size (`V_VEC_SIZE`) is determined so that each thread can fetch and calculate 16 bytes of data at a time. For example, if `scalar_t` is FP16 (2 bytes) and `THREAD_GROUP_SIZE` is 2, `VEC_SIZE` will be 4, while `V_VEC_SIZE` will be 8. (A small sketch of these calculations follows this list.)
- **Thread group**: A thread group is a small group of threads (`THREAD_GROUP_SIZE`) that fetches and calculates one query token and one key token at a time. Each thread handles only a portion of the token data. The total number of elements processed by one thread group is referred to as `x`. For example, if the thread group contains 2 threads and the head size is 8, then thread 0 handles the query and key elements at indices 0, 2, 4, 6, while thread 1 handles the elements at indices 1, 3, 5, 7.
- **Block**: The key and value cache data in vLLM are split into blocks. Each block stores data for a fixed number (`BLOCK_SIZE`) of tokens at one head. Each block may contain only a portion of the whole context tokens. For example, if the block size is 16 and the head size is 128, then for one head, one block can store 16 * 128 = 2048 elements.
- **Warp**: A warp is a group of 32 threads (`WARP_SIZE`) that execute simultaneously on a stream multiprocessor (SM). In this kernel, each warp processes the calculation between one query token and the key tokens of one entire block at a time (it may process multiple blocks in multiple iterations). For example, if there are 4 warps and 6 blocks for one context, warp 0 handles the 0th and 4th blocks, warp 1 handles the 1st and 5th blocks, warp 2 handles the 2nd block, and warp 3 handles the 3rd block.
- **Thread block**: A thread block is a group of threads (`NUM_THREADS`) that can access the same shared memory. Each thread block contains multiple warps (`NUM_WARPS`), and in this kernel, each thread block processes the calculation between one query token and the key tokens of a whole context.
- **Grid**: A grid is a collection of thread blocks and defines the shape of the collection. In this kernel, the shape is `(num_heads, num_seqs, max_num_partitions)`. Therefore, each thread block only handles the calculation for one head, one sequence, and one partition.
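The Python sketch below reproduces how the vec sizes described above follow from the "16 bytes at a time" rule; it is an illustration of the arithmetic, not the kernel's actual code:

```python
# Assumptions taken from the example above: FP16 elements, thread group of 2.
SIZEOF_SCALAR_T = 2
THREAD_GROUP_SIZE = 2
HEAD_SIZE = 128

# Query/key: one thread *group* fetches 16 bytes per step.
VEC_SIZE = 16 // (THREAD_GROUP_SIZE * SIZEOF_SCALAR_T)              # -> 4
# Value: one *thread* fetches 16 bytes per step.
V_VEC_SIZE = 16 // SIZEOF_SCALAR_T                                  # -> 8
# Vecs each thread reads for one query/key token (32 vecs split over 2 threads).
NUM_VECS_PER_THREAD = HEAD_SIZE // (THREAD_GROUP_SIZE * VEC_SIZE)   # -> 16

print(VEC_SIZE, V_VEC_SIZE, NUM_VECS_PER_THREAD)
```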
## Query

- This section will introduce how query data is stored in memory and fetched by each thread. As mentioned above, each thread group fetches one query token's data, while each thread itself only handles a part of one query token's data. Within each warp, every thread group will fetch the same query token data, but will multiply it with different key token data.

```cpp
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
```

```{figure} ../../assets/kernel/query.png
:align: center
:alt: query
:width: 70%

Query data of one token at one head
```

- Each thread defines its own `q_ptr`, which points to the assigned query token data on global memory. For example, if `VEC_SIZE` is 4 and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains a total of 128 elements divided into 128 / 4 = 32 vecs.

```{figure} ../../assets/kernel/q_vecs.png
:align: center
:alt: q_vecs
:width: 70%

`q_vecs` for one thread group
```

```cpp
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
```

- Next, we need to read the global memory data pointed to by `q_ptr` into shared memory as `q_vecs`. It is important to note that each vec is assigned to a different row. For example, if `THREAD_GROUP_SIZE` is 2, thread 0 will handle the 0th row of vecs, while thread 1 handles the 1st row. By reading the query data in this way, neighboring threads like thread 0 and thread 1 read neighboring memory, achieving memory coalescing to improve performance.

## Key

- Similar to the "Query" section, this section introduces the memory layout and assignment for keys. While each thread group only handles one query token per kernel run, it may handle multiple key tokens across multiple iterations. Meanwhile, each warp will process multiple blocks of key tokens in multiple iterations, ensuring that all context tokens are processed by the entire thread group after the kernel run. In this context, "handle" refers to performing the dot multiplication between query data and key data.

```cpp
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                + kv_head_idx * kv_head_stride
                                + physical_block_offset * x;
```

- Unlike `q_ptr`, `k_ptr` in each thread points to a different key token at different iterations. As shown above, `k_ptr` points to key token data based on `k_cache` at the assigned block, assigned head, and assigned token.

```{figure} ../../assets/kernel/key.png
:align: center
:alt: key
:width: 70%

Key data of all context tokens at one head
```

- The diagram above illustrates the memory layout for key data. It assumes that `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is 8, `THREAD_GROUP_SIZE` is 2, and there are a total of 4 warps. Each rectangle represents all the elements for one key token at one head, which will be processed by one thread group. The left half shows the total 16 blocks of key token data for warp 0, while the right half represents the remaining key token data for other warps or iterations. Inside each rectangle, there are a total of 32 vecs (128 elements for one token) that will be processed by 2 threads (one thread group) separately.

```{figure} ../../assets/kernel/k_vecs.png
:align: center
:alt: k_vecs
:width: 70%

`k_vecs` for one thread
```

```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD]
```

- Next, we need to read the key token data from `k_ptr` and store it in register memory as `k_vecs`. We use register memory for `k_vecs` because it will only be accessed by one thread once, whereas `q_vecs` will be accessed by multiple threads multiple times. Each `k_vecs` will contain multiple vectors for later calculation. Each vec will be set at each inner iteration. The assignment of vecs allows neighboring threads in a warp to read neighboring memory together, which again promotes memory coalescing. For instance, thread 0 will read vec 0, while thread 1 will read vec 1. In the next inner loop, thread 0 will read vec 2, while thread 1 will read vec 3, and so on.
- You may still be a little confused about the overall flow. Don't worry, please keep reading the next "QK" section. It will illustrate the query and key calculation flow in a clearer and higher-level manner.

## QK

- As shown in the pseudo code below, before the entire for-loop block, we fetch the query data for one token and store it in `q_vecs`. Then, in the outer for loop, we iterate through different `k_ptrs` that point to different tokens and prepare the `k_vecs` in the inner for loop. Finally, we perform the dot multiplication between the `q_vecs` and each `k_vecs`.

```cpp
q_vecs = ...
for ... {
   k_ptr = ...
   for ... {
      k_vecs[i] = ...
   }
   ...
   float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
}
```

- As mentioned before, each thread only fetches part of the query and key token data at a time. However, a cross-thread-group reduction happens in `Qk_dot<>::dot`, so the `qk` returned here is not just the dot product of part of the query and key token data, but actually the full result over the entire query and key token data.
- For example, if the value of `HEAD_SIZE` is 128 and `THREAD_GROUP_SIZE` is 2, each thread's `k_vecs` will contain a total of 64 elements. However, the returned `qk` is actually the result of the dot multiplication between 128 query elements and 128 key elements. If you want to learn more about the details of the dot multiplication and reduction, you may refer to the implementation of `Qk_dot<>::dot`. However, for the sake of simplicity, I will not cover it in this document.

## Softmax

- Next, we need to calculate the normalized softmax for all `qk`s, as shown below, where each $x$ represents a `qk`. To do this, we must obtain the reduced value of `qk_max` ($m(x)$) and the `exp_sum` ($\ell(x)$) of all `qk`s. The reduction should be performed across the entire thread block, encompassing results between the query token and all context key tokens.

```{math}
:nowrap: true

\begin{gather*}
m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\
\quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)}
\end{gather*}
```
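A small standalone Python illustration of these formulas (not kernel code): subtract the maximum before exponentiating, exactly as the kernel does with `qk_max`.

```python
import math


def stable_softmax(qks):
    qk_max = max(qks)                              # m(x)
    exps = [math.exp(qk - qk_max) for qk in qks]   # f(x)
    exp_sum = sum(exps)                            # l(x)
    return [e / exp_sum for e in exps]             # softmax(x)


print(stable_softmax([1.0, 2.0, 4.0]))
```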
### `qk_max` and `logits`

- Right after we get the `qk` result, we can set the temporary `logits` result with `qk` (in the end, `logits` should store the normalized softmax result). We can also compare and collect the `qk_max` for all `qk`s that are calculated by the current thread group.

```cpp
if (thread_group_offset == 0) {
   const bool mask = token_idx >= context_len;
   logits[token_idx - start_token_idx] = mask ? 0.f : qk;
   qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
```

- Please note that `logits` here is in shared memory, so each thread group will set the fields for its own assigned context tokens. Overall, the size of `logits` should be the number of context tokens.

```cpp
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
   qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}

if (lane == 0) {
   red_smem[warp_idx] = qk_max;
}
```

- Then we need to get the reduced `qk_max` across each warp. The main idea is to let the threads in a warp communicate with each other and obtain the final max `qk`.

```cpp
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
   qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
```

- Finally, we can get the reduced `qk_max` from the whole thread block by comparing the `qk_max` from all warps in this thread block. Then we need to broadcast the final result to each thread.

### `exp_sum`

- Similar to `qk_max`, we need to get the reduced sum value from the entire thread block too.

```cpp
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
   float val = __expf(logits[i] - qk_max);
   logits[i] = val;
   exp_sum += val;
}
...
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
```

- First, sum all the exp values from each thread group and, meanwhile, convert each entry of `logits` from `qk` to `exp(qk - qk_max)`. Please note that the `qk_max` here is already the max `qk` across the whole thread block. We can then perform the reduction for `exp_sum` across the whole thread block just as for `qk_max`.

```cpp
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
   logits[i] *= inv_sum;
}
```

- Finally, with the reduced `qk_max` and `exp_sum`, we can obtain the final normalized softmax result as `logits`. This `logits` variable will be used for dot multiplication with the value data in later steps. Now, it should store the normalized softmax result of `qk` for all assigned context tokens.

## Value

```{figure} ../../assets/kernel/value.png
:align: center
:alt: value
:width: 70%

Value data of all context tokens at one head
```

```{figure} ../../assets/kernel/logits_vec.png
:align: center
:alt: logits_vec
:width: 50%

`logits_vec` for one thread
```

```{figure} ../../assets/kernel/v_vec.png
:align: center
:alt: v_vec
:width: 70%

List of `v_vec` for one thread
```

- Now we need to retrieve the value data and perform dot multiplication with `logits`. Unlike query and key, there is no thread group concept for value data. As shown in the diagram, and different from the key token memory layout, elements from the same column correspond to the same value token. For one block of value data, there are `HEAD_SIZE` rows and `BLOCK_SIZE` columns that are split into multiple `v_vecs`.
- Each thread always fetches `V_VEC_SIZE` elements from the same `V_VEC_SIZE` of tokens at a time. As a result, a single thread retrieves multiple `v_vec`s from different rows and the same columns through multiple inner iterations. For each `v_vec`, it needs to be dot multiplied with the corresponding `logits_vec`, which is also `V_VEC_SIZE` elements from `logits`. Overall, with multiple inner iterations, each warp will process one block of value tokens. And with multiple outer iterations, the whole context's value tokens are processed.

```cpp
float accs[NUM_ROWS_PER_THREAD];
for ... { // Iteration over different blocks.
   logits_vec = ...
   for ... { // Iteration over different rows.
      v_vec = ...
      ...
      accs[i] += dot(logits_vec, v_vec);
   }
}
```

- As shown in the above pseudo code, in the outer loop, similar to `k_ptr`, `logits_vec` iterates over different blocks and reads `V_VEC_SIZE` elements from `logits`. In the inner loop, each thread reads `V_VEC_SIZE` elements from the same tokens as a `v_vec` and performs dot multiplication. It is important to note that in each inner iteration, the thread fetches different head position elements for the same tokens. The dot result is then accumulated in `accs`. Therefore, each entry of `accs` is mapped to a head position assigned to the current thread.
- For example, if `BLOCK_SIZE` is 16 and `V_VEC_SIZE` is 8, each thread fetches 8 value elements for 8 tokens at a time. Each element is from a different token at the same head position. If `HEAD_SIZE` is 128 and `WARP_SIZE` is 32, for each inner loop, a warp needs to fetch `WARP_SIZE * V_VEC_SIZE = 256` elements. This means there are a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle a whole block of value tokens. And each `accs` in each thread contains 8 elements, accumulated at 8 different head positions. For thread 0, the `accs` variable will have 8 elements, which are the 0th, 32nd, ..., 224th elements of a value head, accumulated from all 8 assigned tokens. (The short arithmetic sketch after this list reproduces these numbers.)
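```python
# Reproducing the example's arithmetic (illustration only, not kernel code).
BLOCK_SIZE, V_VEC_SIZE, HEAD_SIZE, WARP_SIZE = 16, 8, 128, 32

elements_per_block = HEAD_SIZE * BLOCK_SIZE                       # 2048 value elements
elements_per_warp_step = WARP_SIZE * V_VEC_SIZE                   # 256 elements per inner loop
inner_iterations = elements_per_block // elements_per_warp_step   # -> 8
accs_entries_per_thread = inner_iterations                        # one head position per iteration -> 8

print(inner_iterations, accs_entries_per_thread)
```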
## LV

- Now, we need to perform a reduction for `accs` within each warp. This process allows each thread to accumulate the `accs` for the assigned head positions of all tokens in one block.

```cpp
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
   float acc = accs[i];
   for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
   }
   accs[i] = acc;
}
```

- Next, we perform a reduction for `accs` across all warps, allowing each thread to have the accumulation of `accs` for the assigned head positions of all context tokens. Please note that each `accs` in every thread only stores the accumulation for a portion of the elements of the entire head for all context tokens. However, overall, all results for the output have been calculated; they are just stored in different threads' register memory.

```cpp
float* out_smem = reinterpret_cast<float*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) {
    // Upper warps write to shared memory.
    ...
    float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        ...
        dst[row_idx] = accs[i];
    }

    // Lower warps update the output.
    const float* src = &out_smem[warp_idx * HEAD_SIZE];
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        ...
        accs[i] += src[row_idx];
    }

    // Write out the accs.
}
```

## Output

- Now we can write all of the calculated results from local register memory to the final output in global memory.

```cpp
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                  + head_idx * max_num_partitions * HEAD_SIZE
                  + partition_idx * HEAD_SIZE;
```

- First, we need to define the `out_ptr` variable, which points to the start address of the assigned sequence and assigned head.

```cpp
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
   const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
   if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
      from_float(*(out_ptr + row_idx), accs[i]);
   }
}
```

- Finally, we need to iterate over the different assigned head positions and write out the corresponding accumulated result based on the `out_ptr`.
@ -1,525 +0,0 @@
|
||||
vLLM Paged Attention
|
||||
====================
|
||||
|
||||
- Currently, vLLM utilizes its own implementation of a multi-head query
|
||||
attention kernel (``csrc/attention/attention_kernels.cu``).
|
||||
This kernel is designed to be compatible with
|
||||
vLLM's paged KV caches, where the key and value cache are stored in
|
||||
separate blocks (note that this block concept differs from the GPU
|
||||
thread block. So in a later document, I will refer to vLLM paged
|
||||
attention block as "block", while refer to GPU thread block as
|
||||
"thread block").
|
||||
- To achieve high performance, this kernel relies on a specially
|
||||
designed memory layout and access method, specifically when threads
|
||||
read data from global memory to shared memory. The purpose of this
|
||||
document is to provide a high-level explanation of the kernel
|
||||
implementation step by step, aiding those who wish to learn about the
|
||||
vLLM multi-head query attention kernel. After going through this
|
||||
document, users will likely have a better understanding and feel easier
|
||||
to follow the actual implementation.
|
||||
- Please note that this document may not cover all details, such as how
|
||||
to calculate the correct index for the corresponding data or the dot
|
||||
multiplication implementation. However, after reading this document
|
||||
and becoming familiar with the high-level logic flow, it should be
|
||||
easier for you to read the actual code and understand the details.
|
||||
|
||||
Inputs
|
||||
------
|
||||
|
||||
- The kernel function takes a list of arguments for the current thread
|
||||
to perform its assigned work. The three most important arguments are
|
||||
the input pointers ``q``, ``k_cache``, and ``v_cache``, which point
|
||||
to query, key, and value data on global memory that need to be read
|
||||
and processed. The output pointer ``out`` points to global memory
|
||||
where the result should be written. These four pointers actually
|
||||
refer to multi-dimensional arrays, but each thread only accesses the
|
||||
portion of data assigned to it. I have omitted all other runtime
|
||||
parameters here for simplicity.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
template<
|
||||
typename scalar_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
int PARTITION_SIZE = 0>
|
||||
__device__ void paged_attention_kernel(
|
||||
... // Other side args.
|
||||
const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
... // Other side args.
|
||||
)
|
||||
|
||||
- There is also a list of template arguments above the function
  signature that are determined at compile time. ``scalar_t``
  represents the data type of the query, key, and value data elements,
  such as FP16. ``HEAD_SIZE`` indicates the number of elements in each
  head. ``BLOCK_SIZE`` refers to the number of tokens in each block.
  ``NUM_THREADS`` denotes the number of threads in each thread block.
  ``PARTITION_SIZE`` represents the number of tensor parallel GPUs (for
  simplicity, we assume this is 0 and tensor parallelism is disabled).
||||
- With these arguments, we need to perform a sequence of preparations.
|
||||
This includes calculating the current head index, block index, and
|
||||
other necessary variables. However, for now, we can ignore these
|
||||
preparations and proceed directly to the actual calculations. It will
|
||||
be easier to understand them once we grasp the entire flow.
|
||||
|
||||
Concepts
|
||||
--------
|
||||
|
||||
- Before we dive into the calculation flow, I want to describe a
  few concepts that are needed for later sections. You may
  skip this section and return later if you encounter any confusing
  terminology.
|
||||
- **Sequence**: A sequence represents a client request. For example,
  the data pointed to by ``q`` has a shape of
  ``[num_seqs, num_heads, head_size]``, which means ``q`` points to a
  total of ``num_seqs`` query sequences. Since this
  kernel is a single query attention kernel, each sequence only has one
  query token. Hence, ``num_seqs`` equals the total number of tokens
  that are processed in the batch.
|
||||
- **Context**: The context consists of the generated tokens from the
|
||||
sequence. For instance, ``["What", "is", "your"]`` are the context
|
||||
tokens, and the input query token is ``"name"``. The model might
|
||||
generate the token ``"?"``.
|
||||
- **Vec**: A vec is a list of elements that are fetched and
  calculated together. For query and key data, the vec size
  (``VEC_SIZE``) is determined so that each thread group can fetch and
  calculate 16 bytes of data at a time. For value data, the vec size
  (``V_VEC_SIZE``) is determined so that each thread can fetch and
  calculate 16 bytes of data at a time. For example, if
  ``scalar_t`` is FP16 (2 bytes) and ``THREAD_GROUP_SIZE`` is 2, the
  ``VEC_SIZE`` will be 4, while the ``V_VEC_SIZE`` will be 8 (see the
  sketch after this list).
|
||||
- **Thread group**: A thread group is a small group of
  threads (``THREAD_GROUP_SIZE``) that fetches and calculates one
  query token and one key token at a time. Each thread handles only a
  portion of the token data. The total number of elements processed by
  one thread group is referred to as ``x``. For example, if the thread
  group contains 2 threads and the head size is 8, then thread 0
  handles the query and key elements at indices 0, 2, 4, 6, while thread
  1 handles the elements at indices 1, 3, 5, 7.
|
||||
- **Block**: The key and value cache data in vLLM are split into
|
||||
blocks. Each block stores data for a fixed number(\ ``BLOCK_SIZE``)
|
||||
of tokens at one head. Each block may contain only a portion of the
|
||||
whole context tokens. For example, if the block size is 16 and the
|
||||
head size is 128, then for one head, one block can store 16 \* 128 =
|
||||
2048 elements.
|
||||
- **Warp**: A warp is a group of 32 threads (``WARP_SIZE``) that
  execute simultaneously on a streaming multiprocessor (SM). In this
  kernel, each warp processes the calculation between one query token
  and the key tokens of one entire block at a time (it may process
  multiple blocks in multiple iterations). For example, if there are 4
  warps and 6 blocks for one context, warp 0 handles the 0th and 4th
  blocks, warp 1 handles the 1st and 5th blocks, warp 2 handles the 2nd
  block, and warp 3 handles the 3rd block.
|
||||
- **Thread block**: A thread block is a group of
|
||||
threads(\ ``NUM_THREADS``) that can access the same shared memory.
|
||||
Each thread block contains multiple warps(\ ``NUM_WARPS``), and in
|
||||
this kernel, each thread block processes the calculation between one
|
||||
query token and key tokens of a whole context.
|
||||
- **Grid**: A grid is a collection of thread blocks and defines the
|
||||
shape of the collection. In this kernel, the shape is
|
||||
``(num_heads, num_seqs, max_num_partitions)``. Therefore, each thread
|
||||
block only handles the calculation for one head, one sequence, and
|
||||
one partition.
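
To make the sizing arithmetic above concrete, here is a small,
standalone host-side sketch. It is illustrative only: the constant
names and the 16-byte-fetch assumption mirror the description above but
do not need to match the kernel's internal names.

.. code:: cpp

   // Standalone host-side sketch of the sizing arithmetic described above.
   // FP16 elements, THREAD_GROUP_SIZE = 2, HEAD_SIZE = 128, BLOCK_SIZE = 16
   // match the running example in this document.
   #include <cstdio>

   constexpr int kScalarBytes     = 2;    // sizeof(scalar_t) for FP16
   constexpr int kThreadGroupSize = 2;
   constexpr int kHeadSize        = 128;
   constexpr int kBlockSize       = 16;

   // Each thread group fetches 16 bytes of query/key data at a time.
   constexpr int kVecSize  = 16 / (kThreadGroupSize * kScalarBytes);  // 4
   // Each thread fetches 16 bytes of value data at a time.
   constexpr int kVVecSize = 16 / kScalarBytes;                       // 8
   // One block stores BLOCK_SIZE tokens of one head.
   constexpr int kElemsPerBlock = kBlockSize * kHeadSize;             // 2048

   int main() {
     std::printf("VEC_SIZE=%d V_VEC_SIZE=%d elements per block=%d\n",
                 kVecSize, kVVecSize, kElemsPerBlock);
     return 0;
   }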
|
||||
|
||||
Query
|
||||
-----
|
||||
|
||||
- This section will introduce how query data is stored in memory and
|
||||
fetched by each thread. As mentioned above, each thread group fetches
|
||||
one query token data, while each thread itself only handles a part of
|
||||
one query token data. Within each warp, every thread group will fetch
|
||||
the same query token data, but will multiply it with different key
|
||||
token data.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
|
||||
.. figure:: ../../assets/kernel/query.png
|
||||
:alt: query
|
||||
:width: 70%
|
||||
:align: center
|
||||
|
||||
Query data of one token at one head
|
||||
|
||||
- Each thread defines its own ``q_ptr``, which points to the assigned
  query token data in global memory. For example, if ``VEC_SIZE`` is 4
  and ``HEAD_SIZE`` is 128, ``q_ptr`` points to data that contains a
  total of 128 elements, divided into 128 / 4 = 32 vecs.
|
||||
|
||||
.. figure:: ../../assets/kernel/q_vecs.png
|
||||
:alt: q_vecs
|
||||
:width: 70%
|
||||
:align: center
|
||||
|
||||
``q_vecs`` for one thread group
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
||||
|
||||
- Next, we need to read the global memory data pointed to by ``q_ptr``
  into shared memory as ``q_vecs``. It is important to note that each
  thread's vecs are assigned to a different row. For example, if
  ``THREAD_GROUP_SIZE`` is 2, thread 0 handles the row 0 vecs,
  while thread 1 handles the row 1 vecs. By reading the query data in
  this way, neighboring threads like thread 0 and thread 1 read
  neighboring memory, achieving memory coalescing and improving
  performance (a small sketch of this split follows).
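
A minimal host-side sketch of this split. It only reproduces the
arithmetic described above, not the actual shared-memory copy, and the
constant names are illustrative.

.. code:: cpp

   // Host-side illustration of how the 32 query vecs of one token could be
   // split across a thread group of 2 threads, one shared-memory row each.
   #include <cstdio>

   int main() {
     constexpr int kHeadSize        = 128;
     constexpr int kVecSize         = 4;
     constexpr int kThreadGroupSize = 2;
     constexpr int kNumVecs          = kHeadSize / kVecSize;          // 32
     constexpr int kNumVecsPerThread = kNumVecs / kThreadGroupSize;   // 16

     for (int thread = 0; thread < kThreadGroupSize; ++thread) {
       std::printf("thread %d fills q_vecs row %d with %d vecs\n",
                   thread, thread, kNumVecsPerThread);
     }
     return 0;
   }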
|
||||
|
||||
Key
|
||||
---
|
||||
|
||||
- Similar to the "Query" section, this section introduces memory layout
|
||||
and assignment for keys. While each thread group only handle one
|
||||
query token one kernel run, it may handle multiple key tokens across
|
||||
multiple iterations. Meanwhile, each warp will process multiple blocks
|
||||
of key tokens in multiple iterations, ensuring that all context
|
||||
tokens are processed by the entire thread group after the kernel run.
|
||||
In this context, "handle" refers to performing the dot multiplication
|
||||
between query data and key data.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride
|
||||
+ physical_block_offset * x;
|
||||
|
||||
- Unlike ``q_ptr``, ``k_ptr`` in each thread points to a different key
  token at different iterations. As shown above, ``k_ptr`` points to
  key token data in ``k_cache`` at the assigned block, assigned head,
  and assigned token.
|
||||
|
||||
.. figure:: ../../assets/kernel/key.png
|
||||
:alt: key
|
||||
:width: 70%
|
||||
:align: center
|
||||
|
||||
Key data of all context tokens at one head
|
||||
|
||||
- The diagram above illustrates the memory layout for key data. It
|
||||
assumes that the ``BLOCK_SIZE`` is 16, ``HEAD_SIZE`` is 128, ``x`` is
|
||||
8, ``THREAD_GROUP_SIZE`` is 2, and there are a total of 4 warps. Each
|
||||
rectangle represents all the elements for one key token at one head,
|
||||
which will be processed by one thread group. The left half shows the
|
||||
total 16 blocks of key token data for warp 0, while the right half
|
||||
represents the remaining key token data for other warps or
|
||||
iterations. Inside each rectangle, there are a total 32 vecs (128
|
||||
elements for one token) that will be processed by 2 threads (one
|
||||
thread group) separately.
|
||||
|
||||
.. figure:: ../../assets/kernel/k_vecs.png
|
||||
:alt: k_vecs
|
||||
:width: 70%
|
||||
:align: center
|
||||
|
||||
``k_vecs`` for one thread
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
K_vec k_vecs[NUM_VECS_PER_THREAD];
|
||||
|
||||
- Next, we need to read the key token data from ``k_ptr`` and store
  it in register memory as ``k_vecs``. We use register memory for
  ``k_vecs`` because it will only be accessed by a single thread, and
  only once, whereas ``q_vecs`` will be accessed by multiple threads
  multiple times. Each ``k_vecs`` will contain multiple vectors for
  later calculation. Each vec is set in each inner iteration. The
  assignment of vecs allows neighboring threads in a warp to read
  neighboring memory together, which again promotes memory
  coalescing. For instance, thread 0 will read vec 0, while thread 1
  will read vec 1. In the next inner loop, thread 0 will read vec 2,
  while thread 1 will read vec 3, and so on (a small sketch of this
  interleaved assignment follows this list).
|
||||
- You may still be a little confused about the overall flow. Don't
|
||||
worry, please keep reading the next "QK" section. It will illustrate
|
||||
the query and key calculation flow in a clearer and higher-level
|
||||
manner.
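
A host-side sketch of the interleaved vec assignment described above
for one key token. It is illustrative only and simply enumerates which
vec indices each thread of a 2-thread group would touch.

.. code:: cpp

   // With a thread group of 2 and 32 vecs per token, thread 0 reads the
   // even-numbered vecs and thread 1 the odd-numbered ones, so that in every
   // inner iteration the two threads touch neighboring memory.
   #include <cstdio>

   int main() {
     constexpr int kThreadGroupSize  = 2;
     constexpr int kNumVecsPerThread = 16;   // 32 vecs / 2 threads

     for (int thread = 0; thread < kThreadGroupSize; ++thread) {
       std::printf("thread %d reads vecs:", thread);
       for (int i = 0; i < kNumVecsPerThread; ++i) {
         std::printf(" %d", thread + i * kThreadGroupSize);
       }
       std::printf("\n");
     }
     return 0;
   }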
|
||||
|
||||
QK
|
||||
---
|
||||
|
||||
- As shown in the pseudo code below, before the entire for loop block,
  we fetch the query data for one token and store it in ``q_vecs``.
  Then, in the outer for loop, we iterate through different ``k_ptr``\ s
  that point to different tokens and prepare ``k_vecs`` in the inner
  for loop. Finally, we perform the dot multiplication between
  ``q_vecs`` and each ``k_vecs``.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
q_vecs = ...
|
||||
for ... {
|
||||
k_ptr = ...
|
||||
for ... {
|
||||
k_vecs[i] = ...
|
||||
}
|
||||
...
|
||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
||||
}
|
||||
|
||||
- As mentioned before, each thread only fetches part of the query and
  key token data at a time. However, a cross-thread-group reduction
  happens inside ``Qk_dot<>::dot``, so the ``qk`` returned here is not
  just the dot product of a portion of the query and key token data,
  but the full result over the entire query and key token.
- For example, if the value of ``HEAD_SIZE`` is 128 and
  ``THREAD_GROUP_SIZE`` is 2, each thread's ``k_vecs`` will contain a
  total of 64 elements. However, the returned ``qk`` is actually the
  result of the dot multiplication between 128 query elements and 128
  key elements. If you want to learn more about the details of the dot
  multiplication and reduction, you may refer to the implementation of
  ``Qk_dot<>::dot``. For the sake of simplicity, I will not cover it in
  this document; a small host-side sketch of the idea follows.
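
The sketch below is not the kernel's reduction, just a host-side
illustration of why summing the per-thread partial dot products yields
the full dot product.

.. code:: cpp

   // Each of the 2 "threads" computes a partial dot product over its own 64
   // interleaved elements; the sum of the partials equals the full
   // 128-element dot product, which is what the cross-thread-group
   // reduction produces.
   #include <cstdio>

   int main() {
     constexpr int kHeadSize        = 128;
     constexpr int kThreadGroupSize = 2;

     float q[kHeadSize], k[kHeadSize];
     for (int i = 0; i < kHeadSize; ++i) {
       q[i] = 0.01f * i;
       k[i] = 0.02f * (kHeadSize - i);
     }

     float partial[kThreadGroupSize] = {0.f, 0.f};
     for (int thread = 0; thread < kThreadGroupSize; ++thread) {
       // Each thread handles the elements at indices thread, thread + 2, ...
       for (int i = thread; i < kHeadSize; i += kThreadGroupSize) {
         partial[thread] += q[i] * k[i];
       }
     }

     // The "reduction" here is simply the sum of the partial results.
     const float qk = partial[0] + partial[1];
     std::printf("partial dots: %f + %f = full qk %f\n",
                 partial[0], partial[1], qk);
     return 0;
   }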
|
||||
|
||||
Softmax
|
||||
-------
|
||||
|
||||
- Next, we need to calculate the normalized softmax for all ``qk``\ s,
  as shown below, where each :math:`x` represents a ``qk``. To do this,
  we must obtain the reduced value of ``qk_max`` (:math:`m(x)`) and
  the ``exp_sum`` (:math:`\ell(x)`) of all ``qk``\ s. The reduction
  should be performed across the entire thread block, encompassing
  results between the query token and all context key tokens.
|
||||
|
||||
.. math::
|
||||
:nowrap:
|
||||
|
||||
\begin{gather*}
m(x) := \max_i \; x_i \\
f(x) := \left[\begin{array}{lll}e^{x_1 - m(x)} & \ldots & e^{x_B - m(x)}\end{array}\right] \\
\ell(x) := \sum_i f(x)_i \\
\operatorname{softmax}(x) := \frac{f(x)}{\ell(x)}
\end{gather*}
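
For reference, here is a small standalone host-side implementation of
these formulas. It is a sketch only and is unrelated to the kernel's
shared-memory layout or reductions.

.. code:: cpp

   // m(x) is the max, f(x) the shifted exponentials, l(x) their sum.
   // Subtracting m(x) keeps the exponentials in a numerically safe range.
   #include <cmath>
   #include <cstdio>
   #include <vector>

   std::vector<float> softmax(const std::vector<float>& x) {
     float m = x[0];
     for (float v : x) m = std::fmax(m, v);          // m(x)

     std::vector<float> f(x.size());
     float l = 0.f;
     for (size_t i = 0; i < x.size(); ++i) {
       f[i] = std::exp(x[i] - m);                    // f(x)_i
       l += f[i];                                    // l(x)
     }
     for (float& v : f) v /= l;                      // softmax(x)
     return f;
   }

   int main() {
     const std::vector<float> probs = softmax({1.0f, 2.0f, 3.0f});
     std::printf("%f %f %f\n", probs[0], probs[1], probs[2]);
     return 0;
   }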
|
||||
|
||||
``qk_max`` and ``logits``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
- Right after we get the ``qk`` result, we can set the temporary
  ``logits`` entry to ``qk`` (in the end, ``logits`` should
  store the normalized softmax results). We can also compare and
  collect the ``qk_max`` over all ``qk``\ s that are calculated by the
  current thread group.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
if (thread_group_offset == 0) {
|
||||
const bool mask = token_idx >= context_len;
|
||||
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
|
||||
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
|
||||
}
|
||||
|
||||
- Please note that ``logits`` here is in shared memory, so each
  thread group sets the entries for its own assigned context tokens.
  Overall, the size of ``logits`` should be the number of context
  tokens.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = qk_max;
|
||||
}
|
||||
|
||||
- Then we need to reduce ``qk_max`` within each warp. The main idea is
  to let the threads in a warp communicate with each other and obtain
  the final max ``qk``.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
|
||||
|
||||
- Finally, we can obtain the reduced ``qk_max`` for the whole thread
  block by comparing the ``qk_max`` from all warps in the thread block.
  Then we need to broadcast the final result to each thread.
|
||||
|
||||
``exp_sum``
|
||||
~~~~~~~~~~~
|
||||
|
||||
- Similar to ``qk_max``, we need to get the reduced sum value from the
|
||||
entire thread block too.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||
float val = __expf(logits[i] - qk_max);
|
||||
logits[i] = val;
|
||||
exp_sum += val;
|
||||
}
|
||||
...
|
||||
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
|
||||
|
||||
- First, sum all the exp values within each thread and, at the same
  time, convert each entry of ``logits`` from ``qk`` to
  ``exp(qk - qk_max)``. Note that the ``qk_max`` here is already the
  max ``qk`` across the whole thread block. Then we can reduce
  ``exp_sum`` across the whole thread block, just like ``qk_max``.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
|
||||
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||
logits[i] *= inv_sum;
|
||||
}
|
||||
|
||||
- Finally, with the reduced ``qk_max`` and ``exp_sum``, we can obtain
|
||||
the final normalized softmax result as ``logits``. This ``logits``
|
||||
variable will be used for dot multiplication with the value data in
|
||||
later steps. Now, it should store the normalized softmax result of
|
||||
``qk`` for all assigned context tokens.
|
||||
|
||||
Value
|
||||
-----
|
||||
|
||||
.. figure:: ../../assets/kernel/value.png
|
||||
:alt: value
|
||||
:width: 70%
|
||||
:align: center
|
||||
|
||||
Value data of all context tokens at one head
|
||||
|
||||
.. figure:: ../../assets/kernel/logits_vec.png
|
||||
:alt: logits_vec
|
||||
:width: 50%
|
||||
:align: center
|
||||
|
||||
``logits_vec`` for one thread
|
||||
|
||||
.. figure:: ../../assets/kernel/v_vec.png
|
||||
:alt: v_vec
|
||||
:width: 70%
|
||||
:align: center
|
||||
|
||||
List of ``v_vec`` for one thread
|
||||
|
||||
- Now we need to retrieve the value data and perform dot multiplication
  with ``logits``. Unlike query and key, there is no thread group
  concept for value data. As shown in the diagram, unlike the key token
  memory layout, elements in the same column correspond to the same
  value token. For one block of value data, there are ``HEAD_SIZE``
  rows and ``BLOCK_SIZE`` columns, which are split into multiple
  ``v_vecs``.
|
||||
- Each thread always fetches ``V_VEC_SIZE`` elements from the same
  ``V_VEC_SIZE`` tokens at a time. As a result, a single thread
  retrieves multiple ``v_vec``\ s from different rows and the same
  columns through multiple inner iterations. Each ``v_vec`` needs
  to be dot multiplied with the corresponding ``logits_vec``,
  which is also ``V_VEC_SIZE`` elements from ``logits``. Overall, with
  multiple inner iterations, each warp will process one block of value
  tokens, and with multiple outer iterations, the whole context's value
  tokens are processed.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
float accs[NUM_ROWS_PER_THREAD];
|
||||
for ... { // Iteration over different blocks.
|
||||
logits_vec = ...
|
||||
for ... { // Iteration over different rows.
|
||||
v_vec = ...
|
||||
...
|
||||
accs[i] += dot(logits_vec, v_vec);
|
||||
}
|
||||
}
|
||||
|
||||
- As shown in the above pseudo code, in the outer loop, similar to
|
||||
``k_ptr``, ``logits_vec`` iterates over different blocks and reads
|
||||
``V_VEC_SIZE`` elements from ``logits``. In the inner loop, each
|
||||
thread reads ``V_VEC_SIZE`` elements from the same tokens as a
|
||||
``v_vec`` and performs dot multiplication. It is important to note
|
||||
that in each inner iteration, the thread fetches different head
|
||||
position elements for the same tokens. The dot result is then
|
||||
accumulated in ``accs``. Therefore, each entry of ``accs`` is mapped
|
||||
to a head position assigned to the current thread.
|
||||
- For example, if ``BLOCK_SIZE`` is 16 and ``V_VEC_SIZE`` is 8, each
  thread fetches 8 value elements for 8 tokens at a time. Each element
  is from a different token at the same head position. If ``HEAD_SIZE``
  is 128 and ``WARP_SIZE`` is 32, for each inner loop, a warp needs to
  fetch ``WARP_SIZE * V_VEC_SIZE = 256`` elements. This means there are
  a total of 128 \* 16 / 256 = 8 inner iterations for a warp to handle
  a whole block of value tokens, and ``accs`` in each thread contains 8
  elements accumulated at 8 different head positions. For thread 0, the
  ``accs`` variable will have 8 elements, which are the 0th, 16th, ...,
  112th elements of a value head, each accumulated from its 8 assigned
  tokens (the sketch below reproduces this arithmetic).
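
A small host-side sketch of this arithmetic, reusing the ``row_idx``
expression that appears in the "Output" section below. The derived
constants are my own naming for illustration and need not match the
kernel's internal names.

.. code:: cpp

   // Value-phase constants for the running example (BLOCK_SIZE = 16,
   // V_VEC_SIZE = 8, HEAD_SIZE = 128, WARP_SIZE = 32) and the head
   // positions (row indices) accumulated by lane 0 of a warp.
   #include <cstdio>

   int main() {
     constexpr int kBlockSize = 16;
     constexpr int kVVecSize  = 8;
     constexpr int kHeadSize  = 128;
     constexpr int kWarpSize  = 32;

     constexpr int kNumVVecsPerRow   = kBlockSize / kVVecSize;        // 2
     constexpr int kNumRowsPerIter   = kWarpSize / kNumVVecsPerRow;   // 16
     constexpr int kNumRowsPerThread = kHeadSize / kNumRowsPerIter;   // 8

     std::printf("inner iterations per block: %d\n",
                 kHeadSize * kBlockSize / (kWarpSize * kVVecSize));   // 8

     const int lane = 0;
     std::printf("lane %d accumulates head positions:", lane);
     for (int i = 0; i < kNumRowsPerThread; ++i) {
       const int row_idx = lane / kNumVVecsPerRow + i * kNumRowsPerIter;
       std::printf(" %d", row_idx);
     }
     std::printf("\n");
     return 0;
   }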
|
||||
|
||||
LV
|
||||
---
|
||||
- Now, we need to perform reduction for ``accs`` within each warp. This
|
||||
process allows each thread to accumulate the ``accs`` for the
|
||||
assigned head positions of all tokens in one block.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
float acc = accs[i];
|
||||
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
||||
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
|
||||
}
|
||||
accs[i] = acc;
|
||||
}
|
||||
|
||||
- Next, we perform reduction for ``accs`` across all warps, allowing
  each thread to have the accumulation of ``accs`` for the assigned
  head positions of all context tokens. Please note that each ``accs``
  in every thread only stores the accumulation for a portion of the
  elements of the entire head for all context tokens. Overall, though,
  all results for the output have been calculated; they are just stored
  in different threads' register memory.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
float* out_smem = reinterpret_cast<float*>(shared_mem);
|
||||
for (int i = NUM_WARPS; i > 1; i /= 2) {
|
||||
// Upper warps write to shared memory.
|
||||
...
|
||||
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
...
|
||||
dst[row_idx] = accs[i];
|
||||
}
|
||||
|
||||
// Lower warps update the output.
|
||||
const float* src = &out_smem[warp_idx * HEAD_SIZE];
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
...
|
||||
accs[i] += src[row_idx];
|
||||
}
|
||||
|
||||
// Write out the accs.
|
||||
}
|
||||
|
||||
Output
|
||||
------
|
||||
|
||||
- Now we can write all of the calculated results from local register
  memory to the final output in global memory.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE
|
||||
+ partition_idx * HEAD_SIZE;
|
||||
|
||||
- First, we need to define the ``out_ptr`` variable, which points to
|
||||
the start address of the assigned sequence and assigned head.
|
||||
|
||||
.. code:: cpp
|
||||
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
|
||||
from_float(*(out_ptr + row_idx), accs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
- Finally, we need to iterate over different assigned head positions
|
||||
and write out the corresponding accumulated result based on the
|
||||
``out_ptr``.
|
||||
16
docs/source/design/multimodal/adding_multimodal_plugin.md
Normal file
@ -0,0 +1,16 @@
|
||||
(adding-multimodal-plugin)=
|
||||
|
||||
# Adding a Multimodal Plugin
|
||||
|
||||
This document teaches you how to add a new modality to vLLM.
|
||||
|
||||
Each modality in vLLM is represented by a {class}`~vllm.multimodal.MultiModalPlugin` and registered to {data}`~vllm.multimodal.MULTIMODAL_REGISTRY`.
|
||||
For vLLM to recognize a new modality type, you have to create a new plugin and then pass it to {meth}`~vllm.multimodal.MultiModalRegistry.register_plugin`.
|
||||
|
||||
The remainder of this document details how to define custom {class}`~vllm.multimodal.MultiModalPlugin` s.
|
||||
|
||||
```{note}
|
||||
This article is a work in progress.
|
||||
```
|
||||
|
||||
% TODO: Add more instructions on how to add new plugins once embeddings is in.
|
||||
@ -1,17 +0,0 @@
|
||||
.. _adding_multimodal_plugin:
|
||||
|
||||
Adding a Multimodal Plugin
|
||||
==========================
|
||||
|
||||
This document teaches you how to add a new modality to vLLM.
|
||||
|
||||
Each modality in vLLM is represented by a :class:`~vllm.multimodal.MultiModalPlugin` and registered to :data:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
|
||||
For vLLM to recognize a new modality type, you have to create a new plugin and then pass it to :meth:`~vllm.multimodal.MultiModalRegistry.register_plugin`.
|
||||
|
||||
The remainder of this document details how to define custom :class:`~vllm.multimodal.MultiModalPlugin` s.
|
||||
|
||||
.. note::
|
||||
This article is a work in progress.
|
||||
|
||||
..
|
||||
TODO: Add more instructions on how to add new plugins once embeddings is in.
|
||||
83
docs/source/design/multimodal/multimodal_index.md
Normal file
@ -0,0 +1,83 @@
|
||||
(multi-modality)=
|
||||
|
||||
# Multi-Modality
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: vllm.multimodal
|
||||
```
|
||||
|
||||
vLLM provides experimental support for multi-modal models through the {mod}`vllm.multimodal` package.
|
||||
|
||||
Multi-modal inputs can be passed alongside text and token prompts to [supported models](#supported-mm-models)
|
||||
via the `multi_modal_data` field in {class}`vllm.inputs.PromptType`.
|
||||
|
||||
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
|
||||
by following [this guide](#adding-multimodal-plugin).
|
||||
|
||||
Looking to add your own multi-modal model? Please follow the instructions listed [here](#enabling-multimodal-inputs).
|
||||
|
||||
## Guides
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
adding_multimodal_plugin
|
||||
```
|
||||
|
||||
## Module Contents
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: vllm.multimodal
|
||||
```
|
||||
|
||||
### Registry
|
||||
|
||||
```{eval-rst}
|
||||
.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.multimodal.MultiModalRegistry
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
### Base Classes
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: vllm.multimodal.base
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
### Input Classes
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: vllm.multimodal.inputs
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
### Audio Classes
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: vllm.multimodal.audio
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
### Image Classes
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: vllm.multimodal.image
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
### Video Classes
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: vllm.multimodal.video
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
@ -1,66 +0,0 @@
|
||||
.. _multi_modality:
|
||||
|
||||
Multi-Modality
|
||||
==============
|
||||
|
||||
.. currentmodule:: vllm.multimodal
|
||||
|
||||
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
|
||||
|
||||
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_mm_models>`
|
||||
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.
|
||||
|
||||
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
|
||||
by following :ref:`this guide <adding_multimodal_plugin>`.
|
||||
|
||||
Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here <enabling_multimodal_inputs>`.
|
||||
|
||||
Guides
|
||||
++++++
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
adding_multimodal_plugin
|
||||
|
||||
Module Contents
|
||||
+++++++++++++++
|
||||
|
||||
.. automodule:: vllm.multimodal
|
||||
|
||||
Registry
|
||||
--------
|
||||
|
||||
.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY
|
||||
|
||||
.. autoclass:: vllm.multimodal.MultiModalRegistry
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
Base Classes
|
||||
------------
|
||||
|
||||
.. autodata:: vllm.multimodal.NestedTensors
|
||||
|
||||
.. autodata:: vllm.multimodal.BatchedTensorInputs
|
||||
|
||||
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autodata:: vllm.multimodal.MultiModalDataDict
|
||||
|
||||
.. autoclass:: vllm.multimodal.MultiModalKwargs
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: vllm.multimodal.MultiModalPlugin
|
||||
:members:
|
||||
:show-inheritance:
|
||||
|
||||
Image Classes
|
||||
-------------
|
||||
|
||||
.. automodule:: vllm.multimodal.image
|
||||
:members:
|
||||
:show-inheritance:
|
||||
@ -2,13 +2,14 @@
|
||||
|
||||
## Debugging
|
||||
|
||||
Please see the [Debugging
|
||||
Tips](https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing)
|
||||
Please see the [Debugging Tips](#debugging-python-multiprocessing)
|
||||
page for information on known issues and how to solve them.
|
||||
|
||||
## Introduction
|
||||
|
||||
*Note that source code references are to the state of the code at the time of writing in December, 2024.*
|
||||
```{important}
|
||||
The source code references are to the state of the code at the time of writing in December, 2024.
|
||||
```
|
||||
|
||||
The use of Python multiprocessing in vLLM is complicated by:
|
||||
|
||||
@ -20,7 +21,7 @@ This document describes how vLLM deals with these challenges.
|
||||
|
||||
## Multiprocessing Methods
|
||||
|
||||
[Python multiprocessing methods](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods) include:
|
||||
[Python multiprocessing methods](https://docs.python.org/3/library/multiprocessing.html.md#contexts-and-start-methods) include:
|
||||
|
||||
- `spawn` - spawn a new Python process. This will be the default as of Python
|
||||
3.14.
|
||||
@ -82,7 +83,7 @@ There are other miscellaneous places hard-coding the use of `spawn`:
|
||||
|
||||
Related PRs:
|
||||
|
||||
- <https://github.com/vllm-project/vllm/pull/8823>
|
||||
- <gh-pr:8823>
|
||||
|
||||
## Prior State in v1
|
||||
|
||||
@ -96,7 +97,7 @@ engine core.
|
||||
|
||||
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/v1/engine/llm_engine.py#L93-L95>
|
||||
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/v1/engine/llm_engine.py#L70-L77>
|
||||
- https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/v1/engine/core_client.py#L44-L45
|
||||
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/v1/engine/core_client.py#L44-L45>
|
||||
|
||||
It was off by default for all the reasons mentioned above - compatibility with
|
||||
dependencies and code using vLLM as a library.
|
||||
@ -119,17 +120,17 @@ instruct users to either add a `__main__` guard or to disable multiprocessing.
|
||||
If that known-failure case occurs, the user will see two messages that explain
|
||||
what is happening. First, a log message from vLLM:
|
||||
|
||||
```
|
||||
WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously
|
||||
initialized. We must use the `spawn` multiprocessing start method. Setting
|
||||
VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See
|
||||
https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing
|
||||
for more information.
|
||||
```console
|
||||
WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously
|
||||
initialized. We must use the `spawn` multiprocessing start method. Setting
|
||||
VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See
|
||||
https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing
|
||||
for more information.
|
||||
```
|
||||
|
||||
Second, Python itself will raise an exception with a nice explanation:
|
||||
|
||||
```
|
||||
```console
|
||||
RuntimeError:
|
||||
An attempt has been made to start a new process before the
|
||||
current process has finished its bootstrapping phase.
|
||||
|
||||
56
docs/source/design/plugin_system.md
Normal file
@ -0,0 +1,56 @@
|
||||
(plugin-system)=
|
||||
|
||||
# vLLM's Plugin System
|
||||
|
||||
The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase. This document explains how plugins work in vLLM and how to create a plugin for vLLM.
|
||||
|
||||
## How Plugins Work in vLLM
|
||||
|
||||
Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [](#arch-overview)), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_general_plugins](https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16) function in the `vllm.plugins` module. This function is called for every process created by vLLM before it starts any work.
|
||||
|
||||
## How vLLM Discovers Plugins
|
||||
|
||||
vLLM's plugin system uses the standard Python `entry_points` mechanism. This mechanism allows developers to register functions in their Python packages for use by other packages. An example of a plugin:
|
||||
|
||||
```python
|
||||
# inside `setup.py` file
|
||||
from setuptools import setup
|
||||
|
||||
setup(name='vllm_add_dummy_model',
|
||||
version='0.1',
|
||||
packages=['vllm_add_dummy_model'],
|
||||
entry_points={
|
||||
'vllm.general_plugins':
|
||||
["register_dummy_model = vllm_add_dummy_model:register"]
|
||||
})
|
||||
|
||||
# inside `vllm_add_dummy_model.py` file
|
||||
def register():
|
||||
from vllm import ModelRegistry
|
||||
|
||||
if "MyLlava" not in ModelRegistry.get_supported_archs():
|
||||
ModelRegistry.register_model("MyLlava",
|
||||
"vllm_add_dummy_model.my_llava:MyLlava")
|
||||
```
|
||||
|
||||
For more information on adding entry points to your package, please check the [official documentation](https://setuptools.pypa.io/en/latest/userguide/entry_point.html).
|
||||
|
||||
Every plugin has three parts:
|
||||
|
||||
1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group `vllm.general_plugins` to register general plugins. This is the key of `entry_points` in the `setup.py` file. Always use `vllm.general_plugins` for vLLM's general plugins.
|
||||
2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
|
||||
3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.
|
||||
|
||||
## Types of supported plugins
|
||||
|
||||
- **General plugins** (with group name `vllm.general_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling `ModelRegistry.register_model` to register the model inside the plugin function.
|
||||
|
||||
- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.
|
||||
|
||||
## Guidelines for Writing Plugins
|
||||
|
||||
- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
|
||||
|
||||
## Compatibility Guarantee
|
||||
|
||||
vLLM guarantees the interface of documented plugins, such as `ModelRegistry.register_model`, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, `"vllm_add_dummy_model.my_llava:MyLlava"` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development.
|
||||
@ -1,62 +0,0 @@
|
||||
.. _plugin_system:
|
||||
|
||||
vLLM's Plugin System
|
||||
====================
|
||||
|
||||
The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase. This document explains how plugins work in vLLM and how to create a plugin for vLLM.
|
||||
|
||||
How Plugins Work in vLLM
|
||||
------------------------
|
||||
|
||||
Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see :ref:`arch_overview`), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the `load_general_plugins <https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16>`__ function in the ``vllm.plugins`` module. This function is called for every process created by vLLM before it starts any work.
|
||||
|
||||
How vLLM Discovers Plugins
|
||||
--------------------------
|
||||
|
||||
vLLM's plugin system uses the standard Python ``entry_points`` mechanism. This mechanism allows developers to register functions in their Python packages for use by other packages. An example of a plugin:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# inside `setup.py` file
|
||||
from setuptools import setup
|
||||
|
||||
setup(name='vllm_add_dummy_model',
|
||||
version='0.1',
|
||||
packages=['vllm_add_dummy_model'],
|
||||
entry_points={
|
||||
'vllm.general_plugins':
|
||||
["register_dummy_model = vllm_add_dummy_model:register"]
|
||||
})
|
||||
|
||||
# inside `vllm_add_dummy_model.py` file
|
||||
def register():
|
||||
from vllm import ModelRegistry
|
||||
|
||||
if "MyLlava" not in ModelRegistry.get_supported_archs():
|
||||
ModelRegistry.register_model("MyLlava",
|
||||
"vllm_add_dummy_model.my_llava:MyLlava")
|
||||
|
||||
For more information on adding entry points to your package, please check the `official documentation <https://setuptools.pypa.io/en/latest/userguide/entry_point.html>`__.
|
||||
|
||||
Every plugin has three parts:
|
||||
|
||||
1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group ``vllm.general_plugins`` to register general plugins. This is the key of ``entry_points`` in the ``setup.py`` file. Always use ``vllm.general_plugins`` for vLLM's general plugins.
|
||||
|
||||
2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the ``entry_points`` dictionary. In the example above, the plugin name is ``register_dummy_model``. Plugins can be filtered by their names using the ``VLLM_PLUGINS`` environment variable. To load only a specific plugin, set ``VLLM_PLUGINS`` to the plugin name.
|
||||
|
||||
3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is ``vllm_add_dummy_model:register``, which refers to a function named ``register`` in the ``vllm_add_dummy_model`` module.
|
||||
|
||||
What Can Plugins Do?
|
||||
--------------------
|
||||
|
||||
Currently, the primary use case for plugins is to register custom, out-of-the-tree models into vLLM. This is done by calling ``ModelRegistry.register_model`` to register the model. In the future, the plugin system may be extended to support more features, such as swapping in custom implementations for certain classes in vLLM.
|
||||
|
||||
Guidelines for Writing Plugins
|
||||
------------------------------
|
||||
|
||||
- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
|
||||
|
||||
Compatibility Guarantee
|
||||
-----------------------
|
||||
|
||||
vLLM guarantees the interface of documented plugins, such as ``ModelRegistry.register_model``, will always be available for plugins to register models. However, it is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting. For example, ``"vllm_add_dummy_model.my_llava:MyLlava"`` should be compatible with the version of vLLM that the plugin targets. The interface for the model may change during vLLM's development.
|
||||
@ -1,6 +1,7 @@
|
||||
AsyncLLMEngine
|
||||
=================================
|
||||
# AsyncLLMEngine
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.AsyncLLMEngine
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
17
docs/source/dev/engine/engine_index.md
Normal file
@ -0,0 +1,17 @@
|
||||
# vLLM Engine
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: vllm.engine
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. currentmodule:: vllm.engine
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
:caption: Engines
|
||||
:maxdepth: 2
|
||||
|
||||
llm_engine
|
||||
async_llm_engine
|
||||
```
|
||||
@ -1,13 +0,0 @@
|
||||
vLLM Engine
|
||||
=================================
|
||||
|
||||
.. automodule:: vllm.engine
|
||||
.. currentmodule:: vllm.engine
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Engines
|
||||
|
||||
llm_engine
|
||||
async_llm_engine
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
LLMEngine
|
||||
=================================
|
||||
# LLMEngine
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.LLMEngine
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
@ -1,6 +1,7 @@
|
||||
LLM Class
|
||||
=========
|
||||
# LLM Class
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.LLM
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
@ -1,14 +1,19 @@
|
||||
LLM Inputs
|
||||
==========
|
||||
# LLM Inputs
|
||||
|
||||
```{eval-rst}
|
||||
.. autodata:: vllm.inputs.PromptType
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.inputs.TextPrompt
|
||||
:show-inheritance:
|
||||
:members:
|
||||
:member-order: bysource
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.inputs.TokensPrompt
|
||||
:show-inheritance:
|
||||
:members:
|
||||
:member-order: bysource
|
||||
```
|
||||
8
docs/source/dev/offline_inference/offline_index.md
Normal file
@ -0,0 +1,8 @@
|
||||
# Offline Inference
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
llm
|
||||
llm_inputs
|
||||
```
|
||||
@ -1,8 +0,0 @@
|
||||
Offline Inference
|
||||
=================================
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
llm
|
||||
llm_inputs
|
||||
@ -1,5 +1,6 @@
|
||||
Pooling Parameters
|
||||
==================
|
||||
# Pooling Parameters
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.PoolingParams
|
||||
:members:
|
||||
```
|
||||
@ -1,5 +1,6 @@
|
||||
Sampling Parameters
|
||||
===================
|
||||
# Sampling Parameters
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.SamplingParams
|
||||
:members:
|
||||
```
|
||||
@ -15,18 +15,12 @@ def fix_case(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
def underline(title: str, character: str = "=") -> str:
|
||||
return f"{title}\n{character * len(title)}"
|
||||
|
||||
|
||||
def generate_title(filename: str) -> str:
|
||||
# Turn filename into a title
|
||||
title = filename.replace("_", " ").title()
|
||||
# Handle acronyms and names
|
||||
title = fix_case(title)
|
||||
# Underline title
|
||||
title = underline(title)
|
||||
return title
|
||||
return f"# {title}"
|
||||
|
||||
|
||||
def generate_examples():
|
||||
@ -38,24 +32,23 @@ def generate_examples():
|
||||
|
||||
# Destination paths
|
||||
doc_dir = root_dir / "docs/source/getting_started/examples"
|
||||
doc_paths = [doc_dir / f"{path.stem}.rst" for path in script_paths]
|
||||
doc_paths = [doc_dir / f"{path.stem}.md" for path in script_paths]
|
||||
|
||||
# Generate the example docs for each example script
|
||||
for script_path, doc_path in zip(script_paths, doc_paths):
|
||||
script_url = f"https://github.com/vllm-project/vllm/blob/main/examples/{script_path.name}"
|
||||
# Make script_path relative to doc_path and call it include_path
|
||||
include_path = '../../../..' / script_path.relative_to(root_dir)
|
||||
content = (f"{generate_title(doc_path.stem)}\n\n"
|
||||
f"Source {script_url}.\n\n"
|
||||
f".. literalinclude:: {include_path}\n"
|
||||
" :language: python\n"
|
||||
" :linenos:\n")
|
||||
f"Source: <gh-file:examples/{script_path.name}>.\n\n"
|
||||
f"```{{literalinclude}} {include_path}\n"
|
||||
":language: python\n"
|
||||
":linenos:\n```")
|
||||
with open(doc_path, "w+") as f:
|
||||
f.write(content)
|
||||
|
||||
# Generate the toctree for the example scripts
|
||||
with open(doc_dir / "examples_index.template.rst") as f:
|
||||
with open(doc_dir / "examples_index.template.md") as f:
|
||||
examples_index = f.read()
|
||||
with open(doc_dir / "examples_index.rst", "w+") as f:
|
||||
example_docs = "\n ".join(path.stem for path in script_paths)
|
||||
with open(doc_dir / "examples_index.md", "w+") as f:
|
||||
example_docs = "\n".join(path.stem + ".md" for path in script_paths)
|
||||
f.write(examples_index.replace(r"%EXAMPLE_DOCS%", example_docs))
|
||||
|
||||
163
docs/source/getting_started/amd-installation.md
Normal file
@ -0,0 +1,163 @@
|
||||
(installation-rocm)=
|
||||
|
||||
# Installation with ROCm
|
||||
|
||||
vLLM supports AMD GPUs with ROCm 6.2.
|
||||
|
||||
## Requirements
|
||||
|
||||
- OS: Linux
|
||||
- Python: 3.9 -- 3.12
|
||||
- GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
|
||||
- ROCm 6.2
|
||||
|
||||
Installation options:
|
||||
|
||||
1. [Build from source with docker](#build-from-source-docker-rocm)
|
||||
2. [Build from source](#build-from-source-rocm)
|
||||
|
||||
(build-from-source-docker-rocm)=
|
||||
|
||||
## Option 1: Build from source with docker (recommended)
|
||||
|
||||
You can build and install vLLM from source.
|
||||
|
||||
First, build a docker image from <gh-file:Dockerfile.rocm> and launch a docker container from the image.
|
||||
It is important that the user kicks off the docker build using buildkit. Either set `DOCKER_BUILDKIT=1` as an environment variable when calling the docker build command, or set up buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
|
||||
|
||||
```console
|
||||
{
|
||||
"features": {
|
||||
"buildkit": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
<gh-file:Dockerfile.rocm> uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches.
|
||||
It provides flexibility to customize the build of docker image using the following arguments:
|
||||
|
||||
- `BASE_IMAGE`: specifies the base image used when running `docker build`, specifically the PyTorch on ROCm base image.
|
||||
- `BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For [Radeon RX 7900 series (gfx1100)](https://rocm.docs.amd.com/projects/radeon/en/latest/index.html), this should be set to 0 before flash-attention supports this target.
|
||||
- `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
|
||||
- `FA_BRANCH`: specifies the branch used to build the CK flash-attention in [ROCm's flash-attention repo](https://github.com/ROCmSoftwarePlatform/flash-attention). The default is `ae7928c`
|
||||
- `BUILD_TRITON`: specifies whether to build triton flash-attention. The default value is 1.
|
||||
|
||||
Their values can be passed in when running `docker build` with `--build-arg` options.
|
||||
|
||||
To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default:
|
||||
|
||||
```console
|
||||
$ DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
|
||||
```
|
||||
|
||||
To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should specify `BUILD_FA` as below:
|
||||
|
||||
```console
|
||||
$ DOCKER_BUILDKIT=1 docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm .
|
||||
```
|
||||
|
||||
To run the above docker image `vllm-rocm`, use the below command:
|
||||
|
||||
```console
|
||||
$ docker run -it \
|
||||
--network=host \
|
||||
--group-add=video \
|
||||
--ipc=host \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device /dev/kfd \
|
||||
--device /dev/dri \
|
||||
-v <path/to/model>:/app/model \
|
||||
vllm-rocm \
|
||||
bash
|
||||
```
|
||||
|
||||
Where the `<path/to/model>` is the location where the model is stored, for example, the weights for llama2 or llama3 models.
|
||||
|
||||
(build-from-source-rocm)=
|
||||
|
||||
## Option 2: Build from source
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
|
||||
- [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/index.html)
|
||||
- [PyTorch](https://pytorch.org/)
|
||||
|
||||
For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`.
|
||||
|
||||
Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch [Getting Started](https://pytorch.org/get-started/locally/)
|
||||
|
||||
1. Install [Triton flash attention for ROCm](https://github.com/ROCm/triton)
|
||||
|
||||
Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from [ROCm/triton](https://github.com/ROCm/triton/blob/triton-mlir/README.md)
|
||||
|
||||
```console
|
||||
$ python3 -m pip install ninja cmake wheel pybind11
|
||||
$ pip uninstall -y triton
|
||||
$ git clone https://github.com/OpenAI/triton.git
|
||||
$ cd triton
|
||||
$ git checkout e192dba
|
||||
$ cd python
|
||||
$ pip3 install .
|
||||
$ cd ../..
|
||||
```
|
||||
|
||||
```{note}
|
||||
- If you see an HTTP issue related to downloading packages while building triton, please try again, as the HTTP error is intermittent.
|
||||
```
|
||||
|
||||
2. Optionally, if you choose to use CK flash attention, you can install [flash attention for ROCm](https://github.com/ROCm/flash-attention/tree/ck_tile)
|
||||
|
||||
Install ROCm's flash attention (v2.5.9.post1) following the instructions from [ROCm/flash-attention](https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support)
|
||||
Alternatively, wheels intended for vLLM use can be accessed under the releases.
|
||||
|
||||
For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo |grep gfx`.
|
||||
|
||||
```console
|
||||
$ git clone https://github.com/ROCm/flash-attention.git
|
||||
$ cd flash-attention
|
||||
$ git checkout 3cea2fb
|
||||
$ git submodule update --init
|
||||
$ GPU_ARCHS="gfx90a" python3 setup.py install
|
||||
$ cd ..
|
||||
```
|
||||
|
||||
```{note}
|
||||
- You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
||||
```
|
||||
|
||||
3. Build vLLM. For example, vLLM on ROCM 6.2 can be built with the following steps:
|
||||
|
||||
```bash
|
||||
$ pip install --upgrade pip
|
||||
|
||||
# Install PyTorch
|
||||
$ pip uninstall torch -y
|
||||
$ pip install --no-cache-dir --pre torch==2.6.0.dev20241024 --index-url https://download.pytorch.org/whl/nightly/rocm6.2
|
||||
|
||||
# Build & install AMD SMI
|
||||
$ pip install /opt/rocm/share/amd_smi
|
||||
|
||||
# Install dependencies
|
||||
$ pip install --upgrade numba scipy huggingface-hub[cli]
|
||||
$ pip install "numpy<2"
|
||||
$ pip install -r requirements-rocm.txt
|
||||
|
||||
# Build vLLM for MI210/MI250/MI300.
|
||||
$ export PYTORCH_ROCM_ARCH="gfx90a;gfx942"
|
||||
$ python3 setup.py develop
|
||||
```
|
||||
|
||||
This may take 5-10 minutes. Currently, {code}`pip install .` does not work for ROCm installation.
|
||||
|
||||
```{tip}
|
||||
- Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers.
|
||||
- Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
|
||||
- To use CK flash-attention or PyTorch naive attention, please use this flag `export VLLM_USE_TRITON_FLASH_ATTN=0` to turn off triton flash attention.
|
||||
- The ROCm version of PyTorch, ideally, should match the ROCm driver version.
|
||||
```
|
||||
|
||||
```{tip}
|
||||
- For MI300x (gfx942) users, to achieve optimal performance, please refer to [MI300x tuning guide](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html) for performance optimization and tuning tips on system and workflow level.
|
||||
For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization).
|
||||
```
|
||||
@ -1,178 +0,0 @@
|
||||
.. _installation_rocm:
|
||||
|
||||
Installation with ROCm
|
||||
======================
|
||||
|
||||
vLLM supports AMD GPUs with ROCm 6.2.
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
* OS: Linux
|
||||
* Python: 3.9 -- 3.12
|
||||
* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
|
||||
* ROCm 6.2
|
||||
|
||||
Installation options:
|
||||
|
||||
#. :ref:`Build from source with docker <build_from_source_docker_rocm>`
|
||||
#. :ref:`Build from source <build_from_source_rocm>`
|
||||
|
||||
.. _build_from_source_docker_rocm:
|
||||
|
||||
Option 1: Build from source with docker (recommended)
|
||||
-----------------------------------------------------
|
||||
|
||||
You can build and install vLLM from source.
|
||||
|
||||
First, build a docker image from `Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/Dockerfile.rocm>`_ and launch a docker container from the image.
|
||||
It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
{
|
||||
"features": {
|
||||
"buildkit": true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
`Dockerfile.rocm <https://github.com/vllm-project/vllm/blob/main/Dockerfile.rocm>`_ uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches.
|
||||
It provides flexibility to customize the build of docker image using the following arguments:
|
||||
|
||||
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image.
|
||||
* `BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 before flash-attention supports this target.
|
||||
* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
|
||||
* `FA_BRANCH`: specifies the branch used to build the CK flash-attention in `ROCm's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `ae7928c`
|
||||
* `BUILD_TRITON`: specifies whether to build triton flash-attention. The default value is 1.
|
||||
|
||||
Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
|
||||
|
||||
|
||||
To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
|
||||
|
||||
To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ DOCKER_BUILDKIT=1 docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm .
|
||||
|
||||
To run the above docker image ``vllm-rocm``, use the below command:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ docker run -it \
|
||||
--network=host \
|
||||
--group-add=video \
|
||||
--ipc=host \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device /dev/kfd \
|
||||
--device /dev/dri \
|
||||
-v <path/to/model>:/app/model \
|
||||
vllm-rocm \
|
||||
bash
|
||||
|
||||
Where the `<path/to/model>` is the location where the model is stored, for example, the weights for llama2 or llama3 models.
|
||||
|
||||
|
||||
.. _build_from_source_rocm:
|
||||
|
||||
Option 2: Build from source
|
||||
---------------------------
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
|
||||
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
||||
- `PyTorch <https://pytorch.org/>`_
|
||||
|
||||
For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`.
|
||||
|
||||
Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch `Getting Started <https://pytorch.org/get-started/locally/>`_
|
||||
|
||||
|
||||
1. Install `Triton flash attention for ROCm <https://github.com/ROCm/triton>`_
|
||||
|
||||
Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton <https://github.com/ROCm/triton/blob/triton-mlir/README.md>`_

   .. code-block:: console

      $ python3 -m pip install ninja cmake wheel pybind11
      $ pip uninstall -y triton
      $ git clone https://github.com/OpenAI/triton.git
      $ cd triton
      $ git checkout e192dba
      $ cd python
      $ pip3 install .
      $ cd ../..

   .. note::
      - If you see an HTTP issue related to downloading packages during building triton, please try again, as the HTTP error is intermittent.

2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm <https://github.com/ROCm/flash-attention/tree/ck_tile>`_

   Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention <https://github.com/ROCm/flash-attention/tree/ck_tile#amd-gpurocm-support>`_.
   Alternatively, wheels intended for vLLM use can be accessed under the releases.

   For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. To get your gfx architecture, run `rocminfo | grep gfx`.

   .. code-block:: console

      $ git clone https://github.com/ROCm/flash-attention.git
      $ cd flash-attention
      $ git checkout 3cea2fb
      $ git submodule update --init
      $ GPU_ARCHS="gfx90a" python3 setup.py install
      $ cd ..

   .. note::
      - You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`).

3. Build vLLM.

   For example, vLLM on ROCm 6.2 can be built with the following steps:

   .. code-block:: console

      $ pip install --upgrade pip

      $ # Install PyTorch
      $ pip uninstall torch -y
      $ pip install --no-cache-dir --pre torch==2.6.0.dev20240918 --index-url https://download.pytorch.org/whl/nightly/rocm6.2

      $ # Build & install AMD SMI
      $ pip install /opt/rocm/share/amd_smi

      $ # Install dependencies
      $ pip install --upgrade numba scipy huggingface-hub[cli]
      $ pip install "numpy<2"
      $ pip install -r requirements-rocm.txt

      $ # Build vLLM for MI210/MI250/MI300.
      $ export PYTORCH_ROCM_ARCH="gfx90a;gfx942"
      $ python3 setup.py develop

   This may take 5-10 minutes. Currently, :code:`pip install .` does not work for ROCm installation.

.. tip::

   - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm-up step before collecting perf numbers.
   - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support.
   - To use CK flash-attention or PyTorch naive attention, please use the flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention.
   - The ROCm version of PyTorch should, ideally, match the ROCm driver version.

.. tip::
   - For MI300x (gfx942) users, to achieve optimal performance, please refer to the `MI300x tuning guide <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html>`_ for performance optimization and tuning tips at the system and workflow level.
     For vLLM, please refer to `vLLM performance optimization <https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization>`_.
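
For example, a quick way to exercise the CK flash-attention (or PyTorch naive attention) path instead of Triton flash attention is to export the flag mentioned above before launching a server; the model name here is only an illustration:

.. code-block:: console

   $ export VLLM_USE_TRITON_FLASH_ATTN=0
   $ vllm serve facebook/opt-125m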

46
docs/source/getting_started/arm-installation.md
Normal file
@ -0,0 +1,46 @@
(installation-arm)=

# Installation for ARM CPUs

vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CPU backend initially developed for the x86 platform. This guide provides installation instructions specific to ARM. For additional details on supported features, refer to the x86 platform documentation covering:

- CPU backend inference capabilities
- Relevant runtime environment variables
- Performance optimization tips

ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes.
Contents:

1. [Requirements](#arm-backend-requirements)
2. [Quick Start with Dockerfile](#arm-backend-quick-start-dockerfile)
3. [Building from Source](#build-arm-backend-from-source)

(arm-backend-requirements)=

## Requirements

- **Operating System**: Linux or macOS
- **Compiler**: `gcc/g++ >= 12.3.0` (optional, but recommended)
- **Instruction Set Architecture (ISA)**: NEON support is required

(arm-backend-quick-start-dockerfile)=

## Quick Start with Dockerfile

You can quickly set up vLLM on ARM using Docker:

```console
$ docker build -f Dockerfile.arm -t vllm-cpu-env --shm-size=4g .
$ docker run -it \
             --rm \
             --network=host \
             --cpuset-cpus=<cpu-id-list, optional> \
             --cpuset-mems=<memory-node, optional> \
             vllm-cpu-env
```

(build-arm-backend-from-source)=

## Building from Source

To build vLLM from source on Ubuntu 22.04 or other Linux distributions, follow a similar process as with x86. Testing has been conducted on AWS Graviton3 instances for compatibility.
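
A minimal sketch of that process, assuming the x86 build steps in this repository carry over unchanged to ARM:

```console
$ pip install --upgrade pip
$ pip install -v -r requirements-cpu.txt
$ VLLM_TARGET_DEVICE=cpu python setup.py install
```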

@ -1,50 +0,0 @@
.. _installation_arm:

Installation for ARM CPUs
=========================

vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CPU backend initially developed for the x86 platform. This guide provides installation instructions specific to ARM. For additional details on supported features, refer to the x86 platform documentation covering:

* CPU backend inference capabilities
* Relevant runtime environment variables
* Performance optimization tips

ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes.
Contents:

1. :ref:`Requirements <arm_backend_requirements>`
2. :ref:`Quick Start with Dockerfile <arm_backend_quick_start_dockerfile>`
3. :ref:`Building from Source <build_arm_backend_from_source>`

.. _arm_backend_requirements:

Requirements
------------

* **Operating System**: Linux or macOS
* **Compiler**: gcc/g++ >= 12.3.0 (optional, but recommended)
* **Instruction Set Architecture (ISA)**: NEON support is required

.. _arm_backend_quick_start_dockerfile:

Quick Start with Dockerfile
---------------------------

You can quickly set up vLLM on ARM using Docker:

.. code-block:: console

   $ docker build -f Dockerfile.arm -t vllm-cpu-env --shm-size=4g .
   $ docker run -it \
                --rm \
                --network=host \
                --cpuset-cpus=<cpu-id-list, optional> \
                --cpuset-mems=<memory-node, optional> \
                vllm-cpu-env

.. _build_arm_backend_from_source:

Building from Source
--------------------

To build vLLM from source on Ubuntu 22.04 or other Linux distributions, follow a similar process as with x86. Testing has been conducted on AWS Graviton3 instances for compatibility.

154
docs/source/getting_started/cpu-installation.md
Normal file
@ -0,0 +1,154 @@
(installation-cpu)=

# Installation with CPU

vLLM initially supports basic model inferencing and serving on the x86 CPU platform, with data types FP32, FP16 and BF16. The vLLM CPU backend supports the following vLLM features:

- Tensor Parallel
- Model Quantization (`INT8 W8A8, AWQ`)
- Chunked-prefill
- Prefix-caching
- FP8-E5M2 KV-Caching (TODO)

Table of contents:

1. [Requirements](#cpu-backend-requirements)
2. [Quick start using Dockerfile](#cpu-backend-quick-start-dockerfile)
3. [Build from source](#build-cpu-backend-from-source)
4. [Related runtime environment variables](#env-intro)
5. [Intel Extension for PyTorch](#ipex-guidance)
6. [Performance tips](#cpu-backend-performance-tips)

(cpu-backend-requirements)=

## Requirements

- OS: Linux
- Compiler: `gcc/g++>=12.3.0` (optional, recommended)
- Instruction set architecture (ISA) requirement: AVX512 (optional, recommended)

(cpu-backend-quick-start-dockerfile)=

## Quick start using Dockerfile

```console
$ docker build -f Dockerfile.cpu -t vllm-cpu-env --shm-size=4g .
$ docker run -it \
             --rm \
             --network=host \
             --cpuset-cpus=<cpu-id-list, optional> \
             --cpuset-mems=<memory-node, optional> \
             vllm-cpu-env
```

(build-cpu-backend-from-source)=

## Build from source

- First, install the recommended compiler. We recommend using `gcc/g++ >= 12.3.0` as the default compiler to avoid potential problems. For example, on Ubuntu 22.04, you can run:

  ```console
  $ sudo apt-get update -y
  $ sudo apt-get install -y gcc-12 g++-12 libnuma-dev
  $ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
  ```

- Second, install the Python packages for building the vLLM CPU backend:

  ```console
  $ pip install --upgrade pip
  $ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy
  $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
  ```

- Finally, build and install the vLLM CPU backend:

  ```console
  $ VLLM_TARGET_DEVICE=cpu python setup.py install
  ```

```{note}
- AVX512_BF16 is an extension ISA that provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.
- If you want to force-enable AVX512_BF16 for cross-compilation, please set the environment variable `VLLM_CPU_AVX512BF16=1` before building.
```
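
After the build finishes, a quick sanity check (assuming you are still inside the vLLM source tree) is to run the bundled offline inference example on the CPU backend:

```console
$ python examples/offline_inference.py
```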

(env-intro)=

## Related runtime environment variables

- `VLLM_CPU_KVCACHE_SPACE`: specifies the KV cache size (e.g., `VLLM_CPU_KVCACHE_SPACE=40` means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the memory management pattern of users.
- `VLLM_CPU_OMP_THREADS_BIND`: specifies the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound to CPU cores 0-31. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, with the 32 OpenMP threads of rank 0 bound to CPU cores 0-31, and the OpenMP threads of rank 1 bound to CPU cores 32-63.

(ipex-guidance)=

## Intel Extension for PyTorch

- [Intel Extension for PyTorch (IPEX)](https://github.com/intel/intel-extension-for-pytorch) extends PyTorch with up-to-date feature optimizations for an extra performance boost on Intel hardware.

(cpu-backend-performance-tips)=

## Performance tips

- We highly recommend using TCMalloc for high-performance memory allocation and better cache locality. For example, on Ubuntu 22.04, you can run:

  ```console
  $ sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library
  $ find / -name *libtcmalloc* # find the dynamic link library path
  $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
  $ python examples/offline_inference.py # run vLLM
  ```

- When using online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP:

  ```console
  $ export VLLM_CPU_KVCACHE_SPACE=40
  $ export VLLM_CPU_OMP_THREADS_BIND=0-29
  $ vllm serve facebook/opt-125m
  ```

- If using the vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using `VLLM_CPU_OMP_THREADS_BIND`. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:

  ```console
  $ lscpu -e # check the mapping between logical CPU cores and physical CPU cores

  # The "CPU" column means the logical CPU core IDs, and the "CORE" column means the physical core IDs. On this platform, two logical cores are sharing one physical core.
  CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE MAXMHZ MINMHZ MHZ
  0 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
  1 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
  2 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
  3 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
  4 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
  5 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
  6 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
  7 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
  8 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
  9 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
  10 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
  11 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
  12 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
  13 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
  14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
  15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000

  # On this platform, it is recommended to only bind OpenMP threads on logical CPU cores 0-7 or 8-15
  $ export VLLM_CPU_OMP_THREADS_BIND=0-7
  $ python examples/offline_inference.py
  ```

- If using the vLLM CPU backend on a multi-socket machine with NUMA, be sure to set the CPU cores via `VLLM_CPU_OMP_THREADS_BIND` to avoid cross-NUMA-node memory access.

## CPU Backend Considerations

- The CPU backend significantly differs from the GPU backend since the vLLM architecture was originally optimized for GPU use. A number of optimizations are needed to enhance its performance.

- Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance.

- On CPU-based setups with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.md#non-uniform-memory-access-numa). For NUMA architectures, two optimizations are recommended: Tensor Parallel or Data Parallel.

- Using Tensor Parallel for a latency-constrained deployment: following the GPU backend design, Megatron-LM's parallel algorithm will be used to shard the model, based on the number of NUMA nodes (e.g. TP = 2 for a two-NUMA-node system). With the [TP feature on CPU](gh-pr:6125) merged, Tensor Parallel is supported for serving and offline inferencing. In general, each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:

  ```console
  $ VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp
  ```

- Using Data Parallel for maximum throughput: launch an LLM serving endpoint on each NUMA node along with one additional load balancer to dispatch the requests to those endpoints. Common solutions like [Nginx](../serving/deploying_with_nginx.md) or HAProxy are recommended. The Anyscale Ray project provides this feature for LLM [serving](https://docs.ray.io/en/latest/serve/index.html). Here is an example of setting up scalable LLM serving with [Ray Serve](https://github.com/intel/llm-on-ray/blob/main/docs/setup.md); a minimal sketch of the layout is shown below.
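
A minimal sketch of that Data Parallel layout, assuming a two-NUMA-node machine with 64 physical cores (the port numbers and model name are illustrative; the load balancer in front of the two endpoints is configured separately):

```console
$ # Endpoint pinned to the cores of NUMA node 0
$ VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND=0-31 vllm serve facebook/opt-125m --port 8000 &
$ # Endpoint pinned to the cores of NUMA node 1
$ VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND=32-63 vllm serve facebook/opt-125m --port 8001 &
```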

@ -1,164 +0,0 @@
.. _installation_cpu:

Installation with CPU
=====================

vLLM initially supports basic model inferencing and serving on the x86 CPU platform, with data types FP32, FP16 and BF16. The vLLM CPU backend supports the following vLLM features:

- Tensor Parallel
- Model Quantization (``INT8 W8A8, AWQ``)
- Chunked-prefill
- Prefix-caching
- FP8-E5M2 KV-Caching (TODO)

Table of contents:

#. :ref:`Requirements <cpu_backend_requirements>`
#. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
#. :ref:`Build from source <build_cpu_backend_from_source>`
#. :ref:`Related runtime environment variables <env_intro>`
#. :ref:`Intel Extension for PyTorch <ipex_guidance>`
#. :ref:`Performance tips <cpu_backend_performance_tips>`

.. _cpu_backend_requirements:

Requirements
------------

* OS: Linux
* Compiler: gcc/g++>=12.3.0 (optional, recommended)
* Instruction set architecture (ISA) requirement: AVX512 (optional, recommended)

.. _cpu_backend_quick_start_dockerfile:

Quick start using Dockerfile
----------------------------

.. code-block:: console

   $ docker build -f Dockerfile.cpu -t vllm-cpu-env --shm-size=4g .
   $ docker run -it \
                --rm \
                --network=host \
                --cpuset-cpus=<cpu-id-list, optional> \
                --cpuset-mems=<memory-node, optional> \
                vllm-cpu-env

.. _build_cpu_backend_from_source:

Build from source
-----------------

- First, install the recommended compiler. We recommend using ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.04, you can run:

  .. code-block:: console

     $ sudo apt-get update -y
     $ sudo apt-get install -y gcc-12 g++-12 libnuma-dev
     $ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12

- Second, install the Python packages for building the vLLM CPU backend:

  .. code-block:: console

     $ pip install --upgrade pip
     $ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy
     $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu

- Finally, build and install the vLLM CPU backend:

  .. code-block:: console

     $ VLLM_TARGET_DEVICE=cpu python setup.py install

.. note::
   - AVX512_BF16 is an extension ISA that provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16.

   - If you want to force-enable AVX512_BF16 for cross-compilation, please set the environment variable VLLM_CPU_AVX512BF16=1 before building.

.. _env_intro:

Related runtime environment variables
-------------------------------------

- ``VLLM_CPU_KVCACHE_SPACE``: specifies the KV cache size (e.g., ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the memory management pattern of users.

- ``VLLM_CPU_OMP_THREADS_BIND``: specifies the CPU cores dedicated to the OpenMP threads. For example, ``VLLM_CPU_OMP_THREADS_BIND=0-31`` means there will be 32 OpenMP threads bound to CPU cores 0-31. ``VLLM_CPU_OMP_THREADS_BIND=0-31|32-63`` means there will be 2 tensor parallel processes, with the 32 OpenMP threads of rank 0 bound to CPU cores 0-31, and the OpenMP threads of rank 1 bound to CPU cores 32-63.

.. _ipex_guidance:

Intel Extension for PyTorch
---------------------------

- `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date feature optimizations for an extra performance boost on Intel hardware.

.. _cpu_backend_performance_tips:

Performance tips
----------------

- We highly recommend using TCMalloc for high-performance memory allocation and better cache locality. For example, on Ubuntu 22.04, you can run:

  .. code-block:: console

     $ sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library
     $ find / -name *libtcmalloc* # find the dynamic link library path
     $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
     $ python examples/offline_inference.py # run vLLM

- When using online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP:

  .. code-block:: console

     $ export VLLM_CPU_KVCACHE_SPACE=40
     $ export VLLM_CPU_OMP_THREADS_BIND=0-29
     $ vllm serve facebook/opt-125m

- If using the vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using ``VLLM_CPU_OMP_THREADS_BIND``. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:

  .. code-block:: console

     $ lscpu -e # check the mapping between logical CPU cores and physical CPU cores

     # The "CPU" column means the logical CPU core IDs, and the "CORE" column means the physical core IDs. On this platform, two logical cores are sharing one physical core.
     CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE MAXMHZ MINMHZ MHZ
     0 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
     1 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
     2 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
     3 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
     4 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
     5 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
     6 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
     7 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
     8 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
     9 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
     10 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
     11 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
     12 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
     13 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
     14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
     15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000

     # On this platform, it is recommended to only bind OpenMP threads on logical CPU cores 0-7 or 8-15
     $ export VLLM_CPU_OMP_THREADS_BIND=0-7
     $ python examples/offline_inference.py

- If using the vLLM CPU backend on a multi-socket machine with NUMA, be sure to set the CPU cores via ``VLLM_CPU_OMP_THREADS_BIND`` to avoid cross-NUMA-node memory access.

CPU Backend Considerations
--------------------------

- The CPU backend significantly differs from the GPU backend since the vLLM architecture was originally optimized for GPU use. A number of optimizations are needed to enhance its performance.

- Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance.

- On CPU-based setups with NUMA enabled, the memory access performance may be largely impacted by the `topology <https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.md#non-uniform-memory-access-numa>`_. For NUMA architectures, two optimizations are recommended: Tensor Parallel or Data Parallel.

  * Using Tensor Parallel for a latency-constrained deployment: following the GPU backend design, Megatron-LM's parallel algorithm will be used to shard the model, based on the number of NUMA nodes (e.g. TP = 2 for a two-NUMA-node system). With the `TP feature on CPU <https://github.com/vllm-project/vllm/pull/6125>`_ merged, Tensor Parallel is supported for serving and offline inferencing. In general, each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:

    .. code-block:: console

       $ VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp

  * Using Data Parallel for maximum throughput: launch an LLM serving endpoint on each NUMA node along with one additional load balancer to dispatch the requests to those endpoints. Common solutions like `Nginx <../serving/deploying_with_nginx.html>`_ or HAProxy are recommended. The Anyscale Ray project provides this feature for LLM `serving <https://docs.ray.io/en/latest/serve/index.html>`_. Here is an example of setting up scalable LLM serving with `Ray Serve <https://github.com/intel/llm-on-ray/blob/main/docs/setup.md>`_.

200
docs/source/getting_started/debugging.md
Normal file
@ -0,0 +1,200 @@
(debugging)=

# Debugging Tips

This document outlines some debugging strategies you can consider. If you think you've discovered a bug, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible.

```{note}
Once you've debugged a problem, remember to turn off any debugging environment variables you defined, or simply start a new shell to avoid being affected by lingering debugging settings. Otherwise, the system might be slow with debugging functionalities left activated.
```

## Hangs downloading a model

If the model isn't already downloaded to disk, vLLM will download it from the internet, which can take time and depends on your internet connection.
It's recommended to download the model first using the [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/cli) and pass the local path to the model to vLLM. This way, you can isolate the issue.
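
For example, a minimal sketch (the model ID and target directory are placeholders):

```console
$ huggingface-cli download facebook/opt-125m --local-dir /path/to/models/opt-125m
$ vllm serve /path/to/models/opt-125m
```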

## Hangs loading a model from disk

If the model is large, it can take a long time to load it from disk. Pay attention to where you store the model. Some clusters have shared filesystems across nodes, e.g. a distributed filesystem or a network filesystem, which can be slow.
It is better to store the model on a local disk. Additionally, keep an eye on the CPU memory usage: when the model is too large, it might take a lot of CPU memory, slowing down the operating system because it needs to frequently swap between disk and memory.

```{note}
To isolate the model downloading and loading issue, you can use the `--load-format dummy` argument to skip loading the model weights. This way, you can check if model downloading and loading is the bottleneck.
```

## Model is too large

If the model is too large to fit in a single GPU, you might want to [consider tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should then remain constant regardless of the size of tensor parallelism.

## Enable more logging

If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue:

- `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging.
- `export CUDA_LAUNCH_BLOCKING=1` to identify which CUDA kernel is causing the problem.
- `export NCCL_DEBUG=TRACE` to turn on more logging for NCCL.
- `export VLLM_TRACE_FUNCTION=1` to record all function calls for inspection in the log files to tell which function crashes or hangs.
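
For example, to launch a server with all of the above enabled (the model name is just an illustration):

```console
$ export VLLM_LOGGING_LEVEL=DEBUG
$ export CUDA_LAUNCH_BLOCKING=1
$ export NCCL_DEBUG=TRACE
$ export VLLM_TRACE_FUNCTION=1
$ vllm serve facebook/opt-125m
```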

## Incorrect network setup

The vLLM instance cannot get the correct IP address if you have a complicated network config. You can find a log line such as `DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl`, and the IP address should be the correct one.
If it's not, override the IP address using the environment variable `export VLLM_HOST_IP=<your_ip_address>`.

You might also need to set `export NCCL_SOCKET_IFNAME=<your_network_interface>` and `export GLOO_SOCKET_IFNAME=<your_network_interface>` to specify the network interface for the IP address.

## Error near `self.graph.replay()`

If vLLM crashes and the error trace captures it somewhere around `self.graph.replay()` in `vllm/worker/model_runner.py`, it is a CUDA error inside CUDAGraph.
To identify the particular CUDA operation that causes the error, you can add `--enforce-eager` to the command line, or `enforce_eager=True` to the {class}`~vllm.LLM` class, to disable the CUDAGraph optimization and isolate the exact CUDA operation that causes the error.
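
With the Python API, a minimal sketch looks like this (the model name is just an illustration):

```python
from vllm import LLM

# Disable CUDAGraph capture so the failing CUDA operation surfaces directly
# in the traceback instead of inside graph replay.
llm = LLM(model="facebook/opt-125m", enforce_eager=True)
```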

## Incorrect hardware/driver

If GPU/CPU communication cannot be established, you can use the following Python script and follow the instructions below to confirm whether the GPU/CPU communication is working correctly.

```python
# Test PyTorch NCCL
import torch
import torch.distributed as dist
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)
data = torch.FloatTensor([1,] * 128).to("cuda")
dist.all_reduce(data, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
value = data.mean().item()
world_size = dist.get_world_size()
assert value == world_size, f"Expected {world_size}, got {value}"

print("PyTorch NCCL is successful!")

# Test PyTorch GLOO
gloo_group = dist.new_group(ranks=list(range(world_size)), backend="gloo")
cpu_data = torch.FloatTensor([1,] * 128)
dist.all_reduce(cpu_data, op=dist.ReduceOp.SUM, group=gloo_group)
value = cpu_data.mean().item()
assert value == world_size, f"Expected {world_size}, got {value}"

print("PyTorch GLOO is successful!")

if world_size <= 1:
    exit()

# Test vLLM NCCL, with cuda graph
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)
# pynccl is enabled by default for 0.6.5+,
# but for 0.6.4 and below, we need to enable it manually.
# keep the code for backward compatibility, because people
# prefer to read the latest documentation.
pynccl.disabled = False

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    data.fill_(1)
    pynccl.all_reduce(data, stream=s)
    value = data.mean().item()
    assert value == world_size, f"Expected {world_size}, got {value}"

print("vLLM NCCL is successful!")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph=g, stream=s):
    pynccl.all_reduce(data, stream=torch.cuda.current_stream())

data.fill_(1)
g.replay()
torch.cuda.current_stream().synchronize()
value = data.mean().item()
assert value == world_size, f"Expected {world_size}, got {value}"

print("vLLM NCCL with cuda graph is successful!")

dist.destroy_process_group(gloo_group)
dist.destroy_process_group()
```

If you are testing with a single node, adjust `--nproc-per-node` to the number of GPUs you want to use:

```console
$ NCCL_DEBUG=TRACE torchrun --nproc-per-node=<number-of-GPUs> test.py
```

If you are testing with multiple nodes, adjust `--nproc-per-node` and `--nnodes` according to your setup and set `MASTER_ADDR` to the correct IP address of the master node, reachable from all nodes. Then, run:

```console
$ NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR test.py
```

If the script runs successfully, you should see all of the success messages printed, ending with `vLLM NCCL with cuda graph is successful!`.

If the test script hangs or crashes, it usually means the hardware/drivers are broken in some sense. You should try to contact your system administrator or hardware vendor for further assistance. As a common workaround, you can try to tune some NCCL environment variables, such as `export NCCL_P2P_DISABLE=1`, to see if it helps. Please check [their documentation](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html) for more information. Please only use these environment variables as a temporary workaround, as they might affect the performance of the system. The best solution is still to fix the hardware/drivers so that the test script can run successfully.

```{note}
A multi-node environment is more complicated than a single-node one. If you see errors such as `torch.distributed.DistNetworkError`, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign the node rank and specify the IP via command line arguments:

- In the first node, run `NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py`.
- In the second node, run `NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py`.

Adjust `--nproc-per-node`, `--nnodes`, and `--node-rank` according to your setup, being sure to execute different commands (with different `--node-rank`) on different nodes.
```

(debugging-python-multiprocessing)=
## Python multiprocessing

### `RuntimeError` Exception

If you have seen a warning in your logs like this:

```console
WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously
    initialized. We must use the `spawn` multiprocessing start method. Setting
    VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See
    https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing
    for more information.
```

or an error from Python that looks like this:

```console
RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

        To fix this issue, refer to the "Safe importing of main module"
        section in https://docs.python.org/3/library/multiprocessing.html
```

then you must update your Python code to guard usage of `vllm` behind a `if __name__ == '__main__':` block. For example, instead of this:

```python
import vllm

llm = vllm.LLM(...)
```

try this instead:

```python
if __name__ == '__main__':
    import vllm

    llm = vllm.LLM(...)
```

## Known Issues

- In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000), which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759).
- To circumvent a NCCL [bug](https://github.com/NVIDIA/nccl/issues/1234), all vLLM processes will set an environment variable `NCCL_CUMEM_ENABLE=0` to disable NCCL's `cuMem` allocator. It does not affect performance, but only gives memory benefits. When external processes want to set up a NCCL connection with vLLM's processes, they should also set this environment variable; otherwise, an inconsistent environment setup will cause NCCL to hang or crash, as observed in the [RLHF integration](https://github.com/OpenRLHF/OpenRLHF/pull/604) and the [discussion](gh-issue:5723#issuecomment-2554389656).
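
When launching such an external process, a minimal sketch of keeping the environments consistent looks like this (the script name is just a placeholder for your own process):

```console
$ export NCCL_CUMEM_ENABLE=0
$ python your_external_trainer.py
```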

@ -1,197 +0,0 @@
.. _debugging:

===============
Debugging Tips
===============

This document outlines some debugging strategies you can consider. If you think you've discovered a bug, please `search existing issues <https://github.com/vllm-project/vllm/issues?q=is%3Aissue>`_ first to see if it has already been reported. If not, please `file a new issue <https://github.com/vllm-project/vllm/issues/new/choose>`_, providing as much relevant information as possible.

.. note::

   Once you've debugged a problem, remember to turn off any debugging environment variables you defined, or simply start a new shell to avoid being affected by lingering debugging settings. Otherwise, the system might be slow with debugging functionalities left activated.

Hangs downloading a model
----------------------------------------
If the model isn't already downloaded to disk, vLLM will download it from the internet, which can take time and depends on your internet connection.
It's recommended to download the model first using the `huggingface-cli <https://huggingface.co/docs/huggingface_hub/en/guides/cli>`_ and pass the local path to the model to vLLM. This way, you can isolate the issue.

Hangs loading a model from disk
----------------------------------------
If the model is large, it can take a long time to load it from disk. Pay attention to where you store the model. Some clusters have shared filesystems across nodes, e.g. a distributed filesystem or a network filesystem, which can be slow.
It is better to store the model on a local disk. Additionally, keep an eye on the CPU memory usage: when the model is too large, it might take a lot of CPU memory, slowing down the operating system because it needs to frequently swap between disk and memory.

.. note::

   To isolate the model downloading and loading issue, you can use the ``--load-format dummy`` argument to skip loading the model weights. This way, you can check if model downloading and loading is the bottleneck.

Model is too large
----------------------------------------
If the model is too large to fit in a single GPU, you might want to `consider tensor parallelism <https://docs.vllm.ai/en/latest/serving/distributed_serving.html#distributed-inference-and-serving>`_ to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using `this example <https://docs.vllm.ai/en/latest/getting_started/examples/save_sharded_state.html>`_. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should then remain constant regardless of the size of tensor parallelism.

Enable more logging
----------------------------------------
If other strategies don't solve the problem, it's likely that the vLLM instance is stuck somewhere. You can use the following environment variables to help debug the issue:

- ``export VLLM_LOGGING_LEVEL=DEBUG`` to turn on more logging.
- ``export CUDA_LAUNCH_BLOCKING=1`` to identify which CUDA kernel is causing the problem.
- ``export NCCL_DEBUG=TRACE`` to turn on more logging for NCCL.
- ``export VLLM_TRACE_FUNCTION=1`` to record all function calls for inspection in the log files to tell which function crashes or hangs.

Incorrect network setup
----------------------------------------
The vLLM instance cannot get the correct IP address if you have a complicated network config. You can find a log line such as ``DEBUG 06-10 21:32:17 parallel_state.py:88] world_size=8 rank=0 local_rank=0 distributed_init_method=tcp://xxx.xxx.xxx.xxx:54641 backend=nccl``, and the IP address should be the correct one.
If it's not, override the IP address using the environment variable ``export VLLM_HOST_IP=<your_ip_address>``.

You might also need to set ``export NCCL_SOCKET_IFNAME=<your_network_interface>`` and ``export GLOO_SOCKET_IFNAME=<your_network_interface>`` to specify the network interface for the IP address.

Error near ``self.graph.replay()``
----------------------------------------
If vLLM crashes and the error trace captures it somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a CUDA error inside CUDAGraph.
To identify the particular CUDA operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the :class:`~vllm.LLM` class, to disable the CUDAGraph optimization and isolate the exact CUDA operation that causes the error.

Incorrect hardware/driver
----------------------------------------
If GPU/CPU communication cannot be established, you can use the following Python script and follow the instructions below to confirm whether the GPU/CPU communication is working correctly.

.. code-block:: python

    # Test PyTorch NCCL
    import torch
    import torch.distributed as dist
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    data = torch.FloatTensor([1,] * 128).to("cuda")
    dist.all_reduce(data, op=dist.ReduceOp.SUM)
    torch.cuda.synchronize()
    value = data.mean().item()
    world_size = dist.get_world_size()
    assert value == world_size, f"Expected {world_size}, got {value}"

    print("PyTorch NCCL is successful!")

    # Test PyTorch GLOO
    gloo_group = dist.new_group(ranks=list(range(world_size)), backend="gloo")
    cpu_data = torch.FloatTensor([1,] * 128)
    dist.all_reduce(cpu_data, op=dist.ReduceOp.SUM, group=gloo_group)
    value = cpu_data.mean().item()
    assert value == world_size, f"Expected {world_size}, got {value}"

    print("PyTorch GLOO is successful!")

    if world_size <= 1:
        exit()

    # Test vLLM NCCL, with cuda graph
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

    pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank)

    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        data.fill_(1)
        pynccl.all_reduce(data, stream=s)
        value = data.mean().item()
        assert value == world_size, f"Expected {world_size}, got {value}"

    print("vLLM NCCL is successful!")

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cuda_graph=g, stream=s):
        pynccl.all_reduce(data, stream=torch.cuda.current_stream())

    data.fill_(1)
    g.replay()
    torch.cuda.current_stream().synchronize()
    value = data.mean().item()
    assert value == world_size, f"Expected {world_size}, got {value}"

    print("vLLM NCCL with cuda graph is successful!")

    dist.destroy_process_group(gloo_group)
    dist.destroy_process_group()

If you are testing with a single node, adjust ``--nproc-per-node`` to the number of GPUs you want to use:

.. code-block:: console

   $ NCCL_DEBUG=TRACE torchrun --nproc-per-node=<number-of-GPUs> test.py

If you are testing with multiple nodes, adjust ``--nproc-per-node`` and ``--nnodes`` according to your setup and set ``MASTER_ADDR`` to the correct IP address of the master node, reachable from all nodes. Then, run:

.. code-block:: console

   $ NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR test.py

If the script runs successfully, you should see all of the success messages printed, ending with ``vLLM NCCL with cuda graph is successful!``.

If the test script hangs or crashes, it usually means the hardware/drivers are broken in some sense. You should try to contact your system administrator or hardware vendor for further assistance. As a common workaround, you can try to tune some NCCL environment variables, such as ``export NCCL_P2P_DISABLE=1``, to see if it helps. Please check `their documentation <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html>`__ for more information. Please only use these environment variables as a temporary workaround, as they might affect the performance of the system. The best solution is still to fix the hardware/drivers so that the test script can run successfully.

.. note::

   A multi-node environment is more complicated than a single-node one. If you see errors such as ``torch.distributed.DistNetworkError``, it is likely that the network/DNS setup is incorrect. In that case, you can manually assign the node rank and specify the IP via command line arguments:

   - In the first node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 0 --master_addr $MASTER_ADDR test.py``.
   - In the second node, run ``NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --node-rank 1 --master_addr $MASTER_ADDR test.py``.

   Adjust ``--nproc-per-node``, ``--nnodes``, and ``--node-rank`` according to your setup, being sure to execute different commands (with different ``--node-rank``) on different nodes.

Python multiprocessing
----------------------

`RuntimeError` Exception
^^^^^^^^^^^^^^^^^^^^^^^^

If you have seen a warning in your logs like this:

.. code-block:: console

    WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously
        initialized. We must use the `spawn` multiprocessing start method. Setting
        VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See
        https://docs.vllm.ai/en/latest/getting_started/debugging.html#python-multiprocessing
        for more information.

or an error from Python that looks like this:

.. code-block:: console

    RuntimeError:
            An attempt has been made to start a new process before the
            current process has finished its bootstrapping phase.

            This probably means that you are not using fork to start your
            child processes and you have forgotten to use the proper idiom
            in the main module:

                if __name__ == '__main__':
                    freeze_support()
                    ...

            The "freeze_support()" line can be omitted if the program
            is not going to be frozen to produce an executable.

            To fix this issue, refer to the "Safe importing of main module"
            section in https://docs.python.org/3/library/multiprocessing.html

then you must update your Python code to guard usage of ``vllm`` behind a ``if __name__ == '__main__':`` block. For example, instead of this:

.. code-block:: python

    import vllm

    llm = vllm.LLM(...)

try this instead:

.. code-block:: python

    if __name__ == '__main__':
        import vllm

        llm = vllm.LLM(...)

Known Issues
----------------------------------------
- In ``v0.5.2``, ``v0.5.3``, and ``v0.5.3.post1``, there is a bug caused by `zmq <https://github.com/zeromq/pyzmq/issues/2000>`_, which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of ``vllm`` to include the `fix <https://github.com/vllm-project/vllm/pull/6759>`_.

Some files were not shown because too many files have changed in this diff.