Compare commits
35 Commits
v0.11.0
...
maybe_fix_
| Author | SHA1 | Date | |
|---|---|---|---|
| cd3ea013d6 | |||
| da63274d9f | |||
| c216119d64 | |||
| 5546acb463 | |||
| c0ec81836f | |||
| b65e56babe | |||
| 49996cd597 | |||
| ecb37e276a | |||
| a5354b3ed2 | |||
| f9df8b4ad7 | |||
| ec152c8748 | |||
| 7977e5027c | |||
| 3f5d902d2a | |||
| 27d7638b94 | |||
| 176173989a | |||
| 23b8ee672d | |||
| 3939152069 | |||
| cd87bfbf37 | |||
| b3613e3ace | |||
| d346ec695e | |||
| c242c98031 | |||
| f1d53d150c | |||
| 92da847cf5 | |||
| 3958b96bf5 | |||
| 8bf8f45822 | |||
| 6f5c0931c1 | |||
| 4e33a7ea85 | |||
| dc48ba0c75 | |||
| 4778b42660 | |||
| c70ac4b8ff | |||
| cf89202855 | |||
| f075693da7 | |||
| f708bd4904 | |||
| 0002b7f0d1 | |||
| 11aafd9886 |
@ -48,7 +48,7 @@ steps:
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
|
||||
@ -44,7 +44,6 @@ docker run \
|
||||
pytest -v -s v1/structured_output
|
||||
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
|
||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
|
||||
pytest -v -s v1/test_metrics
|
||||
pytest -v -s v1/test_serial_utils.py
|
||||
pytest -v -s v1/test_utils.py
|
||||
pytest -v -s v1/test_metrics_reader.py
|
||||
'
|
||||
|
||||
@ -159,10 +159,7 @@ steps:
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/test_internal_lb_dp.py
|
||||
- tests/v1/test_hybrid_lb_dp.py
|
||||
- tests/v1/distributed
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
- tests/distributed/test_symm_mem_allreduce.py
|
||||
commands:
|
||||
@ -180,10 +177,10 @@ steps:
|
||||
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
|
||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
@ -300,12 +297,9 @@ steps:
|
||||
- pytest -v -s v1/spec_decode
|
||||
- pytest -v -s v1/kv_connector/unit
|
||||
- pytest -v -s v1/metrics
|
||||
- pytest -v -s v1/test_kv_sharing.py
|
||||
- pytest -v -s v1/test_metrics_reader.py
|
||||
- pytest -v -s v1/test_oracle.py
|
||||
- pytest -v -s v1/test_request.py
|
||||
- pytest -v -s v1/test_serial_utils.py
|
||||
- pytest -v -s v1/test_utils.py
|
||||
# Integration test for streaming correctness (requires special branch).
|
||||
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
|
||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||
@ -465,29 +459,18 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s kernels/mamba
|
||||
|
||||
- label: Tensorizer Test # 14min
|
||||
timeout_in_minutes: 25
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/model_loader
|
||||
- tests/tensorizer_loader
|
||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
commands:
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s tensorizer_loader
|
||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
|
||||
- label: Model Executor Test # 7min
|
||||
timeout_in_minutes: 20
|
||||
- label: Model Executor Test # 23min
|
||||
timeout_in_minutes: 35
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor
|
||||
- tests/model_executor
|
||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
commands:
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s model_executor
|
||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
|
||||
- label: Benchmarks # 11min
|
||||
timeout_in_minutes: 20
|
||||
@ -522,7 +505,7 @@ steps:
|
||||
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
||||
# we can only upgrade after this is resolved
|
||||
- pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
|
||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/
|
||||
|
||||
- label: LM Eval Small Models # 53min
|
||||
timeout_in_minutes: 75
|
||||
@ -830,6 +813,23 @@ steps:
|
||||
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
||||
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2'
|
||||
|
||||
- label: Blackwell Quantized MoE Test
|
||||
timeout_in_minutes: 60
|
||||
working_dir: "/vllm-workspace/"
|
||||
gpu: b200
|
||||
source_file_dependencies:
|
||||
- tests/quantization/test_blackwell_moe.py
|
||||
- vllm/model_executor/models/deepseek_v2.py
|
||||
- vllm/model_executor/models/gpt_oss.py
|
||||
- vllm/model_executor/models/llama4.py
|
||||
- vllm/model_executor/layers/fused_moe
|
||||
- vllm/model_executor/layers/quantization/compressed_tensors
|
||||
- vllm/model_executor/layers/quantization/modelopt.py
|
||||
- vllm/model_executor/layers/quantization/mxfp4.py
|
||||
- vllm/v1/attention/backends/flashinfer.py
|
||||
commands:
|
||||
- pytest -s -v tests/quantization/test_blackwell_moe.py
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
@ -889,14 +889,13 @@ steps:
|
||||
- tests/compile/test_wrapper.py
|
||||
- tests/distributed/
|
||||
- tests/entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/distributed
|
||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- tests/v1/shutdown
|
||||
- tests/v1/worker/test_worker_memory_snapshot.py
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
|
||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
|
||||
10
.github/CODEOWNERS
vendored
10
.github/CODEOWNERS
vendored
@ -12,8 +12,6 @@
|
||||
/vllm/model_executor/layers/mamba @tdoublep
|
||||
/vllm/model_executor/model_loader @22quinn
|
||||
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||
/vllm/v1/attention @LucasWilkinson
|
||||
/vllm/v1/sample @22quinn @houseroad
|
||||
/vllm/vllm_flash_attn @LucasWilkinson
|
||||
/vllm/lora @jeejeelee
|
||||
/vllm/reasoning @aarnphm @chaunceyjiang
|
||||
@ -28,11 +26,13 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
|
||||
# vLLM V1
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
||||
/vllm/v1/spec_decode @benchislett @luccafong
|
||||
/vllm/v1/attention @LucasWilkinson
|
||||
/vllm/v1/attention/backends/flashinfer.py @mgoin
|
||||
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
||||
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
|
||||
/vllm/v1/sample @22quinn @houseroad @njhill
|
||||
/vllm/v1/spec_decode @benchislett @luccafong
|
||||
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
||||
/vllm/v1/kv_cache_interface.py @heheda12345
|
||||
/vllm/v1/offloading @ApostaC
|
||||
|
||||
@ -54,7 +54,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
||||
/tests/lora @jeejeelee
|
||||
/tests/models/language/generation/test_hybrid.py @tdoublep
|
||||
/tests/v1/kv_connector/nixl_integration @NickLucche
|
||||
/tests/v1/kv_connector/nixl_integration @NickLucche
|
||||
/tests/v1/kv_connector @ApostaC
|
||||
/tests/v1/offloading @ApostaC
|
||||
|
||||
|
||||
2
.github/mergify.yml
vendored
2
.github/mergify.yml
vendored
@ -274,7 +274,7 @@ pull_request_rules:
|
||||
- files~=^vllm/model_executor/model_loader/tensorizer.py
|
||||
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
|
||||
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||
- files~=^tests/tensorizer_loader/
|
||||
- files~=^tests/model_executor/model_loader/tensorizer_loader/
|
||||
actions:
|
||||
assign:
|
||||
users:
|
||||
|
||||
@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio
|
||||
|
||||
*Latest News* 🔥
|
||||
|
||||
- [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing).
|
||||
- [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA).
|
||||
- [2025/08] We hosted [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet). We shared V1 updates, disaggregated serving and MLLM speedups with speakers from Embedded LLM, AMD, WekaIO, and A*STAR. Please find the meetup slides [here](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing).
|
||||
- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH).
|
||||
|
||||
@ -584,9 +584,8 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
elif config.architectures[0] in (
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV32ForCausalLM",
|
||||
"DeepseekV2ForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
):
|
||||
E = config.n_routed_experts
|
||||
|
||||
@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
|
||||
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
|
||||
GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
@ -33,64 +33,23 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
|
||||
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
|
||||
# Only build FlashMLA kernels if we are building for something compatible with
|
||||
# sm90a
|
||||
|
||||
set(SUPPORT_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
|
||||
list(APPEND SUPPORT_ARCHS 9.0a)
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
|
||||
list(APPEND SUPPORT_ARCHS 10.0a)
|
||||
endif()
|
||||
|
||||
|
||||
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
|
||||
if(FLASH_MLA_ARCHS)
|
||||
set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
|
||||
list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
|
||||
|
||||
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
|
||||
set(FlashMLA_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/pybind.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
|
||||
)
|
||||
|
||||
set(FlashMLA_Extension_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
|
||||
)
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
|
||||
|
||||
set(FlashMLA_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
|
||||
)
|
||||
|
||||
set(FlashMLA_Extension_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
|
||||
)
|
||||
${flashmla_SOURCE_DIR}/csrc)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_SOURCES}"
|
||||
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_Extension_SOURCES}"
|
||||
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
|
||||
|
||||
define_gpu_extension_target(
|
||||
_flashmla_C
|
||||
DESTINATION vllm
|
||||
@ -101,32 +60,8 @@ if(FLASH_MLA_ARCHS)
|
||||
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
|
||||
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
|
||||
target_compile_options(_flashmla_C PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
|
||||
|
||||
define_gpu_extension_target(
|
||||
_flashmla_extension_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${FlashMLA_Extension_SOURCES}
|
||||
COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
|
||||
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
|
||||
target_compile_options(_flashmla_extension_C PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
|
||||
else()
|
||||
# Create empty targets for setup.py when not targeting sm90a systems
|
||||
# Create an empty target for setup.py when not targeting sm90a systems
|
||||
add_custom_target(_flashmla_C)
|
||||
add_custom_target(_flashmla_extension_C)
|
||||
endif()
|
||||
|
||||
|
||||
@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_page_table(
|
||||
blk_coord,
|
||||
problem_shape,
|
||||
params.mainloop,
|
||||
shared_storage.tensors,
|
||||
pipeline_page_table, pipeline_pt_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_cpasync(
|
||||
blk_coord,
|
||||
@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
params.mainloop_params,
|
||||
shared_storage.tensors,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
local_split_kv,
|
||||
local_split_kv,
|
||||
/* must be shared pipe */
|
||||
pipeline_page_table, pipeline_pt_consumer_state
|
||||
);
|
||||
@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_tma</* paged= */ true>(
|
||||
blk_coord,
|
||||
@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
shared_storage.tensors,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
|
||||
}
|
||||
@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_tma<false>(
|
||||
blk_coord,
|
||||
@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
shared_storage.tensors,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
|
||||
}
|
||||
@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
mma(blk_coord,
|
||||
problem_shape,
|
||||
@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
pipeline_mma_s, pipeline_mma_s_producer_state,
|
||||
pipeline_p_mma, pipeline_p_mma_consumer_state,
|
||||
pipeline_mma_o, pipeline_mma_o_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto split_kv = params.split_kv;
|
||||
auto local_split_kv = split_kv;
|
||||
auto split_kv = params.split_kv;
|
||||
auto local_split_kv = split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
compute(
|
||||
blk_coord,
|
||||
@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
pipeline_mma_s, pipeline_mma_s_consumer_state,
|
||||
pipeline_p_mma, pipeline_p_mma_producer_state,
|
||||
pipeline_mma_o, pipeline_mma_o_consumer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
}
|
||||
|
||||
@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
cutlass::arch::NamedBarrier(
|
||||
(kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
|
||||
kNamedBarrierEpilogue
|
||||
).arrive_and_wait();
|
||||
).arrive();
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
@ -56,11 +56,3 @@ void cp_gather_cache(
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||
|
||||
// Indexer K quantization and cache function
|
||||
void indexer_k_quant_and_cache(
|
||||
torch::Tensor& k, // [num_tokens, head_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
int64_t quant_block_size, // quantization block size
|
||||
const std::string& scale_fmt);
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat> // FLT_MIN
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
@ -397,180 +396,6 @@ __global__ void concat_and_cache_mla_kernel(
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void concat_and_cache_ds_mla_kernel(
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, //
|
||||
const int entry_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size, //
|
||||
const float* scale //
|
||||
) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
return;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
const int64_t dst_idx_start =
|
||||
block_idx * block_stride + block_offset * entry_stride;
|
||||
|
||||
// Create 4 tile scales in shared memory
|
||||
__shared__ float smem[20];
|
||||
float* shard_abs_max = smem;
|
||||
float* tile_scales = smem + 16;
|
||||
|
||||
// For the NoPE part, each tile of 128 elements is handled by 4 warps
|
||||
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
|
||||
// The first thread of the first warp in each tile writes the scale
|
||||
// value for the tile. The RoPE part (last 64 elements) is handled
|
||||
// by another 2 warps (64 threads).
|
||||
// So in total, we use 18 warps (576 threads) per block.
|
||||
|
||||
// Cast kv_cache to 16_bit for RoPE values
|
||||
scalar_t* kv_cache_16bit =
|
||||
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
|
||||
|
||||
// The last 64 threads handle the RoPE part
|
||||
if (threadIdx.x >= kv_lora_rank) {
|
||||
const int8_t pe_idx = threadIdx.x - kv_lora_rank;
|
||||
const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
|
||||
// RoPE values start after the packed 8-bit NoPE values and the
|
||||
// 32-bit scales
|
||||
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
|
||||
kv_cache_16bit[dst_idx] = k_pe[src_idx];
|
||||
return;
|
||||
}
|
||||
|
||||
// Determine the scale for each chunk of NoPE
|
||||
const int16_t tile_idx = threadIdx.x >> 7;
|
||||
const int16_t warp_idx = (threadIdx.x & 127) >> 5;
|
||||
const int16_t lane_idx = threadIdx.x & 31;
|
||||
|
||||
// Load the NoPE element for this thread into registers
|
||||
const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
|
||||
const scalar_t src_val = kv_c[src_idx];
|
||||
|
||||
// Warp-level reduction to find the max absolute value in the warp
|
||||
float max_abs = fabsf(src_val);
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset /= 2) {
|
||||
#ifdef USE_ROCM
|
||||
max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
|
||||
#else
|
||||
max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
|
||||
#endif
|
||||
}
|
||||
|
||||
// The first lane of each warp in each tile writes the max_abs of this part
|
||||
// of the tile to shared memory
|
||||
if (lane_idx == 0) {
|
||||
shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// The first lane of the first warp in each tile computes the scale for the
|
||||
// tile and writes it to shared memory and to kv_cache
|
||||
if (warp_idx == 0 && lane_idx == 0) {
|
||||
float4 shard_abs_max_vec =
|
||||
reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
|
||||
float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
|
||||
fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
|
||||
448.f;
|
||||
|
||||
// Avoid division by zero in `scaled_convert`
|
||||
tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
|
||||
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
|
||||
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
|
||||
kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Now all threads in the block scale and write their element
|
||||
const float scale_val = tile_scales[tile_idx];
|
||||
const int64_t dst_idx = dst_idx_start + threadIdx.x;
|
||||
kv_cache[dst_idx] =
|
||||
fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
|
||||
src_val, scale_val);
|
||||
}
|
||||
|
||||
// Quantizes one token's indexer key vector to the cache element type and
// stores it, together with its float scale(s), in the paged kv_cache.
//
// Work split, as coded below: blockIdx.x selects the token and every thread
// handles VEC_SIZE (4) consecutive head_dim elements starting at
// head_dim_idx. The absolute maximum is combined by the xor-shuffle loop
// (masks 16..1, i.e. 32 lanes), so one scale covers 32 * VEC_SIZE = 128
// elements.
// NOTE(review): that equals quant_block_size only when it is 128 - confirm
// callers never pass another value.
//
// Cache layout implied by the index math: within a cache block the quantized
// values of all cache_block_size tokens come first
// (cache_block_size * head_dim elements), followed by the float scales.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void indexer_k_quant_and_cache_kernel(
|
||||
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int head_dim, // dimension of each head
|
||||
const int quant_block_size, // quantization block size
|
||||
const int cache_block_size, // cache block size
|
||||
const int cache_stride, // stride for each token in kv_cache
|
||||
const bool use_ue8m0 // use ue8m0 scale format
|
||||
) {
|
||||
constexpr int VEC_SIZE = 4;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
// Flatten (blockIdx.y, threadIdx.y, threadIdx.x) into a linear thread id and
// scale it by VEC_SIZE to get the first head_dim element of this thread.
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
|
||||
threadIdx.y * blockDim.x + threadIdx.x) *
|
||||
VEC_SIZE;
|
||||
// Split the flat slot id into (cache block, position inside the block).
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
const int64_t block_idx = slot_idx / cache_block_size;
|
||||
const int64_t block_offset = slot_idx % cache_block_size;
|
||||
|
||||
// NOTE: slot_idx can be -1 if the token is padded
// NOTE(review): threads that return here do not take part in the full-mask
// shuffles below; presumably whole warps exit together - verify.
|
||||
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Fetch VEC_SIZE elements with a single 8-byte (float2) load.
// assumes sizeof(scalar_t) == 2 and 8-byte aligned rows - TODO confirm
float2 k_val = (reinterpret_cast<const float2*>(
|
||||
k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
|
||||
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
|
||||
// Thread-local absolute maximum over the VEC_SIZE loaded elements.
float amax = 0.0f;
|
||||
for (int i = 0; i < VEC_SIZE; i++) {
|
||||
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
|
||||
}
|
||||
#ifndef USE_ROCM
|
||||
__syncwarp();
|
||||
#endif
|
||||
|
||||
// Reduced amax
// Butterfly (xor) exchange: after the loop every participating lane holds
// the same maximum.
|
||||
for (int mask = 16; mask > 0; mask /= 2) {
|
||||
#ifdef USE_ROCM
|
||||
amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
|
||||
#else
|
||||
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
|
||||
#endif
|
||||
}
|
||||
#ifndef USE_ROCM
|
||||
__syncwarp();
|
||||
#endif
|
||||
// Clamp amax from below so the scale can never be zero.
// NOTE(review): 448 is presumably the largest finite FP8 E4M3 value - confirm.
float scale = fmaxf(amax, 1e-4) / 448.0f;
|
||||
if (use_ue8m0) {
|
||||
// Round the scale up to the next power of two.
scale = exp2f(ceilf(log2f(scale)));
|
||||
}
|
||||
|
||||
// Quantized values are packed with a stride of head_dim per token inside the
// cache block; cache_stride only enters through the per-block base offset.
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
|
||||
block_offset * head_dim + head_dim_idx;
|
||||
for (int i = 0; i < VEC_SIZE; i++) {
|
||||
kv_cache[dst_offset + i] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val_ptr[i], scale);
|
||||
}
|
||||
// Only the thread with threadIdx.x == 0 stores the scale (the launcher in
// this file uses blockDim.x == 32). dst_scale_idx is counted in cache_t
// elements, 4 per scale, and the final / 4 turns it into a float index.
// assumes sizeof(cache_t) == 1 - TODO confirm
if (threadIdx.x == 0) {
|
||||
const int64_t dst_scale_idx =
|
||||
block_idx * cache_block_size * cache_stride +
|
||||
cache_block_size * head_dim +
|
||||
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
|
||||
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
@ -613,7 +438,7 @@ void reshape_and_cache(
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
||||
CALL_RESHAPE_AND_CACHE);
|
||||
CALL_RESHAPE_AND_CACHE)
|
||||
}
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
@ -684,18 +509,6 @@ void reshape_and_cache_flash(
|
||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
// Launches vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE>.
// The macro takes only the three template arguments; grid, block, stream,
// kv_c, k_pe, kv_cache, slot_mapping, scale and the stride/size variables
// must already be in scope where it is expanded.
|
||||
#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
|
||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
@ -718,44 +531,20 @@ void concat_and_cache_mla(
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
if (kv_cache_dtype == "fp8_ds_mla") {
|
||||
TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
|
||||
TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
|
||||
TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
|
||||
"kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
|
||||
TORCH_CHECK(kv_c.itemsize() == 2,
|
||||
"kv_c.itemsize() must be 2 for fp8_ds_mla");
|
||||
TORCH_CHECK(k_pe.itemsize() == 2,
|
||||
"k_pe.itemsize() must be 2 for fp8_ds_mla");
|
||||
} else {
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
}
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
int entry_stride = kv_cache.stride(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (kv_cache_dtype == "fp8_ds_mla") {
|
||||
dim3 grid(num_tokens);
|
||||
// For the NoPE part, each tile of 128 elements is handled by 4 warps
|
||||
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
|
||||
// The first thread of the first warp in each tile writes the scale
|
||||
// value for the tile. The RoPE part (last 64 elements) is handled
|
||||
// by another 2 warps (64 threads).
|
||||
// So in total, we use 18 warps (576 threads) per block.
|
||||
dim3 block(576);
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_DS_MLA);
|
||||
} else {
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
}
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
@ -1133,42 +922,3 @@ void cp_gather_cache(
|
||||
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
|
||||
}
|
||||
}
|
||||
|
||||
// Macro to dispatch the kernel based on the data type.
// It launches vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>
// with no dynamic shared memory. Only the template arguments are macro
// parameters: grid, block, stream, k, kv_cache, slot_mapping, head_dim,
// quant_block_size, cache_block_size, cache_stride and use_ue8m0 are picked
// up by name from the scope where the macro is expanded.
|
||||
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(k.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
|
||||
cache_block_size, cache_stride, use_ue8m0);
|
||||
|
||||
// Host-side launcher: quantizes the indexer keys in `k` and writes the values
// and their scales into `kv_cache` at the slots given by `slot_mapping`.
//
// scale_fmt selects the scale format; only the exact string "ue8m0" changes
// behaviour (it sets use_ue8m0), every other value leaves it false.
// The cache dtype passed to the dispatcher is fixed to "fp8_e4m3".
// Raises (via TORCH_CHECK) when the tensors are on different devices or when
// head_dim is not a multiple of quant_block_size.
// NOTE(review): the local variable names below are referenced by name from
// CALL_INDEXER_K_QUANT_AND_CACHE; renaming them breaks the macro expansion.
void indexer_k_quant_and_cache(
|
||||
torch::Tensor& k, // [num_tokens, head_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
int64_t quant_block_size, // quantization block size
|
||||
const std::string& scale_fmt) {
|
||||
int num_tokens = k.size(0);
|
||||
int head_dim = k.size(1);
|
||||
int cache_block_size = kv_cache.size(1);
|
||||
int cache_stride = kv_cache.size(2);
|
||||
bool use_ue8m0 = scale_fmt == "ue8m0";
|
||||
|
||||
TORCH_CHECK(k.device() == kv_cache.device(),
|
||||
"k and kv_cache must be on the same device");
|
||||
TORCH_CHECK(k.device() == slot_mapping.device(),
|
||||
"k and slot_mapping must be on the same device");
|
||||
TORCH_CHECK(head_dim % quant_block_size == 0,
|
||||
"head_dim must be divisible by quant_block_size");
|
||||
|
||||
// One grid row per token; grid.y is ceil(head_dim / (quant_block_size *
// vec_size)). Each block has 32 * vec_size threads.
// NOTE(review): the kernel gives every thread 4 elements and reduces over 32
// lanes, so this sizing presumably relies on quant_block_size == 128 -
// confirm.
constexpr int vec_size = 4;
|
||||
dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
|
||||
(quant_block_size * vec_size));
|
||||
dim3 block(32, vec_size);
|
||||
// Make k's device current and launch on the current CUDA stream.
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
|
||||
CALL_INDEXER_K_QUANT_AND_CACHE);
|
||||
}
|
||||
|
||||
16
csrc/core/batch_invariant.hpp
Normal file
16
csrc/core/batch_invariant.hpp
Normal file
@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <cctype>
|
||||
|
||||
namespace vllm {

// Reports whether batch-invariant kernel launches were requested through the
// VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT environment variable.
//
// The variable is read on every call. It counts as enabled when it is set and
// std::atoi parses it to a nonzero integer (for example "1"); an unset,
// empty, "0" or non-numeric value yields false.
inline bool vllm_kernel_override_batch_invariant() {
  const char* raw = std::getenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT");
  if (raw == nullptr) {
    return false;
  }
  return std::atoi(raw) != 0;
}

}  // namespace vllm
|
||||
@ -1,6 +1,7 @@
|
||||
#include "type_convert.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "cub_helpers.h"
|
||||
#include "core/batch_invariant.hpp"
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
@ -413,7 +414,9 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
wt_ptr % req_alignment_bytes == 0;
|
||||
bool offsets_are_multiple_of_vector_width =
|
||||
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
|
||||
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
|
||||
!batch_invariant_launch) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
@ -459,7 +462,8 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
|
||||
bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
|
||||
LAUNCH_FUSED_POLY_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_POLY_NORM(0);
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
#include "quantization/fp8/common.cuh"
|
||||
#include "dispatch_utils.h"
|
||||
#include "cub_helpers.h"
|
||||
#include "core/batch_invariant.hpp"
|
||||
|
||||
#include <torch/cuda.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
@ -240,7 +241,9 @@ void fused_add_rms_norm_static_fp8_quant(
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned =
|
||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
|
||||
bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
|
||||
!batch_invariant_launch) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "../cuda_compat.h"
|
||||
#include "../cub_helpers.h"
|
||||
#include "../core/batch_invariant.hpp"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -405,7 +406,8 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
|
||||
const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
|
||||
|
||||
@ -576,17 +576,6 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else if (KV_DTYPE == "fp8_ds_mla") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||
} \
|
||||
|
||||
@ -713,13 +713,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
|
||||
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
|
||||
"slot_mapping, "
|
||||
"int quant_block_size, str kv_cache_dtype) -> ()");
|
||||
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
|
||||
&indexer_k_quant_and_cache);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
|
||||
@ -14,11 +14,6 @@ ARG PYTHON_VERSION=3.12
|
||||
#
|
||||
# Example:
|
||||
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||
|
||||
# Important: We build with an old version of Ubuntu to maintain broad
|
||||
# compatibility with other Linux OSes. The main reason for this is that the
|
||||
# glibc version is baked into the distro, and binaries built with one glibc
|
||||
# version are not backwards compatible with OSes that use an earlier version.
|
||||
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
||||
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
|
||||
@ -80,19 +75,34 @@ ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
ARG DEADSNAKES_GPGKEY_URL
|
||||
ARG GET_PIP_URL
|
||||
|
||||
# Install system dependencies and uv, then create Python virtual environment
|
||||
# Install Python and other dependencies
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo python3-pip \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
|
||||
&& rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
|
||||
&& ln -s /opt/venv/bin/python3 /usr/bin/python3 \
|
||||
&& ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
|
||||
&& ln -s /opt/venv/bin/pip /usr/bin/pip \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo \
|
||||
&& if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
|
||||
if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
|
||||
mkdir -p -m 0755 /etc/apt/keyrings ; \
|
||||
curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
|
||||
fi ; \
|
||||
else \
|
||||
for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done ; \
|
||||
fi \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
@ -101,9 +111,9 @@ ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
|
||||
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
||||
|
||||
# Activate virtual environment and add uv to PATH
|
||||
ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH"
|
||||
ENV VIRTUAL_ENV="/opt/venv"
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
@ -132,7 +142,7 @@ WORKDIR /workspace
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY requirements/cuda.txt requirements/cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \
|
||||
uv pip install --system -r requirements/cuda.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# cuda arch list used by torch
|
||||
@ -162,7 +172,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
|
||||
uv pip install --system -r requirements/build.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
COPY . .
|
||||
@ -259,7 +269,7 @@ COPY requirements/lint.txt requirements/lint.txt
|
||||
COPY requirements/test.txt requirements/test.txt
|
||||
COPY requirements/dev.txt requirements/dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/dev.txt \
|
||||
uv pip install --system -r requirements/dev.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
#################### DEV IMAGE ####################
|
||||
|
||||
@ -381,18 +391,28 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch ${FLASHINFER_GIT_REF} \
|
||||
${FLASHINFER_GIT_REPO} flashinfer
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
pushd flashinfer
|
||||
if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then
|
||||
# NOTE: To make new precompiled wheels, see tools/flashinfer-build.sh
|
||||
echo "🏗️ Installing FlashInfer from pre-compiled wheel"
|
||||
uv pip install --system https://wheels.vllm.ai/flashinfer-python/flashinfer_python-0.3.1-cp39-abi3-manylinux1_x86_64.whl \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
||||
# Download pre-compiled cubins
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins."
|
||||
fi
|
||||
elif [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
|
||||
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
# HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future
|
||||
@ -454,6 +474,12 @@ ENV CUDA_HOME=/usr/local/cuda
|
||||
RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \
|
||||
&& bash install_python_libraries.sh
|
||||
|
||||
# CUDA image changed from /usr/local/nvidia to /usr/local/cuda in 12.8 but will
|
||||
# return to /usr/local/nvidia in 13.0 to allow container providers to mount drivers
|
||||
# consistently from the host (see https://github.com/vllm-project/vllm/issues/18859).
|
||||
# Until then, add /usr/local/nvidia/lib64 before the image cuda path to allow override.
|
||||
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
#################### TEST IMAGE ####################
|
||||
@ -526,7 +552,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
else \
|
||||
BITSANDBYTES_VERSION="0.46.1"; \
|
||||
fi; \
|
||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' boto3 runai-model-streamer runai-model-streamer[s3]
|
||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3]>=0.14.0'
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ ARG CUDA_VERSION=12.8.0
|
||||
#
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
||||
ARG CUDA_VERSION=12.8.0
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
We host regular meetups in the San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
|
||||
|
||||
- [vLLM Toronto Meetup](https://luma.com/e80e0ymm), September 25th 2025. [[Slides]](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing)
|
||||
- [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ), August 30th 2025. [[Slides]](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA)
|
||||
- [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet), August 27th 2025. [[Slides]](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing)
|
||||
- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH)
|
||||
|
||||
@ -66,35 +66,12 @@ Further update the model as follows:
|
||||
!!! important
|
||||
The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g., an image) of the request.
|
||||
|
||||
- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
|
||||
!!! note
|
||||
By default, vLLM merges the multimodal embeddings into text embeddings depending on the information of their locations defined in
|
||||
[PlaceholderRange][vllm.multimodal.inputs.PlaceholderRange] from input processing.
|
||||
This logic can be found at [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings].
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
from .utils import merge_multimodal_embeddings
|
||||
|
||||
class YourModelForImage2Seq(nn.Module):
|
||||
...
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# `get_input_embeddings` should already be implemented for the language
|
||||
# model as one of the requirements of basic vLLM model implementation.
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.image_token_index)
|
||||
|
||||
return inputs_embeds
|
||||
```
|
||||
You may override this method if additional logic is required for your model when merging embeddings.
|
||||
|
||||
- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
|
||||
|
||||
|
||||
@ -208,3 +208,11 @@ One example is [snakeviz](https://jiffyclub.github.io/snakeviz/).
|
||||
pip install snakeviz
|
||||
snakeviz expensive_function.prof
|
||||
```
|
||||
|
||||
### Analyzing Garbage Collection Costs
|
||||
|
||||
Use the `VLLM_GC_DEBUG` environment variable to debug GC costs.
|
||||
|
||||
- VLLM_GC_DEBUG=1: enable GC debugger with gc.collect elapsed times
|
||||
- VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger to log top 5
|
||||
collected objects for each gc.collect
|
||||
|
||||
@ -8,9 +8,6 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
|
||||
|
||||
!!! tip
|
||||
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict the domains that vLLM can access, preventing it from reaching arbitrary endpoints that could be exploited in Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this argument. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
|
||||
|
||||
Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP redirects from being followed to bypass domain restrictions.
|
||||
|
||||
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
|
||||
|
||||
## Offline Inference
|
||||
|
||||
@ -351,13 +351,92 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
|
||||
To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
|
||||
|
||||
Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
|
||||
<!-- TODO: api enforced limits + uploading audios -->
|
||||
|
||||
#### API Enforced Limits
|
||||
|
||||
Set the maximum audio file size (in MB) that vLLM will accept, via the
|
||||
`VLLM_MAX_AUDIO_CLIP_FILESIZE_MB` environment variable. Default is 25 MB.
|
||||
|
||||
#### Uploading Audio Files
|
||||
|
||||
The Transcriptions API supports uploading audio files in various formats including FLAC, MP3, MP4, MPEG, MPGA, M4A, OGG, WAV, and WEBM.
|
||||
|
||||
**Using OpenAI Python Client:**
|
||||
|
||||
??? code
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key="token-abc123",
|
||||
)
|
||||
|
||||
# Upload audio file from disk
|
||||
with open("audio.mp3", "rb") as audio_file:
|
||||
transcription = client.audio.transcriptions.create(
|
||||
model="openai/whisper-large-v3-turbo",
|
||||
file=audio_file,
|
||||
language="en",
|
||||
response_format="verbose_json"
|
||||
)
|
||||
|
||||
print(transcription.text)
|
||||
```
|
||||
|
||||
**Using curl with multipart/form-data:**
|
||||
|
||||
??? code
|
||||
|
||||
```bash
|
||||
curl -X POST "http://localhost:8000/v1/audio/transcriptions" \
|
||||
-H "Authorization: Bearer token-abc123" \
|
||||
-F "file=@audio.mp3" \
|
||||
-F "model=openai/whisper-large-v3-turbo" \
|
||||
-F "language=en" \
|
||||
-F "response_format=verbose_json"
|
||||
```
|
||||
|
||||
**Supported Parameters:**
|
||||
|
||||
- `file`: The audio file to transcribe (required)
|
||||
- `model`: The model to use for transcription (required)
|
||||
- `language`: The language code (e.g., "en", "zh") (optional)
|
||||
- `prompt`: Text to guide the transcription style (optional)
|
||||
- `response_format`: Format of the response ("json", "text") (optional)
|
||||
- `temperature`: Sampling temperature between 0 and 1 (optional)
|
||||
|
||||
For the complete list of supported parameters including sampling parameters and vLLM extensions, see the [protocol definitions](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/protocol.py#L2182).
|
||||
|
||||
**Response Format:**
|
||||
|
||||
For `verbose_json` response format:
|
||||
|
||||
??? code
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "Hello, this is a transcription of the audio file.",
|
||||
"language": "en",
|
||||
"duration": 5.42,
|
||||
"segments": [
|
||||
{
|
||||
"id": 0,
|
||||
"seek": 0,
|
||||
"start": 0.0,
|
||||
"end": 2.5,
|
||||
"text": "Hello, this is a transcription",
|
||||
"tokens": [50364, 938, 428, 307, 275, 28347],
|
||||
"temperature": 0.0,
|
||||
"avg_logprob": -0.245,
|
||||
"compression_ratio": 1.235,
|
||||
"no_speech_prob": 0.012
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### Extra Parameters
|
||||
|
||||
The following [sampling parameters][sampling-params] are supported.
|
||||
|
||||
@ -66,9 +66,6 @@ Restrict domains that vLLM can access for media URLs by setting
|
||||
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
|
||||
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
|
||||
|
||||
Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP
|
||||
redirects from being followed to bypass domain restrictions.
|
||||
|
||||
## Security and Firewalls: Protecting Exposed vLLM Systems
|
||||
|
||||
While vLLM is designed to allow unsafe network services to be isolated to
|
||||
|
||||
@ -38,11 +38,13 @@ client = OpenAI(
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
headers = {"User-Agent": "vLLM Example Client"}
|
||||
|
||||
|
||||
def encode_base64_content_from_url(content_url: str) -> str:
|
||||
"""Encode a content retrieved from a remote url to base64 format."""
|
||||
|
||||
with requests.get(content_url) as response:
|
||||
with requests.get(content_url, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
result = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
@ -50,19 +52,19 @@ def encode_base64_content_from_url(content_url: str) -> str:
|
||||
|
||||
|
||||
# Text-only inference
|
||||
def run_text_only(model: str) -> None:
|
||||
def run_text_only(model: str, max_completion_tokens: int) -> None:
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=[{"role": "user", "content": "What's the capital of France?"}],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion.choices[0].message.content
|
||||
print("Chat completion output:", result)
|
||||
print("Chat completion output:\n", result)
|
||||
|
||||
|
||||
# Single-image input inference
|
||||
def run_single_image(model: str) -> None:
|
||||
def run_single_image(model: str, max_completion_tokens: int) -> None:
|
||||
## Use image url in the payload
|
||||
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
@ -79,11 +81,11 @@ def run_single_image(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print("Chat completion output from image url:", result)
|
||||
print("Chat completion output from image url:\n", result)
|
||||
|
||||
## Use base64 encoded image in the payload
|
||||
image_base64 = encode_base64_content_from_url(image_url)
|
||||
@ -101,7 +103,7 @@ def run_single_image(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
@ -109,7 +111,7 @@ def run_single_image(model: str) -> None:
|
||||
|
||||
|
||||
# Multi-image input inference
|
||||
def run_multi_image(model: str) -> None:
|
||||
def run_multi_image(model: str, max_completion_tokens: int) -> None:
|
||||
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
|
||||
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
@ -130,15 +132,15 @@ def run_multi_image(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print("Chat completion output:", result)
|
||||
print("Chat completion output:\n", result)
|
||||
|
||||
|
||||
# Video input inference
|
||||
def run_video(model: str) -> None:
|
||||
def run_video(model: str, max_completion_tokens: int) -> None:
|
||||
video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4"
|
||||
video_base64 = encode_base64_content_from_url(video_url)
|
||||
|
||||
@ -157,11 +159,11 @@ def run_video(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print("Chat completion output from image url:", result)
|
||||
print("Chat completion output from video url:\n", result)
|
||||
|
||||
## Use base64 encoded video in the payload
|
||||
chat_completion_from_base64 = client.chat.completions.create(
|
||||
@ -178,15 +180,15 @@ def run_video(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print("Chat completion output from base64 encoded image:", result)
|
||||
print("Chat completion output from base64 encoded video:\n", result)
|
||||
|
||||
|
||||
# Audio input inference
|
||||
def run_audio(model: str) -> None:
|
||||
def run_audio(model: str, max_completion_tokens: int) -> None:
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
audio_url = AudioAsset("winning_call").url
|
||||
@ -211,11 +213,11 @@ def run_audio(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print("Chat completion output from input audio:", result)
|
||||
print("Chat completion output from input audio:\n", result)
|
||||
|
||||
# HTTP URL
|
||||
chat_completion_from_url = client.chat.completions.create(
|
||||
@ -235,11 +237,11 @@ def run_audio(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
print("Chat completion output from audio url:", result)
|
||||
print("Chat completion output from audio url:\n", result)
|
||||
|
||||
# base64 URL
|
||||
chat_completion_from_base64 = client.chat.completions.create(
|
||||
@ -259,14 +261,14 @@ def run_audio(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print("Chat completion output from base64 encoded audio:", result)
|
||||
print("Chat completion output from base64 encoded audio:\n", result)
|
||||
|
||||
|
||||
def run_multi_audio(model: str) -> None:
|
||||
def run_multi_audio(model: str, max_completion_tokens: int) -> None:
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
# Two different audios to showcase batched inference.
|
||||
@ -300,11 +302,11 @@ def run_multi_audio(model: str) -> None:
|
||||
}
|
||||
],
|
||||
model=model,
|
||||
max_completion_tokens=64,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
print("Chat completion output from input audio:", result)
|
||||
print("Chat completion output from input audio:\n", result)
|
||||
|
||||
|
||||
example_function_map = {
|
||||
@ -330,13 +332,20 @@ def parse_args():
|
||||
choices=list(example_function_map.keys()),
|
||||
help="Conversation type with multimodal data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-completion-tokens",
|
||||
"-n",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Maximum number of tokens to generate for each completion.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args) -> None:
|
||||
chat_type = args.chat_type
|
||||
model = get_first_model(client)
|
||||
example_function_map[chat_type](model)
|
||||
example_function_map[chat_type](model, args.max_completion_tokens)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -43,7 +43,6 @@ tritonclient==2.51.0
|
||||
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
|
||||
numba == 0.61.2; python_version > '3.9'
|
||||
numpy
|
||||
runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
runai-model-streamer[s3]==0.14.0
|
||||
fastsafetensors>=0.1.10
|
||||
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
||||
|
||||
@ -5,8 +5,6 @@ numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Req
|
||||
numba == 0.61.2; python_version > '3.9'
|
||||
|
||||
# Dependencies for AMD GPUs
|
||||
boto3
|
||||
botocore
|
||||
datasets
|
||||
ray[cgraph]>=2.48.0 # Ray Compiled Graph, required for pipeline parallelism in V1.
|
||||
peft
|
||||
@ -15,7 +13,6 @@ tensorizer==2.10.1
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<80.0.0
|
||||
setuptools-scm>=8
|
||||
runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
runai-model-streamer[s3]==0.14.0
|
||||
conch-triton-kernels==1.2.1
|
||||
timm>=1.0.17
|
||||
@ -51,8 +51,7 @@ tritonclient==2.51.0
|
||||
numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding
|
||||
numba == 0.61.2; python_version > '3.9'
|
||||
numpy
|
||||
runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
runai-model-streamer[s3]==0.14.0
|
||||
fastsafetensors>=0.1.10
|
||||
pydantic>=2.10 # 2.9 leads to error on python 3.10
|
||||
decord==0.6.0
|
||||
|
||||
@ -72,7 +72,9 @@ blobfile==3.0.0
|
||||
bm25s==0.2.13
|
||||
# via mteb
|
||||
boto3==1.35.57
|
||||
# via tensorizer
|
||||
# via
|
||||
# runai-model-streamer-s3
|
||||
# tensorizer
|
||||
botocore==1.35.57
|
||||
# via
|
||||
# boto3
|
||||
@ -925,10 +927,10 @@ rsa==4.9.1
|
||||
# via google-auth
|
||||
rtree==1.4.0
|
||||
# via torchgeo
|
||||
runai-model-streamer==0.11.0
|
||||
# via -r requirements/test.in
|
||||
runai-model-streamer-s3==0.11.0
|
||||
runai-model-streamer==0.14.0
|
||||
# via -r requirements/test.in
|
||||
runai-model-streamer-s3==0.14.0
|
||||
# via runai-model-streamer
|
||||
s3transfer==0.10.3
|
||||
# via boto3
|
||||
sacrebleu==2.4.3
|
||||
|
||||
9
setup.py
9
setup.py
@ -322,8 +322,6 @@ class precompiled_wheel_utils:
|
||||
"vllm/_C.abi3.so",
|
||||
"vllm/_moe_C.abi3.so",
|
||||
"vllm/_flashmla_C.abi3.so",
|
||||
"vllm/_flashmla_extension_C.abi3.so",
|
||||
"vllm/_sparse_flashmla_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
||||
"vllm/cumem_allocator.abi3.so",
|
||||
@ -591,8 +589,6 @@ if _is_cuda():
|
||||
# not targeting a hopper system
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_C", optional=True))
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _build_custom_ops():
|
||||
@ -658,10 +654,7 @@ setup(
|
||||
"bench": ["pandas", "datasets"],
|
||||
"tensorizer": ["tensorizer==2.10.1"],
|
||||
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
|
||||
"runai": [
|
||||
"runai-model-streamer >= 0.14.0", "runai-model-streamer-gcs",
|
||||
"google-cloud-storage", "runai-model-streamer-s3", "boto3"
|
||||
],
|
||||
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
|
||||
"audio": ["librosa", "soundfile",
|
||||
"mistral_common[audio]"], # Required for audio processing
|
||||
"video": [], # Kept for backwards compatibility
|
||||
|
||||
@ -3,12 +3,11 @@
|
||||
import contextlib
|
||||
import os
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import wait_for_gpu_memory_to_clear
|
||||
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.platforms import current_platform
|
||||
@ -33,89 +32,6 @@ def temporary_environ(env_vars):
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
@dataclass
class BackendConfig:
    """One attention-backend configuration for the full-CUDA-graph tests.

    NOTE(review): how each field is consumed is not visible in this block;
    the field notes below are based on the ``backend_configs`` entries that
    instantiate this class. Confirm against the tests that read them.
    """
    # Label of the configuration; the entries in ``backend_configs`` reuse
    # their dictionary key as this value.
    name: str
    # Environment-variable names mapped to string values
    # (e.g. "VLLM_ATTENTION_BACKEND": "FLASHMLA").
    env_vars: dict
    # Compilation settings keyed by option name (e.g. "cudagraph_mode");
    # presumably forwarded to the compilation config - TODO confirm.
    comp_config: dict
    # Two-element tuple such as (9, 0) or (10, 0); None (the default) when
    # the entry does not name a specific GPU architecture.
    specific_gpu_arch: Optional[tuple] = None
|
||||
|
||||
|
||||
# Define all backend configurations of full cudagraph to be tested.
# Each entry is keyed by the same string that is stored in its ``name`` field.
# Entries without ``specific_gpu_arch`` leave it at the dataclass default
# (None).
backend_configs = {
    # FA3 on Hopper
    "FA3":
    BackendConfig(name="FA3",
                  env_vars={
                      "VLLM_FLASH_ATTN_VERSION": "3",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashMLA on Hopper
    "FlashMLA":
    BackendConfig(name="FlashMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA":
    BackendConfig(name="FlashAttentionMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_DECODE_ONLY",
                  },
                  specific_gpu_arch=(9, 0)),
    # Cutlass MLA on Blackwell
    "CutlassMLA":
    BackendConfig(
        name="CutlassMLA",
        env_vars={
            "VLLM_USE_V1": "1",
            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
            "FORCE_NUM_KV_SPLITS":
            "1",  # TODO: remove this when hang issue is fixed
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
        },
        specific_gpu_arch=(10, 0)),
    # FA2
    "FA2":
    BackendConfig(name="FA2",
                  env_vars={
                      "VLLM_FLASH_ATTN_VERSION": "2",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL",
                  }),
    # Triton Attention
    "TritonAttn":
    BackendConfig(name="TritonAttn",
                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
                  comp_config={
                      "cudagraph_mode": "FULL",
                  }),
    # FlashInfer
    "FlashInfer":
    BackendConfig(name="FlashInfer",
                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
}
|
||||
|
||||
test_params_full_cudagraph = []
|
||||
|
||||
# deepseek-ai/DeepSeek-V2-Lite with MLA
|
||||
|
||||
@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
import vllm
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import CompilationConfig, VllmConfig
|
||||
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
|
||||
from vllm.utils import _is_torch_equal_or_newer
|
||||
|
||||
|
||||
@ -106,7 +106,6 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
|
||||
def test_no_compilation(vllm_runner, monkeypatch):
|
||||
# Disable multiprocessing so that the counter is in the same process
|
||||
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
|
||||
|
||||
with (
|
||||
compilation_counter.expect(num_graphs_seen=0,
|
||||
dynamo_as_is_count=0),
|
||||
@ -131,3 +130,67 @@ def test_enforce_eager(vllm_runner, monkeypatch):
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.4) as _):
|
||||
pass
|
||||
|
||||
|
||||
def test_splitting_ops_dynamic():
    """Check how splitting_ops and cudagraph_mode are resolved by VllmConfig.

    Covers the default configuration, inductor graph partitioning, the
    attention-fusion pass, and the combination of the latter two.
    """

    def fusion_kwargs():
        # Build fresh objects on every call so no dict/list instance is
        # shared between the configs constructed below.
        return dict(
            pass_config={
                "enable_attn_fusion": True,
                "enable_noop": True
            },
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )

    # Default config
    default_cc = VllmConfig().compilation_config
    assert default_cc.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    assert default_cc.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    if _is_torch_equal_or_newer('2.9.0.dev'):
        # Inductor graph partition is only available in PyTorch 2.9+.
        # This is a fast config check, so pytest.skip is not used.
        partition_cc = VllmConfig(compilation_config=CompilationConfig(
            use_inductor_graph_partition=True,
            splitting_ops=["silly_attention"],
        )).compilation_config
        # The user-supplied splitting_ops should be ignored.
        assert partition_cc.splitting_ops == []

    # When the attn_fusion pass is enabled.
    fusion_cc = VllmConfig(compilation_config=CompilationConfig(
        **fusion_kwargs())).compilation_config
    assert fusion_cc.splitting_ops == []
    # The cudagraph mode also falls back to FULL.
    assert fusion_cc.cudagraph_mode == CUDAGraphMode.FULL

    # splitting_ops cannot contain attention ops when the attn_fusion
    # pass is enabled.
    with pytest.raises(AssertionError):
        VllmConfig(compilation_config=CompilationConfig(
            **fusion_kwargs(),
            # Workaround for accessing all attention ops.
            splitting_ops=CompilationConfig()._attention_ops,
        ))

    # When both use_inductor_graph_partition and the attn_fusion pass are
    # enabled.
    if _is_torch_equal_or_newer('2.9.0.dev'):
        combined_cc = VllmConfig(compilation_config=CompilationConfig(
            use_inductor_graph_partition=True,
            **fusion_kwargs(),
        )).compilation_config
        assert combined_cc.splitting_ops == []
        # enable_attn_fusion is directly supported under
        # use_inductor_graph_partition=True, and cudagraph_mode
        # is unchanged.
        assert combined_cc.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
|
||||
@ -191,6 +191,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=False,
|
||||
),
|
||||
layer_names=[self.attn.layer_name],
|
||||
vllm_config=self.vllm_config,
|
||||
|
||||
@ -593,119 +593,6 @@ def test_concat_and_cache_mla(
|
||||
torch.testing.assert_close(kv_cache, ref_kv_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    num_tokens: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Compare ``concat_and_cache_mla`` with "fp8_ds_mla" to a reference.

    The reference cache entry for every token is built by hand below and
    then compared section by section (fp8 values, per-tile scales, rope
    values) with what the op wrote into ``kv_cache``.
    """
    if dtype.itemsize != 2:
        pytest.skip("ds_mla only supports 16-bit input")
    kv_cache_dtype = "fp8_ds_mla"
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    # Pick distinct cache slots for the tokens.
    total_slots = num_blocks * block_size
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst,
                                dtype=torch.long,
                                device=device)

    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
    k_pe = torch.randn(num_tokens,
                       qk_rope_head_dim,
                       dtype=dtype,
                       device=device)
    # Bytes per cache entry: kv_lora_rank one-byte values, four 4-byte
    # scales, and qk_rope_head_dim 2-byte rope values.
    entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)

    scale = torch.tensor(1.0, dtype=torch.float32, device=device)
    kv_cache = _create_mla_cache(num_blocks,
                                 block_size,
                                 entry_size,
                                 dtype=torch.uint8,
                                 kv_cache_dtype=kv_cache_dtype,
                                 device=device)

    ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
    # Scratch buffer reused for every 128-element tile.
    tile_data = torch.zeros(128, dtype=dtype, device=device)

    for i in range(num_tokens):
        slot = slot_mapping[i].item()
        block_idx = slot // block_size
        block_offset = slot % block_size

        # Views over the same entry bytes, indexed in 16-bit and 32-bit units.
        ref_cache_slice = ref_cache[block_idx, block_offset]
        ref_cache_16bit = ref_cache_slice.view(dtype)
        ref_cache_32bit = ref_cache_slice.view(torch.float32)

        kv_c_data = kv_c[i]
        # Four tiles of 128 elements are processed.
        # NOTE(review): this assumes kv_lora_rank == 512 - TODO confirm
        # against KV_LORA_RANKS.
        for tile_idx in range(4):
            tile_start = tile_idx * 128
            tile_end = (tile_idx + 1) * 128
            tile_data[:] = kv_c_data[tile_start:tile_end]

            # tile_scale = tile_data.amax().to(torch.float32) / 448.
            # NOTE: Using torch's amax() gives different results,
            # so this must be manually computed.
            tile_data_float = tile_data.to(torch.float32)
            manual_max = abs(tile_data_float[0])
            for j in range(1, 128):
                manual_max = max(manual_max, abs(tile_data_float[j]))
            tile_scale = manual_max / 448.

            # Scales live right after the kv_lora_rank value bytes
            # (float32 index kv_lora_rank // 4 == byte offset kv_lora_rank).
            ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale

            ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
                            tile_data,
                            tile_scale.item(),
                            kv_dtype="fp8")

        # Rope values follow the four scales (16-bit index
        # kv_lora_rank // 2 + 8 == byte offset kv_lora_rank + 16).
        for j in range(qk_rope_head_dim):
            ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]

    opcheck(
        torch.ops._C_cache_ops.concat_and_cache_mla,
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                             kv_cache_dtype, scale)

    # Compare each written entry with the reference, one section at a time.
    for i in range(num_tokens):
        slot = slot_mapping[i].item()
        block_idx = slot // block_size
        block_offset = slot % block_size
        kv_cache_slice = kv_cache[block_idx, block_offset]
        ref_cache_slice = ref_cache[block_idx, block_offset]

        kv_nope = kv_cache_slice[:kv_lora_rank]
        ref_nope = ref_cache_slice[:kv_lora_rank]
        kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
                                                       4:kv_lora_rank // 4 + 4]
        ref_scales = ref_cache_slice.view(
            torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
        kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
        ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]

        torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
        torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
        torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
|
||||
@ -1,279 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
|
||||
fp8_paged_mqa_logits, get_num_sms,
|
||||
get_paged_mqa_logits_metadata)
|
||||
|
||||
|
||||
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (num_blocks, block_size, 1, head_dim)
|
||||
num_blocks, block_size, num_heads, head_dim = x.shape
|
||||
assert num_heads == 1
|
||||
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
x_fp8 = torch.empty(
|
||||
(num_blocks, block_size * (head_dim + 4)),
|
||||
device=x.device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
x_fp8[:, :block_size * head_dim] = x_scaled.view(
|
||||
num_blocks, block_size * head_dim).view(dtype=torch.uint8)
|
||||
x_fp8[:,
|
||||
block_size * head_dim:] = sf.view(num_blocks,
|
||||
block_size).view(dtype=torch.uint8)
|
||||
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
|
||||
|
||||
|
||||
def per_custom_dims_cast_to_fp8(
|
||||
x: torch.Tensor, dims: tuple,
|
||||
use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
|
||||
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled, sf.squeeze()
|
||||
|
||||
|
||||
def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
|
||||
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
|
||||
chunk_size = seq_len // 2
|
||||
cp_size = seq_len_kv // seq_len
|
||||
cp_id = cp_size // 3
|
||||
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
for i in range(chunk_size):
|
||||
ke[i] = cp_id * chunk_size + i
|
||||
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
|
||||
return ks, ke
|
||||
|
||||
|
||||
def _ref_fp8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
):
|
||||
seq_len_kv = kv.shape[0]
|
||||
|
||||
k = kv
|
||||
q = q.float()
|
||||
k = k.float()
|
||||
|
||||
mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
|
||||
>= cu_seqlen_ks[:, None])
|
||||
mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
|
||||
< cu_seqlen_ke[:, None])
|
||||
mask = mask_lo & mask_hi
|
||||
|
||||
score = torch.einsum("mhd,and->hmn", q, k)
|
||||
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="SM90 and SM100 only")
def test_deepgemm_fp8_mqa_logits():
    """Compare ``fp8_mqa_logits`` against the float reference.

    Runs once with context-parallel style start/end indices and once with
    plain indices (``disable_cp``). The ``-inf`` positions must match
    exactly; the remaining values must satisfy ``calc_diff < 1e-3``.
    """
    # Seed both RNGs so the generated inputs are reproducible.
    torch.manual_seed(0)
    random.seed(0)
    num_heads, head_dim = 32, 128
    for seq_len in (512, ):
        for seq_len_kv in (1024, ):
            for disable_cp in (False, True):
                q = torch.randn(
                    seq_len,
                    num_heads,
                    head_dim,
                    device="cuda",
                    dtype=torch.bfloat16,
                )
                kv = torch.randn(seq_len_kv,
                                 head_dim,
                                 device="cuda",
                                 dtype=torch.bfloat16)
                weights = torch.randn(seq_len,
                                      num_heads,
                                      device="cuda",
                                      dtype=torch.float32)

                if disable_cp:
                    # Start at 0; the end index of row i is
                    # i + (seq_len_kv - seq_len).
                    ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
                    ke = torch.arange(seq_len, dtype=torch.int,
                                      device="cuda") + (seq_len_kv - seq_len)
                else:
                    ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)

                q_fp8 = q.to(torch.float8_e4m3fn)
                # kv_fp8 is the (values, scales) pair returned by the helper.
                kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
                logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)

                ref_logits = _ref_fp8_mqa_logits(
                    q=q,
                    kv=kv,
                    weights=weights,
                    cu_seqlen_ks=ks,
                    cu_seqlen_ke=ke,
                )

                # Masked (-inf) positions must agree exactly.
                ref_neginf_mask = ref_logits == float("-inf")
                neginf_mask = logits == float("-inf")
                assert torch.equal(neginf_mask, ref_neginf_mask)

                # Zero out the -inf entries before measuring the difference.
                ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
                logits = logits.masked_fill(neginf_mask, 0)
                diff = calc_diff(logits, ref_logits)
                assert diff < 1e-3, f"{diff=}"
|
||||
|
||||
|
||||
def _ref_fp8_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
    """Reference implementation of paged, head-weighted MQA logits.

    Args:
        q: Queries of shape (batch_size, next_n, heads, dim).
        kv_cache: Paged keys of shape (num_blocks, block_size, 1, dim).
        weights: Per-query-row, per-head weights of shape
            (batch_size * next_n, heads).
        context_lens: Context length of each batch entry.
        block_tables: For each batch entry, the cache block index of each
            logical block.
        max_model_len: Number of columns of the returned logits.

    Returns:
        float32 tensor of shape (batch_size * next_n, max_model_len) on
        ``q.device``. The last ``next_n`` positions of each context are the
        query positions; key positions after a row's query position, and
        columns never visited, hold ``-inf``.
    """
    batch_size, next_n, _, _ = q.size()
    _, block_size, _, _ = kv_cache.size()
    # Every helper tensor is created on the same device as the output
    # instead of a hard-coded one.
    device = q.device
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=device,
        dtype=torch.float32,
    )
    context_lens_list = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens_list[i]
        # Absolute positions of this entry's next_n query tokens.
        q_offsets = torch.arange(context_len - next_n,
                                 context_len,
                                 device=device)
        # (heads, next_n) so it broadcasts against the per-head scores.
        weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
            0, 1).contiguous())
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size,
                (block_rk + 1) * block_size,
                device=device,
            )
            # Keep keys inside the context and not after the query position.
            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
                                                         <= q_offsets[:, None])
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype),
                float("-inf"),
            )
            # relu maps the masked -inf scores to 0 before the head sum.
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n:(i + 1) * next_n,
                block_rk * block_size:(block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
                            float("-inf"))
    return logits
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="SM90 and SM100 only")
def test_deepgemm_fp8_paged_mqa_logits():
    """Compare ``fp8_paged_mqa_logits`` against the float reference.

    For each (batch_size, next_n) pair, random queries, a random paged KV
    cache and a shuffled block table are generated. Both results are zeroed
    outside the valid positions and must satisfy ``calc_diff < 1e-3``.
    """
    # Seed both RNGs so inputs and the block shuffle are reproducible.
    torch.manual_seed(0)
    random.seed(0)

    max_model_len = 4096
    for batch_size, next_n in [(4, 1), (2, 2)]:
        for heads, index_dim in [(32, 128)]:
            for avg_kv in (2048, ):
                num_blocks, blocksize = max_model_len * 2, 64

                q = torch.randn(
                    (batch_size, next_n, heads, index_dim),
                    device="cuda",
                    dtype=torch.bfloat16,
                )
                kv_cache = torch.randn(
                    (num_blocks, blocksize, 1, index_dim),
                    device="cuda",
                    dtype=torch.bfloat16,
                )
                weights = torch.randn(
                    (batch_size * next_n, heads),
                    device="cuda",
                    dtype=torch.float32,
                )

                # Context lengths drawn from [0.8 * avg_kv, 1.2 * avg_kv).
                context_lens = (torch.randint(int(0.8 * avg_kv),
                                              int(1.2 * avg_kv),
                                              (batch_size, )).cuda().to(
                                                  torch.int32))
                # NOTE(review): this is the longest context rounded up to a
                # multiple of blocksize (a token count, not a block count),
                # so block_tables has more columns than the fill loop below
                # ever writes - confirm this is intended.
                max_block_len = ((context_lens.max().item() + blocksize - 1) //
                                 blocksize * blocksize)
                block_tables = torch.zeros(
                    (batch_size, max_block_len),
                    device="cuda",
                    dtype=torch.int32,
                )

                # Hand out distinct, shuffled cache blocks to every sequence.
                counter = 0
                block_idx_pool = list(range(num_blocks))
                random.shuffle(block_idx_pool)
                for i in range(batch_size):
                    ctx_len = int(context_lens[i].item())
                    for j in range((ctx_len + blocksize - 1) // blocksize):
                        block_tables[i][j] = block_idx_pool[counter]
                        counter += 1

                q_fp8 = q.to(torch.float8_e4m3fn)
                kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)

                schedule_metadata = get_paged_mqa_logits_metadata(
                    context_lens, blocksize, get_num_sms())
                logits = fp8_paged_mqa_logits(
                    q_fp8,
                    kv_cache_fp8,
                    weights,
                    context_lens,
                    block_tables,
                    schedule_metadata,
                    max_model_len,
                )

                ref_logits = _ref_fp8_paged_mqa_logits(
                    q,
                    kv_cache,
                    weights,
                    context_lens,
                    block_tables,
                    max_model_len,
                )

                # A column is valid for a row when its position is at most
                # that row's query position:
                # context_len - next_n + (row index within its sequence).
                positions = (torch.arange(max_model_len,
                                          device="cuda").unsqueeze(0).expand(
                                              batch_size * next_n, -1))
                row_indices = (
                    torch.arange(batch_size * next_n, device="cuda") // next_n)
                next_n_offset = (
                    torch.arange(batch_size * next_n, device="cuda") % next_n)
                mask = positions <= (context_lens[row_indices] - next_n +
                                     next_n_offset).unsqueeze(1)

                # Zero out everything outside the valid region before
                # measuring the difference.
                logits = logits.masked_fill(~mask, 0)
                ref_logits = ref_logits.masked_fill(~mask, 0)
                diff = calc_diff(logits, ref_logits)
                assert diff < 1e-3, f"{diff=}"
|
||||
@ -97,16 +97,18 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
descale_k = None
|
||||
|
||||
def flash_mla():
|
||||
return flash_mla_with_kvcache(q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
descale_q=descale_q,
|
||||
descale_k=descale_k)
|
||||
return flash_mla_with_kvcache(
|
||||
q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
descale_q=descale_q,
|
||||
descale_k=descale_k,
|
||||
)
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
query = query.float()
|
||||
|
||||
@ -1,119 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def _cuda_sm90_available() -> bool:
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
return major == 9
|
||||
|
||||
|
||||
def test_sparse_flashmla_metadata_smoke():
    """Smoke test: ``get_mla_metadata`` with ``topk`` and an fp8 KV cache
    returns two int32 tensors."""
    import vllm.attention.ops.flashmla as fm

    supported, skip_reason = fm.is_flashmla_supported()
    if not (supported and _cuda_sm90_available()):
        pytest.skip(skip_reason or "SM90 not available")

    cuda = torch.device("cuda")
    batch_size, seqlen_q = 1, 1
    num_heads_q, num_heads_k = 128, 1
    topk = 128
    q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k

    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=cuda)

    metadata = fm.get_mla_metadata(
        cache_seqlens,
        q_seq_per_hk,
        num_heads_k,
        num_heads_q=num_heads_q,
        topk=topk,
        is_fp8_kvcache=True,
    )
    tile_md, num_splits = metadata
    assert tile_md.dtype == torch.int32
    assert num_splits.dtype == torch.int32
|
||||
|
||||
|
||||
def test_sparse_flashmla_decode_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
|
||||
device = torch.device("cuda")
|
||||
batch_size = 1
|
||||
seqlen_q = 1
|
||||
num_heads_q = 1
|
||||
head_dim_k = 576
|
||||
head_dim_v = 512
|
||||
num_heads_k = 1
|
||||
page_block_size = 64
|
||||
bytes_per_token = 656
|
||||
topk = 128
|
||||
|
||||
# Metadata
|
||||
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
|
||||
# q_heads_per_hk = num_heads_q // num_heads_k
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True)
|
||||
|
||||
# Inputs
|
||||
q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
|
||||
dtype=torch.bfloat16,
|
||||
device=device)
|
||||
k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
indices = torch.zeros((batch_size, seqlen_q, topk),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
block_table = torch.zeros((batch_size, 128),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
out, lse = fm.flash_mla_with_kvcache(q,
|
||||
k_cache,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
head_dim_v,
|
||||
tile_md,
|
||||
num_splits,
|
||||
indices=indices,
|
||||
is_fp8_kvcache=True)
|
||||
assert out.shape[0] == batch_size
|
||||
assert out.shape[-1] == head_dim_v
|
||||
assert lse.shape[0] == batch_size
|
||||
|
||||
|
||||
def test_sparse_flashmla_prefill_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
|
||||
device = torch.device("cuda")
|
||||
s_q = 1
|
||||
s_kv = 1
|
||||
h_q = 64 # kernel expects multiple of 64
|
||||
h_kv = 1
|
||||
d_qk = 576
|
||||
d_v = 512
|
||||
topk = 128
|
||||
|
||||
q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
|
||||
kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
|
||||
indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
|
||||
|
||||
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0,
|
||||
d_v)
|
||||
assert out.shape == (s_q, h_q, d_v)
|
||||
assert max_logits.shape == (s_q, h_q)
|
||||
assert lse.shape == (s_q, h_q)
|
||||
@ -1,245 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
|
||||
|
||||
def test_pack_seq_basic_fp8():
|
||||
"""Test basic functionality of pack_seq_triton with fp8 and 3D tensors."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test cases with 3D tensors (N, H, D)
|
||||
test_cases = [
|
||||
(6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4)
|
||||
(10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8)
|
||||
(20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32)
|
||||
]
|
||||
|
||||
for N, H, D, B, lengths_list in test_cases:
|
||||
# Create input tensor with small values for fp8
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor(lengths_list, device=device)
|
||||
|
||||
# Pack the data
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check output shape and properties
|
||||
expected_shape = (B, max(lengths_list), H, D)
|
||||
assert packed.shape == expected_shape
|
||||
assert packed.dtype == dtype
|
||||
assert packed.device == x.device
|
||||
|
||||
# Check that valid data is preserved (within fp8 precision)
|
||||
for b in range(B):
|
||||
start_idx = sum(lengths_list[:b])
|
||||
seq_len = lengths_list[b]
|
||||
|
||||
expected_data = x[start_idx:start_idx + seq_len].to(torch.float32)
|
||||
actual_data = packed[b, :seq_len].to(torch.float32)
|
||||
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_custom_padding_fp8():
|
||||
"""Test pack_seq_triton with custom padding values for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 20, 8, 16, 2
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
# Test with different padding values
|
||||
for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]:
|
||||
result = pack_seq_triton(x, lengths, pad_value=pad_value)
|
||||
|
||||
# Check valid data
|
||||
for b in range(B):
|
||||
start_idx = b * 10
|
||||
expected_data = x[start_idx:start_idx + 10].to(torch.float32)
|
||||
actual_data = result[b, :10].to(torch.float32)
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
# Check padding (fp8 has limited range, so check for large values)
|
||||
padded_data = result[:, 10:].to(torch.float32)
|
||||
if pad_value < 0:
|
||||
assert torch.all(padded_data < -50) # Large negative values
|
||||
elif pad_value > 0:
|
||||
assert torch.all(padded_data > 50) # Large positive values
|
||||
else:
|
||||
assert torch.allclose(padded_data,
|
||||
torch.zeros_like(padded_data),
|
||||
atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_default_negative_inf_padding_fp8():
|
||||
"""Test that pack_seq_triton uses -inf padding by default for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
# B = 2
|
||||
N, H, D = 20, 8, 16
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check that padding is large negative values (fp8 representation of -inf)
|
||||
padded_data = result[:, 10:].to(torch.float32)
|
||||
assert torch.all(
|
||||
padded_data < -100) # fp8 -inf is represented as large negative number
|
||||
|
||||
|
||||
def test_pack_seq_edge_cases_fp8():
|
||||
"""Test pack_seq_triton with edge cases for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test with single batch element
|
||||
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([10], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (1, 10, 8, 16)
|
||||
|
||||
# Test with very short sequences
|
||||
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([1, 1, 1], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (3, 1, 4, 8)
|
||||
|
||||
# Test with different sequence lengths
|
||||
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([5, 7, 3], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (3, 7, 8, 16)
|
||||
|
||||
|
||||
def test_pack_seq_different_block_sizes_fp8():
|
||||
"""Test pack_seq_triton with different block sizes for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 100, 16, 32, 4
|
||||
lengths = torch.tensor([25, 25, 25, 25], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
# Test different block sizes
|
||||
for block_t, block_d in [(32, 32), (64, 64), (128, 128)]:
|
||||
result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d)
|
||||
|
||||
assert result.shape == (B, 25, H, D)
|
||||
|
||||
# Check that valid data is preserved (within fp8 precision)
|
||||
for b in range(B):
|
||||
start_idx = b * 25
|
||||
expected_data = x[start_idx:start_idx + 25].to(torch.float32)
|
||||
actual_data = result[b, :25].to(torch.float32)
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_shape_consistency():
|
||||
"""Test that pack_seq_triton maintains shape consistency."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 20, 8, 16, 2
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
result = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check shape consistency
|
||||
assert result.shape[0] == B # Batch dimension
|
||||
assert result.shape[1] == lengths.max().item() # Max sequence length
|
||||
assert result.shape[2:] == x.shape[1:] # Feature dimensions preserved
|
||||
|
||||
|
||||
def test_pack_unpack_roundtrip_fp8():
|
||||
"""Test that pack -> unpack gives us back the original data for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test cases with 3D tensors
|
||||
test_cases = [
|
||||
(6, 8, 4, 2, [3, 3]),
|
||||
(10, 4, 8, 3, [2, 4, 4]),
|
||||
(20, 16, 32, 4, [5, 5, 5, 5]),
|
||||
(15, 8, 16, 3, [7, 5, 3]),
|
||||
]
|
||||
|
||||
for N, H, D, B, lengths_list in test_cases:
|
||||
# Create input tensor with small values for fp8
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor(lengths_list, device=device)
|
||||
|
||||
# Pack the data
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
|
||||
# Unpack the data
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
|
||||
# Check that we get back the original data (within fp8 precision)
|
||||
assert unpacked.shape == x.shape
|
||||
x_f32 = x.to(torch.float32)
|
||||
unpacked_f32 = unpacked.to(torch.float32)
|
||||
assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Unpack without explicit start locations (computed in kernel)
|
||||
unpacked_with_loc = unpack_seq_triton(packed, lengths)
|
||||
assert_close(x_f32,
|
||||
unpacked_with_loc.to(torch.float32),
|
||||
rtol=1e-3,
|
||||
atol=1e-2)
|
||||
|
||||
|
||||
def test_unpack_seq_triton_edge_cases_fp8():
|
||||
"""Test unpack function with edge cases for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test with single batch element
|
||||
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([10], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
assert unpacked.shape == x.shape
|
||||
assert_close(x.to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
||||
|
||||
# Test with very short sequences
|
||||
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([1, 1, 1], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
# Only compare the first 3 elements that were actually packed
|
||||
assert_close(x[:3].to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
||||
|
||||
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([5, 7, 3], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
assert unpacked.shape == x.shape
|
||||
assert_close(x.to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
||||
@ -1,52 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_regex():
|
||||
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
||||
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_json_schema():
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"age": {
|
||||
"type": "integer"
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"maxLength": 10
|
||||
},
|
||||
"minItems": 3
|
||||
},
|
||||
"work_history": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"company": {
|
||||
"type": "string"
|
||||
},
|
||||
"duration": {
|
||||
"type": "number"
|
||||
},
|
||||
"position": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["company", "position"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["name", "age", "skills", "work_history"]
|
||||
}
|
||||
@ -14,6 +14,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.model_loader.tensorizer
|
||||
from tests.utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
# yapf: disable
|
||||
@ -27,7 +28,6 @@ from vllm.model_executor.model_loader.tensorizer_loader import (
|
||||
# yapf: enable
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from .conftest import DummyExecutor, assert_from_collective_rpc
|
||||
|
||||
try:
|
||||
@ -207,7 +207,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
|
||||
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT",
|
||||
min_transformers_version="4.54"),
|
||||
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
|
||||
@ -652,6 +651,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
|
||||
trust_remote_code=True,
|
||||
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
|
||||
"Eagle3Qwen2_5vlForCausalLM": _HfExamplesInfo(
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
speculative_model="Rayzl/qwen2.5-vl-7b-eagle3-sgl"),
|
||||
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||
min_transformers_version="4.56.3"),
|
||||
}
|
||||
|
||||
@ -8,8 +8,7 @@ import pytest
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
|
||||
get_kv_cache_configs)
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
|
||||
from vllm.v1.engine.core import EngineCore as V1EngineCore
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
@ -63,13 +62,11 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
|
||||
# Avoid calling model.forward()
|
||||
def _initialize_kv_caches_v1(self, vllm_config):
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
kv_cache_configs = get_kv_cache_configs(
|
||||
scheduler_kv_cache_config = get_kv_cache_configs(
|
||||
vllm_config,
|
||||
kv_cache_specs,
|
||||
[10 * GiB_bytes],
|
||||
)
|
||||
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
|
||||
kv_cache_configs)
|
||||
)[0]
|
||||
|
||||
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
|
||||
return 1, 0, scheduler_kv_cache_config
|
||||
|
||||
@ -100,10 +100,9 @@ def test_distributed(
|
||||
kwargs_test=kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(),
|
||||
reason="bitsandbytes quantization is currently not supported in rocm.")
|
||||
@pytest.mark.parametrize("model, quantization_kwargs", [
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {}),
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", {}),
|
||||
(
|
||||
"meta-llama/Llama-3.2-1B-Instruct",
|
||||
{
|
||||
@ -121,6 +120,11 @@ def test_quantization(
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
if (current_platform.is_rocm()
|
||||
and quantization_kwargs.get("quantization", "") == "bitsandbytes"):
|
||||
pytest.skip(
|
||||
"bitsandbytes quantization is currently not supported in rocm.")
|
||||
|
||||
with vllm_runner(
|
||||
model, model_impl="auto", enforce_eager=True,
|
||||
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
|
||||
|
||||
@ -26,5 +26,5 @@ class DummyPlatform(Platform):
|
||||
|
||||
def get_attn_backend_cls(self, backend_name, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla,
|
||||
has_sink, use_sparse):
|
||||
has_sink):
|
||||
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
|
||||
|
||||
132
tests/quantization/test_blackwell_moe.py
Normal file
132
tests/quantization/test_blackwell_moe.py
Normal file
@ -0,0 +1,132 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_device_capability(100):
|
||||
pytest.skip("This test only runs on Blackwell GPUs (SM100).",
|
||||
allow_module_level=True)
|
||||
|
||||
os.environ["FLASHINFER_NVCC_THREADS"] = "16"
|
||||
|
||||
# dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4,
|
||||
# "text_config": {"num_layers": 4, "num_hidden_layers": 4}}
|
||||
dummy_hf_overrides = {"num_layers": 4, "num_hidden_layers": 4}
|
||||
|
||||
|
||||
def can_initialize(model: str, extra_args: list[str]):
|
||||
|
||||
# Server arguments
|
||||
server_args = [
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-batched-tokens",
|
||||
"256",
|
||||
"--load-format",
|
||||
"dummy",
|
||||
"--trust-remote-code",
|
||||
"--limit-mm-per-prompt",
|
||||
json.dumps({"image": 0}),
|
||||
*extra_args,
|
||||
]
|
||||
|
||||
# Launch server and make a simple request
|
||||
with RemoteOpenAIServer(
|
||||
model,
|
||||
server_args,
|
||||
max_wait_seconds=1000, # Due to FlashInfer compile
|
||||
override_hf_configs=dummy_hf_overrides) as server:
|
||||
client = server.get_client()
|
||||
# Make a simple request to verify the server works
|
||||
completion = client.completions.create(
|
||||
model=model,
|
||||
prompt=["Hello, World!"],
|
||||
temperature=0,
|
||||
max_tokens=2,
|
||||
)
|
||||
print(completion)
|
||||
assert completion.choices[0].text is not None
|
||||
|
||||
|
||||
## Llama4 ##
|
||||
|
||||
|
||||
@pytest.mark.skip(reason=(
|
||||
"RuntimeError: run_moe() Expected a value of type "
|
||||
"'Optional[List[Tensor]]' for argument '_9' but instead found type "
|
||||
"'list'."))
|
||||
def test_llama4_fp8_tensor_moe_flashinfer_cutlass(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
|
||||
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
|
||||
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Works, but takes too long to run")
|
||||
def test_llama4_fp8_tensor_moe_flashinfer_trtllm(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP8", "1")
|
||||
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
|
||||
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", [])
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Works, but takes too long to run")
|
||||
def test_llama4_nvfp4_moe_flashinfer_cutlass(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
|
||||
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
|
||||
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
|
||||
def test_llama4_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
|
||||
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
|
||||
can_initialize("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", [])
|
||||
|
||||
|
||||
## DeepSeekV3 ##
|
||||
|
||||
|
||||
def test_deepseek_fp8_block_moe_deep_gemm(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
can_initialize("deepseek-ai/DeepSeek-V3.1", [])
|
||||
|
||||
|
||||
def test_deepseek_nvfp4_moe_flashinfer_cutlass(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
|
||||
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
|
||||
can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="RuntimeError: No kernel found for the given options")
|
||||
def test_deepseek_nvfp4_moe_flashinfer_trtllm(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
|
||||
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", "latency")
|
||||
can_initialize("nvidia/DeepSeek-R1-0528-FP4-v2", [])
|
||||
|
||||
|
||||
## GPT-OSS ##
|
||||
|
||||
|
||||
def test_gptoss_mxfp4bf16_moe_flashinfer(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "1")
|
||||
can_initialize("openai/gpt-oss-20b", [])
|
||||
|
||||
|
||||
def test_gptoss_mxfp4mxfp8_moe_flashinfer_cutlass(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "1")
|
||||
can_initialize("openai/gpt-oss-20b", [])
|
||||
|
||||
|
||||
def test_gptoss_mxfp4mxfp8_moe_flashinfer_trtllm(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "1")
|
||||
can_initialize("openai/gpt-oss-20b", [])
|
||||
203
tests/reasoning/test_glm4_moe_reasoning_parser.py
Normal file
203
tests/reasoning/test_glm4_moe_reasoning_parser.py
Normal file
@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
parser_name = "glm45"
|
||||
start_token = "<think>"
|
||||
end_token = "</think>"
|
||||
|
||||
REASONING_MODEL_NAME = "zai-org/GLM-4.5"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def glm45_tokenizer():
|
||||
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||
|
||||
|
||||
WITH_THINK = {
|
||||
"output": "<think>This is a reasoning section</think>This is the rest",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
WITH_THINK_STREAM = {
|
||||
"output": "<think>This is a reasoning section</think>This is the rest",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
WITHOUT_THINK = {
|
||||
"output": "This is the rest",
|
||||
"reasoning_content": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
|
||||
WITHOUT_THINK_STREAM = {
|
||||
"output": "This is the rest",
|
||||
"reasoning_content": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
|
||||
COMPLETE_REASONING = {
|
||||
"output": "<think>This is a reasoning section</think>",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
MULTILINE_REASONING = {
|
||||
"output":
|
||||
"<think>This is a reasoning\nsection</think>This is the rest\nThat",
|
||||
"reasoning_content": "This is a reasoning\nsection",
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
ONLY_OPEN_TAG = {
|
||||
"output": "<think>This is a reasoning section",
|
||||
"reasoning_content": None,
|
||||
"content": "<think>This is a reasoning section",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
|
||||
ONLY_OPEN_TAG_STREAM = {
|
||||
"output": "<think>This is a reasoning section",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
False,
|
||||
WITH_THINK,
|
||||
id="with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
WITH_THINK_STREAM,
|
||||
id="with_think_stream",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
WITHOUT_THINK,
|
||||
id="without_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
WITHOUT_THINK_STREAM,
|
||||
id="without_think_stream",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning_stream",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTILINE_REASONING,
|
||||
id="multiline_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTILINE_REASONING,
|
||||
id="multiline_reasoning_stream",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
ONLY_OPEN_TAG,
|
||||
id="only_open_tag",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
ONLY_OPEN_TAG_STREAM,
|
||||
id="only_open_tag_stream",
|
||||
),
|
||||
]
|
||||
|
||||
STILL_REASONING_PROMPT = """[gMASK]<sop><|system|>
|
||||
You are a helpful assistant.<|user|>
|
||||
What is the capital of France?<|assistant|>
|
||||
<think>The user is asking for the capital of"""
|
||||
|
||||
DONE_REASONING_PROMPT = """[gMASK]<sop><|system|>
|
||||
You are a helpful assistant.<|user|>
|
||||
What is the capital of France?<|assistant|>
|
||||
<think>The user is asking for the capital of France.</think>
|
||||
The capital of France is Paris."""
|
||||
|
||||
MULTI_TURN_STILL_REASONING_PROMPT = """[gMASK]<sop><|system|>
|
||||
You are a helpful assistant.<|user|>
|
||||
What is the capital of France?<|assistant|>
|
||||
<think></think>
|
||||
The capital of France is Paris.<|user|>
|
||||
What about Chile?<|assistant|>
|
||||
<think>The user is asking for the capital of"""
|
||||
|
||||
MULTI_TURN_DONE_REASONING_PROMPT = """[gMASK]<sop><|system|>
|
||||
You are a helpful assistant.<|user|>
|
||||
What is the capital of France?<|assistant|>
|
||||
<think></think>
|
||||
The capital of France is Paris.<|user|>
|
||||
What about Chile?<|assistant|>
|
||||
<think>The user is asking for the capital of Chile.</think>
|
||||
The capital of Chile is Santiago."""
|
||||
|
||||
REASONING_END_TEST_CASES = [
|
||||
pytest.param(STILL_REASONING_PROMPT, False, id="still_reasoning"),
|
||||
pytest.param(DONE_REASONING_PROMPT, True, id="done_reasoning"),
|
||||
pytest.param(MULTI_TURN_STILL_REASONING_PROMPT,
|
||||
False,
|
||||
id="multi_turn_still_reasoning"),
|
||||
pytest.param(MULTI_TURN_DONE_REASONING_PROMPT,
|
||||
True,
|
||||
id="multi_turn_done_reasoning")
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||
def test_reasoning(
|
||||
streaming: bool,
|
||||
param_dict: dict,
|
||||
glm45_tokenizer,
|
||||
):
|
||||
output = glm45_tokenizer.tokenize(param_dict["output"])
|
||||
output_tokens: list[str] = [
|
||||
glm45_tokenizer.convert_tokens_to_string([token]) for token in output
|
||||
]
|
||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
|
||||
parser_name)(glm45_tokenizer)
|
||||
|
||||
reasoning, content = run_reasoning_extraction(parser,
|
||||
output_tokens,
|
||||
streaming=streaming)
|
||||
|
||||
assert reasoning == param_dict["reasoning_content"]
|
||||
assert content == param_dict["content"]
|
||||
|
||||
output_ids = glm45_tokenizer.convert_tokens_to_ids(output)
|
||||
is_reasoning_end = parser.is_reasoning_end(output_ids)
|
||||
assert is_reasoning_end == param_dict["is_reasoning_end"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prompt, is_reasoning_end", REASONING_END_TEST_CASES)
|
||||
def test_is_reasoning_end_full_prompt(prompt: str, is_reasoning_end: bool,
|
||||
glm45_tokenizer):
|
||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
|
||||
parser_name)(glm45_tokenizer)
|
||||
tokens = glm45_tokenizer.tokenize(prompt)
|
||||
token_ids = glm45_tokenizer.convert_tokens_to_ids(tokens)
|
||||
check_is_reasoning_end = parser.is_reasoning_end(token_ids)
|
||||
assert check_is_reasoning_end == is_reasoning_end
|
||||
@ -91,8 +91,10 @@ class RemoteOpenAIServer:
|
||||
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
|
||||
if env_dict is not None:
|
||||
env.update(env_dict)
|
||||
serve_cmd = ["vllm", "serve", model, *vllm_serve_args]
|
||||
print(f"Launching RemoteOpenAIServer with: {' '.join(serve_cmd)}")
|
||||
self.proc: subprocess.Popen = subprocess.Popen(
|
||||
["vllm", "serve", model, *vllm_serve_args],
|
||||
serve_cmd,
|
||||
env=env,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
|
||||
69
tests/utils_/test_gc_utils.py
Normal file
69
tests/utils_/test_gc_utils.py
Normal file
@ -0,0 +1,69 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from vllm.utils.gc_utils import (GCDebugConfig, _compute_detailed_type,
|
||||
_compute_top_gc_collected_objects)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Normal:
|
||||
v: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ListWrapper:
|
||||
vs: list[int]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.vs)
|
||||
|
||||
|
||||
def test_compute_detailed_type():
|
||||
assert _compute_detailed_type(
|
||||
Normal(v=8)) == "<class 'tests.utils_.test_gc_utils.Normal'>"
|
||||
|
||||
assert _compute_detailed_type([1, 2, 3]) == "<class 'list'>(size:3)"
|
||||
assert _compute_detailed_type({4, 5}) == "<class 'set'>(size:2)"
|
||||
assert _compute_detailed_type({6: 7}) == "<class 'dict'>(size:1)"
|
||||
assert _compute_detailed_type(ListWrapper(
|
||||
vs=[])) == "<class 'tests.utils_.test_gc_utils.ListWrapper'>(size:0)"
|
||||
|
||||
|
||||
def test_compute_top_gc_collected_objects():
|
||||
objects: list[Any] = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
|
||||
{13, 14}, {
|
||||
15: 16,
|
||||
17: 18
|
||||
},
|
||||
Normal(v=19),
|
||||
Normal(v=20),
|
||||
Normal(v=21)]
|
||||
assert _compute_top_gc_collected_objects(objects, top=-1) == ""
|
||||
assert _compute_top_gc_collected_objects(objects, top=0) == ""
|
||||
assert _compute_top_gc_collected_objects(
|
||||
objects, top=1) == " 4:<class 'list'>(size:3)"
|
||||
assert _compute_top_gc_collected_objects(objects, top=2) == "\n".join([
|
||||
" 4:<class 'list'>(size:3)",
|
||||
" 3:<class 'tests.utils_.test_gc_utils.Normal'>"
|
||||
])
|
||||
assert _compute_top_gc_collected_objects(objects, top=3) == "\n".join([
|
||||
" 4:<class 'list'>(size:3)",
|
||||
" 3:<class 'tests.utils_.test_gc_utils.Normal'>",
|
||||
" 1:<class 'set'>(size:2)"
|
||||
])
|
||||
|
||||
|
||||
def test_gc_debug_config():
|
||||
assert not GCDebugConfig(None).enabled
|
||||
assert not GCDebugConfig("").enabled
|
||||
assert not GCDebugConfig("0").enabled
|
||||
|
||||
config = GCDebugConfig("1")
|
||||
assert config.enabled
|
||||
assert config.top_objects == -1
|
||||
|
||||
config = GCDebugConfig("{\"top_objects\":5}")
|
||||
assert config.enabled
|
||||
assert config.top_objects == 5
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -11,7 +10,6 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
@ -80,9 +78,7 @@ def create_and_prepopulate_kv_cache(
|
||||
device: torch.device,
|
||||
num_blocks: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
randomize_blocks: bool = True,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor:
|
||||
randomize_blocks: bool = True) -> torch.Tensor:
|
||||
"""Create and prepopulate an MLA KV cache with context data.
|
||||
|
||||
Args:
|
||||
@ -97,11 +93,6 @@ def create_and_prepopulate_kv_cache(
|
||||
common_attn_metadata: Common attention metadata
|
||||
randomize_blocks: Whether to randomly permute blocks
|
||||
or use sequential order
|
||||
kv_cache_dtype: Optional kv cache dtype string. When set to
|
||||
"fp8_ds_mla" the cache is populated using the
|
||||
fp8 DeepSeek MLA layout via concat_and_cache_mla.
|
||||
scale: Scaling factor forwarded to concat_and_cache_mla when the
|
||||
fp8 cache layout is requested.
|
||||
|
||||
Returns:
|
||||
MLA KV cache tensor
|
||||
@ -114,61 +105,23 @@ def create_and_prepopulate_kv_cache(
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
|
||||
|
||||
if use_fp8_ds_mla:
|
||||
if not kv_c_contexts:
|
||||
raise ValueError("kv_c_contexts cannot be empty when using"
|
||||
" fp8_ds_mla cache dtype")
|
||||
kv_lora_rank = kv_c_contexts[0].shape[-1]
|
||||
rope_dim = k_pe_contexts[0].shape[-1]
|
||||
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
|
||||
kv_cache = torch.zeros(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
scale_tensor = (scale
|
||||
if isinstance(scale, torch.Tensor) else torch.tensor(
|
||||
scale, dtype=torch.float32, device=device))
|
||||
scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
|
||||
else:
|
||||
# Create MLA KV cache: (num_blocks, block_size, head_size)
|
||||
kv_cache = torch.empty(num_blocks,
|
||||
block_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_cache_flat = kv_cache.view(-1, head_size)
|
||||
# Create MLA KV cache: (num_blocks, block_size, head_size)
|
||||
kv_cache = torch.empty(num_blocks,
|
||||
block_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_cache_flat = kv_cache.view(-1, head_size)
|
||||
|
||||
# Populate the cache with the context tokens
|
||||
# Start from block_id=1 since block_id=0 is considered the null block
|
||||
start_block_idx = 1
|
||||
for i in range(batch_size):
|
||||
kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i]
|
||||
context_len = kv_c_context.shape[0]
|
||||
if context_len == 0:
|
||||
start_block_idx += cdiv(int(seq_lens[i]), block_size)
|
||||
continue
|
||||
|
||||
kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
|
||||
start = start_block_idx * block_size
|
||||
|
||||
if use_fp8_ds_mla:
|
||||
slots = torch.arange(context_len, device=device,
|
||||
dtype=torch.long) + start
|
||||
ops.concat_and_cache_mla(
|
||||
kv_c_context,
|
||||
k_pe_context.squeeze(1),
|
||||
kv_cache,
|
||||
slots,
|
||||
kv_cache_dtype="fp8_ds_mla",
|
||||
scale=scale_tensor,
|
||||
)
|
||||
else:
|
||||
kv_context = torch.cat(
|
||||
[kv_c_context, k_pe_context.squeeze(1)], dim=-1)
|
||||
end = start + kv_context.shape[0]
|
||||
kv_cache_flat[start:end, ...] = kv_context
|
||||
end = start + kv_context.shape[0]
|
||||
kv_cache_flat[start:end, ...] = kv_context
|
||||
|
||||
# Stay block aligned and allocate enough blocks for the new tokens
|
||||
start_block_idx += cdiv(int(seq_lens[i]), block_size)
|
||||
|
||||
@ -1,448 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for the FlashMLA sparse backend utilities."""
|
||||
|
||||
import math
|
||||
from types import MethodType, SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.test_mla_backends import (
|
||||
BATCH_SPECS, BatchSpec, MockAttentionLayer,
|
||||
create_and_prepopulate_kv_cache)
|
||||
from tests.v1.attention.utils import (create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops import flashmla
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
|
||||
FlashMLASparseImpl, FlashMLASparseMetadata)
|
||||
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS = {
|
||||
name: BATCH_SPECS[name]
|
||||
for name in [
|
||||
"mixed_small",
|
||||
"mixed_medium",
|
||||
"small_prefill",
|
||||
"medium_prefill",
|
||||
"single_prefill",
|
||||
]
|
||||
}
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2,
|
||||
query_lens=[256] * 2)
|
||||
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
|
||||
seq_lens=[256] * 2, query_lens=[256] * 2)
|
||||
|
||||
|
||||
def _dequantize_fp8_ds_mla_entry(
|
||||
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int,
|
||||
dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
|
||||
|
||||
# The first kv_lora_rank bytes store FP8 latent values with one scale per
|
||||
# 128 element tile written as float32 right after the latent payload.
|
||||
scales = cache_slice.view(torch.float32)[kv_lora_rank //
|
||||
4:kv_lora_rank // 4 + 4]
|
||||
latent = torch.empty(kv_lora_rank,
|
||||
dtype=torch.float16,
|
||||
device=cache_slice.device)
|
||||
for tile_idx in range(4):
|
||||
tile_start = tile_idx * 128
|
||||
tile_end = tile_start + 128
|
||||
ops.convert_fp8(latent[tile_start:tile_end],
|
||||
cache_slice[tile_start:tile_end],
|
||||
float(scales[tile_idx].item()),
|
||||
kv_dtype="fp8")
|
||||
latent = latent.to(dtype)
|
||||
|
||||
rope_offset = kv_lora_rank // 2 + 8
|
||||
rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim]
|
||||
return latent, rope_vals.clone()
|
||||
|
||||
|
||||
def _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int,
|
||||
scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
|
||||
|
||||
if kv_c.numel() == 0:
|
||||
return kv_c.clone(), k_pe.clone()
|
||||
|
||||
kv_lora_rank = kv_c.shape[-1]
|
||||
rope_dim = k_pe.shape[-1]
|
||||
num_tokens = kv_c.shape[0]
|
||||
num_blocks = max(1, math.ceil(num_tokens / block_size))
|
||||
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
|
||||
|
||||
tmp_cache = torch.zeros(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
device=kv_c.device)
|
||||
slot_mapping = torch.arange(num_tokens,
|
||||
dtype=torch.long,
|
||||
device=kv_c.device)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c,
|
||||
k_pe,
|
||||
tmp_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype="fp8_ds_mla",
|
||||
scale=scale)
|
||||
|
||||
dequant_kv_c = torch.empty_like(kv_c)
|
||||
dequant_k_pe = torch.empty_like(k_pe)
|
||||
|
||||
for token_idx in range(num_tokens):
|
||||
slot = slot_mapping[token_idx].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
cache_slice = tmp_cache[block_idx, block_offset]
|
||||
latent, rope_vals = _dequantize_fp8_ds_mla_entry(
|
||||
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype)
|
||||
dequant_kv_c[token_idx] = latent
|
||||
dequant_k_pe[token_idx] = rope_vals
|
||||
|
||||
return dequant_kv_c, dequant_k_pe
|
||||
|
||||
|
||||
def test_sparse_backend_metadata_registration():
|
||||
backend = FlashMLASparseBackend
|
||||
|
||||
assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1"
|
||||
assert backend.get_metadata_cls() is FlashMLASparseMetadata
|
||||
assert backend.get_impl_cls() is FlashMLASparseImpl
|
||||
|
||||
dtype_list = backend.get_supported_dtypes()
|
||||
assert torch.bfloat16 in dtype_list
|
||||
|
||||
shape = backend.get_kv_cache_shape(num_blocks=2,
|
||||
block_size=64,
|
||||
num_kv_heads=1,
|
||||
head_size=576)
|
||||
assert shape == (2, 64, 576)
|
||||
|
||||
|
||||
def test_sparse_decode_metadata_filters_prefill_indices():
|
||||
prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32)
|
||||
metadata = FlashMLASparseDecodeAndContextMetadata(
|
||||
scheduler_metadata=torch.tensor([[0]], dtype=torch.int32),
|
||||
num_splits=torch.tensor([1, 1], dtype=torch.int32),
|
||||
cache_lens=torch.tensor([10, 12], dtype=torch.int32),
|
||||
prefill_context_lengths=prefill_context_lengths,
|
||||
)
|
||||
|
||||
indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
|
||||
|
||||
context_indices, new_token_indices = metadata.filter_prefill_indices(
|
||||
indices)
|
||||
|
||||
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]],
|
||||
dtype=torch.int32)
|
||||
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]],
|
||||
dtype=torch.int32)
|
||||
|
||||
assert torch.equal(context_indices, expected_context)
|
||||
assert torch.equal(new_token_indices, expected_new_tokens)
|
||||
|
||||
|
||||
def test_sparse_impl_zero_fills_when_metadata_missing():
|
||||
impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl)
|
||||
dummy_layer = object()
|
||||
q = torch.zeros((2, 1, 3))
|
||||
k_c = torch.zeros((2, 3))
|
||||
k_pe = torch.zeros((2, 1, 1))
|
||||
kv_cache = torch.zeros((1, 1, 1))
|
||||
output = torch.ones((2, 4))
|
||||
|
||||
result = FlashMLASparseImpl.forward(impl,
|
||||
dummy_layer,
|
||||
q,
|
||||
k_c,
|
||||
k_pe,
|
||||
kv_cache,
|
||||
attn_metadata=None,
|
||||
output=output)
|
||||
|
||||
assert result is output
|
||||
assert torch.all(result == 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
|
||||
def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
kv_cache_dtype):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
|
||||
|
||||
# Model hyper-parameters (kept intentionally small for the unit test)
|
||||
num_heads = 128
|
||||
kv_lora_rank = 512
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
v_head_dim = 128
|
||||
head_size = kv_lora_rank + qk_rope_head_dim
|
||||
topk_tokens = 2048
|
||||
|
||||
max_seqlen = max(batch_spec.seq_lens)
|
||||
total_cache_tokens = sum(batch_spec.seq_lens)
|
||||
block_size = 64
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
max_model_len=max_seqlen,
|
||||
num_gpu_blocks=max(2048,
|
||||
cdiv(total_cache_tokens, block_size) + 1),
|
||||
block_size=block_size)
|
||||
model_config = vllm_config.model_config
|
||||
model_config.hf_config = SimpleNamespace(
|
||||
attn_module_list_cfg=[{
|
||||
"topk_tokens": topk_tokens
|
||||
}])
|
||||
model_config.hf_text_config = SimpleNamespace(
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
model_type="deepseek_v2",
|
||||
)
|
||||
model_config.dtype = dtype
|
||||
model_config.get_num_attention_heads = MethodType(
|
||||
lambda self, parallel_config: num_heads, model_config)
|
||||
model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1,
|
||||
model_config)
|
||||
model_config.get_head_size = MethodType(lambda self: head_size,
|
||||
model_config)
|
||||
model_config.get_sliding_window = MethodType(lambda self: None,
|
||||
model_config)
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
scale = 1.0 / math.sqrt(head_size)
|
||||
|
||||
# Shared MLA projection weights to keep reference and backend in sync
|
||||
W_UK = torch.randn(kv_lora_rank,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
W_UV = torch.randn(kv_lora_rank,
|
||||
num_heads,
|
||||
v_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
# Build synthetic decode-only workload
|
||||
seq_lens = batch_spec.seq_lens
|
||||
query_lens = batch_spec.query_lens
|
||||
|
||||
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
|
||||
kv_c_contexts, k_pe_contexts = [], []
|
||||
reference_outputs = []
|
||||
|
||||
kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
for i in range(batch_spec.batch_size):
|
||||
s_len = seq_lens[i]
|
||||
q_len = query_lens[i]
|
||||
ctx_len = s_len - q_len
|
||||
|
||||
q_c = torch.rand(q_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim + qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe_full = torch.rand(s_len,
|
||||
1,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c_full,
|
||||
k_pe_full.squeeze(1),
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
scale=kv_cache_scale,
|
||||
)
|
||||
|
||||
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
|
||||
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
|
||||
k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
|
||||
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
|
||||
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
|
||||
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
|
||||
attn_mask[:, ctx_len:] = causal_mask
|
||||
|
||||
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
|
||||
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
|
||||
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
|
||||
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
|
||||
|
||||
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
|
||||
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
|
||||
|
||||
all_q_vllm.append(q_c)
|
||||
all_kv_c_vllm.append(kv_c_full[ctx_len:])
|
||||
all_k_pe_vllm.append(k_pe_full[ctx_len:])
|
||||
kv_c_contexts.append(kv_c_full[:ctx_len + 1])
|
||||
k_pe_contexts.append(k_pe_full[:ctx_len + 1])
|
||||
|
||||
query_vllm = torch.cat(all_q_vllm, dim=0)
|
||||
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
|
||||
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
|
||||
sdpa_reference = torch.cat(reference_outputs, dim=0)
|
||||
|
||||
vllm_config.cache_config.cache_dtype = kv_cache_dtype
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
vllm_config.cache_config.block_size,
|
||||
device,
|
||||
arange_block_indices=True)
|
||||
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
kv_c_contexts=kv_c_contexts,
|
||||
k_pe_contexts=k_pe_contexts,
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
num_blocks=vllm_config.cache_config.num_gpu_blocks,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
randomize_blocks=False,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
scale=kv_cache_scale,
|
||||
)
|
||||
|
||||
builder_cls = FlashMLASparseBackend.get_builder_cls()
|
||||
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
|
||||
metadata = builder.build(common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata)
|
||||
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
|
||||
dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
|
||||
starts[:-1], seg_lengths)
|
||||
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
|
||||
prefix_lengths = seq_lengths - seg_lengths
|
||||
positions += np.repeat(prefix_lengths, seg_lengths)
|
||||
|
||||
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
|
||||
topk = metadata.topk_tokens
|
||||
debug_indices = torch.arange(topk, device=device,
|
||||
dtype=torch.int32).unsqueeze(0)
|
||||
token_positions = pos_gpu.unsqueeze(1)
|
||||
causal_mask = (debug_indices <= token_positions)
|
||||
debug_indices = torch.where(causal_mask, debug_indices,
|
||||
torch.full_like(debug_indices, -1))
|
||||
|
||||
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
|
||||
# buffer, so emulate that contract with a simple namespace mock.
|
||||
debug_indices = debug_indices.expand(metadata.num_actual_tokens,
|
||||
-1).clone()
|
||||
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
|
||||
|
||||
ok, reason = flashmla.is_flashmla_supported()
|
||||
if not ok:
|
||||
pytest.skip(reason)
|
||||
|
||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
|
||||
|
||||
mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
|
||||
output_size=num_heads *
|
||||
(qk_nope_head_dim + v_head_dim),
|
||||
bias=False).to(device=device,
|
||||
dtype=dtype)
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
|
||||
|
||||
impl_cls = FlashMLASparseBackend.get_impl_cls()
|
||||
impl = impl_cls(num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer)
|
||||
|
||||
impl.process_weights_after_loading(dtype)
|
||||
|
||||
layer = MockAttentionLayer(device)
|
||||
out_buffer = torch.empty(metadata.num_actual_tokens,
|
||||
num_heads * v_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
backend_output = impl.forward(layer,
|
||||
query_vllm,
|
||||
kv_c_vllm,
|
||||
k_pe_vllm,
|
||||
kv_cache,
|
||||
metadata,
|
||||
output=out_buffer)
|
||||
|
||||
assert backend_output.shape == sdpa_reference.shape
|
||||
assert backend_output.dtype == sdpa_reference.dtype
|
||||
assert torch.isfinite(backend_output).all()
|
||||
|
||||
torch.testing.assert_close(backend_output,
|
||||
sdpa_reference,
|
||||
rtol=0.5,
|
||||
atol=0.5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,max_buf,start,expected",
|
||||
[
|
||||
# Basic split: totals per chunk ≤ max_buf
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
|
||||
# Non-zero start index
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
|
||||
# Exact fits should split between items when adding the next would
|
||||
# overflow
|
||||
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
|
||||
# All requests fit in a single chunk
|
||||
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
|
||||
# Large buffer with non-zero start
|
||||
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
|
||||
],
|
||||
)
|
||||
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
|
||||
out = split_prefill_chunks(seq_lens, max_buf, start)
|
||||
assert out == expected
|
||||
@ -3,7 +3,7 @@
|
||||
"""Utility functions for attention-related v1 tests."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -168,6 +168,7 @@ def create_standard_kv_cache_spec(
|
||||
vllm_config.parallel_config),
|
||||
head_size=vllm_config.model_config.get_head_size(),
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
use_mla=vllm_config.model_config.use_mla,
|
||||
sliding_window=vllm_config.model_config.get_sliding_window(),
|
||||
)
|
||||
|
||||
@ -259,3 +260,88 @@ def create_dummy_kv_cache(block_size: int,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
return kv_cache
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackendConfig:
|
||||
name: str
|
||||
env_vars: dict
|
||||
comp_config: dict # compilation config
|
||||
specific_gpu_arch: Optional[tuple] = None
|
||||
|
||||
|
||||
# Define all backend configurations of full cudagraph to be tested
|
||||
full_cg_backend_configs = {
|
||||
# FA3 on Hopper
|
||||
"FA3":
|
||||
BackendConfig(name="FA3",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
|
||||
"VLLM_FLASH_ATTN_VERSION": "3",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# FlashMLA on Hopper
|
||||
"FlashMLA":
|
||||
BackendConfig(name="FlashMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# Cutlass MLA on Blackwell
|
||||
"CutlassMLA":
|
||||
BackendConfig(
|
||||
name="CutlassMLA",
|
||||
env_vars={
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
||||
"FORCE_NUM_KV_SPLITS":
|
||||
"1", # TODO: remove this when hang issue is fixed
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
specific_gpu_arch=(10, 0)),
|
||||
# FlashAttention MLA on Hopper
|
||||
"FlashAttentionMLA":
|
||||
BackendConfig(name="FlashAttentionMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# FA2
|
||||
"FA2":
|
||||
BackendConfig(name="FA2",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
|
||||
"VLLM_FLASH_ATTN_VERSION": "2",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
# Triton Attention
|
||||
"TritonAttn":
|
||||
BackendConfig(name="TritonAttn",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
# FlashInfer
|
||||
"FlashInfer":
|
||||
BackendConfig(name="FlashInfer",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
}
|
||||
|
||||
@ -24,8 +24,7 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
make_block_hash_with_group_id)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
KVCacheTensor, MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
KVCacheTensor, SlidingWindowSpec,
|
||||
UniformTypeKVCacheSpecs)
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
@ -78,11 +77,13 @@ def new_kv_cache_spec(block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
use_mla=False,
|
||||
sliding_window=None):
|
||||
return FullAttentionSpec(block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
use_mla=use_mla,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
@ -90,11 +91,13 @@ def new_sliding_window_spec(block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
use_mla=False,
|
||||
sliding_window=1):
|
||||
return SlidingWindowSpec(block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
use_mla=use_mla,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
@ -891,6 +894,7 @@ def test_merge_kv_cache_spec():
|
||||
num_kv_heads=full_spec.num_kv_heads,
|
||||
head_size=full_spec.head_size,
|
||||
dtype=full_spec.dtype,
|
||||
use_mla=full_spec.use_mla,
|
||||
sliding_window=1,
|
||||
),
|
||||
]
|
||||
@ -987,6 +991,7 @@ def test_estimate_max_model_len(model_id, max_model_len,
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
)
|
||||
# Estimate the maximum model length, 16384 model_len need 8GB
|
||||
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
|
||||
@ -1017,6 +1022,7 @@ def test_get_max_concurrency_for_kv_cache_config():
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
sliding_window_spec = SlidingWindowSpec(
|
||||
@ -1024,6 +1030,7 @@ def test_get_max_concurrency_for_kv_cache_config():
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
sliding_window=1024,
|
||||
)
|
||||
|
||||
@ -1405,48 +1412,3 @@ def test_generate_scheduler_kv_cache_config():
|
||||
KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec())
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def new_mla_spec(cache_dtype_str=None):
|
||||
return MLAAttentionSpec(block_size=16,
|
||||
num_kv_heads=16,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
cache_dtype_str=cache_dtype_str)
|
||||
|
||||
|
||||
def test_merge_mla_spec():
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(),
|
||||
new_mla_spec(),
|
||||
]
|
||||
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
|
||||
assert mla_spec == new_mla_spec()
|
||||
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
]
|
||||
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
|
||||
assert mla_spec == new_mla_spec(cache_dtype_str="fp8_ds_mla")
|
||||
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
new_mla_spec(cache_dtype_str=None),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
kv_cache_specs[0].merge(kv_cache_specs)
|
||||
|
||||
kv_cache_specs = [
|
||||
new_kv_cache_spec(),
|
||||
new_mla_spec(),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
kv_cache_specs[0].merge(kv_cache_specs)
|
||||
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
new_kv_cache_spec(),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
kv_cache_specs[0].merge(kv_cache_specs)
|
||||
|
||||
@ -76,7 +76,7 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -90,7 +90,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
@ -98,6 +98,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
1,
|
||||
1,
|
||||
torch.float32,
|
||||
False,
|
||||
sliding_window=2 * block_size),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
@ -106,6 +107,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
1,
|
||||
1,
|
||||
torch.float32,
|
||||
False,
|
||||
sliding_window=2 * block_size),
|
||||
),
|
||||
],
|
||||
@ -1336,6 +1338,7 @@ def test_eagle_with_sliding_window():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=block_size,
|
||||
use_mla=False,
|
||||
)
|
||||
manager = KVCacheManager(
|
||||
KVCacheConfig(
|
||||
|
||||
@ -35,6 +35,7 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
attention_chunk_size=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
@ -99,6 +100,7 @@ def test_sliding_window_possible_cached_prefix():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
@ -163,6 +165,7 @@ def test_chunked_local_attention_remove_skipped_blocks():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
attention_chunk_size=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
||||
@ -214,6 +217,7 @@ def test_sliding_window_remove_skipped_blocks():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
||||
@ -281,6 +285,7 @@ def test_get_num_blocks_to_allocate():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4, # Placeholder value, not related to test result
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
@ -303,6 +308,7 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
attention_chunk_size=4, # Placeholder value, not related to test result
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
|
||||
@ -45,39 +45,22 @@ def _create_vllm_config(compilation_config: CompilationConfig,
|
||||
class TestCudagraphDispatcher:
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"params",
|
||||
"case_id,cudagraph_mode_str,compilation_level",
|
||||
[
|
||||
# Test case 0: Full CG for mixed batches, no separate routine
|
||||
{
|
||||
"case_id": 0,
|
||||
"cudagraph_mode": "FULL",
|
||||
"compilation_level": CompilationLevel.NO_COMPILATION,
|
||||
},
|
||||
(0, "FULL", CompilationLevel.NO_COMPILATION),
|
||||
# Test case 1: Full CG for uniform batches, piecewise for mixed
|
||||
{
|
||||
"case_id": 1,
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
"compilation_level": CompilationLevel.PIECEWISE,
|
||||
},
|
||||
(1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
|
||||
# Test case 2: Full CG for uniform batches, no CG for mixed
|
||||
{
|
||||
"case_id": 2,
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
"compilation_level": CompilationLevel.NO_COMPILATION,
|
||||
},
|
||||
(2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
|
||||
# Test case 3: Piecewise for all
|
||||
{
|
||||
"case_id": 3,
|
||||
"cudagraph_mode": "PIECEWISE",
|
||||
"compilation_level": CompilationLevel.PIECEWISE,
|
||||
},
|
||||
(3, "PIECEWISE", CompilationLevel.PIECEWISE),
|
||||
])
|
||||
def test_dispatcher(self, params):
|
||||
def test_dispatcher(self, cudagraph_mode_str, compilation_level):
|
||||
# Setup dispatcher
|
||||
comp_config = CompilationConfig(
|
||||
cudagraph_mode=params["cudagraph_mode"],
|
||||
level=params["compilation_level"],
|
||||
cudagraph_capture_sizes=[1, 8])
|
||||
comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str,
|
||||
level=compilation_level,
|
||||
cudagraph_capture_sizes=[1, 8])
|
||||
|
||||
config = _create_vllm_config(comp_config, max_num_seqs=8)
|
||||
dispatcher = CudagraphDispatcher(config)
|
||||
@ -86,11 +69,11 @@ class TestCudagraphDispatcher:
|
||||
uniform_decode_query_len=1)
|
||||
|
||||
# Verify the key is initialized correctly
|
||||
if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
|
||||
else:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
|
||||
if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
|
||||
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
|
||||
else:
|
||||
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
|
||||
@ -99,10 +82,10 @@ class TestCudagraphDispatcher:
|
||||
# 1. non-uniform batch, size in cudagraph size list
|
||||
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
|
||||
rt_mode, key = dispatcher.dispatch(desc_full_exact)
|
||||
if params["cudagraph_mode"] == "FULL":
|
||||
if cudagraph_mode_str == "FULL":
|
||||
assert rt_mode == CUDAGraphMode.FULL
|
||||
assert key == desc_full_exact
|
||||
elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||
elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
|
||||
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||
assert key == desc_full_exact
|
||||
else:
|
||||
@ -111,15 +94,13 @@ class TestCudagraphDispatcher:
|
||||
# 2. uniform decode batch, size in cudagraph size list
|
||||
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
|
||||
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
|
||||
if params["cudagraph_mode"] == "FULL":
|
||||
if cudagraph_mode_str == "FULL":
|
||||
assert rt_mode == CUDAGraphMode.FULL
|
||||
assert key == desc_uniform_exact.non_uniform
|
||||
elif params["cudagraph_mode"] in [
|
||||
"FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
|
||||
]:
|
||||
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
|
||||
assert rt_mode == CUDAGraphMode.FULL
|
||||
assert key == desc_uniform_exact
|
||||
elif params["cudagraph_mode"] == "PIECEWISE":
|
||||
elif cudagraph_mode_str == "PIECEWISE":
|
||||
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||
assert key == desc_uniform_exact.non_uniform
|
||||
else:
|
||||
@ -131,6 +112,16 @@ class TestCudagraphDispatcher:
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
assert key is None
|
||||
|
||||
# 4. Cascade attention should have a fallback mode
|
||||
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
|
||||
rt_mode, key = dispatcher.dispatch(desc_full_exact,
|
||||
use_cascade_attn=True)
|
||||
if "PIECEWISE" in cudagraph_mode_str: # string contains check
|
||||
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||
assert key == desc_full_exact.non_uniform
|
||||
else:
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
class TestCUDAGraphWrapper:
|
||||
|
||||
@ -4,12 +4,11 @@ import contextlib
|
||||
import os
|
||||
import weakref
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.utils import wait_for_gpu_memory_to_clear
|
||||
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
|
||||
from vllm import LLM
|
||||
from vllm.config import CompilationConfig
|
||||
from vllm.platforms import current_platform
|
||||
@ -34,74 +33,6 @@ def temporary_environ(env_vars):
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackendConfig:
|
||||
name: str
|
||||
env_vars: dict
|
||||
comp_config: dict
|
||||
specific_gpu_arch: Optional[tuple] = None
|
||||
|
||||
|
||||
# Define all backend configurations of full cudagraph to be tested
|
||||
backend_configs = {
|
||||
# FA3 on Hopper
|
||||
"FA3":
|
||||
BackendConfig(name="FA3",
|
||||
env_vars={
|
||||
"VLLM_FLASH_ATTN_VERSION": "3",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# FlashMLA on Hopper
|
||||
"FlashMLA":
|
||||
BackendConfig(name="FlashMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# FlashAttention MLA on Hopper
|
||||
"FlashAttentionMLA":
|
||||
BackendConfig(name="FlashAttentionMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_DECODE_ONLY",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# FA2
|
||||
"FA2":
|
||||
BackendConfig(name="FA2",
|
||||
env_vars={
|
||||
"VLLM_FLASH_ATTN_VERSION": "2",
|
||||
"VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
# Triton Attention
|
||||
"TritonAttn":
|
||||
BackendConfig(name="TritonAttn",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
# FlashInfer
|
||||
"FlashInfer":
|
||||
BackendConfig(name="FlashInfer",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
}
|
||||
|
||||
# test attention backend and cudagraph_mode combo
|
||||
# (backend_name, cudagraph_mode, supported)
|
||||
combo_cases_1 = [
|
||||
@ -114,9 +45,10 @@ combo_cases_1 = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("combo_case", combo_cases_1)
|
||||
def test_backend_and_cudagraph_mode_combo(combo_case):
|
||||
backend_name, cudagraph_mode, supported = combo_case
|
||||
@pytest.mark.parametrize("backend_name, cudagraph_mode, supported",
|
||||
combo_cases_1)
|
||||
def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode,
|
||||
supported):
|
||||
if backend_name == "FlashInfer":
|
||||
try:
|
||||
import flashinfer # noqa: F401
|
||||
@ -142,7 +74,7 @@ def test_backend_and_cudagraph_mode_combo(combo_case):
|
||||
compilation_config=CompilationConfig(
|
||||
level=3, cudagraph_mode=cudagraph_mode))
|
||||
llm.generate(["Hello, my name is"] * 10)
|
||||
|
||||
# when above code raises, `llm` may be undefined, so we need to catch that
|
||||
try:
|
||||
llm = weakref.proxy(llm)
|
||||
del llm
|
||||
@ -173,7 +105,8 @@ combo_cases_2 = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("combo_case", combo_cases_2)
|
||||
@pytest.mark.parametrize("backend_name,cudagraph_mode,compilation_level,"\
|
||||
"supported", combo_cases_2)
|
||||
def test_cudagraph_compilation_combo(combo_case):
|
||||
backend_name, cudagraph_mode, compilation_level, supported\
|
||||
= combo_case
|
||||
@ -192,6 +125,7 @@ def test_cudagraph_compilation_combo(combo_case):
|
||||
compilation_config=CompilationConfig(
|
||||
level=compilation_level, cudagraph_mode=cudagraph_mode))
|
||||
llm.generate(["Hello, my name is"] * 10)
|
||||
# when above code raises, `llm` may be undefined, so we need to catch that
|
||||
try:
|
||||
llm = weakref.proxy(llm)
|
||||
del llm
|
||||
|
||||
0
tests/v1/distributed/__init__.py
Normal file
0
tests/v1/distributed/__init__.py
Normal file
@ -12,7 +12,7 @@ import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.v1.test_utils import check_request_balancing
|
||||
from tests.v1.utils import check_request_balancing
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||
@ -13,7 +13,7 @@ import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.v1.test_utils import check_request_balancing
|
||||
from tests.v1.utils import check_request_balancing
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||
@ -8,7 +8,7 @@ from typing import Any, Union
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import get_attn_backend_list_based_on_platform
|
||||
from tests.utils import get_attn_backend_list_based_on_platform, large_gpu_mark
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.base import VLLM_S3_BUCKET_URL
|
||||
from vllm.assets.image import VLM_IMAGES_DIR
|
||||
@ -88,69 +88,71 @@ def test_ngram_correctness(
|
||||
Compare the outputs of an original LLM and a speculative LLM
|
||||
should be the same when using ngram speculative decoding.
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
test_prompts = get_test_prompts(mm_enabled=False)
|
||||
test_prompts = get_test_prompts(mm_enabled=False)
|
||||
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
ref_llm = LLM(model=model_name, max_model_len=1024)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
speculative_config={
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
max_model_len=1024,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
speculative_config={
|
||||
"method": "ngram",
|
||||
"prompt_lookup_max": 5,
|
||||
"prompt_lookup_min": 3,
|
||||
"num_speculative_tokens": 3,
|
||||
},
|
||||
max_model_len=1024,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches >= int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches >= int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
|
||||
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
|
||||
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
||||
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
|
||||
pytest.param(
|
||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
False,
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
pytest.param(
|
||||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
True,
|
||||
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
|
||||
(("eagle", "eagle618/deepseek-v3-random",
|
||||
"eagle618/eagle-deepseek-v3-random", 1), False),
|
||||
],
|
||||
ids=[
|
||||
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
|
||||
"llama4_eagle", "llama4_eagle_mm",
|
||||
"deepseek_eagle"
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
["model_setup", "mm_enabled"],
|
||||
[
|
||||
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
|
||||
pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
"Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1),
|
||||
False,
|
||||
marks=pytest.mark.skip(reason="Skipping due to its " \
|
||||
"head_dim not being a a multiple of 32")),
|
||||
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
|
||||
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
|
||||
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
False,
|
||||
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
|
||||
pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
|
||||
True,
|
||||
marks=large_gpu_mark(min_gb=80)), # works on 4x H100
|
||||
(("eagle", "eagle618/deepseek-v3-random",
|
||||
"eagle618/eagle-deepseek-v3-random", 1), False),
|
||||
],
|
||||
ids=[
|
||||
"qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3",
|
||||
"llama4_eagle", "llama4_eagle_mm", "deepseek_eagle"
|
||||
])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
def test_eagle_correctness(
|
||||
@ -174,9 +176,14 @@ def test_eagle_correctness(
|
||||
model_setup: (method, model_name, eagle_model_name, tp_size)
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
|
||||
# Scout requires default backend selection
|
||||
# because vision encoder has head_dim 88 being incompatible
|
||||
# with FLASH_ATTN and needs to fall back to Flex Attn
|
||||
pass
|
||||
else:
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()):
|
||||
pytest.skip("TRITON_ATTN does not support "
|
||||
|
||||
@ -836,7 +836,8 @@ def test_engine_core_proc_instantiation_cuda_empty(
|
||||
mock_spec = FullAttentionSpec(block_size=16,
|
||||
num_kv_heads=1,
|
||||
head_size=64,
|
||||
dtype=torch.float16)
|
||||
dtype=torch.float16,
|
||||
use_mla=False)
|
||||
|
||||
mock_executor.get_kv_cache_specs.return_value = [{
|
||||
"default": mock_spec
|
||||
|
||||
@ -8,7 +8,7 @@ import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.v1.test_utils import check_request_balancing
|
||||
from tests.v1.utils import check_request_balancing
|
||||
|
||||
MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||
|
||||
|
||||
290
tests/v1/generation/test_batch_invariance.py
Normal file
290
tests/v1/generation/test_batch_invariance.py
Normal file
@ -0,0 +1,290 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
# Lightweight random prompt generator to vary prompt lengths and content.
|
||||
vocab = [
|
||||
"alpha",
|
||||
"bravo",
|
||||
"charlie",
|
||||
"delta",
|
||||
"echo",
|
||||
"foxtrot",
|
||||
"golf",
|
||||
"hotel",
|
||||
"india",
|
||||
"juliet",
|
||||
"kilo",
|
||||
"lima",
|
||||
"mike",
|
||||
"november",
|
||||
"oscar",
|
||||
"papa",
|
||||
"quebec",
|
||||
"romeo",
|
||||
"sierra",
|
||||
"tango",
|
||||
"uniform",
|
||||
"victor",
|
||||
"whiskey",
|
||||
"xray",
|
||||
"yankee",
|
||||
"zulu",
|
||||
]
|
||||
n = random.randint(min_words, max_words)
|
||||
words = random.choices(vocab, k=n)
|
||||
|
||||
# Add some noise and punctuation variability
|
||||
if random.random() < 0.5:
|
||||
words[0] = words[0].capitalize()
|
||||
if random.random() < 0.2:
|
||||
words.append("".join(random.choices(string.ascii_lowercase, k=5)))
|
||||
punct = random.choice([".", "?", "!", "...", ""])
|
||||
return " ".join(words) + punct
|
||||
|
||||
|
||||
@pytest.mark.timeout(1000)
|
||||
def test_v1_generation_is_deterministic_across_batch_sizes_with_needle():
|
||||
"""
|
||||
Ensures that the same request (the 'needle' prompt) yields identical output
|
||||
whether run alone (bs=1) or mixed into a larger batch (e.g., bs=64),
|
||||
using the high-level v1 LLM() API only (no manual batching).
|
||||
|
||||
Strategy:
|
||||
- Create two LLM engines with identical config except max_num_seqs: 1 vs N.
|
||||
- Compute a baseline output for the needle prompt with the bs=1 engine.
|
||||
- For many trials, generate a batch (size N) where the needle appears at a
|
||||
random position among random filler prompts using the bs=N engine.
|
||||
- Track how many trials match vs mismatch, and report totals at the end.
|
||||
The test fails if any mismatches occur, but we still dump pass/fail
|
||||
counts.
|
||||
|
||||
Notes:
|
||||
- Use a fixed sampling seed to test determinism; the default temperature is 0.0 (greedy) and can be raised via VLLM_NEEDLE_TEMPERATURE.
|
||||
- Outputs are intentionally longer and sampled at higher temperature/top_p
|
||||
to produce a more random-sounding phrase, yet remain deterministic by
|
||||
seed.
|
||||
- Keep max_tokens and max_model_len bounded for speed and memory use.
|
||||
"""
|
||||
random.seed(12345)
|
||||
|
||||
# Allow overrides from environment (useful for CI tuning)
|
||||
# "facebook/opt-125m" is too small, doesn't reliably test determinism
|
||||
model = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
|
||||
batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "64"))
|
||||
assert batch_size >= 2, "Batch size should be >= 2 to mix needle."
|
||||
|
||||
# Keep GPU memory usage low to avoid startup allocation failures.
|
||||
gpu_mem_util = float(os.getenv("VLLM_GPU_MEMORY_UTILIZATION", "0.3"))
|
||||
max_model_len = int(os.getenv("VLLM_MAX_MODEL_LEN", "4096"))
|
||||
swap_space_gb = int(os.getenv("VLLM_SWAP_SPACE_GB", "4"))
|
||||
|
||||
# Sampling parameters: longer outputs with a more random-sounding
|
||||
# continuation, but still deterministic due to fixed seed.
|
||||
temperature = float(os.getenv("VLLM_NEEDLE_TEMPERATURE", "0.0"))
|
||||
top_p = float(os.getenv("VLLM_NEEDLE_TOP_P", "0.95"))
|
||||
max_tokens = int(os.getenv("VLLM_NEEDLE_MAX_TOKENS", "128"))
|
||||
|
||||
sampling = SamplingParams(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
seed=20240919,
|
||||
)
|
||||
|
||||
needle_prompt = ("There once was a ")
|
||||
|
||||
llm_bs1 = None
|
||||
llm_bsN = None
|
||||
try:
|
||||
# Engine with bs=1 behavior
|
||||
llm_bs1 = LLM_with_max_seqs(
|
||||
model=model,
|
||||
max_num_seqs=1,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
swap_space=swap_space_gb,
|
||||
)
|
||||
|
||||
# Baseline generation for the needle prompt alone.
|
||||
baseline_out = llm_bs1.generate([needle_prompt], sampling)
|
||||
assert len(baseline_out) == 1
|
||||
assert len(baseline_out[0].outputs) >= 1
|
||||
baseline_text = baseline_out[0].outputs[0].text
|
||||
|
||||
# Engine with larger batch limit (e.g., 64)
|
||||
llm_bsN = LLM_with_max_seqs(
|
||||
model=model,
|
||||
max_num_seqs=batch_size,
|
||||
gpu_memory_utilization=gpu_mem_util,
|
||||
max_model_len=max_model_len,
|
||||
swap_space=swap_space_gb,
|
||||
)
|
||||
|
||||
mismatches = 0
|
||||
|
||||
for trial in range(num_trials):
|
||||
# Create a batch of size `batch_size` and insert the needle at
|
||||
# a random index
|
||||
prompts: list[str] = []
|
||||
needle_pos = random.randint(0, batch_size - 1)
|
||||
for i in range(batch_size):
|
||||
if i == needle_pos:
|
||||
prompts.append(needle_prompt)
|
||||
else:
|
||||
prompts.append(_random_prompt())
|
||||
|
||||
# Generate with the larger-batch engine
|
||||
outputs = llm_bsN.generate(prompts, sampling)
|
||||
# Find the needle output by position
|
||||
needle_output = outputs[needle_pos]
|
||||
assert needle_output.prompt == needle_prompt
|
||||
assert len(needle_output.outputs) >= 1
|
||||
text = needle_output.outputs[0].text
|
||||
|
||||
if text != baseline_text:
|
||||
mismatches += 1
|
||||
|
||||
passes = num_trials - mismatches
|
||||
# Dump how many passed vs failed
|
||||
print(f"[determinism] total={num_trials}, passed={passes}, "
|
||||
f"failed={mismatches}, batch_size={batch_size}")
|
||||
|
||||
if mismatches > 0:
|
||||
pytest.fail(
|
||||
f"Nondeterministic outputs detected: {mismatches} failed out "
|
||||
f"of {num_trials} trials (batch_size={batch_size}).")
|
||||
|
||||
finally:
|
||||
# Ensure engines are shut down to free GPU/VRAM across test sessions
|
||||
if llm_bs1 is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
llm_bs1.shutdown()
|
||||
if llm_bsN is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
llm_bsN.shutdown()
|
||||
|
||||
|
||||
def _extract_step_logprobs(request_output):
|
||||
if getattr(request_output, "outputs", None):
|
||||
inner = request_output.outputs[0]
|
||||
if hasattr(inner, "logprobs") and inner.logprobs is not None:
|
||||
t = torch.tensor(
|
||||
[
|
||||
inner.logprobs[i][tid].logprob
|
||||
for i, tid in enumerate(inner.token_ids)
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
return t
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(),
|
||||
reason="Requires CUDA to match production inference path.",
|
||||
)
|
||||
def test_logprobs_bitwise_batch_invariance_bs1_vs_bs2():
|
||||
|
||||
#model_name = os.getenv("VLLM_TEST_MODEL", "facebook/opt-125m")
|
||||
model_name = os.getenv("VLLM_TEST_MODEL", "Qwen/Qwen3-1.7B")
|
||||
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
|
||||
|
||||
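# NOTE(review): no dtype is passed below, so the model's default dtype is used rather than a forced float32 — confirm whether dtype="float32" was intended.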
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=tp_size,
|
||||
enforce_eager=True, # helps reduce nondeterminism from some backends
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The capital of Germany is",
|
||||
]
|
||||
|
||||
sp = SamplingParams(
|
||||
temperature=0.0,
|
||||
top_p=1.0,
|
||||
max_tokens=8,
|
||||
# Seed shouldn't matter at temperature=0, but keeping it stable anyway.
|
||||
seed=1234,
|
||||
logprobs=5,
|
||||
)
|
||||
|
||||
# BS=1: run prompts individually and collect logprobs per step.
|
||||
bs1_logprobs_per_prompt = []
|
||||
for p in prompts:
|
||||
outs = llm.generate([p], sp, use_tqdm=False)
|
||||
assert len(outs) == 1
|
||||
step_logprobs = _extract_step_logprobs(outs[0])
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
bs1_logprobs_per_prompt.append(step_logprobs)
|
||||
|
||||
# BS=2: run prompts in a batch and collect logprobs per step for each
|
||||
# prompt.
|
||||
outs_batched = llm.generate(prompts, sp, use_tqdm=False)
|
||||
assert len(outs_batched) == len(prompts)
|
||||
bs2_logprobs_per_prompt = []
|
||||
for o in outs_batched:
|
||||
step_logprobs = _extract_step_logprobs(o)
|
||||
if step_logprobs is None:
|
||||
pytest.skip("Logits are not available on RequestOutput; "
|
||||
"enable logprobs return to run this test.")
|
||||
bs2_logprobs_per_prompt.append(step_logprobs)
|
||||
|
||||
# Compare step-by-step logprobs for each prompt between BS=1 and BS=2 runs.
|
||||
for i, (logprobs_bs1, logprobs_bs2) in enumerate(
|
||||
zip(bs1_logprobs_per_prompt, bs2_logprobs_per_prompt)):
|
||||
assert len(logprobs_bs1) == len(logprobs_bs2), (
|
||||
f"Different number of generation steps for prompt index {i}: "
|
||||
f"{len(logprobs_bs1)} (BS=1) vs {len(logprobs_bs2)} (BS=2)")
|
||||
for t, (a, b) in enumerate(zip(logprobs_bs1, logprobs_bs2)):
|
||||
assert a.shape == b.shape, (
|
||||
f"Logits shape mismatch at prompt {i}, step {t}: "
|
||||
f"{a.shape} vs {b.shape}")
|
||||
# Bitwise exact equality.
|
||||
assert torch.equal(
|
||||
a, b), (f"Bitwise logprobs mismatch at prompt {i}, step {t} "
|
||||
f"(dtype={a.dtype}, shape={a.shape}).")
|
||||
|
||||
|
||||
def LLM_with_max_seqs(
|
||||
model: str,
|
||||
max_num_seqs: int,
|
||||
gpu_memory_utilization: float,
|
||||
max_model_len: int,
|
||||
swap_space: int,
|
||||
) -> LLM:
|
||||
"""
|
||||
Helper to construct an LLM with a specific max_num_seqs (batch-size limit)
|
||||
using the high-level v1 LLM API, while constraining memory usage.
|
||||
"""
|
||||
return LLM(
|
||||
model=model,
|
||||
max_num_seqs=max_num_seqs,
|
||||
# Constrain GPU memory pool so test can run even on busy GPUs.
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
# Keep KV cache footprint small while allowing longer outputs.
|
||||
max_model_len=max_model_len,
|
||||
# Allow some CPU offload if needed.
|
||||
swap_space=swap_space,
|
||||
# Keep things lean and CI-friendly.
|
||||
dtype="float16",
|
||||
# Single-GPU by default; override externally if desired.
|
||||
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
|
||||
trust_remote_code=os.getenv("VLLM_TRUST_REMOTE_CODE", "0") == "1",
|
||||
)
|
||||
@ -255,9 +255,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
time.sleep(self._hand_shake_latency)
|
||||
# These should've been done in register_kv_caches(), called by
|
||||
# gpu_model_runner. Here we just hardcode some dummy values.
|
||||
slot_size_bytes = 4096
|
||||
self.slot_size_per_layer = [slot_size_bytes]
|
||||
self.block_len_per_layer = [slot_size_bytes * self.block_size]
|
||||
self.slot_size_bytes = 4096
|
||||
self.block_len = self.slot_size_bytes * self.block_size
|
||||
self.num_blocks = 1
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
|
||||
@ -269,7 +268,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
num_blocks=1,
|
||||
block_lens=self.block_len_per_layer,
|
||||
block_len=self.block_len,
|
||||
attn_backend_name=self.backend_name,
|
||||
# `self.kv_cache_layout` is only forced to HND when vllm engine
|
||||
# is started. We mock HND here.
|
||||
@ -486,8 +485,8 @@ class TestNixlHandshake:
|
||||
worker = connector.connector_worker
|
||||
|
||||
# Minimal local registration params used by add_remote_agent
|
||||
worker.slot_size_per_layer = [4096]
|
||||
worker.block_len_per_layer = [4096 * worker.block_size]
|
||||
worker.slot_size_bytes = 4096
|
||||
worker.block_len = worker.slot_size_bytes * worker.block_size
|
||||
worker.num_blocks = 1
|
||||
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
|
||||
|
||||
@ -499,7 +498,7 @@ class TestNixlHandshake:
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
num_blocks=1,
|
||||
block_lens=worker.block_len_per_layer,
|
||||
block_len=worker.block_len,
|
||||
attn_backend_name=worker.backend_name,
|
||||
kv_cache_layout=mismatched_layout,
|
||||
)
|
||||
|
||||
@ -337,19 +337,13 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
"target_attn_1": mock.MagicMock(),
|
||||
"target_attn_2": mock.MagicMock()
|
||||
}
|
||||
target_indx_layers: dict[str, mock.MagicMock] = {}
|
||||
# Draft model has one extra attention layer compared to target model
|
||||
all_attn_layers = {
|
||||
**target_attn_layers, "draft_extra_attn": mock.MagicMock()
|
||||
}
|
||||
|
||||
all_indx_layers: dict[str, mock.MagicMock] = {}
|
||||
|
||||
# Make mock_get_layers return different values for each call
|
||||
mock_get_layers.side_effect = [
|
||||
target_attn_layers, target_indx_layers, all_attn_layers,
|
||||
all_indx_layers
|
||||
]
|
||||
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
|
||||
|
||||
# Setup mock for pp group to return the appropriate value for world size
|
||||
mock_pp_group = mock.MagicMock()
|
||||
@ -664,9 +658,6 @@ def test_propose_tree(spec_token_tree):
|
||||
# Mock runner for attention metadata building.
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_groups.append([mock.MagicMock()])
|
||||
proposer.runner.attn_groups[0][0].metadata_builders = [
|
||||
attn_metadata_builder
|
||||
]
|
||||
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
|
||||
attn_metadata_builder
|
||||
proposer._get_attention_metadata_builder = mock.MagicMock(
|
||||
|
||||
@ -63,13 +63,7 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
|
||||
|
||||
target_attn_layers = {"target_attn_1": mock.MagicMock()}
|
||||
all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
|
||||
target_indexer_layers: dict = {}
|
||||
all_indexer_layers: dict = {}
|
||||
|
||||
mock_get_layers.side_effect = [
|
||||
target_attn_layers, target_indexer_layers, all_attn_layers,
|
||||
all_indexer_layers
|
||||
]
|
||||
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
|
||||
|
||||
mock_pp_group = mock.MagicMock()
|
||||
mock_pp_group.world_size = 1
|
||||
|
||||
@ -1,71 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import regex as re
|
||||
import requests
|
||||
import torch
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
def test_bind_kv_cache():
|
||||
from vllm.attention import Attention
|
||||
|
||||
ctx = {
|
||||
'layers.0.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.1.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.2.self_attn': Attention(32, 128, 0.1),
|
||||
'layers.3.self_attn': Attention(32, 128, 0.1),
|
||||
}
|
||||
kv_cache = {
|
||||
'layers.0.self_attn': torch.zeros((1, )),
|
||||
'layers.1.self_attn': torch.zeros((1, )),
|
||||
'layers.2.self_attn': torch.zeros((1, )),
|
||||
'layers.3.self_attn': torch.zeros((1, )),
|
||||
}
|
||||
runner_kv_caches: list[torch.Tensor] = []
|
||||
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
|
||||
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.0.self_attn']
|
||||
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.1.self_attn']
|
||||
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.2.self_attn']
|
||||
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
|
||||
'layers.3.self_attn']
|
||||
|
||||
assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
|
||||
assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
|
||||
assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
|
||||
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']
|
||||
|
||||
|
||||
def test_bind_kv_cache_non_attention():
    """Check bind_kv_cache when attention layer indices are not contiguous.

    Only layers 20 and 28 are attention layers here, yet the runner-side
    list must still be filled densely (slots 0 and 1) in that order.
    """
    from vllm.attention import Attention

    # example from Jamba PP=2
    layer_names = ('model.layers.20.attn', 'model.layers.28.attn')
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = {name: torch.zeros((1, )) for name in layer_names}

    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Each layer must reference the exact tensor that was passed in.
    for name in layer_names:
        assert ctx[name].kv_cache[0] is kv_cache[name]

    # Slots are assigned by relative order, not by the raw layer index.
    for position, name in enumerate(layer_names):
        assert runner_kv_caches[position] is kv_cache[name]
|
||||
|
||||
|
||||
# Prometheus metrics utilities for testing
|
||||
|
||||
@ -39,6 +39,7 @@ def initialize_kv_cache(runner: GPUModelRunner):
|
||||
runner.parallel_config),
|
||||
head_size=runner.model_config.get_head_size(),
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=False,
|
||||
)
|
||||
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
|
||||
kv_cache_config = KVCacheConfig(
|
||||
|
||||
63
tests/v1/worker/test_utils.py
Normal file
63
tests/v1/worker/test_utils.py
Normal file
@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.worker.utils import bind_kv_cache
|
||||
|
||||
|
||||
def test_bind_kv_cache():
    """Check that bind_kv_cache wires every layer to its own cache tensor.

    Four attention layers are registered under ``layers.<i>.self_attn``;
    after binding, each layer's ``kv_cache[0]`` and the matching slot of
    the runner-side list must be the very same tensor object.
    """
    from vllm.attention import Attention

    layer_names = (
        'layers.0.self_attn',
        'layers.1.self_attn',
        'layers.2.self_attn',
        'layers.3.self_attn',
    )
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = {name: torch.zeros((1, )) for name in layer_names}
    runner_kv_caches: list[torch.Tensor] = []

    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Identity (not equality) is what matters: the layer must hold the
    # exact tensor that was handed in.
    for name in layer_names:
        assert ctx[name].kv_cache[0] is kv_cache[name]

    # The runner list is expected in the same order as the layer names.
    for position, name in enumerate(layer_names):
        assert runner_kv_caches[position] is kv_cache[name]
|
||||
|
||||
|
||||
def test_bind_kv_cache_non_attention():
    """Check bind_kv_cache when attention layer indices are not contiguous.

    Only layers 20 and 28 are attention layers here, yet the runner-side
    list must still be filled densely (slots 0 and 1) in that order.
    """
    from vllm.attention import Attention

    # example from Jamba PP=2
    layer_names = ('model.layers.20.attn', 'model.layers.28.attn')
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = {name: torch.zeros((1, )) for name in layer_names}

    runner_kv_caches: list[torch.Tensor] = []
    bind_kv_cache(kv_cache, ctx, runner_kv_caches)

    # Each layer must reference the exact tensor that was passed in.
    for name in layer_names:
        assert ctx[name].kv_cache[0] is kv_cache[name]

    # Slots are assigned by relative order, not by the raw layer index.
    for position, name in enumerate(layer_names):
        assert runner_kv_caches[position] is kv_cache[name]
|
||||
63
tools/flashinfer-build.sh
Normal file
63
tools/flashinfer-build.sh
Normal file
@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env bash
# This script is used to build FlashInfer wheels with AOT kernels
#
# Inputs (environment variables):
#   FLASHINFER_GIT_REF  (required) branch or tag to clone
#   CUDA_VERSION        (required) e.g. 12.8.1; selects the arch list and
#                       the uv torch backend
#   BUILD_WHEEL         (optional, default "true") "true" builds a wheel
#                       into ./flashinfer-dist, anything else installs
#                       FlashInfer into the system environment

set -ex

# FlashInfer configuration
FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
FLASHINFER_GIT_REF="${FLASHINFER_GIT_REF}"
CUDA_VERSION="${CUDA_VERSION}"
BUILD_WHEEL="${BUILD_WHEEL:-true}"

if [[ -z "${FLASHINFER_GIT_REF}" ]]; then
    echo "❌ FLASHINFER_GIT_REF must be specified" >&2
    exit 1
fi

if [[ -z "${CUDA_VERSION}" ]]; then
    echo "❌ CUDA_VERSION must be specified" >&2
    exit 1
fi

echo "🏗️  Building FlashInfer ${FLASHINFER_GIT_REF} for CUDA ${CUDA_VERSION}"

# Clone FlashInfer
# Expansions are quoted so a ref containing spaces or glob characters is
# passed to git as a single argument.
git clone --depth 1 --recursive --shallow-submodules \
    --branch "${FLASHINFER_GIT_REF}" \
    "${FLASHINFER_GIT_REPO}" flashinfer

# Set CUDA arch list based on CUDA version
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
if [[ "${CUDA_VERSION}" == 11.* ]]; then
    FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
    FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
else
    # CUDA 12.8+ supports 10.0a and 12.0
    FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
fi

echo "🏗️  Building FlashInfer AOT for arches: ${FI_TORCH_CUDA_ARCH_LIST}"

pushd flashinfer
    # Make sure the wheel is built for the correct CUDA version
    # (e.g. 12.8.1 -> cu128). Assign first, export second: `export VAR=$(...)`
    # would report export's exit status and hide a failing pipeline.
    UV_TORCH_BACKEND="cu$(echo "${CUDA_VERSION}" | cut -d. -f1,2 | tr -d '.')"
    export UV_TORCH_BACKEND

    # Build AOT kernels
    export TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
    export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
    python3 -m flashinfer.aot

    if [[ "${BUILD_WHEEL}" == "true" ]]; then
        # Build wheel for distribution
        uv build --no-build-isolation --wheel --out-dir ../flashinfer-dist .
        echo "✅ FlashInfer wheel built successfully in flashinfer-dist/"
    else
        # Install directly (for Dockerfile)
        uv pip install --system --no-build-isolation --force-reinstall .
        echo "✅ FlashInfer installed successfully"
    fi
popd

# Cleanup
rm -rf flashinfer
|
||||
@ -1678,15 +1678,6 @@ def cp_gather_cache(src_cache: torch.Tensor,
|
||||
cu_seq_lens, batch_size, seq_starts)
|
||||
|
||||
|
||||
def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor,
                              slot_mapping: torch.Tensor,
                              quant_block_size: int,
                              kv_cache_dtype: str) -> None:
    """Thin Python entry point for the ``_C_cache_ops`` op of the same name.

    All five arguments are forwarded unchanged and nothing is returned.
    NOTE(review): presumably the op quantizes ``k`` and writes it into
    ``kv_cache`` in place at ``slot_mapping`` -- confirm against the
    kernel implementation.
    """
    cache_ops = torch.ops._C_cache_ops
    cache_ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        kv_cache_dtype,
    )
|
||||
|
||||
|
||||
def get_device_attribute(attribute: int, device: int) -> int:
    """Return the result of the ``_C_cuda_utils.get_device_attribute`` op.

    Both arguments are passed through as given.
    """
    cuda_utils = torch.ops._C_cuda_utils
    value = cuda_utils.get_device_attribute(attribute, device)
    return value
|
||||
|
||||
|
||||
@ -70,7 +70,6 @@ class AttentionBackend(ABC):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> Tuple[int, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -95,7 +95,6 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
use_sparse: bool = False,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@ -156,7 +155,6 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
self._o_scale_float: Optional[float] = None
|
||||
|
||||
self.use_mla = use_mla
|
||||
self.use_sparse = use_sparse
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
@ -189,8 +187,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla=use_mla,
|
||||
has_sink=self.has_sink,
|
||||
use_sparse=use_sparse)
|
||||
has_sink=self.has_sink)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from typing import ClassVar, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -11,8 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, QuantizationConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, CommonAttentionMetadata,
|
||||
make_local_attention_virtual_batches, subclass_attention_backend)
|
||||
CommonAttentionMetadata, make_local_attention_virtual_batches,
|
||||
subclass_attention_backend)
|
||||
|
||||
from ..layer import Attention
|
||||
|
||||
@ -28,8 +28,6 @@ def create_chunked_local_attention_backend(
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
|
||||
@ -138,208 +138,3 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
|
||||
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
||||
out = cp_group.reduce_scatter(out, dim=1)
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
def _pack_seq_kernel(
        x_ptr,  # [N, D]
        out_ptr,  # [B, Lmax, D]
        lengths_ptr,  # *i32, [B]
        N: tl.constexpr,
        D: tl.constexpr,
        Lmax: tl.constexpr,
        PAD_VALUE: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    # One program handles a [BLOCK_T, BLOCK_D] tile of one sequence:
    # grid axis 0 = sequence, axis 1 = time tile, axis 2 = feature tile.
    batch_idx = tl.program_id(0)
    t_offsets = tl.program_id(1) * BLOCK_T + tl.arange(0, BLOCK_T)
    d_offsets = tl.program_id(2) * BLOCK_D + tl.arange(0, BLOCK_D)

    # First input row of this sequence = sum of all preceding lengths.
    row_base = 0
    for prev in range(batch_idx):
        row_base += tl.load(lengths_ptr + prev)
    cur_len = tl.load(lengths_ptr + batch_idx)

    # Bounds: inside the padded output tile / inside the real sequence.
    in_bounds_t = t_offsets < Lmax
    has_token = in_bounds_t & (t_offsets < cur_len)
    in_bounds_d = d_offsets[None, :] < D

    # Both tensors are addressed as dense row-major buffers.
    src_ptrs = x_ptr + (row_base + t_offsets)[:, None] * D \
        + d_offsets[None, :]
    dst_ptrs = out_ptr + (batch_idx * Lmax + t_offsets)[:, None] * D \
        + d_offsets[None, :]

    tile_mask = in_bounds_t[:, None] & in_bounds_d
    token_mask = has_token[:, None] & in_bounds_d

    # Fill the whole tile with PAD first (cast happens on store), then
    # overwrite the positions that hold real tokens.
    padding = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
    tl.store(dst_ptrs, padding, mask=tile_mask)

    tokens = tl.load(src_ptrs, mask=token_mask)
    tl.store(dst_ptrs, tokens, mask=token_mask)
|
||||
|
||||
|
||||
def pack_seq_triton(x: torch.Tensor,
                    lengths: torch.Tensor,
                    pad_value: float = -float('inf'),
                    block_t: int = 64,
                    block_d: int = 64) -> torch.Tensor:
    """
    Pack sequences of different lengths into a batched tensor.

    Args:
        x: [N, ...] - input tensor where N is total number of tokens
        lengths: [B] - sequence lengths for each batch
        pad_value: value to use for padding
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        packed: [B, Lmax, ...] - packed tensor, where Lmax is the largest
            entry of ``lengths`` (0 when ``lengths`` is empty). A 1-D ``x``
            is treated as having a single feature and yields [B, Lmax, 1].
    """
    original_shape = x.shape
    N = original_shape[0]

    # Flatten all trailing dims into one feature dim. D is computed
    # explicitly because reshape(N, -1) cannot infer -1 when N == 0.
    D = 1
    for dim in original_shape[1:]:
        D *= dim

    # The kernel indexes x as a dense row-major [N, D] buffer (no strides
    # are passed), so a strided view must be materialized first.
    x_reshaped = x.reshape(N, D).contiguous()

    B = lengths.numel()
    # max() of an empty tensor raises, so an empty batch maps to Lmax == 0.
    Lmax = int(lengths.max().item()) if B > 0 else 0

    # Starts are computed inside the kernel from lengths
    out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)

    # Nothing to write for an empty output; skip the kernel launch.
    if out.numel() > 0:
        grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
        _pack_seq_kernel[grid](x_reshaped,
                               out,
                               lengths.int(),
                               N,
                               D,
                               Lmax,
                               PAD_VALUE=float(pad_value),
                               BLOCK_T=block_t,
                               BLOCK_D=block_d,
                               num_warps=4,
                               num_stages=2)

    # Reshape output back to original dimensions (except first dimension)
    if len(original_shape) > 2:
        output_shape = (B, Lmax) + tuple(original_shape[1:])
        out = out.reshape(output_shape)

    return out
|
||||
|
||||
|
||||
@triton.jit
def _unpack_seq_triton_kernel(
        packed_ptr,  # [B, Lmax, D]
        out_ptr,  # [N, D]
        lengths_ptr,  # *i32, [B]
        B: tl.constexpr,
        Lmax: tl.constexpr,
        D: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    # One program copies a [BLOCK_T, BLOCK_D] tile of one sequence:
    # grid axis 0 = sequence, axis 1 = time tile, axis 2 = feature tile.
    batch_idx = tl.program_id(0)
    t_offsets = tl.program_id(1) * BLOCK_T + tl.arange(0, BLOCK_T)
    d_offsets = tl.program_id(2) * BLOCK_D + tl.arange(0, BLOCK_D)

    # First output row of this sequence = sum of all preceding lengths.
    row_base = 0
    for prev in range(batch_idx):
        row_base += tl.load(lengths_ptr + prev)
    cur_len = tl.load(lengths_ptr + batch_idx)

    # Only positions that are inside the padded tile AND inside the real
    # sequence are copied; padding is never read or written.
    in_bounds_t = t_offsets < Lmax
    has_token = in_bounds_t & (t_offsets < cur_len)
    in_bounds_d = d_offsets[None, :] < D
    copy_mask = has_token[:, None] & in_bounds_d

    # Both tensors are addressed as dense row-major buffers.
    src_ptrs = packed_ptr + (batch_idx * Lmax + t_offsets)[:, None] * D \
        + d_offsets[None, :]
    dst_ptrs = out_ptr + (row_base + t_offsets)[:, None] * D \
        + d_offsets[None, :]

    tokens = tl.load(src_ptrs, mask=copy_mask)
    tl.store(dst_ptrs, tokens, mask=copy_mask)
|
||||
|
||||
|
||||
def unpack_seq_triton(packed_tensor: torch.Tensor,
                      lengths: torch.Tensor,
                      block_t: int = 64,
                      block_d: int = 64) -> torch.Tensor:
    """
    Unpack a packed decode query tensor back to the original format.
    Efficient Triton implementation.

    Args:
        packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
        lengths: [B] - sequence lengths for each batch
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        unpacked_tensor: [N, ...] where N = sum(lengths). A 2-D
            ``packed_tensor`` is treated as having a single feature and
            yields [N, 1].
    """
    original_shape = packed_tensor.shape
    B, Lmax = original_shape[:2]

    # Flatten all trailing dims into one feature dim. D is computed
    # explicitly because reshape(B, Lmax, -1) cannot infer -1 when the
    # tensor has zero elements.
    D = 1
    for dim in original_shape[2:]:
        D *= dim

    # The kernel indexes the input as a dense row-major [B, Lmax, D]
    # buffer (no strides are passed), so materialize strided views first.
    packed_reshaped = packed_tensor.reshape(B, Lmax, D).contiguous()

    # Calculate total number of elements
    N = int(lengths.sum().item())

    out = torch.empty((N, D),
                      device=packed_tensor.device,
                      dtype=packed_tensor.dtype)

    # Nothing to copy for an empty output; skip the kernel launch.
    if out.numel() > 0:
        grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
        _unpack_seq_triton_kernel[grid](packed_reshaped,
                                        out,
                                        lengths.int(),
                                        B,
                                        Lmax,
                                        D,
                                        BLOCK_T=block_t,
                                        BLOCK_D=block_d,
                                        num_warps=4,
                                        num_stages=2)

    # Reshape output back to original dimensions (except first dimension)
    if len(original_shape) > 3:
        output_shape = (N, ) + tuple(original_shape[2:])
        out = out.reshape(output_shape)

    return out
|
||||
|
||||
@ -19,15 +19,6 @@ if current_platform.is_cuda():
|
||||
else:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_extension_C # noqa: F401
|
||||
_flashmla_extension_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
|
||||
|
||||
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
@ -46,34 +37,24 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_q_tokens_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
num_heads_q: Optional[int] = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
- cache_seqlens: (batch_size), dtype torch.int32.
|
||||
- num_q_tokens_per_head_k:
|
||||
Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
|
||||
- num_heads_k: The number of k heads.
|
||||
- num_heads_q:
|
||||
The number of q heads.
|
||||
This argument is optional when sparse attention is not enabled
|
||||
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
|
||||
- topk: If not None, sparse attention will be enabled,
|
||||
and only tokens in the `indices` array
|
||||
passed to `flash_mla_with_kvcache_sm90` will be attended to.
|
||||
cache_seqlens: (batch_size), dtype torch.int32.
|
||||
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
||||
num_heads_k: num_heads_k.
|
||||
|
||||
Returns:
|
||||
- tile_scheduler_metadata:
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
- num_splits: (batch_size + 1), dtype torch.int32.
|
||||
Return:
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
|
||||
dtype torch.int32.
|
||||
num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
return torch.ops._flashmla_C.get_mla_decoding_metadata(
|
||||
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
|
||||
is_fp8_kvcache, topk)
|
||||
return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
|
||||
num_heads_per_head_k,
|
||||
num_heads_k)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
@ -88,95 +69,45 @@ def flash_mla_with_kvcache(
|
||||
causal: bool = False,
|
||||
descale_q: Optional[torch.Tensor] = None,
|
||||
descale_k: Optional[torch.Tensor] = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
- cache_seqlens: (batch_size), torch.int32.
|
||||
- head_dim_v: Head dimension of v.
|
||||
- tile_scheduler_metadata:
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
|
||||
returned by get_mla_metadata.
|
||||
- num_splits:
|
||||
(batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||
- softmax_scale: float.
|
||||
The scale of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(head_dim).
|
||||
- causal: bool. Whether to apply causal attention mask.
|
||||
- descale_q: (batch_size),
|
||||
torch.float32. Descaling factors for Q, used for fp8 quantization.
|
||||
- descale_k: (batch_size),
|
||||
torch.float32. Descaling factors for K, used for fp8 quantization.
|
||||
- is_fp8_kvcache: bool.
|
||||
Whether the k_cache and v_cache are in fp8 format.
|
||||
For the format of FP8 KV cache, please refer to README.md
|
||||
- indices: (batch_size, seq_len_q, topk), torch.int32.
|
||||
If not None, sparse attention will be enabled,
|
||||
and only tokens in the `indices` array will be attended to.
|
||||
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
|
||||
For details about how to set up `indices`, please refer to README.md.
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head_dim of v.
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
|
||||
torch.int32, return by get_mla_metadata.
|
||||
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
descale_q: (batch_size), torch.float32. Descaling factors for Q.
|
||||
descale_k: (batch_size), torch.float32. Descaling factors for K.
|
||||
|
||||
Returns:
|
||||
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
Return:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1]**(-0.5)
|
||||
if indices is not None:
|
||||
# NOTE (zyongye): sparse attention is also causal
|
||||
# since it only attend to the tokens before
|
||||
# but here `causal` should not be specified
|
||||
assert not causal, \
|
||||
"causal must be `false` if sparse attention is enabled."
|
||||
assert (descale_q is None) == (
|
||||
descale_k is None
|
||||
), "descale_q and descale_k should be both None or both not None"
|
||||
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
q,
|
||||
k_cache,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
descale_q,
|
||||
descale_k,
|
||||
)
|
||||
|
||||
if indices is None and q.element_size() == 1:
|
||||
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
|
||||
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
|
||||
causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
|
||||
else:
|
||||
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
|
||||
causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
|
||||
indices)
|
||||
return out, softmax_lse
|
||||
|
||||
|
||||
def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run the sparse-attention prefill kernel.

    This is a direct pass-through to the ``_flashmla_C.sparse_prefill_fwd``
    op; its result is returned unchanged.

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
        Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
        For the definition of output, max_logits and lse,
        please refer to README.md
        - output: [s_q, h_q, d_v], bfloat16
        - max_logits: [s_q, h_q], float
        - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    sparse_prefill_fwd = torch.ops._flashmla_C.sparse_prefill_fwd
    return sparse_prefill_fwd(q, kv, indices, sm_scale, d_v)
|
||||
# Note(hc): need revisit when we support DCP with decode query_len > 1.
|
||||
return out.squeeze(1), softmax_lse.squeeze(-1)
|
||||
|
||||
|
||||
#
|
||||
|
||||
@ -50,7 +50,6 @@ class PagedAttention:
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
||||
|
||||
|
||||
@ -144,7 +144,6 @@ def get_attn_backend(
|
||||
block_size: int,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
@ -159,7 +158,6 @@ def get_attn_backend(
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
has_sink=has_sink,
|
||||
use_sparse=use_sparse,
|
||||
)
|
||||
|
||||
|
||||
@ -172,7 +170,6 @@ def _cached_get_attn_backend(
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
|
||||
# Check whether a particular choice of backend was
|
||||
@ -206,7 +203,7 @@ def _cached_get_attn_backend(
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
|
||||
use_mla, has_sink, use_sparse)
|
||||
use_mla, has_sink)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}")
|
||||
|
||||
@ -1450,6 +1450,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
):
|
||||
dataset_class = MLPerfDataset
|
||||
args.hf_split = "train"
|
||||
elif (
|
||||
args.dataset_path in MMStarDataset.SUPPORTED_DATASET_PATHS
|
||||
or args.hf_name in MMStarDataset.SUPPORTED_DATASET_PATHS
|
||||
):
|
||||
dataset_class = MMStarDataset
|
||||
args.hf_split = "val"
|
||||
args.hf_subset = None
|
||||
else:
|
||||
supported_datasets = set([
|
||||
dataset_name for cls in HuggingFaceDataset.__subclasses__()
|
||||
@ -2721,3 +2728,76 @@ class PrefixRepetitionRandomDataset(BenchmarkDataset):
|
||||
|
||||
random.shuffle(requests)
|
||||
return requests
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# MMStar Dataset Implementation
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MMStarDataset(HuggingFaceDataset):
    """
    Benchmark dataset backed by Lin-Chen/MMStar
    (https://huggingface.co/datasets/Lin-Chen/MMStar).

    Each request pairs an item's image with its question text, with the
    multiple-choice options stripped off.
    refer to: https://github.com/sgl-project/SpecForge/pull/106
    """
    DEFAULT_OUTPUT_LEN = 128
    SUPPORTED_DATASET_PATHS = {"Lin-Chen/MMStar"}
    IS_MULTIMODAL = True

    def sample(
        self,
        tokenizer: PreTrainedTokenizerBase,
        num_requests: int,
        output_len: Optional[int] = None,
        enable_multimodal_chat: bool = False,
        request_id_prefix: str = "",
        no_oversample: bool = False,
        **kwargs,
    ) -> list[SampleRequest]:
        """Build up to ``num_requests`` requests from the dataset items.

        Items are consumed in order; afterwards the list is handed to
        ``maybe_oversample_requests`` together with ``num_requests``,
        ``request_id_prefix`` and ``no_oversample``.
        """
        # If --hf-output-len is not set, use the default output length.
        if output_len is None:
            output_len = self.DEFAULT_OUTPUT_LEN

        sampled_requests: list[SampleRequest] = []
        for ind, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break

            # Keep only the question part that precedes "Options:".
            full_q: str = item.get("question", "")
            question_text = full_q.partition("Options:")[0].strip()

            # Multimodal image content.
            mm_content = process_image(item["image"])

            # Prompt length counts the question text tokens only, also
            # when the prompt is later turned into a chat message.
            prompt_len = len(tokenizer(question_text).input_ids)

            if not enable_multimodal_chat:
                # Default: prompt is plain text,
                # image is in mm_content for the bench to assemble.
                prompt = question_text
                mm_for_request = mm_content
            else:
                # Embed the image in the chat message itself, i.e.
                # convert to [{"role":"user","content":[...]}]
                prompt = self.apply_multimodal_chat_transformation(
                    question_text, mm_content)
                mm_for_request = None  # Already embedded in chat content.

            request = SampleRequest(
                prompt=prompt,
                prompt_len=prompt_len,
                expected_output_len=output_len,
                multi_modal_data=mm_for_request,
                request_id=request_id_prefix + str(ind),
            )
            sampled_requests.append(request)

        self.maybe_oversample_requests(sampled_requests, num_requests,
                                       request_id_prefix, no_oversample)
        return sampled_requests
|
||||
|
||||
@ -340,15 +340,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None)
|
||||
# Lazy import here to avoid circular import
|
||||
from .cuda_piecewise_backend import PiecewiseBackend
|
||||
from .piecewise_backend import PiecewiseBackend
|
||||
|
||||
piecewise_backend = PiecewiseBackend(
|
||||
submod, self.vllm_config, index,
|
||||
len(self.compile_submod_names), sym_shape_indices,
|
||||
compiled_graph_for_dynamic_shape, self.vllm_backend)
|
||||
|
||||
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and
|
||||
if (self.compilation_config.cudagraph_mode.\
|
||||
has_piecewise_cudagraphs() and
|
||||
not self.compilation_config.use_inductor_graph_partition):
|
||||
# We're using Dynamo-based piecewise splitting, so we wrap
|
||||
# the whole subgraph with a static graph wrapper.
|
||||
|
||||
@ -336,7 +336,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
from vllm.config import CUDAGraphMode
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and compilation_config.use_inductor_graph_partition):
|
||||
from torch._inductor.utils import CUDAGraphWrapperMetadata
|
||||
|
||||
@ -365,7 +365,7 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
|
||||
|
||||
yield
|
||||
|
||||
if (compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
if (compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and compilation_config.use_inductor_graph_partition):
|
||||
torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
|
||||
@ -18,13 +17,12 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig):
|
||||
torch_compile_start_time = time.time()
|
||||
|
||||
compilation_config: CompilationConfig = vllm_config.compilation_config
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE and \
|
||||
compilation_config.debug_dump_path:
|
||||
path = vllm_config.compile_debug_dump_path()
|
||||
if compilation_config.level == CompilationLevel.PIECEWISE and path:
|
||||
import depyf
|
||||
path = os.path.join(compilation_config.debug_dump_path,
|
||||
f"rank_{vllm_config.parallel_config.rank}")
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
global context_manager
|
||||
context_manager = depyf.prepare_debug(path)
|
||||
context_manager = depyf.prepare_debug(path.as_posix())
|
||||
context_manager.__enter__()
|
||||
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
import functools
|
||||
import operator
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import regex as re
|
||||
@ -96,12 +95,10 @@ class VllmPatternMatcherPass(VllmInductorPass):
|
||||
|
||||
TODO(luka): use pattern object to manually produce pattern graph
|
||||
"""
|
||||
debug_dump_path = config.compilation_config.debug_dump_path
|
||||
debug_dump_path = config.compile_debug_dump_path()
|
||||
if not debug_dump_path:
|
||||
return
|
||||
|
||||
rank = config.parallel_config.rank
|
||||
debug_dump_path = Path(debug_dump_path) / f"rank_{rank}"
|
||||
debug_dump_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
from vllm.utils import unique_filepath
|
||||
|
||||
@ -92,12 +92,11 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
return
|
||||
|
||||
self.compiled_codes.append(new_code)
|
||||
debug_dump_dir = self.vllm_config.compilation_config.debug_dump_path
|
||||
if isinstance(debug_dump_dir, str) and debug_dump_dir != "":
|
||||
rank = self.vllm_config.parallel_config.rank
|
||||
decompiled_file = os.path.join(debug_dump_dir, f"rank_{rank}",
|
||||
"transformed_code.py")
|
||||
if not os.path.exists(decompiled_file):
|
||||
|
||||
path = self.vllm_config.compile_debug_dump_path()
|
||||
if path:
|
||||
decompiled_file = path / "transformed_code.py"
|
||||
if not decompiled_file.exists():
|
||||
try:
|
||||
# usually the decompilation will succeed for most models,
|
||||
# as we guarantee a full-graph compilation in Dynamo.
|
||||
|
||||
@ -12,6 +12,7 @@ import textwrap
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import field, fields, is_dataclass, replace
|
||||
from functools import cached_property, lru_cache
|
||||
from pathlib import Path
|
||||
from typing import (TYPE_CHECKING, Any, Literal, Optional, Protocol, TypeVar,
|
||||
Union, cast)
|
||||
|
||||
@ -459,15 +460,22 @@ class VllmConfig:
|
||||
"to True to enable.")
|
||||
current_platform.check_and_update_config(self)
|
||||
|
||||
# final check of cudagraph mode after platform-specific update
|
||||
# Do this after all the updates to compilation_config.level
|
||||
if envs.VLLM_USE_V1 and \
|
||||
self.compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
# final check of cudagraph mode after all possible updates
|
||||
if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
|
||||
if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\
|
||||
and self.model_config is not None and \
|
||||
not self.model_config.disable_cascade_attn:
|
||||
logger.info("CUDAGraphMode.FULL is not supported with "
|
||||
"cascade attention currently. Disabling cascade"
|
||||
"attention.")
|
||||
self.model_config.disable_cascade_attn = True
|
||||
not self.model_config.disable_cascade_attn and\
|
||||
not self.compilation_config.cudagraph_mode.\
|
||||
has_piecewise_cudagraphs():
|
||||
logger.warning_once(
|
||||
"No piecewise cudagraph for executing cascade attention."
|
||||
" Will fall back to eager execution if a batch runs "
|
||||
"into cascade attentions")
|
||||
|
||||
if self.compilation_config.cudagraph_mode\
|
||||
.requires_piecewise_compilation():
|
||||
@ -477,6 +485,12 @@ class VllmConfig:
|
||||
"when cudagraph_mode piecewise cudagraphs is used, "\
|
||||
f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
|
||||
|
||||
# finally, migrate the deprecated flags
|
||||
self.compilation_config.use_cudagraph = self.compilation_config.\
|
||||
cudagraph_mode!= CUDAGraphMode.NONE
|
||||
self.compilation_config.full_cuda_graph = self.compilation_config.\
|
||||
cudagraph_mode.has_full_cudagraphs()
|
||||
|
||||
if self.parallel_config.enable_dbo:
|
||||
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
|
||||
assert a2a_backend in \
|
||||
@ -487,14 +501,14 @@ class VllmConfig:
|
||||
"variable to deepep_low_latency or deepep_high_throughput and "\
|
||||
"install the DeepEP kernels."
|
||||
|
||||
if not self.model_config.disable_cascade_attn:
|
||||
self.model_config.disable_cascade_attn = True
|
||||
logger.warning_once(
|
||||
"Disabling cascade attention when DBO is enabled.")
|
||||
|
||||
if not self.instance_id:
|
||||
self.instance_id = random_uuid()[:5]
|
||||
|
||||
# Do this after all the updates to compilation_config.level
|
||||
if envs.VLLM_USE_V1 and \
|
||||
self.compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
self.compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
if (envs.VLLM_USE_V1
|
||||
and not self.scheduler_config.disable_hybrid_kv_cache_manager):
|
||||
# logger should only print warning message for hybrid models. As we
|
||||
@ -528,6 +542,17 @@ class VllmConfig:
|
||||
# local attention.
|
||||
self.scheduler_config.disable_hybrid_kv_cache_manager = True
|
||||
|
||||
if self.compilation_config.debug_dump_path:
|
||||
self.compilation_config.debug_dump_path = \
|
||||
self.compilation_config.debug_dump_path.absolute().expanduser()
|
||||
if envs.VLLM_DEBUG_DUMP_PATH is not None:
|
||||
env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser()
|
||||
if self.compilation_config.debug_dump_path:
|
||||
logger.warning(
|
||||
"Config-specified debug dump path is overridden"
|
||||
" by VLLM_DEBUG_DUMP_PATH to %s", env_path)
|
||||
self.compilation_config.debug_dump_path = env_path
|
||||
|
||||
def update_sizes_for_sequence_parallelism(self,
|
||||
possible_sizes: list) -> list:
|
||||
# remove the sizes that are not multiples of tp_size when
|
||||
@ -659,6 +684,20 @@ class VllmConfig:
|
||||
f"but got '{self.load_config.load_format}'. "
|
||||
f"Model: {self.model_config.model}")
|
||||
|
||||
def compile_debug_dump_path(self) -> Optional[Path]:
    """Return the rank-aware directory for torch.compile debug dumps.

    The result is ``compilation_config.debug_dump_path`` joined with a
    per-rank suffix: ``rank_{rank}`` when ``data_parallel_size`` is 1,
    otherwise ``rank_{rank}_dp_{data_parallel_rank}``.

    This method only builds the path; it does not create the directory.

    Returns:
        The per-rank dump directory, or ``None`` when no debug dump path
        is configured (``None`` or an empty string).
    """
    base_path = self.compilation_config.debug_dump_path
    # Callers test the result with plain truthiness, so every "unset"
    # spelling must map to None instead of failing on the `/` join below.
    if base_path is None or base_path == "":
        return None

    rank = self.parallel_config.rank
    dp_rank = self.parallel_config.data_parallel_rank
    if self.parallel_config.data_parallel_size == 1:
        suffix = f"rank_{rank}"
    else:
        # Several data-parallel replicas can share the same rank value, so
        # the DP rank is appended to keep their dump directories apart.
        suffix = f"rank_{rank}_dp_{dp_rank}"

    # Path() leaves Path inputs unchanged and lets a plain-string setting
    # work, where `str / str` would raise TypeError.
    return Path(base_path) / suffix
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"model={self.model_config.model!r}, "
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user