Compare commits

2 commits: remove-asy ... remove_mam

| Author | SHA1 | Date |
|---|---|---|
| | ddb65dad96 | |
| | c41ea52634 | |
@@ -168,9 +168,9 @@ See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description

### Workflow

- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for the different LLM serving engines.
- Inside each container, we run [scripts/run-nightly-benchmarks.sh](scripts/run-nightly-benchmarks.sh), which probes the serving engine of the current container.
- The `scripts/run-nightly-benchmarks.sh` script parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and launches the right benchmark for the specified serving engine via `scripts/launch-server.sh`.
- Finally, we run [scripts/summary-nightly-results.py](scripts/summary-nightly-results.py) to collect and plot the final benchmarking results, and upload the results to Buildkite.
- Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which probes the serving engine of the current container.
- The `run-nightly-suite.sh` script redirects the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark (see the illustrative sketch below).
- Finally, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and upload the results to Buildkite.
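The following is a minimal, illustrative Python sketch of the dispatch step described in the list above: read the workload file and hand each test to the engine-specific runner. The function name and argument handling are assumptions for illustration, not the real interface of these scripts.

```python
import json
import subprocess


def dispatch_nightly_tests(workload_path: str, engine: str) -> None:
    """Hypothetical sketch: run every workload entry against one serving engine."""
    with open(workload_path) as f:
        tests = json.load(f)
    for test in tests:
        # Each entry carries the server/client parameters for one benchmark run;
        # the per-engine script decides how to launch the server and the client.
        subprocess.run(
            ["bash", f"tests/run-{engine}-nightly.sh", json.dumps(test)],
            check=True,
        )


if __name__ == "__main__":
    dispatch_nightly_tests("tests/nightly-tests.json", "vllm")
```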
### Nightly tests

@@ -180,6 +180,6 @@ In [nightly-tests.json](tests/nightly-tests.json), we include the command line arguments

The docker containers for benchmarking are specified in `nightly-pipeline.yaml`.

WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `scripts/run-nightly-benchmarks.sh` and `scripts/launch-server.sh`.

WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`.

WARNING: populating `trt-llm` to the latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git).
@@ -12,6 +12,7 @@
    "vllm_server_parameters": {
        "disable_log_stats": "",
        "gpu_memory_utilization": 0.9,
        "num_scheduler_steps": 10,
        "max_num_seqs": 512,
        "dtype": "bfloat16"
    },

@@ -36,6 +36,7 @@
    "vllm_server_parameters": {
        "disable_log_stats": "",
        "gpu_memory_utilization": 0.9,
        "num_scheduler_steps": 10,
        "max_num_seqs": 512,
        "dtype": "bfloat16"
    },

@@ -89,6 +90,7 @@
    "vllm_server_parameters": {
        "disable_log_stats": "",
        "gpu_memory_utilization": 0.9,
        "num_scheduler_steps": 10,
        "max_num_seqs": 512,
        "dtype": "bfloat16"
    },

@@ -142,6 +144,7 @@
    "vllm_server_parameters": {
        "disable_log_stats": "",
        "gpu_memory_utilization": 0.9,
        "num_scheduler_steps": 10,
        "max_num_seqs": 512,
        "dtype": "bfloat16"
    },

@@ -192,6 +195,7 @@
    "vllm_server_parameters": {
        "disable_log_stats": "",
        "gpu_memory_utilization": 0.9,
        "num_scheduler_steps": 10,
        "max_num_seqs": 512,
        "dtype": "bfloat16"
    },

@@ -244,6 +248,7 @@
    "vllm_server_parameters": {
        "disable_log_stats": "",
        "gpu_memory_utilization": 0.9,
        "num_scheduler_steps": 10,
        "max_num_seqs": 512,
        "dtype": "bfloat16"
    },

@@ -296,6 +301,7 @@
    "vllm_server_parameters": {
        "disable_log_stats": "",
        "gpu_memory_utilization": 0.9,
        "num_scheduler_steps": 10,
        "max_num_seqs": 512,
        "dtype": "bfloat16"
    },
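The `vllm_server_parameters` blocks above are plain key/value pairs that the benchmark scripts turn into command-line flags when the server is started. Below is a minimal sketch of that conversion; the helper name is hypothetical, while the resulting flags are real `vllm serve` options.

```python
# Illustrative only: convert a JSON parameter block like the ones above into
# CLI flags. Keys with empty-string values are treated as boolean switches;
# everything else becomes `--key value`.
def params_to_cli_flags(params: dict) -> list[str]:
    flags: list[str] = []
    for key, value in params.items():
        flag = "--" + key.replace("_", "-")
        if value == "":
            flags.append(flag)                    # e.g. --disable-log-stats
        else:
            flags.extend([flag, str(value)])      # e.g. --num-scheduler-steps 10
    return flags


server_params = {
    "disable_log_stats": "",
    "gpu_memory_utilization": 0.9,
    "num_scheduler_steps": 10,
    "max_num_seqs": 512,
    "dtype": "bfloat16",
}
print(params_to_cli_flags(server_params))
# ['--disable-log-stats', '--gpu-memory-utilization', '0.9', '--num-scheduler-steps', '10', ...]
```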
@@ -128,7 +128,7 @@ run_and_track_test() {

# --- Actual Test Execution ---
run_and_track_test 1 "test_struct_output_generate.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
  "HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
run_and_track_test 2 "test_moe_pallas.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
run_and_track_test 3 "test_lora.py" \

@@ -139,8 +139,6 @@ run_and_track_test 5 "test_spmd_model_weight_loading.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
run_and_track_test 6 "test_kv_cache_update_kernel.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
run_and_track_test 7 "test_tpu_int8.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_int8.py"

# After all tests have been attempted, exit with the overall status.
if [ "$overall_script_exit_code" -ne 0 ]; then

@@ -134,7 +134,7 @@ run_and_track_test 1 "test_compilation.py" \
run_and_track_test 2 "test_basic.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py"
run_and_track_test 3 "test_accuracy.py::test_lm_eval_accuracy_v1_engine" \
  "python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine"
  "HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine"
run_and_track_test 4 "test_quantization_accuracy.py" \
  "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py"
run_and_track_test 5 "examples/offline_inference/tpu.py" \
@@ -56,19 +56,21 @@ steps:
  source_file_dependencies:
  - vllm/
  - tests/mq_llm_engine
  - tests/test_inputs.py
  - tests/test_outputs.py
  - tests/async_engine
  - tests/test_inputs
  - tests/multimodal
  - tests/utils_
  - tests/test_utils
  - tests/worker
  - tests/standalone_tests/lazy_imports.py
  commands:
  - python3 standalone_tests/lazy_imports.py
  - pytest -v -s mq_llm_engine # MQLLMEngine
  - pytest -v -s async_engine # AsyncLLMEngine
  - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
  - pytest -v -s test_inputs.py
  - pytest -v -s test_outputs.py
  - pytest -v -s multimodal
  - pytest -v -s utils_ # Utils
  - pytest -v -s test_utils.py # Utils
  - pytest -v -s worker # Worker

- label: Python-only Installation Test

@@ -424,6 +426,7 @@ steps:

- label: Tensorizer Test # 11min
  mirror_hardwares: [amdexperimental]
  soft_fail: true
  source_file_dependencies:
  - vllm/model_executor/model_loader
  - tests/tensorizer_loader

@@ -532,6 +535,8 @@ steps:
  - vllm/
  - tests/models/language
  commands:
  # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
  - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
  - pip freeze | grep -E 'torch'
  - pytest -v -s models/language -m core_model

@@ -542,10 +547,8 @@ steps:
  - vllm/
  - tests/models/language/generation
  commands:
  # Install fast path packages for testing against transformers
  # Note: also needed to run plamo2 model in vLLM
  - uv pip install --system --no-build-isolation 'git+https://github.com/state-spaces/mamba@v2.2.5'
  - uv pip install --system --no-build-isolation 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.2'
  # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
  - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
  - pytest -v -s models/language/generation -m hybrid_model

- label: Language Models Test (Extended Generation) # 1hr20min

@@ -770,6 +773,27 @@ steps:
  - pytest -v -s models/test_oot_registration.py # it needs a clean process
  - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins

- label: Multi-step Tests (4 GPUs) # 36min
  mirror_hardwares: [amdexperimental]
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
  source_file_dependencies:
  - vllm/model_executor/layers/sampler.py
  - vllm/sequence.py
  - vllm/worker/worker_base.py
  - vllm/worker/worker.py
  - vllm/worker/multi_step_worker.py
  - vllm/worker/model_runner_base.py
  - vllm/worker/model_runner.py
  - vllm/worker/multi_step_model_runner.py
  - vllm/engine
  - tests/multi_step
  commands:
  # this test is quite flaky
  # TODO: investigate and fix.
  # - pytest -v -s multi_step/test_correctness_async_llm.py
  - pytest -v -s multi_step/test_correctness_llm.py

- label: Pipeline Parallelism Test # 45min
  mirror_hardwares: [amdexperimental]
  working_dir: "/vllm-workspace/tests"
.github/CODEOWNERS (vendored): 11 changes

@@ -9,7 +9,7 @@
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
/vllm/multimodal @DarkLight1337 @ywang96
/vllm/vllm_flash_attn @LucasWilkinson
/vllm/lora @jeejeelee

@@ -20,7 +20,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson

# Any change to the VllmConfig changes can have a large user-facing impact,
# so spam a lot of people
/vllm/config @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg
/vllm/config.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor

# vLLM V1
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat

@@ -34,15 +34,16 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm
/tests/kernels @tlrmchlsmth @WoosukKwon @yewentao256
/tests/kernels @tlrmchlsmth @WoosukKwon
/tests/models @DarkLight1337 @ywang96
/tests/multi_step @alexm-redhat @comaniac
/tests/multimodal @DarkLight1337 @ywang96
/tests/prefix_caching @comaniac @KuntaiDu
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256
/tests/quantization @mgoin @robertgshaw2-redhat
/tests/test_inputs.py @DarkLight1337 @ywang96
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm
/tests/v1/structured_output @mgoin @russellb @aarnphm
/tests/weight_loading @mgoin @youkaichao @yewentao256
/tests/weight_loading @mgoin @youkaichao
/tests/lora @jeejeelee

# Docs
.github/PULL_REQUEST_TEMPLATE.md (vendored): 20 changes

@@ -1,5 +1,11 @@
<!-- markdownlint-disable -->
PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED.
# Essential Elements of an Effective PR Description Checklist

- [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
- [ ] The test plan, such as providing test command.
- [ ] The test results, such as pasting the results comparison before and after, or e2e results
- [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED.

## Purpose

@@ -9,14 +15,4 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM)

## (Optional) Documentation Update

---
<details>
<summary> Essential Elements of an Effective PR Description Checklist </summary>

- [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
- [ ] The test plan, such as providing test command.
- [ ] The test results, such as pasting the results comparison before and after, or e2e results
- [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.
</details>

**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing>** (anything written below this line will be removed by GitHub Actions)
.github/mergify.yml (vendored): 14 changes

@@ -118,20 +118,6 @@ pull_request_rules:
      add:
        - qwen

- name: label-gpt-oss
  description: Automatically apply gpt-oss label
  conditions:
    - or:
      - files~=^examples/.*gpt[-_]?oss.*\.py
      - files~=^tests/.*gpt[-_]?oss.*\.py
      - files~=^vllm/model_executor/models/.*gpt[-_]?oss.*\.py
      - files~=^vllm/model_executor/layers/.*gpt[-_]?oss.*\.py
      - title~=(?i)gpt[-_]?oss
  actions:
    label:
      add:
        - gpt-oss

- name: label-rocm
  description: Automatically apply rocm label
  conditions:
.github/scripts/cleanup_pr_body.sh (vendored): 8 changes

@@ -15,11 +15,11 @@ NEW=/tmp/new_pr_body.txt
gh pr view --json body --template "{{.body}}" "${PR_NUMBER}" > "${OLD}"
cp "${OLD}" "${NEW}"

# Remove markdown comments (like the <!-- markdownlint-disable --> at the start)
sed -i '/<!--.*-->$/d' "${NEW}"
# Remove "FIX #xxxx (*link existing issues this PR will resolve*)"
sed -i '/FIX #xxxx.*$/d' "${NEW}"

# Remove "PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTTOM) HAVE BEEN CONSIDERED."
sed -i '/PLEASE FILL IN THE PR DESCRIPTION HERE.*$/d' "${NEW}"
# Remove "FILL IN THE PR DESCRIPTION HERE"
sed -i '/FILL IN THE PR DESCRIPTION HERE/d' "${NEW}"

# Remove all lines after and including "**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**"
sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}"
.gitignore (vendored): 6 changes

@@ -4,9 +4,6 @@
# vllm-flash-attn built from source
vllm/vllm_flash_attn/*

# triton jit
.triton

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

@@ -150,8 +147,7 @@ venv.bak/
# mkdocs documentation
/site
docs/argparse
docs/examples/*
!docs/examples/README.md
docs/examples

# mypy
.mypy_cache/
@@ -427,7 +427,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
        set(SRCS
            "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
            "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
            "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
        )
        set_gencode_flags_for_srcs(
            SRCS "${SRCS}"
@@ -18,15 +18,14 @@ Easy, fast, and cheap LLM serving for everyone

*Latest News* 🔥

- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152).
- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).

<details>
<summary>Previous News</summary>

- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).

@@ -122,7 +121,6 @@ Cash Donations:

Compute Resources:

- Alibaba Cloud
- AMD
- Anyscale
- AWS

@@ -162,7 +160,7 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
## Contact Us

<!-- --8<-- [start:contact-us] -->
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues)
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions)
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
@@ -31,7 +31,7 @@ class RequestFuncInput:
    model_name: Optional[str] = None
    logprobs: Optional[int] = None
    extra_body: Optional[dict] = None
    multi_modal_content: Optional[dict | list[dict]] = None
    multi_modal_content: Optional[dict] = None
    ignore_eos: bool = False
    language: Optional[str] = None

@@ -364,15 +364,7 @@ async def async_request_openai_chat_completions(
    ) as session:
        content = [{"type": "text", "text": request_func_input.prompt}]
        if request_func_input.multi_modal_content:
            mm_content = request_func_input.multi_modal_content
            if isinstance(mm_content, list):
                content.extend(mm_content)
            elif isinstance(mm_content, dict):
                content.append(mm_content)
            else:
                raise TypeError(
                    "multi_modal_content must be a dict or list[dict] for openai-chat"
                )
            content.append(request_func_input.multi_modal_content)
        payload = {
            "model": request_func_input.model_name
            if request_func_input.model_name

@@ -499,10 +491,7 @@ async def async_request_openai_audio(
            buffer.seek(0)
            return buffer

        mm_audio = request_func_input.multi_modal_content
        if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
            raise TypeError("multi_modal_content must be a dict containing 'audio'")
        with to_bytes(*mm_audio["audio"]) as f:
        with to_bytes(*request_func_input.multi_modal_content["audio"]) as f:
            form = aiohttp.FormData()
            form.add_field("file", f, content_type="audio/wav")
            for key, value in payload.items():
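For context, the `multi_modal_content` change above lets a single request carry several multimodal items instead of one. A minimal, hypothetical example of the OpenAI-style chat `content` list this code builds (the URLs and prompt text are placeholders):

```python
# Hypothetical payload illustrating why multi_modal_content may now be a
# list[dict]: a text part plus several image parts in one chat message.
prompt = "Describe the differences between these two images."
multi_modal_content = [
    {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
    {"type": "image_url", "image_url": {"url": "https://example.com/b.png"}},
]

content = [{"type": "text", "text": prompt}]
content.extend(multi_modal_content)  # mirrors the list branch in the diff above

messages = [{"role": "user", "content": content}]
```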
@ -1,74 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
from benchmark_utils import TimeCollector
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
|
||||
|
||||
def main(args):
|
||||
rows = []
|
||||
for allocate_block in args.allocate_blocks:
|
||||
# Enforce a GC collect ahead to minimize the impact among runs
|
||||
gc.collect()
|
||||
block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)
|
||||
|
||||
get_blocks_times = TimeCollector(TimeCollector.US)
|
||||
free_blocks_times = TimeCollector(TimeCollector.US)
|
||||
for _ in range(args.num_iteration):
|
||||
with get_blocks_times:
|
||||
blocks = block_pool.get_new_blocks(allocate_block)
|
||||
with free_blocks_times:
|
||||
block_pool.free_blocks(blocks)
|
||||
|
||||
rows.append(
|
||||
[get_blocks_times.cnt, args.num_gpu_blocks, allocate_block]
|
||||
+ get_blocks_times.dump_avg_max()
|
||||
+ free_blocks_times.dump_avg_max()
|
||||
)
|
||||
|
||||
print(
|
||||
tabulate(
|
||||
rows,
|
||||
headers=[
|
||||
"Iterations",
|
||||
"Total\nBlocks",
|
||||
"Allocated\nBlocks",
|
||||
"Get Blocks\nAvg (us)",
|
||||
"Get Blocks\nMax (us)",
|
||||
"Free Blocks\nAvg (us)",
|
||||
"Free Blocks\nMax (us)",
|
||||
],
|
||||
tablefmt="grid",
|
||||
floatfmt=".3f",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def invoke_main() -> None:
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the performance of BlockPool for KV Cache."
|
||||
)
|
||||
parser.add_argument("--num-gpu-blocks", type=int, default=100000)
|
||||
parser.add_argument(
|
||||
"--num-iteration",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of iterations to run to stablize final data readings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allocate-blocks",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[10, 50, 100, 500, 1000],
|
||||
help="Number of blocks to allocate",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_main() # pragma: no cover
|
||||
@@ -52,7 +52,7 @@ class SampleRequest:
    prompt: Union[str, Any]
    prompt_len: int
    expected_output_len: int
    multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
    multi_modal_data: Optional[Union[MultiModalDataDict, dict]] = None
    lora_request: Optional[LoRARequest] = None
@ -1,112 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
|
||||
import numpy as np
|
||||
from tabulate import tabulate
|
||||
|
||||
from benchmark_utils import TimeCollector
|
||||
from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
|
||||
|
||||
def main(args):
|
||||
rows = []
|
||||
for max_ngram in args.max_ngram:
|
||||
collector = TimeCollector(TimeCollector.US)
|
||||
|
||||
model_config = ModelConfig(
|
||||
model="facebook/opt-125m",
|
||||
task="generate",
|
||||
max_model_len=args.num_token + args.num_spec_token,
|
||||
tokenizer="facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
dtype="auto",
|
||||
seed=None,
|
||||
trust_remote_code=False,
|
||||
)
|
||||
proposer = NgramProposer(
|
||||
vllm_config=VllmConfig(
|
||||
model_config=model_config,
|
||||
speculative_config=SpeculativeConfig(
|
||||
prompt_lookup_min=args.min_ngram,
|
||||
prompt_lookup_max=max_ngram,
|
||||
num_speculative_tokens=args.num_spec_token,
|
||||
method="ngram",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Warm up
|
||||
proposer.propose(np.random.randint(0, 20, (args.num_token,)))
|
||||
|
||||
gc.collect()
|
||||
for _ in range(args.num_iteration):
|
||||
tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
|
||||
with collector:
|
||||
for i in range(args.num_req):
|
||||
proposer.propose(tokens[i, :])
|
||||
rows.append(
|
||||
[args.num_req, args.num_token, args.min_ngram, max_ngram]
|
||||
+ collector.dump_avg_max()
|
||||
)
|
||||
|
||||
print(
|
||||
tabulate(
|
||||
rows,
|
||||
headers=[
|
||||
"# Request",
|
||||
"# Token",
|
||||
"Min Ngram",
|
||||
"Max Ngram",
|
||||
"Avg (us)",
|
||||
"Max (us)",
|
||||
],
|
||||
tablefmt="grid",
|
||||
floatfmt=".3f",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def invoke_main() -> None:
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the performance of N-gram speculative decode drafting"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-iteration",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of iterations to run to stablize final data readings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-req", type=int, default=128, help="Number of requests in the batch"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-token", type=int, default=1500, help="Number of tokens for each request"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-ngram",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Minimum n-gram to match",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-ngram",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[5, 7, 10, 15, 20],
|
||||
help="Maximum n-gram to match",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-spec-token",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of speculative tokens to generate",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_main() # pragma: no cover
|
||||
@ -263,14 +263,7 @@ async def benchmark(
|
||||
input_requests[0].multi_modal_data,
|
||||
)
|
||||
|
||||
assert (
|
||||
test_mm_content is None
|
||||
or isinstance(test_mm_content, dict)
|
||||
or (
|
||||
isinstance(test_mm_content, list)
|
||||
and all(isinstance(item, dict) for item in test_mm_content)
|
||||
)
|
||||
), "multi_modal_data must be a dict or list[dict]"
|
||||
assert test_mm_content is None or isinstance(test_mm_content, dict)
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
model_name=model_name,
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
|
||||
def convert_to_pytorch_benchmark_format(
|
||||
@ -73,53 +72,3 @@ def write_to_json(filename: str, records: list) -> None:
|
||||
cls=InfEncoder,
|
||||
default=lambda o: f"<{type(o).__name__} object is not JSON serializable>",
|
||||
)
|
||||
|
||||
|
||||
# Collect time and generate time metrics
|
||||
#
|
||||
# Example Usage:
|
||||
# collector = TimeCollector(TimeCollector.US)
|
||||
# for _ in range(total_iteration):
|
||||
# with collector:
|
||||
# ...
|
||||
# collector.dump_avg_max()
|
||||
class TimeCollector:
|
||||
NS: int = 1
|
||||
US: int = NS * 1000
|
||||
MS: int = US * 1000
|
||||
S: int = MS * 1000
|
||||
|
||||
def __init__(self, scale: int) -> None:
|
||||
self.cnt: int = 0
|
||||
self._sum: int = 0
|
||||
self._max: Optional[int] = None
|
||||
self.scale = scale
|
||||
self.start_time: int = time.monotonic_ns()
|
||||
|
||||
def collect(self, v: int) -> None:
|
||||
self.cnt += 1
|
||||
self._sum += v
|
||||
if self._max is None:
|
||||
self._max = v
|
||||
else:
|
||||
self._max = max(self._max, v)
|
||||
|
||||
def avg(self) -> Union[float, str]:
|
||||
return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
|
||||
|
||||
def max(self) -> Union[float, str]:
|
||||
return self._max / self.scale if self._max else "N/A"
|
||||
|
||||
def dump_avg_max(self) -> list[Union[float, str]]:
|
||||
return [self.avg(), self.max()]
|
||||
|
||||
def __enter__(self) -> None:
|
||||
self.start_time = time.monotonic_ns()
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_value: Optional[BaseException],
|
||||
exc_traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.collect(time.monotonic_ns() - self.start_time)
|
||||
|
||||
@@ -3,8 +3,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from packaging import version

from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
    MINIMUM_BITBLAS_VERSION,
)

@@ -12,7 +10,7 @@ from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
try:
    import bitblas

    if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION):
    if bitblas.__version__ < MINIMUM_BITBLAS_VERSION:
        raise ImportError(
            "bitblas version is wrong. Please "
            f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
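The switch between a plain string comparison and `packaging.version.parse` in the hunk above matters because string comparison orders versions lexicographically. A small self-contained illustration (the version numbers are arbitrary examples):

```python
from packaging import version

# Lexicographic string comparison gets multi-digit components wrong:
# "0.10.0" < "0.9.1" is True because the character "1" sorts before "9".
print("0.10.0" < "0.9.1")                                # True (incorrect ordering)

# packaging.version compares release components numerically, so 10 > 9.
print(version.parse("0.10.0") < version.parse("0.9.1"))  # False (correct ordering)
```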
@@ -22,10 +22,10 @@ from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()


def ensure_divisibility(numerator, denominator, text):
def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
        text, numerator, denominator
    assert numerator % denominator == 0, (
        "intermediate_size {} is not divisible by tp {}.".format(numerator, denominator)
    )
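For illustration, the three-argument variant that appears in the diff above just labels which quantity failed the divisibility check; the numbers below are arbitrary examples, not values from the benchmark.

```python
def ensure_divisibility(numerator, denominator, text):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
        text, numerator, denominator
    )


ensure_divisibility(64, 4, "Number of experts")   # passes silently
ensure_divisibility(30, 4, "Number of experts")
# AssertionError: Number of experts 30 is not divisible by tp 4.
```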
@ -577,10 +577,12 @@ def main(args: argparse.Namespace):
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in (
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV2ForCausalLM",
|
||||
@ -589,14 +591,17 @@ def main(args: argparse.Namespace):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
|
||||
E = config.num_experts
|
||||
topk = config.moe_topk[0]
|
||||
intermediate_size = config.moe_intermediate_size[0]
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Support for llama4
|
||||
config = config.get_text_config()
|
||||
@ -604,14 +609,8 @@ def main(args: argparse.Namespace):
|
||||
E = config.num_local_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
enable_ep = bool(args.enable_expert_parallel)
|
||||
if enable_ep:
|
||||
ensure_divisibility(E, args.tp_size, "Number of experts")
|
||||
E = E // args.tp_size
|
||||
shard_intermediate_size = 2 * intermediate_size
|
||||
else:
|
||||
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
ensure_divisibility(intermediate_size, args.tp_size)
|
||||
hidden_size = config.hidden_size
|
||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
@ -743,7 +742,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
|
||||
)
|
||||
parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
||||
)
|
||||
|
||||
@ -1,328 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models).
|
||||
# It generates test data, runs benchmarks, and saves results to a CSV file.
|
||||
#
|
||||
# The CSV file (named with current date/time) contains these columns:
|
||||
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
|
||||
# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
|
||||
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
|
||||
# speedup
|
||||
#
|
||||
# == Usage Examples ==
|
||||
#
|
||||
# Single model benchmark:
|
||||
# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \
|
||||
# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
|
||||
#
|
||||
# All models benchmark:
|
||||
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
|
||||
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
|
||||
#
|
||||
# All models with different TP sizes:
|
||||
# python3 benchmark_mrope.py --model-name "" --tp-size 1 2 4 8 --warmup-iter 10 \
|
||||
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
|
||||
#
|
||||
# All models with different token counts:
|
||||
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
|
||||
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384
|
||||
import csv
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def generate_test_data(
|
||||
num_tokens: int,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
max_position_embeddings: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Generate test data for given configuration."""
|
||||
# Create 2D positions (3, num_tokens) for multimodal case
|
||||
positions = torch.randint(
|
||||
0, max_position_embeddings // 4, (3, num_tokens), device=device
|
||||
)
|
||||
|
||||
# Create query and key tensors
|
||||
query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
|
||||
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||
|
||||
return positions, query, key
|
||||
|
||||
|
||||
def calculate_stats(times: list[float]) -> dict[str, float]:
|
||||
"""Calculate statistics from a list of times."""
|
||||
times_array = np.array(times)
|
||||
return {
|
||||
"mean": np.mean(times_array),
|
||||
"median": np.median(times_array),
|
||||
"p99": np.percentile(times_array, 99),
|
||||
"min": np.min(times_array),
|
||||
"max": np.max(times_array),
|
||||
}
|
||||
|
||||
|
||||
def benchmark_mrope(
|
||||
model_name: str,
|
||||
num_tokens: int,
|
||||
head_dim: int,
|
||||
tp_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 8192,
|
||||
rope_theta: float = 10000,
|
||||
is_neox_style: bool = True,
|
||||
rope_scaling: dict[str, Any] = None,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
seed: int = 0,
|
||||
warmup_iter: int = 10,
|
||||
benchmark_iter: int = 100,
|
||||
csv_writer=None,
|
||||
):
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
# the parameters to compute the q k v size based on tp_size
|
||||
mrope_helper_class = get_rope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=head_dim,
|
||||
max_position=max_position,
|
||||
base=rope_theta,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_scaling=rope_scaling,
|
||||
dtype=dtype,
|
||||
).to(device=device)
|
||||
|
||||
print(80 * "=")
|
||||
print(
|
||||
f"Evaluating model: {model_name} "
|
||||
f"with tp_size: {tp_size} "
|
||||
f"and num_tokens: {num_tokens}, "
|
||||
f"dtype: {dtype}"
|
||||
)
|
||||
|
||||
# create q k v input tensors
|
||||
# create rotary pos emb input tensors
|
||||
positions, query, key = generate_test_data(
|
||||
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
|
||||
)
|
||||
|
||||
# Warm up
|
||||
for _ in range(warmup_iter):
|
||||
mrope_helper_class.forward_native(
|
||||
positions,
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
mrope_helper_class.forward_cuda(
|
||||
positions,
|
||||
query.clone(),
|
||||
key.clone(),
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Time reference implementation
|
||||
torch_times = []
|
||||
for _ in range(benchmark_iter):
|
||||
query_clone = query.clone()
|
||||
key_clone = key.clone()
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
|
||||
mrope_helper_class.forward_native(
|
||||
positions,
|
||||
query_clone,
|
||||
key_clone,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
torch_times.append(time.time() - start_time)
|
||||
|
||||
# Time triton kernel implementation
|
||||
triton_times = []
|
||||
for _ in range(benchmark_iter):
|
||||
query_clone = query.clone()
|
||||
key_clone = key.clone()
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
mrope_helper_class.forward_cuda(
|
||||
positions,
|
||||
query_clone,
|
||||
key_clone,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
triton_times.append(time.time() - start_time)
|
||||
|
||||
# Calculate statistics
|
||||
torch_stats = calculate_stats(torch_times)
|
||||
triton_stats = calculate_stats(triton_times)
|
||||
print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")
|
||||
|
||||
print(
|
||||
f"Torch implementation: "
|
||||
f"mean={torch_stats['mean']:.8f}s, "
|
||||
f"median={torch_stats['median']:.8f}s, "
|
||||
f"p99={torch_stats['p99']:.8f}s"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Triton implementation: "
|
||||
f"mean={triton_stats['mean']:.8f}s, "
|
||||
f"median={triton_stats['median']:.8f}s, "
|
||||
f"p99={triton_stats['p99']:.8f}s"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
|
||||
)
|
||||
|
||||
# Write to CSV
|
||||
if csv_writer:
|
||||
row = [
|
||||
model_name,
|
||||
tp_size,
|
||||
num_tokens,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
max_position,
|
||||
rope_theta,
|
||||
is_neox_style,
|
||||
str(rope_scaling),
|
||||
str(dtype).split(".")[-1],
|
||||
torch_stats["mean"],
|
||||
torch_stats["median"],
|
||||
torch_stats["p99"],
|
||||
torch_stats["min"],
|
||||
torch_stats["max"],
|
||||
triton_stats["mean"],
|
||||
triton_stats["median"],
|
||||
triton_stats["p99"],
|
||||
triton_stats["min"],
|
||||
triton_stats["max"],
|
||||
torch_stats["mean"] / triton_stats["mean"], # speedup
|
||||
]
|
||||
csv_writer.writerow(row)
|
||||
|
||||
return torch_stats, triton_stats
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the rotary embedding kernels."
|
||||
)
|
||||
parser.add_argument("--model-name", type=str, default="")
|
||||
parser.add_argument("--tp-size", type=int, default=1)
|
||||
parser.add_argument("--warmup-iter", type=int, default=10)
|
||||
parser.add_argument("--benchmark-iter", type=int, default=100)
|
||||
parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
# Create CSV file for results
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv"
|
||||
|
||||
with open(csv_filename, "w", newline="") as csvfile:
|
||||
csv_writer = csv.writer(csvfile)
|
||||
# Write header
|
||||
header = [
|
||||
"model_name",
|
||||
"tp_size",
|
||||
"num_tokens",
|
||||
"num_heads",
|
||||
"num_kv_heads",
|
||||
"head_dim",
|
||||
"max_position",
|
||||
"rope_theta",
|
||||
"is_neox_style",
|
||||
"rope_scaling",
|
||||
"dtype",
|
||||
"torch_mean",
|
||||
"torch_median",
|
||||
"torch_p99",
|
||||
"torch_min",
|
||||
"torch_max",
|
||||
"triton_mean",
|
||||
"triton_median",
|
||||
"triton_p99",
|
||||
"triton_min",
|
||||
"triton_max",
|
||||
"speedup",
|
||||
]
|
||||
csv_writer.writerow(header)
|
||||
|
||||
model_tp_dict = {}
|
||||
if args.model_name == "":
|
||||
model_tp_dict = {
|
||||
"Qwen/Qwen2-VL-2B-Instruct": [1],
|
||||
"Qwen/Qwen2-VL-7B-Instruct": [1],
|
||||
"Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
|
||||
"Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
|
||||
}
|
||||
else:
|
||||
model_tp_dict[args.model_name] = [args.tp_size]
|
||||
|
||||
if args.num_tokens is None:
|
||||
num_tokens_list = [2**i for i in range(0, 18)]
|
||||
else:
|
||||
num_tokens_list = args.num_tokens
|
||||
|
||||
for model_name, tp_list in model_tp_dict.items():
|
||||
config = get_config(model_name, trust_remote_code=args.trust_remote_code)
|
||||
for tp_size in tp_list:
|
||||
# get the model config
|
||||
total_num_kv_heads = config.num_key_value_heads
|
||||
total_num_heads = config.num_attention_heads
|
||||
num_heads = total_num_heads // tp_size
|
||||
num_kv_heads = max(1, total_num_kv_heads // tp_size)
|
||||
head_dim = config.hidden_size // total_num_heads
|
||||
q_size = num_heads * head_dim
|
||||
kv_size = num_kv_heads * head_dim
|
||||
is_neox_style = True
|
||||
rope_theta = config.rope_theta
|
||||
max_position = config.max_position_embeddings
|
||||
|
||||
for num_tokens in num_tokens_list:
|
||||
benchmark_mrope(
|
||||
model_name=model_name,
|
||||
num_tokens=num_tokens,
|
||||
head_dim=head_dim,
|
||||
tp_size=tp_size,
|
||||
num_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
max_position=max_position,
|
||||
rope_theta=rope_theta,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_scaling=config.rope_scaling,
|
||||
dtype=getattr(torch, args.dtype),
|
||||
seed=args.seed,
|
||||
warmup_iter=args.warmup_iter,
|
||||
benchmark_iter=args.benchmark_iter,
|
||||
csv_writer=csv_writer,
|
||||
)
|
||||
|
||||
print(f"Benchmark results saved to {csv_filename}")
|
||||
benchmarks/kv_cache/benchmark_block_pool.py (new file): 108 changes
@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
|
||||
|
||||
class Metric:
|
||||
def __init__(self) -> None:
|
||||
self.cnt: int = 0
|
||||
self.sum_v: int = 0
|
||||
self.max_v: Optional[int] = None
|
||||
|
||||
def update(self, v: int) -> None:
|
||||
self.cnt += 1
|
||||
self.sum_v += v
|
||||
if self.max_v is None:
|
||||
self.max_v = v
|
||||
else:
|
||||
self.max_v = max(self.max_v, v)
|
||||
|
||||
def avg_v(self) -> float:
|
||||
return self.sum_v * 1.0 / self.cnt
|
||||
|
||||
|
||||
def main(args):
|
||||
rows = []
|
||||
for allocate_block in args.allocate_blocks:
|
||||
# Enforce a GC collect ahead to minimize the impact among runs
|
||||
gc.collect()
|
||||
block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)
|
||||
|
||||
get_blocks_metric: Metric = Metric()
|
||||
free_blocks_metric: Metric = Metric()
|
||||
for _ in range(args.num_iteration):
|
||||
t1 = time.monotonic_ns()
|
||||
blocks = block_pool.get_new_blocks(allocate_block)
|
||||
t2 = time.monotonic_ns()
|
||||
block_pool.free_blocks(blocks)
|
||||
t3 = time.monotonic_ns()
|
||||
get_blocks_metric.update(t2 - t1)
|
||||
free_blocks_metric.update(t3 - t2)
|
||||
|
||||
if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None:
|
||||
rows.append(
|
||||
[
|
||||
get_blocks_metric.cnt,
|
||||
args.num_gpu_blocks,
|
||||
allocate_block,
|
||||
get_blocks_metric.avg_v() / 1000000,
|
||||
get_blocks_metric.max_v / 1000000.0,
|
||||
free_blocks_metric.avg_v() / 1000000,
|
||||
free_blocks_metric.max_v / 1000000.0,
|
||||
]
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"No valid metrics found."
|
||||
f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
|
||||
)
|
||||
|
||||
print(
|
||||
tabulate(
|
||||
rows,
|
||||
headers=[
|
||||
"Iterations",
|
||||
"Total\nBlocks",
|
||||
"Allocated\nBlocks",
|
||||
"Get Blocks\nAvg (ms)",
|
||||
"Get Blocks\nMax (ms)",
|
||||
"Free Blocks\nAvg (ms)",
|
||||
"Free Blocks\nMax (ms)",
|
||||
],
|
||||
tablefmt="grid",
|
||||
floatfmt=".6f",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def invoke_main() -> None:
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the performance of BlockPool for KV Cache."
|
||||
)
|
||||
parser.add_argument("--num-gpu-blocks", type=int, default=100000)
|
||||
parser.add_argument(
|
||||
"--num-iteration",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of iterations to run to stablize final data readings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allocate-blocks",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[10, 50, 100, 500, 1000],
|
||||
help="Number of blocks to allocate",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_main() # pragma: no cover
|
||||
@ -1,71 +0,0 @@
|
||||
# Benchmark KV Cache Offloading with Multi-Turn Conversations
|
||||
|
||||
The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `requirements.txt`
|
||||
|
||||
First start serving your model
|
||||
|
||||
```bash
|
||||
export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||
|
||||
vllm serve $MODEL_NAME --disable-log-requests
|
||||
```
|
||||
|
||||
## Synthetic Multi-Turn Conversations
|
||||
|
||||
Download the following text file (used for generation of synthetic conversations)
|
||||
|
||||
```bash
|
||||
wget https://www.gutenberg.org/ebooks/1184.txt.utf-8
|
||||
mv 1184.txt.utf-8 pg1184.txt
|
||||
```
|
||||
|
||||
The filename `pg1184.txt` is used in `generate_multi_turn.json` (see `"text_files"`).
|
||||
|
||||
But you may use other text files if you prefer (using this specific file is not required).
|
||||
|
||||
Then run the benchmarking script
|
||||
|
||||
```bash
|
||||
export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||
|
||||
python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \
|
||||
--num-clients 2 --max-active-conversations 6
|
||||
```
|
||||
|
||||
You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.).
|
||||
|
||||
If successful, you will see the following output
|
||||
|
||||
```bash
|
||||
----------------------------------------------------------------------------------------------------
|
||||
Statistics summary:
|
||||
runtime_sec = 215.810
|
||||
requests_per_sec = 0.769
|
||||
----------------------------------------------------------------------------------------------------
|
||||
count mean std min 25% 50% 75% 90% 99% max
|
||||
ttft_ms 166.0 78.22 67.63 45.91 59.94 62.26 64.43 69.66 353.18 567.54
|
||||
tpot_ms 166.0 25.37 0.57 24.40 25.07 25.31 25.50 25.84 27.50 28.05
|
||||
latency_ms 166.0 2591.07 326.90 1998.53 2341.62 2573.01 2860.10 3003.50 3268.46 3862.94
|
||||
input_num_turns 166.0 7.43 4.57 1.00 3.00 7.00 11.00 13.00 17.00 17.00
|
||||
input_num_tokens 166.0 2006.20 893.56 522.00 1247.75 2019.00 2718.00 3233.00 3736.45 3899.00
|
||||
output_num_tokens 166.0 100.01 11.80 80.00 91.00 99.00 109.75 116.00 120.00 120.00
|
||||
output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75 115.00 119.00 119.00
|
||||
----------------------------------------------------------------------------------------------------
|
||||
```
|
||||
|
||||
## ShareGPT Conversations
|
||||
|
||||
To run with the ShareGPT data, download the following ShareGPT dataset:
|
||||
`https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json`
|
||||
|
||||
Use the `convert_sharegpt_to_openai.py` script to convert the dataset to a format supported by `benchmark_serving_multi_turn.py`
|
||||
|
||||
```bash
|
||||
python convert_sharegpt_to_openai.py sharegpt_20230401_clean_lang_split.json sharegpt_conv_128.json --seed=99 --max-items=128
|
||||
```
|
||||
|
||||
The script will convert the ShareGPT dataset to a dataset with the standard user/assistant roles.
|
||||
|
||||
The flag `--max-items=128` is used to sample 128 conversations from the original dataset (change as needed).
|
||||
|
||||
Use the output JSON file `sharegpt_conv_128.json` as the `--input-file` for `benchmark_serving_multi_turn.py`.
|
||||
@ -1,493 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from statistics import mean
|
||||
from typing import Any, NamedTuple, Optional, Union
|
||||
|
||||
import numpy as np # type: ignore
|
||||
import pandas as pd # type: ignore
|
||||
from bench_utils import (
|
||||
TEXT_SEPARATOR,
|
||||
Color,
|
||||
logger,
|
||||
)
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
# Conversation ID is a string (e.g: "UzTK34D")
|
||||
ConvId = str
|
||||
|
||||
# A list of dicts (dicts with keys "id" and "messages")
|
||||
ShareGptConversations = list[dict[str, Any]]
|
||||
|
||||
# A list of dicts (dicts with keys "role" and "content")
|
||||
MessagesList = list[dict[str, str]]
|
||||
|
||||
# Map conversation ID to conversation messages
|
||||
ConversationsMap = list[ConvId, MessagesList]
|
||||
|
||||
|
||||
class Distribution(ABC):
|
||||
@abstractmethod
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
pass
|
||||
|
||||
|
||||
class UniformDistribution(Distribution):
|
||||
def __init__(
|
||||
self,
|
||||
min_val: Union[int, float],
|
||||
max_val: Union[int, float],
|
||||
is_integer: bool = True,
|
||||
) -> None:
|
||||
self.min_val = min_val
|
||||
self.max_val = max_val
|
||||
self.is_integer = is_integer
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
if self.is_integer:
|
||||
return np.random.randint(
|
||||
int(self.min_val), int(self.max_val + 1), size=size
|
||||
)
|
||||
else:
|
||||
return np.random.uniform(self.min_val, self.max_val, size=size)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"UniformDistribution[{self.min_val}, {self.max_val}]"
|
||||
|
||||
|
||||
class ConstantDistribution(Distribution):
|
||||
def __init__(self, value: Union[int, float]) -> None:
|
||||
self.value = value
|
||||
self.max_val = value
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
return np.full(shape=size, fill_value=self.value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Constant[{self.value}]"
|
||||
|
||||
|
||||
class ZipfDistribution(Distribution):
|
||||
def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
|
||||
self.alpha = alpha
|
||||
self.max_val = max_val
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
samples = np.random.zipf(self.alpha, size=size)
|
||||
if self.max_val:
|
||||
samples = np.minimum(samples, self.max_val)
|
||||
return samples
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ZipfDistribution[{self.alpha}]"
|
||||
|
||||
|
||||
class PoissonDistribution(Distribution):
|
||||
def __init__(self, alpha: float, max_val: Optional[int] = None) -> None:
|
||||
self.alpha = alpha
|
||||
self.max_val = max_val
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
samples = np.random.poisson(self.alpha, size=size)
|
||||
if self.max_val:
|
||||
samples = np.minimum(samples, self.max_val)
|
||||
return samples
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"PoissonDistribution[{self.alpha}]"
|
||||
|
||||
|
||||
class LognormalDistribution(Distribution):
|
||||
def __init__(
|
||||
self, mean: float, sigma: float, max_val: Optional[int] = None
|
||||
) -> None:
|
||||
self.mean = mean
|
||||
self.sigma = sigma
|
||||
self.max_val = max_val
|
||||
|
||||
def sample(self, size: int = 1) -> np.ndarray:
|
||||
samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size)
|
||||
if self.max_val:
|
||||
samples = np.minimum(samples, self.max_val)
|
||||
|
||||
return np.round(samples).astype(int)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"LognormalDistribution[{self.mean}, {self.sigma}]"
|
||||
|
||||
|
||||
class GenConvArgs(NamedTuple):
|
||||
num_conversations: int
|
||||
text_files: list[str]
|
||||
input_num_turns: Distribution
|
||||
input_common_prefix_num_tokens: Distribution
|
||||
input_prefix_num_tokens: Distribution
|
||||
input_num_tokens: Distribution
|
||||
output_num_tokens: Distribution
|
||||
print_stats: bool
|
||||
|
||||
|
||||
def verify_field_exists(
|
||||
conf: dict, field_name: str, section: str, subsection: str
|
||||
) -> None:
|
||||
if field_name not in conf:
|
||||
raise ValueError(
|
||||
f"Missing field '{field_name}' in {section=} and {subsection=}"
|
||||
)
|
||||
|
||||
|
||||
def get_random_distribution(
|
||||
conf: dict, section: str, subsection: str, optional: bool = False
|
||||
) -> Distribution:
|
||||
# section can be "prompt_input" or "prompt_output" (both required)
|
||||
conf = conf[section]
|
||||
|
||||
if optional and subsection not in conf:
|
||||
# Optional subsection, if not found assume the value is always 0
|
||||
return ConstantDistribution(0)
|
||||
|
||||
# subsection can be "num_turns", "num_tokens" or "prefix_num_tokens"
|
||||
if subsection not in conf:
|
||||
raise ValueError(f"Missing subsection {subsection} in section {section}")
|
||||
|
||||
conf = conf[subsection]
|
||||
|
||||
distribution = conf.get("distribution")
|
||||
if distribution is None:
|
||||
raise ValueError(
|
||||
f"Missing field 'distribution' in {section=} and {subsection=}"
|
||||
)
|
||||
|
||||
if distribution == "constant":
|
||||
verify_field_exists(conf, "value", section, subsection)
|
||||
return ConstantDistribution(conf["value"])
|
||||
|
||||
elif distribution == "zipf":
|
||||
verify_field_exists(conf, "alpha", section, subsection)
|
||||
max_val = conf.get("max", None)
|
||||
return ZipfDistribution(conf["alpha"], max_val=max_val)
|
||||
|
||||
elif distribution == "poisson":
|
||||
verify_field_exists(conf, "alpha", section, subsection)
|
||||
max_val = conf.get("max", None)
|
||||
return PoissonDistribution(conf["alpha"], max_val=max_val)
|
||||
|
||||
elif distribution == "lognormal":
|
||||
verify_field_exists(conf, "mean", section, subsection)
|
||||
verify_field_exists(conf, "sigma", section, subsection)
|
||||
max_val = conf.get("max", None)
|
||||
return LognormalDistribution(conf["mean"], conf["sigma"], max_val=max_val)
|
||||
|
||||
elif distribution == "uniform":
|
||||
verify_field_exists(conf, "min", section, subsection)
|
||||
verify_field_exists(conf, "max", section, subsection)
|
||||
|
||||
min_value = conf["min"]
|
||||
max_value = conf["max"]
|
||||
|
||||
assert min_value > 0
|
||||
assert min_value <= max_value
|
||||
|
||||
is_integer = isinstance(min_value, int) and isinstance(max_value, int)
|
||||
return UniformDistribution(min_value, max_value, is_integer)
|
||||
else:
|
||||
raise ValueError(f"Unknown distribution: {distribution}")
|
||||
|
||||
|
||||
def parse_input_json_file(conf: dict) -> GenConvArgs:
|
||||
# Validate the input file
|
||||
assert isinstance(conf, dict)
|
||||
required_fields = [
|
||||
"filetype",
|
||||
"num_conversations",
|
||||
"text_files",
|
||||
"prompt_input",
|
||||
"prompt_output",
|
||||
]
|
||||
for field in required_fields:
|
||||
assert field in conf, f"Missing field {field} in input {conf}"
|
||||
|
||||
assert conf["filetype"] == "generate_conversations"
|
||||
|
||||
assert conf["num_conversations"] > 0, "num_conversations should be larger than zero"
|
||||
|
||||
text_files = conf["text_files"]
|
||||
|
||||
assert isinstance(text_files, list), "Field 'text_files' should be a list"
|
||||
assert len(text_files) > 0, (
|
||||
"Field 'text_files' should be a list with at least one file"
|
||||
)
|
||||
|
||||
# Parse the parameters for the prompt input/output workload
|
||||
input_num_turns = get_random_distribution(conf, "prompt_input", "num_turns")
|
||||
input_num_tokens = get_random_distribution(conf, "prompt_input", "num_tokens")
|
||||
input_common_prefix_num_tokens = get_random_distribution(
|
||||
conf, "prompt_input", "common_prefix_num_tokens", optional=True
|
||||
)
|
||||
input_prefix_num_tokens = get_random_distribution(
|
||||
conf, "prompt_input", "prefix_num_tokens"
|
||||
)
|
||||
output_num_tokens = get_random_distribution(conf, "prompt_output", "num_tokens")
|
||||
|
||||
print_stats: bool = conf.get("print_stats", False)
|
||||
assert isinstance(print_stats, bool), (
|
||||
"Field 'print_stats' should be either 'true' or 'false'"
|
||||
)
|
||||
|
||||
args = GenConvArgs(
|
||||
num_conversations=conf["num_conversations"],
|
||||
text_files=text_files,
|
||||
input_num_turns=input_num_turns,
|
||||
input_common_prefix_num_tokens=input_common_prefix_num_tokens,
|
||||
input_prefix_num_tokens=input_prefix_num_tokens,
|
||||
input_num_tokens=input_num_tokens,
|
||||
output_num_tokens=output_num_tokens,
|
||||
print_stats=print_stats,
|
||||
)
|
||||
return args
|
||||
|
||||
|
||||
def print_conv_stats(conversations: ConversationsMap, tokenizer: AutoTokenizer) -> None:
|
||||
# Collect statistics
|
||||
conv_stats: list[dict[Any, Any]] = []
|
||||
req_stats: list[int] = []
|
||||
|
||||
print("\nCollecting statistics...")
|
||||
for messages in conversations.values():
|
||||
# messages is a list of dicts
|
||||
user_tokens: list[int] = []
|
||||
assistant_tokens: list[int] = []
|
||||
request_tokens: list[int] = []
|
||||
|
||||
req_tokens = 0
|
||||
for m in messages:
|
||||
content = m["content"]
|
||||
num_tokens = len(tokenizer(content).input_ids)
|
||||
|
||||
if m["role"] == "user":
|
||||
user_tokens.append(num_tokens)
|
||||
# New user prompt including all chat history
|
||||
req_tokens += num_tokens
|
||||
request_tokens.append(req_tokens)
|
||||
|
||||
elif m["role"] == "assistant":
|
||||
assistant_tokens.append(num_tokens)
|
||||
# Update assistant answer
|
||||
# (will be part of chat history for the next user prompt)
|
||||
req_tokens += num_tokens
|
||||
|
||||
item_stats = {
|
||||
"conversation_turns": len(messages),
|
||||
"user_tokens": mean(user_tokens),
|
||||
"assistant_tokens": mean(assistant_tokens),
|
||||
}
|
||||
|
||||
conv_stats.append(item_stats)
|
||||
req_stats.extend(request_tokens)
|
||||
|
||||
# Print statistics
|
||||
percentiles = [0.25, 0.5, 0.75, 0.9, 0.99]
|
||||
|
||||
print(TEXT_SEPARATOR)
|
||||
print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
|
||||
print(TEXT_SEPARATOR)
|
||||
df = pd.DataFrame(conv_stats)
|
||||
print(df.describe(percentiles=percentiles).transpose())
|
||||
print(TEXT_SEPARATOR)
|
||||
print(f"{Color.YELLOW}Request statistics:{Color.RESET}")
|
||||
print(TEXT_SEPARATOR)
|
||||
df = pd.DataFrame(req_stats, columns=["request_tokens"])
|
||||
print(df.describe(percentiles=percentiles).transpose())
|
||||
print(TEXT_SEPARATOR)
|
||||
|
||||
|
||||
def generate_conversations(
|
||||
args: GenConvArgs, tokenizer: AutoTokenizer
|
||||
) -> ConversationsMap:
|
||||
# Text for all user prompts
|
||||
# (text from the input text files will be appended to this line)
|
||||
base_prompt_text = "Please rewrite the following text and add more content: "
|
||||
base_prompt_token_count = len(
|
||||
tokenizer.encode(base_prompt_text, add_special_tokens=False)
|
||||
)
|
||||
|
||||
logger.info(f"{Color.PURPLE}Generating conversations...{Color.RESET}")
|
||||
logger.info(args)
|
||||
|
||||
list_of_tokens = []
|
||||
|
||||
for filename in args.text_files:
|
||||
# Load text file that will be used to generate prompts
|
||||
with open(filename) as file:
|
||||
data = file.read()
|
||||
tokens_in_file = tokenizer.encode(data, add_special_tokens=False)
|
||||
list_of_tokens.extend(tokens_in_file)
|
||||
|
||||
conversations: ConversationsMap = {}
|
||||
conv_id = 0
|
||||
|
||||
# Generate number of turns for every conversation
|
||||
turn_count: np.ndarray = args.input_num_turns.sample(args.num_conversations)
|
||||
|
||||
# Turn count should be at least 2 (one user prompt and one assistant answer)
|
||||
turn_count = np.maximum(turn_count, 2)
|
||||
|
||||
# Round up to an even number (every user prompt should have an answer)
|
||||
turn_count = turn_count + (turn_count % 2)
|
||||
|
||||
# Generate number of prefix tokens for every conversation
|
||||
conv_prefix_tokens: np.ndarray = args.input_prefix_num_tokens.sample(
|
||||
args.num_conversations
|
||||
)
|
||||
|
||||
# Used to reduce shared text between conversations
|
||||
# (jump/skip over text sections between conversations)
|
||||
base_offset = 0
|
||||
|
||||
# Common prefix size for all conversations (only 1 sample required)
|
||||
common_prefix_text = ""
|
||||
common_prefix_tokens: int = args.input_common_prefix_num_tokens.sample(1)[0]
|
||||
if common_prefix_tokens > 0:
|
||||
# Using "." at the end to separate sentences
|
||||
common_prefix_text = (
|
||||
tokenizer.decode(list_of_tokens[: common_prefix_tokens - 2]) + "."
|
||||
)
|
||||
base_offset += common_prefix_tokens
|
||||
|
||||
for conv_id in range(args.num_conversations):
|
||||
# Generate a single conversation
|
||||
messages: MessagesList = []
|
||||
|
||||
nturns = turn_count[conv_id]
|
||||
|
||||
# User prompt token count per turn (with lower limit)
|
||||
input_token_count: np.ndarray = args.input_num_tokens.sample(nturns)
|
||||
input_token_count = np.maximum(input_token_count, base_prompt_token_count)
|
||||
|
||||
# Assistant answer token count per turn (with lower limit)
|
||||
output_token_count: np.ndarray = args.output_num_tokens.sample(nturns)
|
||||
output_token_count = np.maximum(output_token_count, 1)
|
||||
|
||||
user_turn = True
|
||||
for turn_id in range(nturns):
|
||||
if user_turn:
|
||||
role = "user"
|
||||
num_tokens = input_token_count[turn_id]
|
||||
|
||||
# Generate the user prompt,
|
||||
# use a unique prefix (the conv_id) for each conversation
|
||||
# (to avoid shared prefix between conversations)
|
||||
content = f"{conv_id} is a nice number... "
|
||||
|
||||
if len(common_prefix_text) > 0 and turn_id == 0:
|
||||
content = common_prefix_text + content
|
||||
|
||||
# Update the number of tokens left for the content
|
||||
num_tokens -= len(tokenizer.encode(content, add_special_tokens=False))
|
||||
|
||||
if turn_id == 0:
|
||||
prefix_num_tokens = conv_prefix_tokens[conv_id]
|
||||
if prefix_num_tokens > 0:
|
||||
# Add prefix text (context) to the first turn
|
||||
start_offset = base_offset
|
||||
end_offset = start_offset + prefix_num_tokens
|
||||
assert len(list_of_tokens) > end_offset, (
|
||||
"Not enough input text to generate "
|
||||
f"{prefix_num_tokens} tokens for the "
|
||||
f"prefix text ({start_offset=}, {end_offset=})"
|
||||
)
|
||||
|
||||
content += f"{conv_id}, " + tokenizer.decode(
|
||||
list_of_tokens[start_offset:end_offset]
|
||||
)
|
||||
base_offset += prefix_num_tokens
|
||||
|
||||
# Add the actual user prompt/question after the prefix text
|
||||
content += base_prompt_text
|
||||
num_tokens -= base_prompt_token_count
|
||||
|
||||
if num_tokens > 0:
|
||||
# Add text from the input file (to reach the desired token count)
|
||||
start_offset = base_offset + turn_id * input_token_count.max()
|
||||
end_offset = start_offset + num_tokens
|
||||
assert len(list_of_tokens) > end_offset, (
|
||||
f"Not enough input text to generate {num_tokens} tokens "
|
||||
f"for the prompt ({start_offset=}, {end_offset=})"
|
||||
)
|
||||
|
||||
# Convert tokens back to text
|
||||
content += tokenizer.decode(list_of_tokens[start_offset:end_offset])
|
||||
else:
|
||||
role = "assistant"
|
||||
# This content will not be used as input to the LLM server
|
||||
# (actual answers will be used instead).
|
||||
# Content is only required to determine the min_tokens/max_tokens
|
||||
# (inputs to the LLM server).
|
||||
num_tokens = output_token_count[turn_id]
|
||||
assert len(list_of_tokens) > num_tokens, (
|
||||
f"Not enough input text to generate {num_tokens} "
|
||||
"tokens for assistant content"
|
||||
)
|
||||
content = tokenizer.decode(list_of_tokens[:num_tokens])
|
||||
|
||||
# Append the user/assistant message to the list of messages
|
||||
messages.append({"role": role, "content": content})
|
||||
user_turn = not user_turn
|
||||
|
||||
# Add the new conversation
|
||||
conversations[f"CONV_ID_{conv_id}"] = messages
|
||||
|
||||
# Increase base offset for the next conversation
|
||||
base_offset += nturns
|
||||
|
||||
if args.print_stats:
|
||||
print_conv_stats(conversations, tokenizer)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
def conversations_list_to_dict(input_list: ShareGptConversations) -> ConversationsMap:
|
||||
conversations: ConversationsMap = {}
|
||||
|
||||
for item in input_list:
|
||||
conv_id: str = item["id"]
|
||||
assert isinstance(conv_id, str)
|
||||
|
||||
assert conv_id not in conversations, (
|
||||
f"Conversation ID {conv_id} found more than once in the input"
|
||||
)
|
||||
|
||||
messages: MessagesList = item["messages"]
|
||||
assert isinstance(messages, list), (
|
||||
f"Conversation messages should be a list (ID: {conv_id})"
|
||||
)
|
||||
assert len(messages) > 0, f"Conversation with no messages (ID: {conv_id})"
|
||||
|
||||
conversations[conv_id] = messages
|
||||
|
||||
logger.info(f"Using {len(conversations)} unique conversations (IDs)")
|
||||
assert len(conversations) == len(input_list)
|
||||
|
||||
# Print statistics about the selected conversations
|
||||
stats: list[dict[str, Any]] = []
|
||||
for conv_data in conversations.values():
|
||||
stats.append({"num_turns": len(conv_data)})
|
||||
|
||||
print(TEXT_SEPARATOR)
|
||||
print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
|
||||
print(TEXT_SEPARATOR)
|
||||
percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
|
||||
conv_stats = pd.DataFrame(stats).describe(percentiles=percentiles)
|
||||
print(conv_stats.transpose())
|
||||
print(TEXT_SEPARATOR)
|
||||
|
||||
return conversations
|
||||
|
||||
|
||||
def conversations_dict_to_list(input_dict: ConversationsMap) -> ShareGptConversations:
|
||||
output: ShareGptConversations = []
|
||||
for conv_id, conv_data in input_dict.items():
|
||||
new_item = {"id": conv_id, "messages": conv_data}
|
||||
output.append(new_item)
|
||||
|
||||
return output
|
||||
@ -1,28 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Color(Enum):
|
||||
RED = "\033[91m"
|
||||
GREEN = "\033[92m"
|
||||
BLUE = "\033[94m"
|
||||
PURPLE = "\033[95m"
|
||||
CYAN = "\033[96m"
|
||||
YELLOW = "\033[93m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
TEXT_SEPARATOR = "-" * 100
|
||||
|
||||
# Configure the logger
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] - %(message)s",
|
||||
datefmt="%d-%m-%Y %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,354 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Download dataset from:
|
||||
https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json
|
||||
|
||||
Convert to OpenAI API:
|
||||
export INPUT_FILE=sharegpt_20230401_clean_lang_split.json
|
||||
python convert_sharegpt_to_openai.py $INPUT_FILE sharegpt_conv_128.json --max-items=128
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import random
|
||||
from statistics import mean
|
||||
from typing import Any, Optional
|
||||
|
||||
import pandas as pd # type: ignore
|
||||
import tqdm # type: ignore
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
|
||||
def has_non_english_chars(text: str) -> bool:
|
||||
return not text.isascii()
|
||||
|
||||
|
||||
def content_is_valid(
|
||||
content: str, min_content_len: Optional[int], max_content_len: Optional[int]
|
||||
) -> bool:
|
||||
if min_content_len and len(content) < min_content_len:
|
||||
return False
|
||||
|
||||
if max_content_len and len(content) > max_content_len:
|
||||
return False
|
||||
|
||||
return has_non_english_chars(content)
|
||||
|
||||
|
||||
def print_stats(
|
||||
conversations: "list[dict[Any, Any]]", tokenizer: Optional[AutoTokenizer] = None
|
||||
) -> None:
|
||||
# Collect statistics
|
||||
stats = []
|
||||
|
||||
print("\nCollecting statistics...")
|
||||
for item in tqdm.tqdm(conversations):
|
||||
# item has "id" and "messages"
|
||||
messages = item["messages"]
|
||||
|
||||
user_turns = 0
|
||||
assistant_turns = 0
|
||||
user_words = 0
|
||||
assistant_words = 0
|
||||
conv_chars = 0
|
||||
|
||||
user_tokens: list[int] = []
|
||||
assistant_tokens: list[int] = []
|
||||
|
||||
for m in messages:
|
||||
content = m["content"]
|
||||
conv_chars += len(content)
|
||||
content_num_words = content.count(" ") + 1
|
||||
|
||||
num_tokens = 0
|
||||
if tokenizer:
|
||||
num_tokens = len(tokenizer(m["content"]).input_ids)
|
||||
|
||||
if m["role"] == "user":
|
||||
user_turns += 1
|
||||
user_words += content_num_words
|
||||
if tokenizer:
|
||||
user_tokens.append(num_tokens)
|
||||
|
||||
elif m["role"] == "assistant":
|
||||
assistant_turns += 1
|
||||
assistant_words += content_num_words
|
||||
if tokenizer:
|
||||
assistant_tokens.append(num_tokens)
|
||||
|
||||
# assert user_turns == assistant_turns, \
|
||||
# f"Invalid conversation ID {item['id']}"
|
||||
|
||||
conv_words = user_words + assistant_words
|
||||
item_stats = {
|
||||
"user_turns": user_turns,
|
||||
"assistant_turns": assistant_turns,
|
||||
"user_words": user_words,
|
||||
"assistant_words": assistant_words,
|
||||
"conv_turns": len(messages),
|
||||
"conv_words": conv_words,
|
||||
"conv_characters": conv_chars,
|
||||
}
|
||||
|
||||
if len(user_tokens) > 0:
|
||||
item_stats["user_tokens"] = int(mean(user_tokens))
|
||||
|
||||
if len(assistant_tokens) > 0:
|
||||
item_stats["assistant_tokens"] = int(mean(assistant_tokens))
|
||||
|
||||
stats.append(item_stats)
|
||||
|
||||
print("\nStatistics:")
|
||||
percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
|
||||
df = pd.DataFrame(stats)
|
||||
print(df.describe(percentiles=percentiles).transpose())
|
||||
|
||||
|
||||
def convert_sharegpt_to_openai(
|
||||
seed: int,
|
||||
input_file: str,
|
||||
output_file: str,
|
||||
max_items: Optional[int],
|
||||
min_content_len: Optional[int] = None,
|
||||
max_content_len: Optional[int] = None,
|
||||
min_turns: Optional[int] = None,
|
||||
max_turns: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> None:
|
||||
if min_turns and max_turns:
|
||||
assert min_turns <= max_turns
|
||||
|
||||
if min_content_len and max_content_len:
|
||||
# Verify that min is not larger than max if both were given
|
||||
assert min_content_len <= max_content_len
|
||||
|
||||
print(
|
||||
f"Input parameters:\n{seed=}, {max_items=}, {min_content_len=},"
|
||||
f" {max_content_len=}, {min_turns=}, {max_turns=}\n"
|
||||
)
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
tokenizer = None
|
||||
if model is not None:
|
||||
print(f"Loading tokenizer from: {model}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
|
||||
# Read the ShareGPT JSON file
|
||||
print(f"Reading file: {input_file}")
|
||||
with open(input_file, encoding="utf-8") as f:
|
||||
# Should be a list of dicts
|
||||
# Each dict should have "id" (string) and "conversations" (list of dicts)
|
||||
sharegpt_data = json.load(f)
|
||||
|
||||
assert isinstance(sharegpt_data, list), "Input file should contain a list of dicts"
|
||||
|
||||
print(f"Total items in input file: {len(sharegpt_data):,}")
|
||||
|
||||
print(f"Shuffling dataset with seed {seed}")
|
||||
random.shuffle(sharegpt_data)
|
||||
|
||||
# Map conversation ID to the all the messages
|
||||
conversation_parts: dict[str, list[Any]] = {}
|
||||
|
||||
for item in tqdm.tqdm(sharegpt_data):
|
||||
assert "id" in item, "Missing key 'id'"
|
||||
assert "conversations" in item, "Missing key 'conversations'"
|
||||
|
||||
# Conversation ID (e.g: "hiWPlMD") and part/session (0, 1, 2, etc.)
|
||||
conv_id, _ = item["id"].split("_")
|
||||
new_turns = item["conversations"]
|
||||
|
||||
if conv_id not in conversation_parts:
|
||||
# Start new conversation
|
||||
conversation_parts[conv_id] = []
|
||||
elif len(conversation_parts[conv_id]) > 0 and len(new_turns) > 0:
|
||||
prev_turns = conversation_parts[conv_id][-1]
|
||||
if prev_turns[-1]["from"] == new_turns[0]["from"]:
|
||||
new_turns = new_turns[1:]
|
||||
|
||||
if len(new_turns) > 0:
|
||||
# We assume that parts are in order in the ShareGPT dataset
|
||||
conversation_parts[conv_id].append(new_turns)
|
||||
|
||||
dataset: list[dict[str, Any]] = []
|
||||
for conv_id, conv_parts in conversation_parts.items():
|
||||
new_item = {"id": conv_id}
|
||||
|
||||
conversations: list[dict[str, str]] = []
|
||||
|
||||
# Merge all parts
|
||||
for conv_part in conv_parts:
|
||||
conversations.extend(conv_part)
|
||||
|
||||
if len(conversations) > 0:
|
||||
new_item["conversations"] = conversations
|
||||
dataset.append(new_item)
|
||||
|
||||
print(f"Total unique conversations (IDs) in input file: {len(dataset):,}")
|
||||
|
||||
# Final output data
|
||||
final_openai_dataset: list[dict] = []
|
||||
|
||||
# Filter conversations from the ShareGPT dataset and convert to OpenAI format
|
||||
for item in tqdm.tqdm(dataset):
|
||||
messages: list[dict] = []
|
||||
|
||||
assert "id" in item, "Missing key 'id'"
|
||||
assert "conversations" in item, "Missing key 'conversations'"
|
||||
|
||||
conv_id = item["id"]
|
||||
conversations = item["conversations"]
|
||||
|
||||
if min_turns is not None and len(conversations) < min_turns:
|
||||
# Skip short conversations
|
||||
continue
|
||||
|
||||
# Convert each message in the conversation, up to max_turns if specified
|
||||
for i, turn in enumerate(conversations):
|
||||
assert "from" in turn and "value" in turn, (
|
||||
f"Invalid conversation ID {conv_id} - missing 'from' or 'value'"
|
||||
)
|
||||
|
||||
role = None
|
||||
turn_from = turn["from"]
|
||||
|
||||
if turn_from in {"human", "user"}:
|
||||
role = "user"
|
||||
elif turn_from in {"gpt", "bing", "chatgpt", "bard"}:
|
||||
role = "assistant"
|
||||
elif turn_from == "system":
|
||||
role = "system"
|
||||
|
||||
assert role is not None, (
|
||||
f"Invalid conversation ID {conv_id} - 'from'='{turn_from}' is invalid"
|
||||
)
|
||||
|
||||
if i == 0 and role != "user":
|
||||
# If the first message is from assistant (gpt), skip it.
|
||||
# this happens when the conversation is a follow-up
|
||||
# to a previous conversation (from the same user).
|
||||
continue
|
||||
|
||||
if max_turns is not None and i >= max_turns:
|
||||
break
|
||||
|
||||
# Convert message to OpenAI format (with "role" and "content")
|
||||
content = turn["value"]
|
||||
messages.append({"role": role, "content": content})
|
||||
|
||||
# Add the converted conversation to the OpenAI format
|
||||
if len(messages) > 0:
|
||||
valid_messages = True
|
||||
|
||||
# First turn should always be from the user
|
||||
user_turn = True
|
||||
|
||||
for m in messages:
|
||||
# Make sure that turns alternate between user and assistant
|
||||
if (user_turn and m["role"] != "user") or (
|
||||
not user_turn and m["role"] != "assistant"
|
||||
):
|
||||
valid_messages = False
|
||||
break
|
||||
|
||||
user_turn = not user_turn
|
||||
|
||||
content = m["content"]
|
||||
valid_messages = content_is_valid(
|
||||
content, min_content_len, max_content_len
|
||||
)
|
||||
if not valid_messages:
|
||||
break
|
||||
|
||||
if valid_messages is True:
|
||||
final_openai_dataset.append({"id": conv_id, "messages": messages})
|
||||
|
||||
assert len(final_openai_dataset) > 0, "Final number of conversations is zero"
|
||||
|
||||
print_stats(final_openai_dataset)
|
||||
|
||||
print_stats_again = False
|
||||
if max_items is not None and len(final_openai_dataset) > max_items:
|
||||
print(f"\n\nSampling {max_items} items from the dataset...")
|
||||
print_stats_again = True
|
||||
final_openai_dataset = random.sample(final_openai_dataset, max_items)
|
||||
|
||||
if print_stats_again:
|
||||
# Print stats after the dataset changed
|
||||
print_stats(final_openai_dataset, tokenizer)
|
||||
|
||||
# Write the converted data to a new JSON file
|
||||
final_size = len(final_openai_dataset)
|
||||
print(f"\nTotal conversations converted (after filtering): {final_size:,}")
|
||||
print(f"\nWriting file: {output_file}")
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(final_openai_dataset, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert ShareGPT dataset to OpenAI API format"
|
||||
)
|
||||
parser.add_argument("input_file", help="Path to the input ShareGPT JSON file")
|
||||
parser.add_argument(
|
||||
"output_file", help="Path to the output OpenAI format JSON file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=0, help="Seed for random number generators"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-items",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of items in the output file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-turns",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Minimum number of turns per conversation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-turns",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of turns per conversation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-content-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Min number of characters in the messages' content",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-content-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max number of characters in the messages' content",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="LLM model, only the tokenizer will be used",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_sharegpt_to_openai(
|
||||
args.seed,
|
||||
args.input_file,
|
||||
args.output_file,
|
||||
args.max_items,
|
||||
args.min_content_len,
|
||||
args.max_content_len,
|
||||
args.min_turns,
|
||||
args.max_turns,
|
||||
args.model,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,35 +0,0 @@
|
||||
{
|
||||
"filetype": "generate_conversations",
|
||||
"num_conversations": 24,
|
||||
"text_files": ["pg1184.txt"],
|
||||
"print_stats": false,
|
||||
"prompt_input": {
|
||||
"num_turns": {
|
||||
"distribution": "uniform",
|
||||
"min": 12,
|
||||
"max": 18
|
||||
},
|
||||
"common_prefix_num_tokens": {
|
||||
"distribution": "constant",
|
||||
"value": 500
|
||||
},
|
||||
"prefix_num_tokens": {
|
||||
"distribution": "lognormal",
|
||||
"mean": 6,
|
||||
"sigma": 4,
|
||||
"max": 1500
|
||||
},
|
||||
"num_tokens": {
|
||||
"distribution": "uniform",
|
||||
"min": 120,
|
||||
"max": 160
|
||||
}
|
||||
},
|
||||
"prompt_output": {
|
||||
"num_tokens": {
|
||||
"distribution": "uniform",
|
||||
"min": 80,
|
||||
"max": 120
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
numpy>=1.24
|
||||
pandas>=2.0.0
|
||||
aiohttp>=3.10
|
||||
transformers>=4.46
|
||||
xlsxwriter>=3.2.1
|
||||
@ -19,7 +19,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
|
||||
GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
|
||||
GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
@ -37,9 +37,9 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
|
||||
set(FlashMLA_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)
|
||||
|
||||
set(FlashMLA_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
|
||||
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 93cf5a08f421a3efd0c4a7e005ef8f742b578ce0
|
||||
GIT_TAG 6dbc6e011a3ebe9349eeb74578940dd7095436ba
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
@ -60,13 +60,3 @@ struct enable_sm100_only : Kernel {
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm120_only : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
|
||||
Kernel::operator()(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@ -45,9 +45,6 @@ struct SSMParamsBase {
|
||||
index_t out_d_stride;
|
||||
index_t out_z_batch_stride;
|
||||
index_t out_z_d_stride;
|
||||
index_t ssm_states_batch_stride;
|
||||
index_t ssm_states_dim_stride;
|
||||
index_t ssm_states_dstate_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ A_ptr;
|
||||
|
||||
@ -132,10 +132,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
|
||||
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
|
||||
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
|
||||
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
|
||||
cache_index * params.ssm_states_batch_stride +
|
||||
dim_id * kNRows * params.ssm_states_dim_stride;
|
||||
|
||||
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate;
|
||||
|
||||
float D_val[kNRows] = {0};
|
||||
if (params.D_ptr != nullptr) {
|
||||
#pragma unroll
|
||||
@ -250,7 +248,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
}
|
||||
// Initialize running total
|
||||
|
||||
scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0);
|
||||
scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0);
|
||||
|
||||
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
|
||||
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
|
||||
@ -261,7 +259,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
|
||||
if (threadIdx.x == 0) {
|
||||
smem_running_prefix[state_idx] = prefix_op.running_prefix;
|
||||
if (chunk == n_chunks - 1) {
|
||||
ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y);
|
||||
ssm_states[state_idx] = input_t(prefix_op.running_prefix.y);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
@ -483,10 +481,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
params.out_batch_stride = out.stride(1);
|
||||
params.out_d_stride = out.stride(0);
|
||||
|
||||
params.ssm_states_batch_stride = ssm_states.stride(0);
|
||||
params.ssm_states_dim_stride = ssm_states.stride(1);
|
||||
params.ssm_states_dstate_stride = ssm_states.stride(2);
|
||||
|
||||
}
|
||||
else{
|
||||
if (!is_variable_B) {
|
||||
@ -515,10 +509,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
}
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_d_stride = out.stride(1);
|
||||
|
||||
params.ssm_states_batch_stride = ssm_states.stride(0);
|
||||
params.ssm_states_dim_stride = ssm_states.stride(1);
|
||||
params.ssm_states_dstate_stride = ssm_states.stride(2);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -188,9 +188,7 @@ __launch_bounds__(TPB) __global__ void moeTopK(
|
||||
It fuses the softmax, max and argmax into a single kernel.
|
||||
|
||||
Limitations:
|
||||
1) This implementation is optimized for when the number of experts is a small power of 2.
|
||||
Additionally it also supports when number of experts is multiple of 64 which is still
|
||||
faster than the computing softmax and topK separately (only tested on CUDA yet).
|
||||
1) This implementation is intended for when the number of experts is a small power of 2.
|
||||
2) This implementation assumes k is small, but will work for any k.
|
||||
*/
|
||||
|
||||
@ -200,6 +198,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
|
||||
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
||||
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
|
||||
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
|
||||
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
||||
|
||||
@ -407,10 +407,12 @@ struct TopkConstants
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, int MAX_BYTES_PER_LDG, typename IndType>
|
||||
template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
|
||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
|
||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
@ -423,27 +425,21 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
||||
}
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
|
||||
static_assert(WARP_SIZE == 32, \
|
||||
"Unsupported warp size. Only 32 is supported for CUDA"); \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, WARP_SIZE, MAX_BYTES>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream);
|
||||
#else
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB, MAX_BYTES) \
|
||||
if (WARP_SIZE == 64) { \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64, MAX_BYTES>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
|
||||
} else if (WARP_SIZE == 32) { \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32, MAX_BYTES>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
|
||||
} else { \
|
||||
assert(false && "Unsupported warp size. Only 32 and 64 are supported for ROCm"); \
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
||||
switch (warpSize) { \
|
||||
case 32: \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
|
||||
break; \
|
||||
case 64: \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>( \
|
||||
gating_output, nullptr, topk_weights, topk_indices, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported warp size: ", warpSize); \
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename IndType>
|
||||
void topkGatingSoftmaxKernelLauncher(
|
||||
@ -457,64 +453,38 @@ void topkGatingSoftmaxKernelLauncher(
|
||||
const int topk,
|
||||
cudaStream_t stream) {
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
static constexpr int BYTES_PER_LDG_POWER_OF_2 = 16;
|
||||
#ifndef USE_ROCM
|
||||
static constexpr int BYTES_PER_LDG_MULTIPLE_64 = 8;
|
||||
#endif
|
||||
auto warpSize = WARP_SIZE;
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_SOFTMAX(1, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_SOFTMAX(2, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_SOFTMAX(4, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_SOFTMAX(8, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_SOFTMAX(16, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_SOFTMAX(32, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_SOFTMAX(64, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_SOFTMAX(128, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_SOFTMAX(256, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
|
||||
break;
|
||||
case 512:
|
||||
LAUNCH_SOFTMAX(512, WARPS_PER_TB, BYTES_PER_LDG_POWER_OF_2);
|
||||
break;
|
||||
// (CUDA only) support multiples of 64 when num_experts is not power of 2.
|
||||
// ROCm uses WARP_SIZE 64 so 8 bytes loading won't fit for some of num_experts,
|
||||
// alternatively we can test 4 bytes loading and enable it in future.
|
||||
#ifndef USE_ROCM
|
||||
case 192:
|
||||
LAUNCH_SOFTMAX(192, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 320:
|
||||
LAUNCH_SOFTMAX(320, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 384:
|
||||
LAUNCH_SOFTMAX(384, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 448:
|
||||
LAUNCH_SOFTMAX(448, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
case 576:
|
||||
LAUNCH_SOFTMAX(576, WARPS_PER_TB, BYTES_PER_LDG_MULTIPLE_64);
|
||||
break;
|
||||
#endif
|
||||
default: {
|
||||
TORCH_CHECK(softmax_workspace != nullptr,
|
||||
"softmax_workspace must be provided for num_experts that are not a power of 2 or multiple of 64.");
|
||||
"softmax_workspace must be provided for num_experts that are not a power of 2.");
|
||||
static constexpr int TPB = 256;
|
||||
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||
gating_output, nullptr, softmax_workspace, num_experts);
|
||||
|
||||
@ -1,23 +0,0 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::bfloat16_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
cutlass_gemm_blockwise_sm120_fp8_dispatch<cutlass::half_t>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -1,183 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "cuda_utils.h"
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
|
||||
#include "cutlass/epilogue/dispatch_policy.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
|
||||
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// clang-format off
|
||||
template <class OutType, int ScaleGranularityM,
|
||||
int ScaleGranularityN, int ScaleGranularityK,
|
||||
class MmaTileShape, class ClusterShape,
|
||||
class EpilogueScheduler, class MainloopScheduler>
|
||||
struct cutlass_3x_gemm_fp8_blockwise {
|
||||
using ElementAB = cutlass::float_e4m3_t;
|
||||
|
||||
using ElementA = ElementAB;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
|
||||
using ElementB = ElementAB;
|
||||
// ColumnMajor is used for B to match the CUTLASS convention.
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
using ElementD = OutType;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using ElementC = void; // TODO: support bias
|
||||
using LayoutC = LayoutD;
|
||||
using LayoutC_Transpose = LayoutD_Transpose;
|
||||
static constexpr int AlignmentC = AlignmentD;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ElementCompute = float;
|
||||
using ElementBlockScale = float;
|
||||
|
||||
using ScaleConfig = cutlass::detail::Sm120BlockwiseScaleConfig<
|
||||
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
|
||||
|
||||
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
|
||||
using ArchTag = cutlass::arch::Sm120;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
using ElementScalar = float;
|
||||
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator,
|
||||
ElementCompute,
|
||||
ElementC,
|
||||
LayoutC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
LayoutD,
|
||||
AlignmentD,
|
||||
EpilogueScheduler,
|
||||
DefaultOperation
|
||||
>::CollectiveOp;
|
||||
|
||||
using StageCountType = cutlass::gemm::collective::StageCountAuto;
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag,
|
||||
OperatorClass,
|
||||
ElementA,
|
||||
cute::tuple<LayoutA, LayoutSFA>,
|
||||
AlignmentA,
|
||||
ElementB,
|
||||
cute::tuple<LayoutB, LayoutSFB>,
|
||||
AlignmentB,
|
||||
ElementAccumulator,
|
||||
MmaTileShape,
|
||||
ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
MainloopScheduler
|
||||
>::CollectiveOp;
|
||||
|
||||
using KernelType = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
};
|
||||
|
||||
template <typename Gemm>
|
||||
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
using StrideA = typename Gemm::GemmKernel::StrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::StrideB;
|
||||
using StrideD = typename Gemm::GemmKernel::StrideD;
|
||||
using StrideC = typename Gemm::GemmKernel::StrideC;
|
||||
using LayoutSFA = typename Gemm::LayoutSFA;
|
||||
using LayoutSFB = typename Gemm::LayoutSFB;
|
||||
using ScaleConfig = typename Gemm::ScaleConfig;
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
|
||||
|
||||
StrideA a_stride;
|
||||
StrideB b_stride;
|
||||
StrideC c_stride;
|
||||
a_stride =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
b_stride =
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
c_stride =
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
|
||||
|
||||
LayoutSFA layout_SFA =
|
||||
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
|
||||
LayoutSFB layout_SFB =
|
||||
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
|
||||
|
||||
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
||||
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
|
||||
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
|
||||
|
||||
auto mainloop_args = [&](){
|
||||
return typename GemmKernel::MainloopArguments{
|
||||
a_ptr, a_stride, b_ptr, b_stride,
|
||||
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
|
||||
};
|
||||
}();
|
||||
auto prob_shape = cute::make_shape(m, n, k, 1);
|
||||
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, c_ptr, c_stride, c_ptr, c_stride};
|
||||
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
// TODO: better heuristics
|
||||
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
|
||||
OutType, 1, 128, 128, Shape<_128, _128, _128>,
|
||||
Shape<_1, _1, _1>, cutlass::epilogue::collective::EpilogueScheduleAuto,
|
||||
cutlass::gemm::collective::KernelScheduleAuto>>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -47,10 +47,4 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm120_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales);
|
||||
} // namespace vllm
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
#include "c3x/scaled_mm_helper.hpp"
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm120 (Blackwell).
|
||||
NVIDIA GPUs with sm120 (Blackwell Geforce).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
@ -13,10 +15,20 @@ void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
|
||||
vllm::cutlass_scaled_mm_sm120_fp8,
|
||||
nullptr, // int8 not supported on SM120
|
||||
vllm::cutlass_scaled_mm_blockwise_sm120_fp8);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
TORCH_CHECK(
|
||||
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||
vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@ -270,7 +270,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -304,12 +304,12 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
|
||||
const auto max_num_partitions = gridDim.y;
|
||||
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int context_len = context_lens[seq_idx];
|
||||
|
||||
const int partition_start_token_idx =
|
||||
partition_idx * T_PAR_SIZE; // partition_size;
|
||||
// exit if partition is out of context for seq
|
||||
if (partition_start_token_idx >= seq_len) {
|
||||
if (partition_start_token_idx >= context_len) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -361,8 +361,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
// output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens
|
||||
// across 4 rows x 4 tokens per lane
|
||||
|
||||
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
|
||||
const int last_seq_block = num_seq_blocks - 1;
|
||||
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
|
||||
const int last_ctx_block = num_context_blocks - 1;
|
||||
|
||||
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
|
||||
@ -373,9 +373,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int klocal_token_idx =
|
||||
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
|
||||
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
|
||||
const int kblock_idx = (kglobal_token_idx < seq_len)
|
||||
const int kblock_idx = (kglobal_token_idx < context_len)
|
||||
? kglobal_token_idx / BLOCK_SIZE
|
||||
: last_seq_block;
|
||||
: last_ctx_block;
|
||||
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
|
||||
}
|
||||
|
||||
@ -476,9 +476,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
// tokens
|
||||
const int vglobal_token_idx =
|
||||
partition_start_token_idx + vlocal_token_idx;
|
||||
const int vblock_idx = (vglobal_token_idx < seq_len)
|
||||
const int vblock_idx = (vglobal_token_idx < context_len)
|
||||
? vglobal_token_idx / BLOCK_SIZE
|
||||
: last_seq_block;
|
||||
: last_ctx_block;
|
||||
vphysical_block_number[vtoken_depth][vblock_depth] =
|
||||
block_table_seq[vblock_idx];
|
||||
}
|
||||
@ -554,7 +554,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
if constexpr (ALIBI_ENABLED) {
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
const int local_token_idx = qkout_token_idx + token_depth * 16;
|
||||
const int alibi_offset = local_token_idx - seq_len + 1;
|
||||
const int alibi_offset = local_token_idx - context_len + 1;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
d_out[token_depth][i] += alibi_slope * (alibi_offset + i);
|
||||
}
|
||||
@ -568,8 +568,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
const int local_token_idx = qkout_token_idx + token_depth * 16;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
const float tmp =
|
||||
(local_token_idx + i < seq_len) ? d_out[token_depth][i] : -FLT_MAX;
|
||||
const float tmp = (local_token_idx + i < context_len)
|
||||
? d_out[token_depth][i]
|
||||
: -FLT_MAX;
|
||||
qk_max = fmaxf(qk_max, tmp);
|
||||
}
|
||||
}
|
||||
@ -581,7 +582,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
const int local_token_idx = qkout_token_idx + token_depth * 16;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
const float tmp = (local_token_idx + i < seq_len)
|
||||
const float tmp = (local_token_idx + i < context_len)
|
||||
? __expf(d_out[token_depth][i] - qk_max)
|
||||
: 0.0f;
|
||||
d_out[token_depth][i] = tmp;
|
||||
@ -779,7 +780,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -808,10 +809,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const auto partition_size = blockDim.x;
const auto max_num_partitions = gridDim.y;

const int seq_len = seq_lens[seq_idx];
const int context_len = context_lens[seq_idx];
const int partition_start_token_idx = partition_idx * partition_size;
// exit if partition is out of context for seq
if (partition_start_token_idx >= seq_len) {
if (partition_start_token_idx >= context_len) {
return;
}
// every 4 lanes fetch 4 different qheads
@ -854,7 +855,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int warp_start_token_idx =
partition_start_token_idx + warpid * WARP_SIZE;

if (warp_start_token_idx >= seq_len) { // warp out of context
if (warp_start_token_idx >= context_len) { // warp out of context
#pragma unroll
for (int h = 0; h < GQA_RATIO4; h++) {
shared_qk_max[warpid][h] = -FLT_MAX;
@ -862,8 +863,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}
} else { // warp within context

const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int last_seq_block = num_seq_blocks - 1;
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1;

const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// token id within partition
@ -872,9 +873,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int global_token_idx = partition_start_token_idx + local_token_idx;

// fetch block number for k
const int block_idx = (global_token_idx < seq_len)
const int block_idx = (global_token_idx < context_len)
? global_token_idx / BLOCK_SIZE
: last_seq_block;
: last_ctx_block;

// fetch k physical block number
// int32 physical_block_number leads to overflow when multiplied with
@ -887,7 +888,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
(vblock_idx <= last_seq_block) ? vblock_idx : last_seq_block;
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
}

@ -1056,7 +1057,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int lane4_token_idx = 4 * (global_token_idx >> 2);

if constexpr (ALIBI_ENABLED) {
const int alibi_offset = lane4_token_idx - seq_len + 1;
const int alibi_offset = lane4_token_idx - context_len + 1;
for (int h = 0; h < QHLOOP; h++) {
for (int i = 0; i < 4; i++) {
d_out[h][i] += alibi_slope[h] * (alibi_offset + i);
@ -1069,7 +1070,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int h = 0; h < QHLOOP; h++) {
qk_max[h] = -FLT_MAX;
for (int i = 0; i < 4; i++) {
qk_max[h] = (lane4_token_idx + i < seq_len)
qk_max[h] = (lane4_token_idx + i < context_len)
? fmaxf(qk_max[h], d_out[h][i])
: qk_max[h];
}
@ -1100,7 +1101,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int h = 0; h < QHLOOP; h++) {
exp_sum[h] = 0.0f;
for (int i = 0; i < 4; i++) {
d_out[h][i] = (lane4_token_idx + i < seq_len)
d_out[h][i] = (lane4_token_idx + i < context_len)
? __expf(d_out[h][i] - qk_max[h])
: 0.0f;
exp_sum[h] += d_out[h][i];
@ -1180,7 +1181,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}
}

if (warp_start_token_idx >= seq_len) { // warp out of context
if (warp_start_token_idx >= context_len) { // warp out of context
for (int qh = 0; qh < QHLOOP; qh++) {
for (int vh = 0; vh < VHELOOP; vh++) {
vout_shared[qh][vh][laneid][warpid] = {0};
@ -1278,7 +1279,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
@ -1292,8 +1293,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return;
}

const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
const auto warpid = threadIdx.x / WARP_SIZE;

__shared__ float shared_global_exp_sum;
@ -1580,7 +1581,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@ -1614,11 +1615,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(

const int max_num_partitions = gridDim.y;

const int seq_len = seq_lens[seq_idx]; // length of a seq
const int context_len = context_lens[seq_idx]; // length of a seq

const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
// exit if partition is out of context for seq
if (partition_start_token_idx >= seq_len) {
if (partition_start_token_idx >= context_len) {
return;
}

@ -1714,8 +1715,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
}
}

const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int last_seq_block = num_seq_blocks - 1;
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1;

const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;

@ -1726,9 +1727,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
const int kblock_idx = (kglobal_token_idx < seq_len)
const int kblock_idx = (kglobal_token_idx < context_len)
? kglobal_token_idx / BLOCK_SIZE
: last_seq_block;
: last_ctx_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}

@ -1780,9 +1781,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
const int vblock_idx = (vglobal_token_idx < seq_len)
const int vblock_idx = (vglobal_token_idx < context_len)
? vglobal_token_idx / BLOCK_SIZE
: last_seq_block;
: last_ctx_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
@ -1835,8 +1836,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp =
(local_token_idx + 2 * i < seq_len) ? dout[token_depth][i] : -FLT_MAX;
const float tmp = (local_token_idx + 2 * i < context_len)
? dout[token_depth][i]
: -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
@ -1846,7 +1848,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp = (local_token_idx + 2 * i < seq_len)
const float tmp = (local_token_idx + 2 * i < context_len)
? __expf(dout[token_depth][i] - qk_max)
: 0.0f;
dout[token_depth][i] = tmp;
@ -2017,7 +2019,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@ -2044,7 +2046,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
@ -2058,8 +2060,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return;
}

const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
const int warpid = threadIdx.x / WARP_SIZE;

__shared__ float shared_global_exp_sum;
@ -2347,7 +2349,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@ -2380,11 +2382,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(

const int max_num_partitions = gridDim.y;

const int seq_len = seq_lens[seq_idx]; // length of a seq
const int context_len = context_lens[seq_idx]; // length of a seq

const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
// exit if partition is out of context for seq
if (partition_start_token_idx >= seq_len) {
if (partition_start_token_idx >= context_len) {
return;
}

@ -2480,8 +2482,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
}
}

const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int last_seq_block = num_seq_blocks - 1;
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1;

const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;

@ -2492,9 +2494,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
const int kblock_idx = (kglobal_token_idx < seq_len)
const int kblock_idx = (kglobal_token_idx < context_len)
? kglobal_token_idx / BLOCK_SIZE
: last_seq_block;
: last_ctx_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}

@ -2546,9 +2548,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
const int vblock_idx = (vglobal_token_idx < seq_len)
const int vblock_idx = (vglobal_token_idx < context_len)
? vglobal_token_idx / BLOCK_SIZE
: last_seq_block;
: last_ctx_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
@ -2602,7 +2604,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp =
(local_token_idx + i < seq_len) ? dout[token_depth][i] : -FLT_MAX;
(local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
@ -2612,7 +2614,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp = (local_token_idx + i < seq_len)
const float tmp = (local_token_idx + i < context_len)
? __expf(dout[token_depth][i] - qk_max)
: 0.0f;
dout[token_depth][i] = tmp;
@ -2749,7 +2751,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@ -2776,7 +2778,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
@ -2790,8 +2792,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return;
}

const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
const int warpid = threadIdx.x / WARP_SIZE;

__shared__ float shared_global_exp_sum;
@ -2978,7 +2980,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@ -3005,7 +3007,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@ -3029,7 +3031,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
UNREACHABLE_CODE
@ -3044,7 +3046,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
@ -3055,17 +3057,18 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);

#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
query_start_loc_ptr, max_num_partitions, fp8_out_scale_ptr);
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
fp8_out_scale_ptr);

template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@ -3074,8 +3077,8 @@ void paged_attention_custom_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_seq_len,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale) {
int num_seqs = block_tables.size(0);
@ -3106,7 +3109,7 @@ void paged_attention_custom_launcher(
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// NOTE: fp8_out_scale is optional.
@ -3116,12 +3119,13 @@ void paged_attention_custom_launcher(
: nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());

const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);

// partition size is fixed at 256 since both mfma4 and mfma16 kernels support
// it mfma4 kernel also supports partition size 512
constexpr int PARTITION_SIZE = 256;
const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
const int max_num_partitions =
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
@ -3230,8 +3234,8 @@ void paged_attention_custom_launcher_navi(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_seq_len,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_seqs = block_tables.size(0);
@ -3259,7 +3263,7 @@ void paged_attention_custom_launcher_navi(
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();

const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
@ -3267,10 +3271,11 @@ void paged_attention_custom_launcher_navi(
const auto fp8_out_scale_ptr = nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());

const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);

constexpr int PARTITION_SIZE = 256;
const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
const int max_num_partitions =
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
@ -3402,14 +3407,14 @@ void paged_attention_custom_launcher_navi(
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
} else { \
paged_attention_custom_launcher_navi< \
T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
max_seq_len, alibi_slopes, k_scale, v_scale); \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale); \
}

#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
@ -3497,9 +3502,9 @@ void paged_attention(
int64_t num_kv_heads,
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
torch::Tensor& context_lens, // [num_seqs]
const std::optional<torch::Tensor>& query_start_loc, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale,

@ -15,8 +15,8 @@ void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);

@ -41,10 +41,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor seq_lens,"
" Tensor context_lens,"
" Tensor? query_start_loc,"
" int block_size,"
" int max_seq_len,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale,"
@ -210,7 +210,16 @@ ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0

# Flag to control whether to use pre-built vLLM wheels
ARG VLLM_USE_PRECOMPILED=""
ARG VLLM_USE_PRECOMPILED
# TODO: in setup.py VLLM_USE_PRECOMPILED is sensitive to truthiness, it will take =0 as "true", this should be fixed
ENV VLLM_USE_PRECOMPILED=""
RUN if [ "${VLLM_USE_PRECOMPILED}" = "1" ]; then \
export VLLM_USE_PRECOMPILED=1 && \
echo "Using precompiled wheels"; \
else \
unset VLLM_USE_PRECOMPILED && \
echo "Leaving VLLM_USE_PRECOMPILED unset to build wheels from source"; \
fi

# if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/uv \
@ -227,8 +236,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
&& export SCCACHE_IDLE_TIMEOUT=0 \
&& export CMAKE_BUILD_TYPE=Release \
&& export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \
&& export VLLM_DOCKER_BUILD_CONTEXT=1 \
&& sccache --show-stats \
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
&& sccache --show-stats; \
@ -242,8 +249,6 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
# Clean any existing CMake artifacts
rm -rf .deps && \
mkdir -p .deps && \
export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \
export VLLM_DOCKER_BUILD_CONTEXT=1 && \
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
fi

@ -387,7 +392,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt
# We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel.
ARG FLASHINFER_GIT_REF="v0.2.11"
ARG FLASHINFER_GIT_REF="v0.2.9"
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
. /etc/environment
git clone --depth 1 --recursive --shallow-submodules \
@ -432,7 +437,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \

# Install DeepGEMM from source
ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
ARG DEEPGEMM_GIT_REF="187656694f7f69e3e7975617a68bc3387680a7e1"
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
. /etc/environment
CUDA_MAJOR="${CUDA_VERSION%%.*}"

@ -113,7 +113,6 @@ WORKDIR /workspace/vllm

RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
cp requirements/test.in requirements/cpu-test.in && \
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
sed -i 's/^torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \

@ -1,12 +1,9 @@
FROM intel/deep-learning-essentials:2025.1.3-0-devel-ubuntu24.04 AS vllm-base
# oneapi 2025.0.2 docker base image use rolling 2448 package. https://dgpu-docs.intel.com/releases/packages.html?release=Rolling+2448.13&os=Ubuntu+22.04, and we don't need install driver manually.
FROM intel/deep-learning-essentials:2025.0.2-0-devel-ubuntu22.04 AS vllm-base

RUN rm /etc/apt/sources.list.d/intel-graphics.list

RUN apt clean && apt-get update -y && \
apt-get install -y software-properties-common && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get install -y python3.10 python3.10-distutils && \
curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && \
RUN apt-get update -y && \
apt-get install -y --no-install-recommends --fix-missing \
curl \
ffmpeg \
@ -17,13 +14,11 @@ RUN apt clean && apt-get update -y && \
libgl1 \
lsb-release \
numactl \
python3.10-dev \
python3 \
python3-dev \
python3-pip \
wget

RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1

WORKDIR /workspace/vllm
COPY requirements/xpu.txt /workspace/vllm/requirements/xpu.txt
COPY requirements/common.txt /workspace/vllm/requirements/common.txt
@ -1,17 +1,25 @@
nav:
- Home: README.md
- User Guide:
- usage/README.md
- Home:
- vLLM: README.md
- Getting Started:
- getting_started/quickstart.md
- getting_started/installation
- Examples:
- examples/README.md
- Offline Inference: examples/offline_inference
- Online Serving: examples/online_serving
- Others: examples/others
- Quick Links:
- User Guide: usage/README.md
- Developer Guide: contributing/README.md
- API Reference: api/README.md
- CLI Reference: cli/README.md
- Timeline:
- Roadmap: https://roadmap.vllm.ai
- Releases: https://github.com/vllm-project/vllm/releases
- User Guide:
- Summary: usage/README.md
- usage/v1_guide.md
- General:
- usage/v1_guide.md
- usage/*
- Inference and Serving:
- serving/offline_inference.md
@ -24,7 +32,7 @@ nav:
- deployment/integrations
- Training: training
- Configuration:
- configuration/README.md
- Summary: configuration/README.md
- configuration/*
- Models:
- models/supported_models.md
@ -37,11 +45,11 @@ nav:
- features/*
- features/quantization
- Developer Guide:
- contributing/README.md
- Summary: contributing/README.md
- General:
- glob: contributing/*
flatten_single_child_sections: true
- Model Implementation:
- Model Implementation:
- contributing/model/README.md
- contributing/model/basic.md
- contributing/model/registration.md
@ -50,9 +58,12 @@ nav:
- CI: contributing/ci
- Design Documents: design
- API Reference:
- api/README.md
- api/vllm/*
- CLI Reference: cli
- Summary: api/README.md
- Contents:
- glob: api/vllm/*
preserve_directory_names: true
- CLI Reference:
- Summary: cli/README.md
- Community:
- community/*
- Blog: https://blog.vllm.ai

@ -1,9 +1,3 @@
---
hide:
- navigation
- toc
---

# Welcome to vLLM

<figure markdown="span">
@ -27,17 +21,6 @@ vLLM is a fast and easy-to-use library for LLM inference and serving.

Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.

Where to get started with vLLM depends on the type of user. If you are looking to:

- Run open-source models on vLLM, we recommend starting with the [Quickstart Guide](./getting_started/quickstart.md)
- Build applications with vLLM, we recommend starting with the [User Guide](./usage)
- Build vLLM, we recommend starting with [Developer Guide](./contributing)

For information about the development of vLLM, see:

- [Roadmap](https://roadmap.vllm.ai)
- [Releases](https://github.com/vllm-project/vllm/releases)

vLLM is fast with:

- State-of-the-art serving throughput

@ -1,5 +1,7 @@
# Summary

[](){ #configuration }

## Configuration

API documentation for vLLM's configuration classes.

Binary file not shown. (Before: 91 KiB)
Binary file not shown. (Before: 88 KiB)

@ -1 +0,0 @@
toc_depth: 3

@ -1,8 +0,0 @@
nav:
- README.md
- serve.md
- chat.md
- complete.md
- run-batch.md
- vllm bench:
- bench/*.md

@ -1,3 +1,7 @@
---
toc_depth: 4
---

# vLLM CLI Guide

The vllm command-line tool is used to run and manage vLLM models. You can start by viewing the help message with:
@ -14,46 +18,37 @@ vllm {chat,complete,serve,bench,collect-env,run-batch}

## serve

Starts the vLLM OpenAI Compatible API server.
Start the vLLM OpenAI Compatible API server.

Start with a model:
??? console "Examples"

```bash
vllm serve meta-llama/Llama-2-7b-hf
```
```bash
# Start with a model
vllm serve meta-llama/Llama-2-7b-hf

Specify the port:
# Specify the port
vllm serve meta-llama/Llama-2-7b-hf --port 8100

```bash
vllm serve meta-llama/Llama-2-7b-hf --port 8100
```
# Check with --help for more options
# To list all groups
vllm serve --help=listgroup

Serve over a Unix domain socket:
# To view a argument group
vllm serve --help=ModelConfig

```bash
vllm serve meta-llama/Llama-2-7b-hf --uds /tmp/vllm.sock
```
# To view a single argument
vllm serve --help=max-num-seqs

Check with --help for more options:
# To search by keyword
vllm serve --help=max

```bash
# To list all groups
vllm serve --help=listgroup
# To view full help with pager (less/more)
vllm serve --help=page
```

# To view a argument group
vllm serve --help=ModelConfig
### Options

# To view a single argument
vllm serve --help=max-num-seqs

# To search by keyword
vllm serve --help=max

# To view full help with pager (less/more)
vllm serve --help=page
```

See [vllm serve](./serve.md) for the full reference of all available arguments.
--8<-- "docs/argparse/serve.md"

## chat

@ -70,8 +65,6 @@ vllm chat --url http://{vllm-serve-host}:{vllm-serve-port}/v1
vllm chat --quick "hi"
```

See [vllm chat](./chat.md) for the full reference of all available arguments.

## complete

Generate text completions based on the given prompt via the running API server.
@ -87,7 +80,7 @@ vllm complete --url http://{vllm-serve-host}:{vllm-serve-port}/v1
vllm complete --quick "The future of AI is"
```

See [vllm complete](./complete.md) for the full reference of all available arguments.
</details>

## bench

@ -114,8 +107,6 @@ vllm bench latency \
--load-format dummy
```

See [vllm bench latency](./bench/latency.md) for the full reference of all available arguments.

### serve

Benchmark the online serving throughput.
@ -130,8 +121,6 @@ vllm bench serve \
--num-prompts 5
```

See [vllm bench serve](./bench/serve.md) for the full reference of all available arguments.

### throughput

Benchmark offline inference throughput.
@ -145,8 +134,6 @@ vllm bench throughput \
--load-format dummy
```

See [vllm bench throughput](./bench/throughput.md) for the full reference of all available arguments.

## collect-env

Start collecting environment information.
@ -159,25 +146,24 @@ vllm collect-env

Run batch prompts and write results to file.

Running with a local file:
<details>
<summary>Examples</summary>

```bash
# Running with a local file
vllm run-batch \
-i offline_inference/openai_batch/openai_example_batch.jsonl \
-o results.jsonl \
--model meta-llama/Meta-Llama-3-8B-Instruct
```

Using remote file:

```bash
# Using remote file
vllm run-batch \
-i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl \
-o results.jsonl \
--model meta-llama/Meta-Llama-3-8B-Instruct
```

See [vllm run-batch](./run-batch.md) for the full reference of all available arguments.
</details>

## More Help

@ -1,9 +0,0 @@
# vllm bench latency

## JSON CLI Arguments

--8<-- "docs/cli/json_tip.inc.md"

## Options

--8<-- "docs/argparse/bench_latency.md"

@ -1,9 +0,0 @@
# vllm bench serve

## JSON CLI Arguments

--8<-- "docs/cli/json_tip.inc.md"

## Options

--8<-- "docs/argparse/bench_serve.md"

@ -1,9 +0,0 @@
# vllm bench throughput

## JSON CLI Arguments

--8<-- "docs/cli/json_tip.inc.md"

## Options

--8<-- "docs/argparse/bench_throughput.md"

@ -1,5 +0,0 @@
# vllm chat

## Options

--8<-- "docs/argparse/chat.md"

@ -1,5 +0,0 @@
# vllm complete

## Options

--8<-- "docs/argparse/complete.md"

@ -1,9 +0,0 @@
When passing JSON CLI arguments, the following sets of arguments are equivalent:

- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`

Additionally, list elements can be passed individually using `+`:

- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`
- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`

@ -1,9 +0,0 @@
# vllm run-batch

## JSON CLI Arguments

--8<-- "docs/cli/json_tip.inc.md"

## Options

--8<-- "docs/argparse/run-batch.md"

@ -1,9 +0,0 @@
# vllm serve

## JSON CLI Arguments

--8<-- "docs/cli/json_tip.inc.md"

## Options

--8<-- "docs/argparse/serve.md"

@ -2,7 +2,6 @@

We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:

- [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152).
- [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing)
- [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama), March 27th 2025. [[Slides]](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).

@ -15,7 +15,6 @@ Cash Donations:

Compute Resources:

- Alibaba Cloud
- AMD
- Anyscale
- AWS

@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",

If you run out of CPU RAM, try the following options:

- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process)
- (Multi-modal models only) you can set the size of multi-modal input cache using `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB).
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).

## Multi-modal input limits
@ -129,18 +129,20 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory.

Here are some examples:

```python
from vllm import LLM
??? code

# Available for Qwen2-VL series models
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_kwargs={
"max_pixels": 768 * 768, # Default is 1280 * 28 * 28
})
```python
from vllm import LLM

# Available for InternVL series models
llm = LLM(model="OpenGVLab/InternVL2-2B",
mm_processor_kwargs={
"max_dynamic_patch": 4, # Default is 12
})
```
# Available for Qwen2-VL series models
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_kwargs={
"max_pixels": 768 * 768, # Default is 1280 * 28 * 28
})

# Available for InternVL series models
llm = LLM(model="OpenGVLab/InternVL2-2B",
mm_processor_kwargs={
"max_dynamic_patch": 4, # Default is 12
})
```

@ -11,8 +11,6 @@ Engine arguments control the behavior of the vLLM engine.

The engine argument classes, [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs], are a combination of the configuration classes defined in [vllm.config][]. Therefore, if you are interested in developer documentation, we recommend looking at these configuration classes as they are the source of truth for types, defaults and docstrings.

--8<-- "docs/cli/json_tip.inc.md"

## `EngineArgs`

--8<-- "docs/argparse/engine_args.md"
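To make the relationship described in the hunk above concrete, here is a minimal, hedged sketch of using the engine argument classes programmatically. It assumes the `EngineArgs` dataclass fields and its `create_engine_config()` helper behave as in recent vLLM releases; the model name is purely illustrative.

```python
# Minimal sketch (assumption: EngineArgs fields and create_engine_config()
# match the vLLM release you are running).
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    max_model_len=2048,
    gpu_memory_utilization=0.9,
)

# Materialize the vllm.config objects that the docstrings above describe
# as the source of truth for types and defaults.
config = args.create_engine_config()
print(type(config).__name__)
```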
@ -2,9 +2,6 @@

This guide covers optimization strategies and performance tuning for vLLM V1.

!!! tip
Running out of memory? Consult [this guide](./conserving_memory.md) on how to conserve memory.

## Preemption

Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests.
@ -129,50 +126,62 @@ Data parallelism replicates the entire model across multiple GPU sets and processes
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.

## Input Processing
## Reducing Memory Usage

### Parallel Processing
If you encounter out-of-memory issues, consider these strategies:

You can run input processing in parallel via [API server scale-out](../serving/data_parallel_deployment.md#internal-load-balancing).
This is useful when input processing (which is run inside the API server)
becomes a bottleneck compared to model execution (which is run inside engine core)
and you have excess CPU capacity.
### Context Length and Batch Size

```console
# Run 4 API processes and 1 engine core process
vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4

# Run 4 API processes and 2 engine core processes
vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
```

!!! note
API server scale-out is only available for online inference.

!!! note
[Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
because it requires a one-to-one correspondance between API and engine core processes.

## Multi-Modal Caching

### Processor Cache

By default, the multi-modal processor cache is enabled to avoid repeatedly processing
the same multi-modal inputs via Hugging Face `AutoProcessor`,
which commonly occurs in multi-turn conversations.

You can adjust the size of the cache by setting the value of `mm_processor_cache_gb`
(default 4 GiB per API process + 4 GiB per engine core process).
If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`.

Examples:
You can reduce memory usage by limiting the context length and batch size:

```python
# Use a larger cache
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_cache_gb=8)
from vllm import LLM

# Disable the cache
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_cache_gb=0)
llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
max_model_len=2048, # Limit context window
max_num_seqs=4 # Limit batch size
)
```

### Adjust CUDA Graph Compilation

CUDA graph compilation in V1 uses more memory than in V0. You can reduce memory usage by adjusting the compilation level:

```python
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel

llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
cudagraph_capture_sizes=[1, 2, 4, 8] # Capture fewer batch sizes
)
)
```

Or, if you are not concerned about latency or overall performance, disable CUDA graph compilation entirely with `enforce_eager=True`:

```python
from vllm import LLM

llm = LLM(
model="meta-llama/Llama-3.1-8B-Instruct",
enforce_eager=True # Disable CUDA graph compilation
)
```

### Multimodal Models

For multi-modal models, you can reduce memory usage by limiting the number of images/videos per request:

```python
from vllm import LLM

# Accept up to 2 images per prompt
llm = LLM(
model="Qwen/Qwen2.5-VL-3B-Instruct",
limit_mm_per_prompt={"image": 2}
)
```

@ -96,7 +96,7 @@ Although it’s common to do this with GPUs, don't try to fragment 2 or 8 differ

### Tune your workloads

Although we try to have great default configs, we strongly recommend you check out the [vLLM auto-tuner](gh-file:benchmarks/auto_tune/README.md) to optimize your workloads for your use case.
Although we try to have great default configs, we strongly recommend you check out the [vLLM auto-tuner](../../benchmarks/auto_tune/README.md) to optimize your workloads for your use case.

### Future Topics We'll Cover

@ -117,7 +117,7 @@ For models with interleaving sliding windows (e.g. `google/gemma-2-2b-it` and `m

To support a model with interleaving sliding windows, we need to take care of the following details:

- Make sure the model's `config.json` contains `layer_types`.
- Make sure the model's `config.json` contains `sliding_window_pattern`. vLLM then sets `self.hf_text_config.interleaved_sliding_window` to the value of `self.hf_text_config.sliding_window` and deletes `sliding_window` from `self.hf_text_config`. The model will then be treated as a full-attention model.
- In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).

With these two steps, interleave sliding windows should work with the model.
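The bullet above about passing a per-layer value to `per_layer_sliding_window` can be illustrated with a small, self-contained sketch. This is only an illustration: the `"sliding_attention"` / `"full_attention"` strings follow the Hugging Face `layer_types` convention referenced in the list, and the helper name itself is hypothetical.

```python
# Hypothetical helper: derive the per-layer sliding window from config.json's
# `layer_types`, as described in the list above.
def per_layer_sliding_window(layer_types: list[str], sliding_window: int,
                             layer_idx: int) -> int | None:
    """Return the window size for sliding layers, or None for full attention."""
    if layer_types[layer_idx] == "sliding_attention":  # assumed label
        return sliding_window
    return None  # treated as a full-attention layer


# Example: an interleaved pattern of sliding / full attention layers.
layer_types = ["sliding_attention", "full_attention"] * 2
print([per_layer_sliding_window(layer_types, 4096, i) for i in range(4)])
# -> [4096, None, 4096, None]
```

The value returned per layer would then be forwarded to the attention layer's `per_layer_sliding_window` argument in the modeling code, as the bullet above describes.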
@ -540,10 +540,8 @@ return a schema of the tensors outputted by the HF processor that are related to
The shape of `image_patches` outputted by `FuyuImageProcessor` is therefore
`(1, num_images, num_patches, patch_width * patch_height * num_channels)`.

In order to support the use of
[MultiModalFieldConfig.batched][vllm.multimodal.inputs.MultiModalFieldConfig.batched]
like in LLaVA, we remove the extra batch dimension by overriding
[BaseMultiModalProcessor._call_hf_processor][vllm.multimodal.processing.BaseMultiModalProcessor._call_hf_processor]:
In order to support the use of [MultiModalFieldConfig.batched][] like in LLaVA,
we remove the extra batch dimension by overriding [BaseMultiModalProcessor._call_hf_processor][]:

??? code

@ -818,7 +816,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies

After you have defined [BaseProcessingInfo][vllm.multimodal.processing.BaseProcessingInfo] (Step 2),
[BaseDummyInputsBuilder][vllm.multimodal.profiling.BaseDummyInputsBuilder] (Step 3),
and [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] (Step 4),
decorate the model class with [MULTIMODAL_REGISTRY.register_processor][vllm.multimodal.registry.MultiModalRegistry.register_processor]
decorate the model class with [MULTIMODAL_REGISTRY.register_processor][vllm.multimodal.processing.MultiModalRegistry.register_processor]
to register them to the multi-modal registry:

```diff

@ -200,8 +200,7 @@ vision-language model.
lora_config = vllm_config.lora_config
super().__init__(config, cache_config, quant_config, lora_config, prefix)

from packaging import version
if version.parse(__version__) >= version.parse("0.6.4"):
if __version__ >= "0.6.4":
MyModel = MyNewModel
else:
MyModel = MyOldModel

@ -57,11 +57,11 @@ In v0, the following metrics are exposed via a Prometheus-compatible `/metrics`
- `vllm:spec_decode_num_draft_tokens_total` (Counter)
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)

These are documented under [Inferencing and Serving -> Production Metrics](../usage/metrics.md).
These are documented under [Inferencing and Serving -> Production Metrics](../../usage/metrics.md).

### Grafana Dashboard

vLLM also provides [a reference example](../examples/online_serving/prometheus_grafana.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard.
vLLM also provides [a reference example](../../examples/online_serving/prometheus_grafana.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard.

The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important:

@ -455,7 +455,7 @@ In general:
[an escape hatch](https://kubernetes.io/docs/concepts/cluster-administration/system-metrics/#show-hidden-metrics)
for some time before deleting them.

See the [deprecation policy](../contributing/deprecation_policy.md) for
See the [deprecation policy](../../contributing/deprecation_policy.md) for
the project-wide deprecation policy.

### Unimplemented - `vllm:tokens_total`
@ -655,7 +655,7 @@ v0 has support for OpenTelemetry tracing:
- Added by <gh-pr:4687>
- Configured with `--oltp-traces-endpoint` and `--collect-detailed-traces`
- [OpenTelemetry blog post](https://opentelemetry.io/blog/2024/llm-observability/)
- [User-facing docs](../examples/online_serving/opentelemetry.md)
- [User-facing docs](../../examples/online_serving/opentelemetry.md)
- [Blog post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f)
- [IBM product docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview)

@ -1,7 +0,0 @@
# Examples

vLLM's examples are split into three categories:

- If you are using vLLM from within Python code, see [Offline Inference](./offline_inference/)
- If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving/)
- For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others/)

@ -5,7 +5,7 @@

Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part.

!!! note
Technical details on how vLLM implements APC can be found [here](../design/prefix_caching.md).
Technical details on how vLLM implements APC can be found [here](../design/automatic_prefix_caching.md).

## Enabling APC in vLLM
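As a hedged illustration of the APC section above, prefix caching is switched on through an engine argument. The sketch below assumes the `enable_prefix_caching` flag of the `LLM` class; the model name and prompts are chosen only for the example.

```python
# Sketch: enable Automatic Prefix Caching (APC) for offline inference.
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_prefix_caching=True)

shared_prefix = "You are a helpful assistant. Answer concisely.\n\n"
params = SamplingParams(temperature=0.0, max_tokens=32)

# The second request shares the prefix, so its KV cache can be reused.
outputs = llm.generate(
    [shared_prefix + "What is PagedAttention?",
     shared_prefix + "What is prefix caching?"],
    params,
)
for output in outputs:
    print(output.outputs[0].text)
```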
@ -19,18 +19,6 @@ Two main reasons:
|
||||
|
||||
Please refer to <gh-file:examples/online_serving/disaggregated_prefill.sh> for the example usage of disaggregated prefilling.
|
||||
|
||||
Now supports 5 types of connectors:
|
||||
|
||||
- **SharedStorageConnector**: refer to <gh-file:examples/offline_inference/disaggregated-prefill-v1/run.sh> for the example usage of SharedStorageConnector disaggregated prefilling.
|
||||
- **LMCacheConnectorV1**: refer to <gh-file:examples/others/lmcache/disagg_prefill_lmcache_v1/disagg_example_nixl.sh> for the example usage of LMCacheConnectorV1 disaggregated prefilling which uses NIXL as the underlying KV transmission.
|
||||
- **NixlConnector**: refer to <gh-file:tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh> for the example usage of NixlConnector disaggregated prefilling which support fully async send/recv.
|
||||
- **P2pNcclConnector**: refer to <gh-file:examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh> for the example usage of P2pNcclConnector disaggregated prefilling.
|
||||
- **MultiConnector**: take advantage of the kv_connector_extra_config: dict[str, Any] already present in KVTransferConfig to stash all the connectors we want in an ordered list of kwargs.such as:
|
||||
|
||||
```bash
|
||||
--kv-transfer-config '{"kv_connector":"MultiConnector","kv_role":"kv_both","kv_connector_extra_config":{"connectors":[{"kv_connector":"NixlConnector","kv_role":"kv_both"},{"kv_connector":"SharedStorageConnector","kv_role":"kv_both","kv_connector_extra_config":{"shared_storage_path":"local_storage"}}]}}'
|
||||
```
|
||||
|
||||
## Benchmarks
|
||||
|
||||
Please refer to <gh-file:benchmarks/disagg_benchmarks> for disaggregated prefilling benchmarks.
|
||||
@ -60,19 +48,6 @@ The workflow of disaggregated prefilling is as follows:
|
||||
|
||||
The `buffer` corresponds to `insert` API in LookupBuffer, and the `drop_select` corresponds to `drop_select` API in LookupBuffer.
|
||||
|
||||
Now every process in vLLM will have a corresponding connector. Specifically, we have:
|
||||
|
||||
- Scheduler connector: the connector that locates in the same process as the scheduler process. It schedules the KV cache transfer ops.
|
||||
- Worker connectors: the connectors that locate in the worker processes. They execute KV cache transfer ops.
|
||||
|
||||
Here is a figure illustrating how the above 2 connectors are organized:
|
||||
|
||||

|
||||
|
||||
The figure below shows how the worker connector works with the attention module to achieve layer-by-layer KV cache store and load:
|
||||
|
||||

|
||||
|
||||
## Third-party contributions
|
||||
|
||||
Disaggregated prefilling is highly related to infrastructure, so vLLM relies on third-party connectors for production-level disaggregated prefilling (and vLLM team will actively review and merge new PRs for third-party connectors).
|
||||
|
||||
@ -351,22 +351,3 @@ vllm serve ibm-granite/granite-speech-3.3-2b \
|
||||
```

Note: Default multimodal LoRAs are currently only available for `.generate` and chat completions.

## Usage Tips

### Configuring `max_lora_rank`

The `--max-lora-rank` parameter controls the maximum rank allowed for LoRA adapters. This setting affects memory allocation and performance:

- **Set it to the maximum rank** among all LoRA adapters you plan to use
- **Avoid setting it too high** - using a value much larger than needed wastes memory and can cause performance issues

For example, if your LoRA adapters have ranks [16, 32, 64], use `--max-lora-rank 64` rather than 256.

```bash
# Good: matches actual maximum rank
vllm serve model --enable-lora --max-lora-rank 64

# Bad: unnecessarily high, wastes memory
vllm serve model --enable-lora --max-lora-rank 256
```
@ -1,4 +1,7 @@
|
||||
# FP8 INC
---
title: FP8 INC
---
[](){ #inc }

vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators.
Currently, quantization is validated only on Llama models.
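As a rough, unverified sketch only (the `quantization="inc"` and `kv_cache_dtype="fp8_inc"` arguments, the `QUANT_CONFIG` environment variable, and the model name are assumptions not confirmed by this page), offline inference with INC might be set up as follows:

```python
import os

from vllm import LLM

# Assumption: INC reads its measurement/quantization settings from a JSON file
# pointed to by QUANT_CONFIG; adjust the path for your setup.
os.environ["QUANT_CONFIG"] = "/path/to/inc_quant_config.json"

# Assumption: "inc" is accepted as the quantization backend and
# "fp8_inc" as the KV cache dtype on Gaudi devices.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    quantization="inc",
    kv_cache_dtype="fp8_inc",
)
```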
@ -1,80 +0,0 @@
|
||||
# Sleep Mode

vLLM's Sleep Mode allows you to temporarily release most GPU memory used by a model, including model weights and KV cache, without stopping the server or unloading the Docker container. This is especially useful for RLHF, training, or cost-saving scenarios where GPU resources need to be freed between inference workloads.

Key benefits:

- **Frees GPU memory**: Offloads model weights to CPU RAM and discards KV cache, releasing up to 90%+ of GPU memory for other tasks.
- **Fast resume**: Quickly wake up the engine and resume inference without a full model reload.
- **API endpoints**: Control the sleep/wake_up state via HTTP endpoints or the Python API.
- **Supports distributed workloads**: Works with tensor parallelism, pipeline parallelism, etc.
- **Fine-grained control**: Optionally wake up only model weights or KV cache to avoid OOM during weight updates.

!!! note
    This feature is only supported on the CUDA platform.
## Sleep levels

Level 1 sleep offloads the model weights to CPU memory and discards the KV cache; the content of the KV cache is forgotten. It is suited to sleeping and waking up the engine to run the same model again. Because the weights are backed up in CPU memory, make sure there is enough CPU memory to store them.

Level 2 sleep discards both the model weights and the KV cache (the model's buffers, such as rope scaling tensors, are kept in CPU memory); the content of both is forgotten. It is suited to sleeping and waking up the engine to run a different model or to update the model, where the previous weights are no longer needed, e.g. an RLHF weight update.
## Usage

### Offline inference

Enable sleep mode by passing `enable_sleep_mode=True` to the `LLM` class.

```python
from vllm import LLM
llm = LLM("Qwen/Qwen3-0.6B", enable_sleep_mode=True)
```

#### Python API

```python
# Put the engine to sleep (level=1: offload weights to CPU RAM, discard KV cache)
llm.sleep(level=1)

# Wake up the engine (restore weights)
llm.wake_up()
```
#### RLHF weight updates

During RLHF training, vLLM allows you to selectively wake up only the model weights or the KV cache using the `tags` argument of `wake_up()`. This fine-grained control is especially useful when updating model weights: by waking up just the weights (e.g., `llm.wake_up(tags=["weights"])`), you avoid allocating memory for the KV cache until after the weight update is complete. This approach helps prevent GPU out-of-memory (OOM) errors, particularly with large models, by minimizing peak memory usage during weight synchronization and update operations.

Use `tags=["weights"]` or `tags=["kv_cache"]` to control which resources are restored; this is useful for RLHF and weight updates. **Note** that `is_sleeping` will report `true` until all components are awake.

```python
# Put the engine into deep sleep (level=2)
llm.sleep(level=2)
# ... Get the new weights
# Wake up only the weights to avoid OOM
llm.wake_up(tags=["weights"])
# ... Update the weights
# Wake up the KV cache after the weights are updated
llm.wake_up(tags=["kv_cache"])
```
### Online Serving

To enable sleep mode in a vLLM server, start it with the environment variable `VLLM_SERVER_DEV_MODE=1` and pass `--enable-sleep-mode` to the server.

#### Server in development mode

Setting `VLLM_SERVER_DEV_MODE=1` enables development endpoints; these endpoints should not be exposed to end users.

```bash
VLLM_SERVER_DEV_MODE=1 python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen3-0.6B \
    --enable-sleep-mode \
    --port 8000
```

#### HTTP endpoints

- `POST /sleep?level=1` — Put the model to sleep (`level=1`).
- `POST /wake_up` — Wake up the model. Supports optional `tags` query parameters for partial wake-up (e.g., `?tags=weights`).
- `GET /is_sleeping` — Check if the model is sleeping.

!!! note
    These endpoints are only available when the server is started with `VLLM_SERVER_DEV_MODE=1`.
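For instance, a minimal client-side sketch (assuming the server above is listening on `localhost:8000`; the JSON response shape of `/is_sleeping` is an assumption) can drive these endpoints with `requests`:

```python
import requests

BASE = "http://localhost:8000"

# Put the model to sleep at level 1 (weights offloaded to CPU, KV cache discarded).
requests.post(f"{BASE}/sleep", params={"level": 1}).raise_for_status()

# Check the current state (payload assumed to be JSON).
print(requests.get(f"{BASE}/is_sleeping").json())

# Wake up only the weights first, then the KV cache.
requests.post(f"{BASE}/wake_up", params={"tags": "weights"}).raise_for_status()
requests.post(f"{BASE}/wake_up", params={"tags": "kv_cache"}).raise_for_status()
```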
@ -203,7 +203,6 @@ an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https
|
||||
"model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||
"draft_tensor_parallel_size": 1,
|
||||
"num_speculative_tokens": 2,
|
||||
"method": "eagle",
|
||||
},
|
||||
)
|
||||
|
||||
@ -232,9 +231,6 @@ A few important things to consider when using the EAGLE based draft models:
|
||||
reported in the reference implementation [here](https://github.com/SafeAILab/EAGLE). This issue is under
investigation and tracked here: <gh-issue:9565>.

4. When using an EAGLE-3 based draft model, the option "method" must be set to "eagle3".
   That is, specify `"method": "eagle3"` in `speculative_config`.
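As a minimal sketch of such a configuration (the draft model name here is an assumption for illustration, mirroring the EAGLE example above):

```python
from vllm import LLM

llm = LLM(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    speculative_config={
        # Draft model name is illustrative; pick the EAGLE-3 head trained for your base model.
        "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 2,
        "method": "eagle3",
    },
)
```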
A variety of EAGLE draft models are available on the Hugging Face hub:
|
||||
|
||||
| Base Model | EAGLE on Hugging Face | # EAGLE Parameters |
|
||||
|
||||
@ -14,16 +14,3 @@ vLLM supports the following hardware platforms:
|
||||
- [Google TPU](google_tpu.md)
- [Intel Gaudi](intel_gaudi.md)
- [AWS Neuron](aws_neuron.md)

## Hardware Plugins

The backends below live **outside** the main `vllm` repository and follow the
[Hardware-Pluggable RFC](../design/plugin_system.md).

| Accelerator | PyPI / package | Repository |
|-------------|----------------|------------|
| Ascend NPU | `vllm-ascend` | <https://github.com/vllm-project/vllm-ascend> |
| Intel Gaudi (HPU) | N/A, install from source | <https://github.com/vllm-project/vllm-gaudi> |
| MetaX MACA GPU | N/A, install from source | <https://github.com/MetaX-MACA/vLLM-metax> |
| Rebellions ATOM / REBEL NPU | `vllm-rbln` | <https://github.com/rebellions-sw/vllm-rbln> |
| IBM Spyre AIU | `vllm-spyre` | <https://github.com/vllm-project/vllm-spyre> |
@ -6,7 +6,7 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data
|
||||
# --8<-- [start:requirements]
|
||||
|
||||
- OS: Linux
|
||||
- CPU flags: `avx512f` (Recommended), `avx512_bf16` (Optional), `avx512_vnni` (Optional)
|
||||
- CPU flags: `avx512f`, `avx512_bf16` (Optional), `avx512_vnni` (Optional)
|
||||
|
||||
!!! tip
|
||||
Use `lscpu` to check the CPU flags.
|
||||
@ -28,7 +28,7 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data
|
||||
[https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo)
|
||||
|
||||
!!! warning
|
||||
If deploying the pre-built images on machines without `avx512f`, `avx512_bf16`, or `avx512_vnni` support, an `Illegal instruction` error may be raised. It is recommended to build images for these machines with the appropriate build arguments (e.g., `--build-arg VLLM_CPU_DISABLE_AVX512=true`, `--build-arg VLLM_CPU_AVX512BF16=false`, or `--build-arg VLLM_CPU_AVX512VNNI=false`) to disable unsupported features. Please note that without `avx512f`, AVX2 will be used and this version is not recommended because it only has basic feature support.
|
||||
If deploying the pre-built images on machines only contain `avx512f`, `Illegal instruction` error may be raised. It is recommended to build images for these machines with `--build-arg VLLM_CPU_AVX512BF16=false` and `--build-arg VLLM_CPU_AVX512VNNI=false`.
|
||||
|
||||
# --8<-- [end:pre-built-images]
|
||||
# --8<-- [start:build-image-from-source]
|
||||
@ -37,7 +37,6 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data
|
||||
docker build -f docker/Dockerfile.cpu \
|
||||
--build-arg VLLM_CPU_AVX512BF16=false (default)|true \
|
||||
--build-arg VLLM_CPU_AVX512VNNI=false (default)|true \
|
||||
--build-arg VLLM_CPU_DISABLE_AVX512=false (default)|true \
|
||||
--tag vllm-cpu-env \
|
||||
--target vllm-openai .
|
||||
|
||||
|
||||
@ -15,14 +15,8 @@ sys.modules["aiohttp"] = MagicMock()
|
||||
sys.modules["blake3"] = MagicMock()
|
||||
sys.modules["vllm._C"] = MagicMock()
|
||||
|
||||
from vllm.benchmarks import latency # noqa: E402
|
||||
from vllm.benchmarks import serve # noqa: E402
|
||||
from vllm.benchmarks import throughput # noqa: E402
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402
|
||||
from vllm.entrypoints.cli.openai import ChatCommand # noqa: E402
|
||||
from vllm.entrypoints.cli.openai import CompleteCommand # noqa: E402
|
||||
from vllm.entrypoints.openai import cli_args # noqa: E402
|
||||
from vllm.entrypoints.openai import run_batch # noqa: E402
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser # noqa: E402
|
||||
from vllm.utils import FlexibleArgumentParser # noqa: E402
|
||||
|
||||
logger = logging.getLogger("mkdocs")
|
||||
@ -74,8 +68,7 @@ class MarkdownFormatter(HelpFormatter):
|
||||
self._markdown_output.append(
|
||||
f"Possible choices: {metavar}\n\n")
|
||||
|
||||
if action.help:
|
||||
self._markdown_output.append(f"{action.help}\n\n")
|
||||
self._markdown_output.append(f"{action.help}\n\n")
|
||||
|
||||
if (default := action.default) != SUPPRESS:
|
||||
self._markdown_output.append(f"Default: `{default}`\n\n")
|
||||
@ -85,7 +78,7 @@ class MarkdownFormatter(HelpFormatter):
|
||||
return "".join(self._markdown_output)
|
||||
|
||||
|
||||
def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser:
|
||||
def create_parser(cls, **kwargs) -> FlexibleArgumentParser:
|
||||
"""Create a parser for the given class with markdown formatting.
|
||||
|
||||
Args:
|
||||
@ -95,12 +88,18 @@ def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser:
|
||||
Returns:
|
||||
FlexibleArgumentParser: A parser with markdown formatting for the class.
|
||||
"""
|
||||
parser = FlexibleArgumentParser(add_json_tip=False)
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.formatter_class = MarkdownFormatter
|
||||
with patch("vllm.config.DeviceConfig.__post_init__"):
|
||||
_parser = add_cli_args(parser, **kwargs)
|
||||
# add_cli_args might be in-place so return parser if _parser is None
|
||||
return _parser or parser
|
||||
return cls.add_cli_args(parser, **kwargs)
|
||||
|
||||
|
||||
def create_serve_parser() -> FlexibleArgumentParser:
|
||||
"""Create a parser for the serve command with markdown formatting."""
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.formatter_class = lambda prog: MarkdownFormatter(
|
||||
prog, starting_heading_level=4)
|
||||
return make_arg_parser(parser)
|
||||
|
||||
|
||||
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||
@ -114,24 +113,10 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||
|
||||
# Create parsers to document
|
||||
parsers = {
|
||||
"engine_args":
|
||||
create_parser(EngineArgs.add_cli_args),
|
||||
"async_engine_args":
|
||||
create_parser(AsyncEngineArgs.add_cli_args, async_args_only=True),
|
||||
"serve":
|
||||
create_parser(cli_args.make_arg_parser),
|
||||
"chat":
|
||||
create_parser(ChatCommand.add_cli_args),
|
||||
"complete":
|
||||
create_parser(CompleteCommand.add_cli_args),
|
||||
"bench_latency":
|
||||
create_parser(latency.add_cli_args),
|
||||
"bench_throughput":
|
||||
create_parser(throughput.add_cli_args),
|
||||
"bench_serve":
|
||||
create_parser(serve.add_cli_args),
|
||||
"run-batch":
|
||||
create_parser(run_batch.make_arg_parser),
|
||||
"engine_args": create_parser(EngineArgs),
|
||||
"async_engine_args": create_parser(AsyncEngineArgs,
|
||||
async_args_only=True),
|
||||
"serve": create_serve_parser(),
|
||||
}
|
||||
|
||||
# Generate documentation for each parser
|
||||
|
||||
@ -105,7 +105,7 @@ class Example:
|
||||
return fix_case(self.path.stem.replace("_", " ").title())
|
||||
|
||||
def generate(self) -> str:
|
||||
content = f"# {self.title}\n\n"
|
||||
content = f"---\ntitle: {self.title}\n---\n\n"
|
||||
content += f"Source <gh-file:{self.path.relative_to(ROOT_DIR)}>.\n\n"
|
||||
|
||||
# Use long code fence to avoid issues with
|
||||
|
||||
@ -23,13 +23,6 @@ a:not(:has(svg)):not(.md-icon):not(.autorefs-external) {
|
||||
}
|
||||
}
|
||||
|
||||
a[href*="localhost"]::after,
|
||||
a[href*="127.0.0.1"]::after,
|
||||
a[href*="org.readthedocs.build"]::after,
|
||||
a[href*="docs.vllm.ai"]::after {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
/* Light mode: darker section titles */
|
||||
body[data-md-color-scheme="default"] .md-nav__item--section > label.md-nav__link .md-ellipsis {
|
||||
color: rgba(0, 0, 0, 0.7) !important;
|
||||
|
||||
@ -4,7 +4,7 @@ vLLM provides first-class support for generative models, which covers most of LL
|
||||
|
||||
In vLLM, generative models implement the [VllmModelForTextGeneration][vllm.model_executor.models.VllmModelForTextGeneration] interface.
Based on the final hidden states of the input, these models output log probabilities of the tokens to generate,
which are then passed through [Sampler][vllm.model_executor.layers.sampler.Sampler] to obtain the final text.
which are then passed through [Sampler][vllm.model_executor.layers.Sampler] to obtain the final text.
|
||||
|
||||
## Configuration
|
||||
|
||||
@ -19,7 +19,7 @@ Run a model in generation mode via the option `--runner generate`.
|
||||
## Offline Inference
|
||||
|
||||
The [LLM][vllm.LLM] class provides various methods for offline inference.
|
||||
See [configuration](../api/summary.md#configuration) for a list of options when initializing the model.
|
||||
See [configuration][configuration] for a list of options when initializing the model.
|
||||
|
||||
### `LLM.generate`
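The body of this section is elided in the diff; as a minimal sketch of the call pattern (model name illustrative):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B")
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.8, max_tokens=32),
)
for out in outputs:
    print(out.outputs[0].text)
```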
|
||||
|
||||
|
||||
@ -81,7 +81,7 @@ which takes priority over both the model's and Sentence Transformers's defaults.
|
||||
## Offline Inference
|
||||
|
||||
The [LLM][vllm.LLM] class provides various methods for offline inference.
|
||||
See [configuration](../api/summary.md#configuration) for a list of options when initializing the model.
|
||||
See [configuration][configuration] for a list of options when initializing the model.
|
||||
|
||||
### `LLM.embed`
|
||||
|
||||
|
||||
@ -320,7 +320,7 @@ th {
|
||||
}
|
||||
</style>
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -331,7 +331,7 @@ th {
|
||||
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
|
||||
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
|
||||
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -349,10 +349,9 @@ th {
|
||||
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
|
||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
|
||||
| `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4MoeForCausalLM` | GLM-4.5 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ |
|
||||
| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -371,9 +370,9 @@ th {
|
||||
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
|
||||
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | |
|
||||
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | |
|
||||
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
|
||||
| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -405,19 +404,16 @@ th {
|
||||
| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | ✅︎ |
|
||||
| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | ✅︎ |
|
||||
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | |
|
||||
| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | |
|
||||
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ |
|
||||
|
||||
Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `SmolLM3ForCausalLM` | SmolLM3 | `HuggingFaceTB/SmolLM3-3B` | ✅︎ | ✅︎ | ✅︎ |
|
||||
|
||||
!!! note
|
||||
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
|
||||
|
||||
!!! note
|
||||
Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0.
|
||||
|
||||
### Pooling Models
|
||||
|
||||
See [this page](./pooling_models.md) for more information on how to use pooling models.
|
||||
@ -430,7 +426,7 @@ See [this page](./pooling_models.md) for more information on how to use pooling
|
||||
|
||||
These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | |
|
||||
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ |
|
||||
@ -470,7 +466,7 @@ of the whole prompt are extracted from the normalized hidden state corresponding
|
||||
|
||||
These models primarily support the [`LLM.classify`](./pooling_models.md#llmclassify) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | |
|
||||
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ |
|
||||
@ -487,7 +483,7 @@ If your model is not in the above list, we will try to automatically convert the
|
||||
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
|
||||
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | | |
|
||||
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -525,7 +521,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
|
||||
These models primarily support the [`LLM.reward`](./pooling_models.md#llmreward) API.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `LlamaForCausalLM`<sup>C</sup> | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -587,9 +583,6 @@ See [this page](../features/multimodal_inputs.md) on how to pass multi-modal inp
|
||||
|
||||
**This is no longer required if you are using vLLM V1.**
|
||||
|
||||
!!! tip
    For hybrid-only models such as Llama-4, Step3 and Mistral-3, a text-only mode can be enabled by setting all supported multimodal modalities to 0 (e.g., `--limit-mm-per-prompt '{"image":0}'`) so that their multimodal modules are not loaded, freeing up more GPU memory for the KV cache.
|
||||
|
||||
!!! note
|
||||
vLLM currently only supports adding LoRA to the language backbone of multimodal models.
|
||||
|
||||
@ -601,21 +594,20 @@ See [this page](generative_models.md) for more information on how to use generat
|
||||
|
||||
These models primarily accept the [`LLM.generate`](./generative_models.md#llmgenerate) API. Chat/Instruct models additionally support the [`LLM.chat`](./generative_models.md#llmchat) API.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ |
|
||||
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
|
||||
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
|
||||
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
|
||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Glm4MoeForCausalLM` | GLM-4.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4v_moeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
|
||||
@ -630,7 +622,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ |
|
||||
| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ |
|
||||
| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ |
|
||||
| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
|
||||
@ -655,7 +647,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
|
||||
Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
|
||||
| `Emu3ForConditionalGeneration` | Emu3 | T + I | `BAAI/Emu3-Chat-hf` | ✅︎ | ✅︎ | ✅︎ |
|
||||
|
||||
@ -682,15 +674,6 @@ Some models are supported only via the [Transformers backend](#transformers). Th
|
||||
|
||||
This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
|
||||
|
||||
!!! note
|
||||
`Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
|
||||
MobileNet-v5 vision backbone.
|
||||
|
||||
Performance is not yet fully optimized mainly due to:
|
||||
|
||||
- Both audio and vision MM encoders use `transformers.AutoModel` implementation.
|
||||
- There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
|
||||
|
||||
!!! note
|
||||
Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
|
||||
|
||||
@ -743,7 +726,7 @@ Some models are supported only via the [Transformers backend](#transformers). Th
|
||||
|
||||
Speech2Text models trained specifically for Automatic Speech Recognition.
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | |
|
||||
| `VoxtralForConditionalGeneration` | Voxtral (Mistral format) | `mistralai/Voxtral-Mini-3B-2507`, `mistralai/Voxtral-Small-24B-2507`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -761,7 +744,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|
||||
|
||||
The following table lists those that are tested in vLLM.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) |
|
||||
|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------|
|
||||
| `LlavaNextForConditionalGeneration`<sup>C</sup> | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | |
|
||||
| `Phi3VForCausalLM`<sup>C</sup> | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | |
|
||||
@ -777,7 +760,7 @@ The following table lists those that are tested in vLLM.
|
||||
Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
|
||||
These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
|
||||
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|
||||
| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) |
|
||||
|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------|
|
||||
| `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | | | ✅︎ |
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Parallelism and Scaling
|
||||
# Distributed inference and serving
|
||||
|
||||
## Distributed inference strategies for a single-model replica
|
||||
|
||||
@ -128,17 +128,12 @@ vllm serve /path/to/the/model/in/the/container \
|
||||
--tensor-parallel-size 16
|
||||
```
|
||||
|
||||
## Optimizing network communication for tensor parallelism
|
||||
## Troubleshooting distributed deployments
|
||||
|
||||
Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand.
|
||||
To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the
|
||||
<gh-file:examples/online_serving/run_cluster.sh> helper script.
|
||||
Contact your system administrator for more information about the required flags.
|
||||
To make tensor parallelism performant, ensure that communication between nodes is efficient, for example, by using high-speed network cards such as InfiniBand. To set up the cluster to use InfiniBand, append additional arguments like `--privileged -e NCCL_IB_HCA=mlx5` to the `run_cluster.sh` script. Contact your system administrator for more information about the required flags. One way to confirm if InfiniBand is working is to run `vllm` with the `NCCL_DEBUG=TRACE` environment variable set, for example `NCCL_DEBUG=TRACE vllm serve ...`, and check the logs for the NCCL version and the network used. If you find `[send] via NET/Socket` in the logs, NCCL uses a raw TCP socket, which is not efficient for cross-node tensor parallelism. If you find `[send] via NET/IB/GDRDMA` in the logs, NCCL uses InfiniBand with GPUDirect RDMA, which is efficient.
|
||||
|
||||
## Enabling GPUDirect RDMA
|
||||
|
||||
GPUDirect RDMA (Remote Direct Memory Access) is an NVIDIA technology that allows network adapters to directly access GPU memory, bypassing the CPU and system memory. This direct access reduces latency and CPU overhead, which is beneficial for large data transfers between GPUs across nodes.
|
||||
|
||||
To enable GPUDirect RDMA with vLLM, configure the following settings:
|
||||
|
||||
- `IPC_LOCK` security context: add the `IPC_LOCK` capability to the container's security context to lock memory pages and prevent swapping to disk.
|
||||
@ -180,17 +175,21 @@ spec:
|
||||
...
|
||||
```
|
||||
|
||||
!!! tip "Confirm GPUDirect RDMA operation"
|
||||
To confirm your InfiniBand card is using GPUDirect RDMA, run vLLM with detailed NCCL logs: `NCCL_DEBUG=TRACE vllm serve ...`.
|
||||
Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand. To enable InfiniBand, append flags such as `--privileged -e NCCL_IB_HCA=mlx5` to `run_cluster.sh`. For cluster-specific settings, consult your system administrator.
|
||||
|
||||
Then look for the NCCL version and the network used.
|
||||
To confirm InfiniBand operation, enable detailed NCCL logs:
|
||||
|
||||
- If you find `[send] via NET/IB/GDRDMA` in the logs, then NCCL is using InfiniBand with GPUDirect RDMA, which *is* efficient.
|
||||
- If you find `[send] via NET/Socket` in the logs, NCCL used a raw TCP socket, which *is not* efficient for cross-node tensor parallelism.
|
||||
```bash
|
||||
NCCL_DEBUG=TRACE vllm serve ...
|
||||
```
|
||||
|
||||
Search the logs for the transport method. Entries containing `[send] via NET/Socket` indicate raw TCP sockets, which perform poorly for cross-node tensor parallelism. Entries containing `[send] via NET/IB/GDRDMA` indicate InfiniBand with GPUDirect RDMA, which provides high performance.
|
||||
|
||||
!!! tip "Verify inter-node GPU communication"
|
||||
After you start the Ray cluster, verify GPU-to-GPU communication across nodes. Proper configuration can be non-trivial. For more information, see [troubleshooting script][troubleshooting-incorrect-hardware-driver]. If you need additional environment variables for communication configuration, append them to `run_cluster.sh`, for example `-e NCCL_SOCKET_IFNAME=eth0`. Setting environment variables during cluster creation is recommended because the variables propagate to all nodes. In contrast, setting environment variables in the shell affects only the local node. For more information, see <gh-issue:6803>.
|
||||
|
||||
!!! tip "Pre-download Hugging Face models"
|
||||
If you use Hugging Face models, downloading the model before starting vLLM is recommended. Download the model on every node to the same path, or store the model on a distributed file system accessible by all nodes. Then pass the path to the model in place of the repository ID. Otherwise, supply a Hugging Face token by appending `-e HF_TOKEN=<TOKEN>` to `run_cluster.sh`.
|
||||
|
||||
## Troubleshooting distributed deployments
|
||||
|
||||
For information about distributed debugging, see [Troubleshooting distributed deployments](distributed_troubleshooting.md).
|
||||
!!! tip
|
||||
The error message `Error: No available node types can fulfill resource request` can appear even when the cluster has enough GPUs. The issue often occurs when nodes have multiple IP addresses and vLLM can't select the correct one. Ensure that vLLM and Ray use the same IP address by setting `VLLM_HOST_IP` in `run_cluster.sh` (with a different value on each node). Use `ray status` and `ray list nodes` to verify the chosen IP address. For more information, see <gh-issue:7815>.
|
||||
@ -1,16 +0,0 @@
|
||||
# Troubleshooting distributed deployments
|
||||
|
||||
For general troubleshooting, see [Troubleshooting](../usage/troubleshooting.md).
|
||||
|
||||
## Verify inter-node GPU communication
|
||||
|
||||
After you start the Ray cluster, verify GPU-to-GPU communication across nodes. Proper configuration can be non-trivial. For more information, see [troubleshooting script][troubleshooting-incorrect-hardware-driver]. If you need additional environment variables for communication configuration, append them to <gh-file:examples/online_serving/run_cluster.sh>, for example `-e NCCL_SOCKET_IFNAME=eth0`. Setting environment variables during cluster creation is recommended because the variables propagate to all nodes. In contrast, setting environment variables in the shell affects only the local node. For more information, see <gh-issue:6803>.
|
||||
|
||||
## No available node types can fulfill resource request
|
||||
|
||||
The error message `Error: No available node types can fulfill resource request` can appear even when the cluster has enough GPUs. The issue often occurs when nodes have multiple IP addresses and vLLM can't select the correct one. Ensure that vLLM and Ray use the same IP address by setting `VLLM_HOST_IP` in <gh-file:examples/online_serving/run_cluster.sh> (with a different value on each node). Use `ray status` and `ray list nodes` to verify the chosen IP address. For more information, see <gh-issue:7815>.
|
||||
|
||||
## Ray observability
|
||||
|
||||
Debugging a distributed system can be challenging due to the large scale and complexity. Ray provides a suite of tools to help monitor, debug, and optimize Ray applications and clusters. For more information about Ray observability, visit the [official Ray observability docs](https://docs.ray.io/en/latest/ray-observability/index.html). For more information about debugging Ray applications, visit the [Ray Debugging Guide](https://docs.ray.io/en/latest/ray-observability/user-guides/debug-apps/index.html). For information about troubleshooting Kubernetes clusters, see the
|
||||
[official KubeRay troubleshooting guide](https://docs.ray.io/en/latest/serve/advanced-guides/multi-node-gpu-troubleshooting.html).
|
||||
@ -1,8 +1,6 @@
|
||||
# Using vLLM
|
||||
|
||||
First, vLLM must be [installed](../getting_started/installation) for your chosen device in either a Python or Docker environment.
|
||||
|
||||
Then, vLLM supports the following usage patterns:
|
||||
vLLM supports the following usage patterns:
|
||||
|
||||
- [Inference and Serving](../serving/offline_inference.md): Run a single instance of a model.
|
||||
- [Deployment](../deployment/docker.md): Scale up model instances for production.
|
||||
|
||||
@ -289,7 +289,7 @@ Traceback (most recent call last):
|
||||
...
|
||||
```
|
||||
|
||||
This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` linux capability or an unmounted `/dev/shm`. Refer to [Enabling GPUDirect RDMA](../serving/parallelism_scaling.md#enabling-gpudirect-rdma) for guidance on properly configuring the environment for GPUDirect RDMA.
|
||||
This indicates vLLM failed to initialize the NCCL communicator, possibly due to a missing `IPC_LOCK` linux capability or an unmounted `/dev/shm`. Refer to [Distributed Inference and Serving](../serving/distributed_serving.md#running-vllm-on-multiple-nodes) for guidance on properly configuring the environment for distributed serving.
|
||||
|
||||
## Known Issues
|
||||
|
||||
|
||||
@ -59,13 +59,12 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
|
||||
|
||||
### Hardware
|
||||
|
||||
| Hardware | Status |
|
||||
|------------|-----------------------------------------------|
|
||||
| **NVIDIA** | <nobr>🚀</nobr> |
|
||||
| **AMD** | <nobr>🟢</nobr> |
|
||||
| **INTEL GPU** | <nobr>🟢</nobr> |
|
||||
| **TPU** | <nobr>🟢</nobr> |
|
||||
| **CPU** | <nobr>🟢 (x86\_64/aarch64) 🟡 (MacOS) </nobr> |
|
||||
| Hardware | Status |
|
||||
|------------|------------------------------------|
|
||||
| **NVIDIA** | <nobr>🚀</nobr> |
|
||||
| **AMD** | <nobr>🟢</nobr> |
|
||||
| **TPU** | <nobr>🟢</nobr> |
|
||||
| **CPU** | <nobr>🟢 (x86) 🟡 (MacOS) </nobr> |
|
||||
|
||||
!!! note
|
||||
|
||||
@ -73,7 +72,6 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
|
||||
|
||||
- [vllm-ascend](https://github.com/vllm-project/vllm-ascend)
|
||||
- [vllm-spyre](https://github.com/vllm-project/vllm-spyre)
|
||||
- [vllm-gaudi](https://github.com/vllm-project/vllm-gaudi)
|
||||
- [vllm-openvino](https://github.com/vllm-project/vllm-openvino)
|
||||
|
||||
Please check their corresponding repositories for more details.
|
||||
@ -85,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
|
||||
| **Decoder-only Models** | <nobr>🚀 Optimized</nobr> |
|
||||
| **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> |
|
||||
| **Embedding Models** | <nobr>🟢 Functional</nobr> |
|
||||
| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> |
|
||||
| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟡 (Mamba-1)</nobr> |
|
||||
| **Multimodal Models** | <nobr>🟢 Functional</nobr> |
|
||||
|
||||
vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol.
|
||||
@ -106,17 +104,15 @@ to enable simultaneous generation and embedding using the same engine instance i
|
||||
|
||||
#### Mamba Models

Models using selective state-space mechanisms instead of standard transformer attention are supported.
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.
Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
disabling prefix caching in V1.

Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM`, `GraniteMoeHybridForCausalLM`, and `JambaForCausalLM`). Please note that
Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
these models currently require disabling prefix caching and using the FlashInfer attention backend in V1.

Hybrid models with mechanisms different from Mamba are also supported (e.g., `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`).
Please note that these models currently require disabling prefix caching, enforcing eager mode, and using the FlashInfer
attention backend in V1.
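As a minimal sketch of the constraints described above (the model name is taken from the table earlier on this page; prefix caching is disabled and, for Mamba-1, eager mode is enforced):

```python
from vllm import LLM

# Mamba-1 model: prefix caching off and eager mode enforced, per the notes above.
llm = LLM(
    model="state-spaces/mamba-130m-hf",
    enable_prefix_caching=False,
    enforce_eager=True,
)
```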
#### Encoder-Decoder Models
|
||||
|
||||
Models requiring cross-attention between separate encoder and decoder (e.g., `BartForConditionalGeneration`, `MllamaForConditionalGeneration`)
|
||||
|
||||
@ -96,25 +96,6 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# Gemma3N
|
||||
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
|
||||
model_name = "google/gemma-3n-E2B-it"
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
max_num_batched_tokens=2048,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"audio": audio_count},
|
||||
enforce_eager=True,
|
||||
)
|
||||
prompt = f"<start_of_turn>user\n<audio_soft_token>{question}"
|
||||
"<end_of_turn>\n<start_of_turn>model\n"
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
# Granite Speech
|
||||
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    # NOTE - the settings in this example are somewhat different than what is
|
||||
@ -350,7 +331,6 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
|
||||
|
||||
model_example_map = {
|
||||
"voxtral": run_voxtral,
|
||||
"gemma3n": run_gemma3n,
|
||||
"granite_speech": run_granite_speech,
|
||||
"minicpmo": run_minicpmo,
|
||||
"phi4_mm": run_phi4mm,
|
||||
|
||||
@ -68,7 +68,7 @@ def run_simple_demo(args: argparse.Namespace):
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
tensor_parallel_size=2,
|
||||
mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
prompt = "Describe this image in one sentence."
|
||||
@ -105,7 +105,7 @@ def run_advanced_demo(args: argparse.Namespace):
|
||||
limit_mm_per_prompt={"image": max_img_per_msg},
|
||||
max_model_len=max_img_per_msg * max_tokens_per_img,
|
||||
tensor_parallel_size=2,
|
||||
mm_processor_cache_gb=0 if args.disable_mm_processor_cache else 4,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
)
|
||||
|
||||
prompt = "Describe the following image."
|
||||
@ -164,9 +164,9 @@ def parse_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-mm-processor-cache",
|
||||
"--disable-mm-preprocessor-cache",
|
||||
action="store_true",
|
||||
help="If True, disables caching of multi-modal processor.",
|
||||
help="If True, disables caching of multi-modal preprocessor/mapper.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -126,29 +126,6 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def run_command_a_vision(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "CohereLabs/command-a-vision-07-2025"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=32768,
|
||||
tensor_parallel_size=4,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
prompts = [
|
||||
f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|><|IMG_PATCH|>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Deepseek-VL2
|
||||
def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -234,33 +211,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# Gemma3N
|
||||
def run_gemma3n(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "google/gemma-3n-E2B-it"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
prompts = [
|
||||
(
|
||||
"<start_of_turn>user\n"
|
||||
f"<image_soft_token>{question}<end_of_turn>\n"
|
||||
"<start_of_turn>model\n"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
@ -1440,12 +1391,10 @@ model_example_map = {
|
||||
"aya_vision": run_aya_vision,
|
||||
"blip-2": run_blip2,
|
||||
"chameleon": run_chameleon,
|
||||
"command_a_vision": run_command_a_vision,
|
||||
"deepseek_vl_v2": run_deepseek_vl2,
|
||||
"florence2": run_florence2,
|
||||
"fuyu": run_fuyu,
|
||||
"gemma3": run_gemma3,
|
||||
"gemma3n": run_gemma3n,
|
||||
"glm4v": run_glm4v,
|
||||
"glm4_1v": run_glm4_1v,
|
||||
"h2ovl_chat": run_h2ovl,
|
||||
@ -1614,9 +1563,9 @@ def parse_args():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-mm-processor-cache",
|
||||
"--disable-mm-preprocessor-cache",
|
||||
action="store_true",
|
||||
help="If True, disables caching of multi-modal processor.",
|
||||
help="If True, disables caching of multi-modal preprocessor/mapper.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -1654,7 +1603,7 @@ def main(args):
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {
|
||||
"seed": args.seed,
|
||||
"mm_processor_cache_gb": 0 if args.disable_mm_processor_cache else 4,
|
||||
"disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
|
||||
}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
|
||||
@ -107,42 +107,6 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def load_command_a_vision(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "CohereLabs/command-a-vision-07-2025"
|
||||
|
||||
# NOTE: This model is 122B parameters and requires tensor parallelism
|
||||
# Recommended to use tp=4 on H100 GPUs
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=32768,
|
||||
tensor_parallel_size=4,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
|
||||
prompt = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
)
|
||||
|
||||
|
||||
def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "deepseek-ai/deepseek-vl2-tiny"
|
||||
|
||||
@ -1067,7 +1031,6 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_example_map = {
|
||||
"aria": load_aria,
|
||||
"aya_vision": load_aya_vision,
|
||||
"command_a_vision": load_command_a_vision,
|
||||
"deepseek_vl_v2": load_deepseek_vl2,
|
||||
"gemma3": load_gemma3,
|
||||
"h2ovl_chat": load_h2ovl,
|
||||
|
||||
@ -1,186 +0,0 @@
|
||||
# Long Text Embedding with Chunked Processing

This directory contains examples for using vLLM's **chunked processing** feature to handle long text embedding that exceeds the model's maximum context length.

## 🚀 Quick Start

### Start the Server

Use the provided script to start a vLLM server with chunked processing enabled:

```bash
# Basic usage (supports very long texts up to ~3M tokens)
./service.sh

# Custom configuration with different models
MODEL_NAME="jinaai/jina-embeddings-v3" \
MAX_EMBED_LEN=1048576 \
./service.sh

# For extremely long documents
MODEL_NAME="intfloat/multilingual-e5-large" \
MAX_EMBED_LEN=3072000 \
./service.sh
```

### Test Long Text Embedding

Run the comprehensive test client:

```bash
python client.py
```
|
||||
|
||||
## 📁 Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `service.sh` | Server startup script with chunked processing enabled |
|
||||
| `client.py` | Comprehensive test client for long text embedding |
|
||||
|
||||
## ⚙️ Configuration
|
||||
|
||||
### Server Configuration
|
||||
|
||||
The key parameters for chunked processing are in the `--override-pooler-config`:
|
||||
|
||||
```json
|
||||
{
|
||||
"pooling_type": "auto",
|
||||
"normalize": true,
|
||||
"enable_chunked_processing": true,
|
||||
"max_embed_len": 3072000
|
||||
}
|
||||
```
|
||||
|
||||
!!! note
|
||||
`pooling_type` sets the model's own pooling strategy for processing within each chunk. The cross-chunk aggregation automatically uses MEAN strategy when input exceeds the model's native maximum length.
|
||||
|
||||
#### Chunked Processing Behavior
|
||||
|
||||
Chunked processing uses **MEAN aggregation** for cross-chunk combination when input exceeds the model's native maximum length:
|
||||
|
||||
| Component | Behavior | Description |
|
||||
|-----------|----------|-------------|
|
||||
| **Within chunks** | Model's native pooling | Uses the model's configured pooling strategy |
|
||||
| **Cross-chunk aggregation** | Always MEAN | Weighted averaging based on chunk token counts |
|
||||
| **Performance** | Optimal | All chunks processed for complete semantic coverage |
|
||||
|
||||
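The weighted averaging in the table above can be made concrete with a short sketch. This is illustrative only and assumes per-chunk embeddings and token counts are already available; the function name `aggregate_chunk_embeddings` is an invention for this example, not part of vLLM's API.

```python
# Illustrative sketch of token-count-weighted MEAN aggregation across chunks.
# Not vLLM's internal implementation; inputs are assumed to be per-chunk
# pooled embeddings plus the number of tokens each chunk contained.
import numpy as np


def aggregate_chunk_embeddings(
    chunk_embeddings: list[np.ndarray], chunk_token_counts: list[int]
) -> np.ndarray:
    weights = np.asarray(chunk_token_counts, dtype=np.float64)
    stacked = np.stack(chunk_embeddings)               # (num_chunks, dim)
    combined = (stacked * weights[:, None]).sum(axis=0) / weights.sum()
    return combined / np.linalg.norm(combined)         # normalized, as configured above


# Example: three chunks of 4096, 4096, and 1200 tokens
chunks = [np.random.rand(1024) for _ in range(3)]
print(aggregate_chunk_embeddings(chunks, [4096, 4096, 1200]).shape)  # (1024,)
```
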
### Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `MODEL_NAME` | `intfloat/multilingual-e5-large` | Embedding model to use (supports multiple models) |
| `PORT` | `31090` | Server port |
| `GPU_COUNT` | `1` | Number of GPUs to use |
| `MAX_EMBED_LEN` | `3072000` | Maximum embedding input length (supports very long documents) |
| `POOLING_TYPE` | `auto` | Model's native pooling type: `auto`, `MEAN`, `CLS`, `LAST` (only affects within-chunk pooling, not cross-chunk aggregation) |
| `API_KEY` | `EMPTY` | API key for authentication |

## 🔧 How It Works

1. **Enhanced Input Validation**: `max_embed_len` allows accepting inputs longer than `max_model_len` without environment variables
2. **Smart Chunking**: Text is split based on `max_position_embeddings` to maintain semantic integrity (see the sketch after this list)
3. **Unified Processing**: All chunks are processed separately through the model using its configured pooling strategy
4. **MEAN Aggregation**: When input exceeds the model's native length, results are combined using token count-based weighted averaging across all chunks
5. **Consistent Output**: Final embeddings maintain the same dimensionality as standard processing

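As a concrete illustration of the chunking step, here is a minimal, hypothetical sketch of splitting a tokenized input into windows of at most `max_position_embeddings` tokens. It is not vLLM's internal code; the helper name is invented for illustration.

```python
# Hypothetical sketch of the chunking rule: consecutive windows of at most
# max_position_embeddings tokens. Not vLLM's internal implementation.
def split_into_chunks(
    token_ids: list[int], max_position_embeddings: int
) -> list[list[int]]:
    return [
        token_ids[i : i + max_position_embeddings]
        for i in range(0, len(token_ids), max_position_embeddings)
    ]


# 150,000 tokens with a 4096-token window -> 37 chunks, matching the
# server log shown in the Debug Information section below.
chunks = split_into_chunks(list(range(150_000)), 4096)
print(len(chunks), len(chunks[0]), len(chunks[-1]))  # 37 4096 2544
```
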
### Input Length Handling

- **Within max_embed_len**: Input is accepted and processed (up to 3M+ tokens)
- **Exceeds max_position_embeddings**: Chunked processing is automatically triggered
- **Exceeds max_embed_len**: Input is rejected with a clear error message
- **No environment variables required**: Works without `VLLM_ALLOW_LONG_MAX_MODEL_LEN` (the decision flow is sketched below)

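The three cases above can be summarized as a small decision function. This is a hypothetical sketch of the routing logic only, not the server's actual validation code; the helper name and return values are invented for illustration.

```python
# Hypothetical sketch of the input-length routing described above.
def route_embedding_input(
    num_tokens: int, max_position_embeddings: int, max_embed_len: int
) -> str:
    if num_tokens > max_embed_len:
        # Rejected with a clear error message
        raise ValueError(
            f"Input of {num_tokens} tokens exceeds max_embed_len={max_embed_len}"
        )
    if num_tokens > max_position_embeddings:
        return "chunked"       # chunked processing is triggered automatically
    return "single-pass"       # fits in one pass, processed normally


print(route_embedding_input(150_000, 4096, 3_072_000))  # -> chunked
```
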
### Extreme Long Text Support

With `MAX_EMBED_LEN=3072000`, you can process:

- **Academic papers**: Full research papers with references
- **Legal documents**: Complete contracts and legal texts
- **Books**: Entire chapters or small books
- **Code repositories**: Large codebases and documentation

## 📊 Performance Characteristics

### Chunked Processing Performance

| Aspect | Behavior | Performance |
|--------|----------|-------------|
| **Chunk Processing** | All chunks processed with native pooling | Consistent with input length |
| **Cross-chunk Aggregation** | MEAN weighted averaging | Minimal overhead |
| **Memory Usage** | Proportional to number of chunks | Moderate, scalable |
| **Semantic Quality** | Complete text coverage | Optimal for long documents |

## 🧪 Test Cases

The test client demonstrates:

- ✅ **Short text**: Normal processing (baseline)
- ✅ **Medium text**: Single chunk processing
- ✅ **Long text**: Multi-chunk processing with aggregation
- ✅ **Very long text**: Many chunks processing
- ✅ **Extreme long text**: Document-level processing (100K+ tokens)
- ✅ **Batch processing**: Mixed-length inputs in one request
- ✅ **Consistency**: Reproducible results across runs

## 🐛 Troubleshooting

### Common Issues

1. **Chunked processing not enabled**:

    ```log
    ValueError: This model's maximum position embeddings length is 4096 tokens...
    ```

    **Solution**: Ensure `enable_chunked_processing: true` in the pooler config

2. **Input exceeds max_embed_len**:

    ```log
    ValueError: This model's maximum embedding input length is 3072000 tokens...
    ```

    **Solution**: Increase `max_embed_len` in the pooler config or reduce the input length

3. **Memory errors**:

    ```log
    RuntimeError: CUDA out of memory
    ```

    **Solution**: Reduce the chunk size by adjusting the model's `max_position_embeddings` or use fewer GPUs

4. **Slow processing**:

    **Expected**: Long text takes more time due to multiple inference calls

### Debug Information

Server logs show chunked processing activity:

```log
INFO: Input length 150000 exceeds max_position_embeddings 4096, will use chunked processing
INFO: Split input of 150000 tokens into 37 chunks (max_chunk_size: 4096)
```

## 🤝 Contributing

To extend chunked processing support to other embedding models:

1. Check model compatibility with the pooling architecture
2. Test with various text lengths
3. Validate embedding quality compared to single-chunk processing
4. Submit a PR with test cases and documentation updates

## 🆕 Enhanced Features

### max_embed_len Parameter

The new `max_embed_len` parameter provides:

- **Simplified Configuration**: No need for the `VLLM_ALLOW_LONG_MAX_MODEL_LEN` environment variable
- **Flexible Input Validation**: Accept inputs longer than `max_model_len`, up to `max_embed_len`
- **Extreme Length Support**: Process documents with millions of tokens
- **Clear Error Messages**: Better feedback when inputs exceed limits
- **Backward Compatibility**: Existing configurations continue to work

@ -1,366 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
Example script demonstrating long text embedding with chunked processing in vLLM.

This example shows how to use vLLM's chunked processing feature to handle text
inputs that exceed the model's maximum token length. The feature automatically
splits long text into chunks and handles different pooling types optimally.

Prerequisites:
1. Start vLLM server with chunked processing enabled:

   # MEAN pooling (processes all chunks, recommended for complete coverage)
   vllm serve intfloat/multilingual-e5-large \
     --override-pooler-config \
     '{"pooling_type": "MEAN", "normalize": true, ' \
     '"enable_chunked_processing": true, "max_embed_len": 3072000}' \
     --served-model-name multilingual-e5-large \
     --trust-remote-code \
     --port 31090 \
     --api-key your-api-key

   # OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks)
   vllm serve BAAI/bge-large-en-v1.5 \
     --override-pooler-config \
     '{"pooling_type": "CLS", "normalize": true, ' \
     '"enable_chunked_processing": true, "max_embed_len": 1048576}' \
     --served-model-name bge-large-en-v1.5 \
     --trust-remote-code \
     --port 31090 \
     --api-key your-api-key

2. Install required dependencies:
   pip install openai requests
"""

import time

import numpy as np
from openai import OpenAI

# Configuration
API_KEY = "your-api-key"  # Replace with your actual API key
BASE_URL = "http://localhost:31090/v1"
MODEL_NAME = "multilingual-e5-large"


def generate_long_text(base_text: str, repeat_count: int) -> str:
    """Generate long text by repeating base text."""
    return base_text * repeat_count


def test_embedding_with_different_lengths():
    """Test embedding generation with different text lengths."""
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    # Test cases with different text lengths
    test_cases = [
        {
            "name": "Short Text",
            "text": "Hello, this is a short text for embedding.",
            "expected_chunks": 1,
        },
        {
            "name": "Medium Text",
            "text": generate_long_text(
                "This is a medium-length text that should fit within the "
                "model's context window. " * 20,
                2,
            ),
            "expected_chunks": 1,
        },
        {
            "name": "Long Text (2 chunks)",
            "text": generate_long_text(
                "This is a very long text that will exceed the model's "
                "maximum context length and trigger chunked processing. " * 50,
                5,
            ),
            "expected_chunks": 2,
        },
        {
            "name": "Very Long Text (3+ chunks)",
            "text": generate_long_text(
                "This text is extremely long and will definitely "
                "require multiple chunks for processing. " * 100,
                10,
            ),
            "expected_chunks": 3,
        },
    ]

    print("🧪 Testing vLLM Long Text Embedding with Chunked Processing")
    print("=" * 70)

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n📝 Test {i}: {test_case['name']}")
        print(f"Text length: {len(test_case['text'])} characters")

        try:
            start_time = time.time()

            response = client.embeddings.create(
                input=test_case["text"], model=MODEL_NAME, encoding_format="float"
            )

            end_time = time.time()
            processing_time = end_time - start_time

            # Extract embedding data
            embedding = response.data[0].embedding
            embedding_dim = len(embedding)

            print("✅ Success!")
            print(f" - Embedding dimension: {embedding_dim}")
            print(f" - Processing time: {processing_time:.2f}s")
            print(f" - Expected chunks: ~{test_case['expected_chunks']}")
            print(f" - First 5 values: {embedding[:5]}")

        except Exception as e:
            print(f"❌ Failed: {str(e)}")


def test_batch_embedding():
    """Test batch embedding with mixed-length inputs."""
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    print("\n🔄 Testing Batch Embedding with Mixed Lengths")
    print("=" * 50)

    # Mix of short and long texts
    batch_inputs = [
        "Short text 1",
        generate_long_text("Medium length text that fits in one chunk. " * 20, 1),
        "Another short text",
        generate_long_text("Long text requiring chunked processing. " * 100, 5),
    ]

    try:
        start_time = time.time()

        response = client.embeddings.create(
            input=batch_inputs, model=MODEL_NAME, encoding_format="float"
        )

        end_time = time.time()
        processing_time = end_time - start_time

        print("✅ Batch processing successful!")
        print(f" - Number of inputs: {len(batch_inputs)}")
        print(f" - Number of embeddings: {len(response.data)}")
        print(f" - Total processing time: {processing_time:.2f}s")
        print(
            f" - Average time per input: {processing_time / len(batch_inputs):.2f}s"
        )

        for i, data in enumerate(response.data):
            input_length = len(batch_inputs[i])
            embedding_dim = len(data.embedding)
            print(
                f" - Input {i + 1}: {input_length} chars → {embedding_dim}D embedding"
            )

    except Exception as e:
        print(f"❌ Batch processing failed: {str(e)}")


def test_multiple_long_texts_batch():
    """Test batch processing with multiple long texts to verify chunk ID uniqueness."""
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    print("\n🔧 Testing Multiple Long Texts in Batch (Chunk ID Fix Verification)")
    print("=" * 70)

    # Create multiple distinct long texts that will all require chunking
    # Note: All pooling types now use MEAN aggregation across chunks:
    # - Native pooling (MEAN/CLS/LAST) is used within each chunk
    # - MEAN aggregation combines results across all chunks
    # - Full semantic coverage for all pooling types
    long_texts = [
        generate_long_text(
            "First long document about artificial intelligence and machine learning. "
            * 80,
            6,
        ),
        generate_long_text(
            "Second long document about natural language processing and transformers. "
            * 80,
            6,
        ),
        generate_long_text(
            "Third long document about computer vision and neural networks. " * 80, 6
        ),
    ]

    # Add some short texts to mix things up
    batch_inputs = [
        "Short text before long texts",
        long_texts[0],
        "Short text between long texts",
        long_texts[1],
        long_texts[2],
        "Short text after long texts",
    ]

    print("📊 Batch composition:")
    for i, text in enumerate(batch_inputs):
        length = len(text)
        text_type = "Long (will be chunked)" if length > 5000 else "Short"
        print(f" - Input {i + 1}: {length} chars ({text_type})")

    try:
        start_time = time.time()

        response = client.embeddings.create(
            input=batch_inputs, model=MODEL_NAME, encoding_format="float"
        )

        end_time = time.time()
        processing_time = end_time - start_time

        print("\n✅ Multiple long texts batch processing successful!")
        print(f" - Number of inputs: {len(batch_inputs)}")
        print(f" - Number of embeddings returned: {len(response.data)}")
        print(f" - Total processing time: {processing_time:.2f}s")

        # Verify each embedding is different (no incorrect aggregation)
        embeddings = [data.embedding for data in response.data]

        if len(embeddings) >= 3:
            import numpy as np

            # Compare embeddings of the long texts (indices 1, 3, 4)
            long_embeddings = [
                np.array(embeddings[1]),  # First long text
                np.array(embeddings[3]),  # Second long text
                np.array(embeddings[4]),  # Third long text
            ]

            print("\n🔍 Verifying embedding uniqueness:")
            for i in range(len(long_embeddings)):
                for j in range(i + 1, len(long_embeddings)):
                    cosine_sim = np.dot(long_embeddings[i], long_embeddings[j]) / (
                        np.linalg.norm(long_embeddings[i])
                        * np.linalg.norm(long_embeddings[j])
                    )
                    print(
                        f" - Similarity between long text {i + 1} and {j + 1}: "
                        f"{cosine_sim:.4f}"
                    )

                    if (
                        cosine_sim < 0.9
                    ):  # Different content should have lower similarity
                        print(" ✅ Good: Embeddings are appropriately different")
                    else:
                        print(
                            " ⚠️ High similarity - may indicate chunk "
                            "aggregation issue"
                        )

        print("\n📋 Per-input results:")
        for i, data in enumerate(response.data):
            input_length = len(batch_inputs[i])
            embedding_dim = len(data.embedding)
            embedding_norm = np.linalg.norm(data.embedding)
            print(
                f" - Input {i + 1}: {input_length} chars → {embedding_dim}D "
                f"embedding (norm: {embedding_norm:.4f})"
            )

        print(
            "\n✅ This test verifies the fix for chunk ID collisions in "
            "batch processing"
        )
        print(" - Before fix: Multiple long texts would have conflicting chunk IDs")
        print(" - After fix: Each prompt's chunks have unique IDs with prompt index")

    except Exception as e:
        print(f"❌ Multiple long texts batch test failed: {str(e)}")
        print(" This might indicate the chunk ID collision bug is present!")


def test_embedding_consistency():
    """Test that chunked processing produces consistent results."""
    client = OpenAI(api_key=API_KEY, base_url=BASE_URL)

    print("\n🔍 Testing Embedding Consistency")
    print("=" * 40)

    # Use the same long text multiple times
    long_text = generate_long_text(
        "Consistency test text for chunked processing validation. " * 50, 3
    )

    embeddings = []

    try:
        for i in range(3):
            response = client.embeddings.create(
                input=long_text, model=MODEL_NAME, encoding_format="float"
            )
            embeddings.append(response.data[0].embedding)
            print(f" - Generated embedding {i + 1}")

        # Check consistency (embeddings should be identical)
        if len(embeddings) >= 2:
            # Calculate similarity between first two embeddings

            emb1 = np.array(embeddings[0])
            emb2 = np.array(embeddings[1])

            # Cosine similarity
            cosine_sim = np.dot(emb1, emb2) / (
                np.linalg.norm(emb1) * np.linalg.norm(emb2)
            )

            print("✅ Consistency test completed!")
            print(f" - Cosine similarity between runs: {cosine_sim:.6f}")
            print(" - Expected: ~1.0 (identical embeddings)")

            if cosine_sim > 0.999:
                print(" - ✅ High consistency achieved!")
            else:
                print(" - ⚠️ Consistency may vary due to numerical precision")

    except Exception as e:
        print(f"❌ Consistency test failed: {str(e)}")


def main():
    """Main function to run all tests."""
    print("🚀 vLLM Long Text Embedding Client")
    print(f"📡 Connecting to: {BASE_URL}")
    print(f"🤖 Model: {MODEL_NAME}")
    masked_key = "*" * (len(API_KEY) - 4) + API_KEY[-4:] if len(API_KEY) > 4 else "****"
    print(f"🔑 API Key: {masked_key}")

    # Run all test cases
    test_embedding_with_different_lengths()
    test_batch_embedding()
    test_multiple_long_texts_batch()
    test_embedding_consistency()

    print("\n" + "=" * 70)
    print("🎉 All tests completed!")
    print("\n💡 Key Features Demonstrated:")
    print(" - ✅ Automatic chunked processing for long text")
    print(" - ✅ Seamless handling of mixed-length batches")
    print(" - ✅ Multiple long texts in single batch (chunk ID fix)")
    print(" - ✅ Unified chunked processing:")
    print(" • Native pooling used within each chunk")
    print(" • MEAN aggregation across all chunks")
    print(" • Complete semantic coverage for all pooling types")
    print(" - ✅ Consistent embedding generation")
    print(" - ✅ Backward compatibility with short text")
    print("\n📚 For more information, see:")
    print(
        " - Documentation: https://docs.vllm.ai/en/latest/models/pooling_models.html"
    )
    print(" - Chunked Processing Guide: openai_embedding_long_text.md")


if __name__ == "__main__":
    main()
@ -1,137 +0,0 @@
#!/bin/bash

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# vLLM Embedding Server with Enhanced Chunked Processing
# This script starts a vLLM server with chunked processing enabled for long text embedding.
# Now supports proper pooling type validation and model-specific configurations.

set -euo pipefail

# Configuration
MODEL_NAME=${MODEL_NAME:-"intfloat/multilingual-e5-large"}
MODEL_CODE=${MODEL_CODE:-"multilingual-e5-large"}

PORT=${PORT:-31090}
GPU_COUNT=${GPU_COUNT:-1}
MAX_EMBED_LEN=${MAX_EMBED_LEN:-3072000}
API_KEY=${API_KEY:-"your-api-key"}

# Enhanced pooling configuration with model-specific defaults
POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST
export VLLM_ENABLE_CHUNKED_PROCESSING=true
export CUDA_VISIBLE_DEVICES=2,3,4,5
# export VLLM_ATTENTION_BACKEND=XFORMERS

echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing"
echo "=================================================================="

# Environment variables for optimization
export VLLM_WORKER_MULTIPROC_METHOD=spawn

# Function to determine optimal pooling type for known models
get_optimal_pooling_type() {
    local model="$1"
    case "$model" in
        *"e5-"* | *"multilingual-e5"*)
            echo "MEAN" # E5 series native pooling
            ;;
        *"bge-"*)
            echo "CLS" # BGE series native pooling
            ;;
        *"gte-"*)
            echo "LAST" # GTE series native pooling
            ;;
        *"sentence-t5"* | *"st5"*)
            echo "MEAN" # Sentence-T5 native pooling
            ;;
        *"jina-embeddings"*)
            echo "MEAN" # Jina embeddings native pooling
            ;;
        *"Qwen"*"Embedding"*)
            echo "LAST" # Qwen embeddings native pooling
            ;;
        *)
            echo "MEAN" # Default native pooling for unknown models
            ;;
    esac
}

# Auto-detect pooling type if not explicitly set
if [ "$POOLING_TYPE" = "auto" ]; then
    POOLING_TYPE=$(get_optimal_pooling_type "$MODEL_NAME")
    echo "🔍 Auto-detected pooling type: $POOLING_TYPE for model $MODEL_NAME"
fi

# Display configuration
echo "📋 Configuration:"
echo " - Model: $MODEL_NAME"
echo " - Port: $PORT"
echo " - GPU Count: $GPU_COUNT"
echo " - Enhanced Chunked Processing: ${VLLM_ENABLE_CHUNKED_PROCESSING}"
echo " - Max Embed Length: ${MAX_EMBED_LEN} tokens"
echo " - Native Pooling Type: $POOLING_TYPE + Normalization"
echo " - Cross-chunk Aggregation: MEAN (automatic)"
echo ""

# Validate GPU availability
if command -v nvidia-smi &> /dev/null; then
    gpu_count=$(nvidia-smi --list-gpus | wc -l)
    echo "🖥️ Available GPUs: $gpu_count"
    if [ "$GPU_COUNT" -gt "$gpu_count" ]; then
        echo "⚠️ Warning: Requested $GPU_COUNT GPUs but only $gpu_count available"
        echo " Adjusting to use $gpu_count GPUs"
        GPU_COUNT=$gpu_count
    fi
else
    echo "⚠️ Warning: nvidia-smi not found. GPU detection skipped."
fi

# Chunked processing uses unified MEAN aggregation
echo "ℹ️ Chunked Processing: Using $POOLING_TYPE pooling within chunks, MEAN aggregation across chunks"
echo " - All chunks processed for complete semantic coverage"
echo " - Weighted averaging based on chunk token counts"

echo ""
echo "🔧 Starting server with enhanced chunked processing configuration..."

# Build pooler config JSON
POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enable_chunked_processing\": ${VLLM_ENABLE_CHUNKED_PROCESSING}, \"max_embed_len\": ${MAX_EMBED_LEN}}"

# Start vLLM server with enhanced chunked processing
vllm serve "$MODEL_NAME" \
    --tensor-parallel-size "$GPU_COUNT" \
    --enforce-eager \
    --override-pooler-config "$POOLER_CONFIG" \
    --served-model-name ${MODEL_CODE} \
    --api-key "$API_KEY" \
    --trust-remote-code \
    --port "$PORT" \
    --host 0.0.0.0

echo ""
echo "✅ vLLM Embedding Server started successfully!"
echo ""
echo "📡 Server Information:"
echo " - Base URL: http://localhost:$PORT"
echo " - Model Code: ${MODEL_CODE}"
echo " - API Key: $API_KEY"
echo " - Native Pooling: $POOLING_TYPE | Cross-chunk: MEAN"
echo ""
echo "🧪 Test the server with:"
echo " python examples/online_serving/openai_embedding_long_text_client.py"
echo ""
echo "📚 Enhanced features enabled:"
echo " ✅ Intelligent native pooling type detection"
echo " ✅ Unified MEAN aggregation for chunked processing"
echo " ✅ Model-specific native pooling optimization"
echo " ✅ Enhanced max embedding length (${MAX_EMBED_LEN} tokens)"
echo " ✅ Complete semantic coverage for all pooling types"
echo " ✅ OpenAI-compatible API"
echo " ✅ GPU acceleration"
echo ""
echo "🔧 Advanced usage:"
echo " - Set POOLING_TYPE=MEAN|CLS|LAST to override auto-detection"
echo " - Set MAX_EMBED_LEN to adjust maximum input length"
echo " - All pooling types use MEAN aggregation across chunks"
@ -15,14 +15,6 @@ else
    MODEL=$2
fi

# The prefillers and decoders in LMCache use the same hash seed for all chunk keys.
# This seed must be aligned so that decoders can identify and retrieve KV cache
# entries stored by prefillers.
#
# WARNING: Using a fixed hash seed is insecure and makes the application vulnerable to
# denial-of-service attacks. In a production environment, this should be set to a
# secure random value. This is set to a fixed value for demonstration purposes only.
export PYTHONHASHSEED=${VLLM_PYTHON_HASH_SEED:-123}

if [[ $1 == "prefiller" ]]; then
    # Prefiller listens on port 8100
