Compare commits
90 Commits
moondream2
...
tpu_v1_opt
| Author | SHA1 | Date | |
|---|---|---|---|
| 70b4e46e70 | |||
| 5fb9dbe6f6 | |||
| 996b92ccb4 | |||
| 2b0526fa15 | |||
| 7be649256f | |||
| 627efde813 | |||
| c2867d5bc1 | |||
| 39c4a4cdb5 | |||
| 1ccf100c6a | |||
| 248c5b632d | |||
| 950f349492 | |||
| 61bb55f3d5 | |||
| 0bddb6b9a5 | |||
| c715fb19e5 | |||
| 24b0205f58 | |||
| c5cffcd0cd | |||
| 682b55bc07 | |||
| 9726ad676d | |||
| eb5cb5e528 | |||
| 2cbeedad09 | |||
| 2c85529bfc | |||
| e97f802b2d | |||
| 6e650f56a1 | |||
| 3f50c148fd | |||
| 8c01b8022c | |||
| 99d01a5e3d | |||
| d07efb31c5 | |||
| 978b45f399 | |||
| c5b4b11d7f | |||
| 8ae5ff2009 | |||
| 511627445e | |||
| f0ef37233e | |||
| 7551a34032 | |||
| 01a55941f5 | |||
| 8d7aa9de71 | |||
| 68c4421b6d | |||
| aea94362c9 | |||
| 7206ce4ce1 | |||
| 96f6a7596f | |||
| 84bee4bd5c | |||
| fc66dee76d | |||
| 6609cdf019 | |||
| 16366ee8bb | |||
| 528dbcac7d | |||
| cd7b6f0857 | |||
| 68ad4e3a8d | |||
| 4004f144f3 | |||
| 66818e5b63 | |||
| 222a9dc350 | |||
| cbdc4ad5a5 | |||
| 016e3676e7 | |||
| 64ea24d0b3 | |||
| df76e5af26 | |||
| 09ccc9c8f7 | |||
| 69196a9bc7 | |||
| 2acba47d9b | |||
| 9c485d9e25 | |||
| fa9ee08121 | |||
| 347eeebe3b | |||
| 18fd4a8331 | |||
| 132a132100 | |||
| 1e60f87bb3 | |||
| 9705b90bcf | |||
| 3aec49e56f | |||
| c64612802b | |||
| 9a7c3a0042 | |||
| b197a5ccfd | |||
| c81081fece | |||
| a94eee4456 | |||
| f2e9f2a3be | |||
| 1f1542afa9 | |||
| 96912550c8 | |||
| 2fc6944c5e | |||
| 5fe6bf29d6 | |||
| d4b62d4641 | |||
| ecf67814f1 | |||
| 750f4cabfa | |||
| 06a760d6e8 | |||
| da7512215f | |||
| af69a6aded | |||
| 7bd3630067 | |||
| 96663699b2 | |||
| 18572e3384 | |||
| 86bfb6dba7 | |||
| 5f0ec3935a | |||
| c222f47992 | |||
| 170eb35079 | |||
| b37d82791e | |||
| 3127e975fb | |||
| 4001ea1266 |
@ -25,8 +25,11 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then
|
||||
last_build=$(cat /tmp/neuron-docker-build-timestamp)
|
||||
current_time=$(date +%s)
|
||||
if [ $((current_time - last_build)) -gt 86400 ]; then
|
||||
# Remove dangling images (those that are not tagged and not used by any container)
|
||||
docker image prune -f
|
||||
docker system prune -f
|
||||
# Remove unused volumes / force the system prune for old images as well.
|
||||
docker volume prune -f && docker system prune -f
|
||||
# Remove huggingface model artifacts and compiler cache
|
||||
rm -rf "${HF_MOUNT:?}/*"
|
||||
rm -rf "${NEURON_COMPILE_CACHE_MOUNT:?}/*"
|
||||
echo "$current_time" > /tmp/neuron-docker-build-timestamp
|
||||
|
||||
@ -76,7 +76,9 @@ steps:
|
||||
- tests/basic_correctness/test_basic_correctness
|
||||
- tests/basic_correctness/test_cpu_offload
|
||||
- tests/basic_correctness/test_preemption
|
||||
- tests/basic_correctness/test_cumem.py
|
||||
commands:
|
||||
- pytest -v -s basic_correctness/test_cumem.py
|
||||
- pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
@ -477,7 +479,9 @@ steps:
|
||||
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
# this test fails consistently.
|
||||
# TODO: investigate and fix
|
||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
|
||||
|
||||
@ -515,7 +519,9 @@ steps:
|
||||
- vllm/engine
|
||||
- tests/multi_step
|
||||
commands:
|
||||
- pytest -v -s multi_step/test_correctness_async_llm.py
|
||||
# this test is quite flaky
|
||||
# TODO: investigate and fix.
|
||||
# - pytest -v -s multi_step/test_correctness_async_llm.py
|
||||
- pytest -v -s multi_step/test_correctness_llm.py
|
||||
|
||||
- label: Pipeline Parallelism Test # 45min
|
||||
|
||||
27
.github/CODEOWNERS
vendored
27
.github/CODEOWNERS
vendored
@ -2,32 +2,35 @@
|
||||
# for more info about CODEOWNERS file
|
||||
|
||||
# This lists cover the "core" components of vLLM that require careful review
|
||||
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
|
||||
/vllm/model_executor/guided_decoding @mgoin
|
||||
/vllm/multimodal @DarkLight1337 @ywang96
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
|
||||
# vLLM V1
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||
|
||||
# Test ownership
|
||||
/tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo
|
||||
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
|
||||
/tests/test_inputs.py @DarkLight1337 @ywang96
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-neuralmagic @simon-mo
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo
|
||||
/tests/models @DarkLight1337 @ywang96
|
||||
/tests/multimodal @DarkLight1337 @ywang96
|
||||
/tests/prefix_caching @comaniac @KuntaiDu
|
||||
/tests/spec_decode @njhill @LiuXiaoxuanPKU
|
||||
/tests/kernels @tlrmchlsmth @WoosukKwon
|
||||
/tests/quantization @mgoin @robertgshaw2-neuralmagic
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat
|
||||
/.buildkite/lm-eval-harness @mgoin @simon-mo
|
||||
/tests/distributed/test_multi_node_assignment.py @youkaichao
|
||||
/tests/distributed/test_pipeline_parallel.py @youkaichao
|
||||
/tests/distributed/test_same_node.py @youkaichao
|
||||
/tests/multi_step @alexm-neuralmagic @comaniac
|
||||
/tests/multi_step @alexm-redhat @comaniac
|
||||
/tests/weight_loading @mgoin @youkaichao
|
||||
/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac
|
||||
|
||||
20
.github/workflows/dummy.yml
vendored
20
.github/workflows/dummy.yml
vendored
@ -1,20 +0,0 @@
|
||||
name: dummy-checks
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
mypy:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- run: echo "This is a dummy step that always passes"
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- run: echo "This is a dummy step that always passes"
|
||||
2
.github/workflows/pre-commit.yml
vendored
2
.github/workflows/pre-commit.yml
vendored
@ -15,3 +15,5 @@ jobs:
|
||||
python-version: "3.12"
|
||||
- run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
|
||||
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
with:
|
||||
extra_args: --all-files --hook-stage manual
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
default_stages:
|
||||
- pre-commit # Run locally
|
||||
- manual # Run in CI
|
||||
repos:
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.32.0
|
||||
@ -31,32 +34,47 @@ repos:
|
||||
hooks:
|
||||
- id: pymarkdown
|
||||
files: docs/.*
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: v1.7.6
|
||||
hooks:
|
||||
- id: actionlint
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy-local
|
||||
name: Run mypy for local Python installation
|
||||
entry: tools/mypy.sh 0 "local"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
|
||||
stages: [pre-commit] # Don't run in CI
|
||||
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.9
|
||||
entry: tools/mypy.sh 1 "3.9"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.10
|
||||
entry: tools/mypy.sh 1 "3.10"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.11
|
||||
entry: tools/mypy.sh 1 "3.11"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.12
|
||||
entry: tools/mypy.sh 1 "3.12"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: shellcheck
|
||||
name: Lint shell scripts
|
||||
entry: tools/shellcheck.sh
|
||||
@ -67,7 +85,8 @@ repos:
|
||||
entry: tools/png-lint.sh
|
||||
language: script
|
||||
types: [png]
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: v1.7.6
|
||||
hooks:
|
||||
- id: actionlint
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
language: system
|
||||
verbose: true
|
||||
@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
# Suppress potential warnings about unused manually-specified variables
|
||||
set(ignoreMe "${VLLM_PYTHON_PATH}")
|
||||
|
||||
# Prevent installation of dependencies (cutlass) by default.
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
||||
|
||||
#
|
||||
# Supported python versions. These versions will be searched in order, the
|
||||
# first match will be selected. These should be kept in sync with setup.py.
|
||||
@ -181,6 +178,31 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
|
||||
# Define other extension targets
|
||||
#
|
||||
|
||||
#
|
||||
# cumem_allocator extension
|
||||
#
|
||||
|
||||
set(VLLM_CUMEM_EXT_SRC
|
||||
"csrc/cumem_allocator.cpp")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_CUMEM_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Enabling cumem allocator extension.")
|
||||
# link against cuda driver library
|
||||
list(APPEND CUMEM_LIBS cuda)
|
||||
define_gpu_extension_target(
|
||||
cumem_allocator
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_CUMEM_EXT_SRC}
|
||||
LIBRARIES ${CUMEM_LIBS}
|
||||
USE_SABI 3.8
|
||||
WITH_SOABI)
|
||||
endif()
|
||||
|
||||
#
|
||||
# _C extension
|
||||
#
|
||||
@ -510,7 +532,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
endif()
|
||||
|
||||
# vllm-flash-attn currently only supported on CUDA
|
||||
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
|
||||
if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
return()
|
||||
endif ()
|
||||
|
||||
@ -533,7 +555,7 @@ endif()
|
||||
# They should be identical but if they aren't, this is a massive footgun.
|
||||
#
|
||||
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
||||
# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
|
||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
|
||||
# If no component is specified, vllm-flash-attn is still installed.
|
||||
|
||||
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
||||
@ -545,43 +567,41 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(VLLM_FLASH_ATTN_SRC_DIR)
|
||||
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn SOURCE_DIR
|
||||
${VLLM_FLASH_ATTN_SRC_DIR}
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 96266b1111111f3d11aabefaf3bacbab6a89d03c
|
||||
GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
endif()
|
||||
|
||||
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
|
||||
set(VLLM_PARENT_BUILD ON)
|
||||
|
||||
# Ensure the vllm/vllm_flash_attn directory exists before installation
|
||||
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
|
||||
|
||||
# Make sure vllm-flash-attn install rules are nested under vllm/
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
|
||||
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)
|
||||
|
||||
# Fetch the vllm-flash-attn library
|
||||
FetchContent_MakeAvailable(vllm-flash-attn)
|
||||
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
||||
|
||||
# Restore the install prefix
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
|
||||
|
||||
# Copy over the vllm-flash-attn python files
|
||||
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
|
||||
# case only one is built, in the case both are built redundant work is done)
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm/vllm_flash_attn
|
||||
COMPONENT vllm_flash_attn_c
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa2_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
# Nothing after vllm-flash-attn, see comment about macros above
|
||||
|
||||
@ -52,7 +52,7 @@ WORKDIR /workspace
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
|
||||
python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \
|
||||
fi
|
||||
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
|
||||
261
Dockerfile.rocm
261
Dockerfile.rocm
@ -1,174 +1,119 @@
|
||||
# Default ROCm 6.2 base image
|
||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0"
|
||||
# default base image
|
||||
ARG REMOTE_VLLM="0"
|
||||
ARG USE_CYTHON="0"
|
||||
ARG BUILD_RPD="1"
|
||||
ARG COMMON_WORKDIR=/app
|
||||
ARG BASE_IMAGE=rocm/vllm-dev:base
|
||||
|
||||
# Default ROCm ARCHes to build vLLM for.
|
||||
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
# Whether to install CK-based flash-attention
|
||||
# If 0, will not install flash-attention
|
||||
ARG BUILD_FA="1"
|
||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
||||
ARG FA_BRANCH="3cea2fb"
|
||||
|
||||
# Whether to build triton on rocm
|
||||
ARG BUILD_TRITON="1"
|
||||
ARG TRITON_BRANCH="e192dba"
|
||||
|
||||
### Base image build stage
|
||||
FROM $BASE_IMAGE AS base
|
||||
|
||||
# Import arg(s) defined before this build stage
|
||||
ARG PYTORCH_ROCM_ARCH
|
||||
ARG ARG_PYTORCH_ROCM_ARCH
|
||||
ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}}
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
ca-certificates \
|
||||
sudo \
|
||||
git \
|
||||
bzip2 \
|
||||
libx11-6 \
|
||||
build-essential \
|
||||
wget \
|
||||
unzip \
|
||||
tmux \
|
||||
ccache \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# When launching the container, mount the code directory to /vllm-workspace
|
||||
ARG APP_MOUNT=/vllm-workspace
|
||||
WORKDIR ${APP_MOUNT}
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
# Remove sccache so it doesn't interfere with ccache
|
||||
# TODO: implement sccache support across components
|
||||
RUN apt-get update -q -y && apt-get install -q -y \
|
||||
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev
|
||||
# Remove sccache
|
||||
RUN python3 -m pip install --upgrade pip && pip install setuptools_scm
|
||||
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
|
||||
|
||||
# Install torch == 2.6.0 on ROCm
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
|
||||
*"rocm-6.2"*) \
|
||||
python3 -m pip uninstall -y torch torchvision \
|
||||
&& python3 -m pip install --pre \
|
||||
torch \
|
||||
'setuptools-scm>=8' \
|
||||
torchvision \
|
||||
--extra-index-url https://download.pytorch.org/whl/rocm6.2;; \
|
||||
*) ;; esac
|
||||
|
||||
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
|
||||
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
|
||||
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
|
||||
|
||||
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
ARG COMMON_WORKDIR
|
||||
WORKDIR ${COMMON_WORKDIR}
|
||||
|
||||
|
||||
### AMD-SMI build stage
|
||||
FROM base AS build_amdsmi
|
||||
# Build amdsmi wheel always
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
&& python3 -m pip wheel . --wheel-dir=/install
|
||||
# -----------------------
|
||||
# vLLM fetch stages
|
||||
FROM base AS fetch_vllm_0
|
||||
ONBUILD COPY ./ vllm/
|
||||
FROM base AS fetch_vllm_1
|
||||
ARG VLLM_REPO="https://github.com/vllm-project/vllm.git"
|
||||
ARG VLLM_BRANCH="main"
|
||||
ONBUILD RUN git clone ${VLLM_REPO} \
|
||||
&& cd vllm \
|
||||
&& git checkout ${VLLM_BRANCH}
|
||||
FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm
|
||||
|
||||
# -----------------------
|
||||
# vLLM build stages
|
||||
FROM fetch_vllm AS build_vllm
|
||||
ARG USE_CYTHON
|
||||
# Build vLLM
|
||||
RUN cd vllm \
|
||||
&& python3 -m pip install -r requirements-rocm.txt \
|
||||
&& python3 setup.py clean --all \
|
||||
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch AS export_vllm
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt /
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
|
||||
|
||||
### Flash-Attention wheel build stage
|
||||
FROM base AS build_fa
|
||||
ARG BUILD_FA
|
||||
ARG FA_GFX_ARCHS
|
||||
ARG FA_BRANCH
|
||||
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
if [ "$BUILD_FA" = "1" ]; then \
|
||||
mkdir -p libs \
|
||||
&& cd libs \
|
||||
&& git clone https://github.com/ROCm/flash-attention.git \
|
||||
&& cd flash-attention \
|
||||
&& git checkout "${FA_BRANCH}" \
|
||||
&& git submodule update --init \
|
||||
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
|
||||
# Create an empty directory otherwise as later build stages expect one
|
||||
else mkdir -p /install; \
|
||||
fi
|
||||
# -----------------------
|
||||
# Test vLLM image
|
||||
FROM base AS test
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
### Triton wheel build stage
|
||||
FROM base AS build_triton
|
||||
ARG BUILD_TRITON
|
||||
ARG TRITON_BRANCH
|
||||
# Build triton wheel if `BUILD_TRITON = 1`
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
if [ "$BUILD_TRITON" = "1" ]; then \
|
||||
mkdir -p libs \
|
||||
&& cd libs \
|
||||
&& python3 -m pip install ninja cmake wheel pybind11 \
|
||||
&& git clone https://github.com/OpenAI/triton.git \
|
||||
&& cd triton \
|
||||
&& git checkout "${TRITON_BRANCH}" \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=/install; \
|
||||
# Create an empty directory otherwise as later build stages expect one
|
||||
else mkdir -p /install; \
|
||||
fi
|
||||
# Install vLLM
|
||||
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
cd /install \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& pip uninstall -y vllm \
|
||||
&& pip install *.whl
|
||||
|
||||
|
||||
### Final vLLM build stage
|
||||
FROM base AS final
|
||||
# Import the vLLM development directory from the build context
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
|
||||
# Package upgrades for useful functionality or to avoid dependency issues
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard
|
||||
|
||||
|
||||
# Workaround for ray >= 2.10.0
|
||||
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
|
||||
# Silences the HF Tokenizers warning
|
||||
ENV TOKENIZERS_PARALLELISM=false
|
||||
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -Ur requirements-rocm.txt \
|
||||
&& python3 setup.py clean --all \
|
||||
&& python3 setup.py develop
|
||||
|
||||
# Copy amdsmi wheel into final image
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
|
||||
mkdir -p libs \
|
||||
&& cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& python3 -m pip uninstall -y amdsmi;
|
||||
|
||||
# Copy triton wheel(s) into final image if they were built
|
||||
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
|
||||
mkdir -p libs \
|
||||
&& if ls /install/*.whl; then \
|
||||
cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& python3 -m pip uninstall -y triton; fi
|
||||
|
||||
# Copy flash-attn wheel(s) into final image if they were built
|
||||
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
|
||||
mkdir -p libs \
|
||||
&& if ls /install/*.whl; then \
|
||||
cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& python3 -m pip uninstall -y flash-attn; fi
|
||||
|
||||
# Install wheels that were built to the final image
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if ls libs/*.whl; then \
|
||||
python3 -m pip install libs/*.whl; fi
|
||||
WORKDIR /vllm-workspace
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
RUN cd /vllm-workspace \
|
||||
&& rm -rf vllm \
|
||||
&& python3 -m pip install -e tests/vllm_test_utils \
|
||||
&& python3 -m pip install lm-eval[api]==0.4.4 \
|
||||
&& python3 -m pip install pytest-shard
|
||||
|
||||
# -----------------------
|
||||
# Final vLLM image
|
||||
FROM base AS final
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
|
||||
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
|
||||
# Manually remove it so that later steps of numpy upgrade can continue
|
||||
RUN case "$(which python3)" in \
|
||||
*"/opt/conda/envs/py_3.9"*) \
|
||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
|
||||
*) ;; esac
|
||||
|
||||
RUN python3 -m pip install --upgrade huggingface-hub[cli]
|
||||
ARG BUILD_RPD
|
||||
RUN if [ ${BUILD_RPD} -eq "1" ]; then \
|
||||
git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git \
|
||||
&& cd rocmProfileData/rpd_tracer \
|
||||
&& pip install -r requirements.txt && cd ../ \
|
||||
&& make && make install \
|
||||
&& cd hipMarker && python3 setup.py install ; fi
|
||||
|
||||
# Install vLLM
|
||||
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
cd /install \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& pip uninstall -y vllm \
|
||||
&& pip install *.whl
|
||||
|
||||
ARG COMMON_WORKDIR
|
||||
|
||||
# Copy over the benchmark scripts as well
|
||||
COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks
|
||||
COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples
|
||||
|
||||
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
|
||||
ENV TOKENIZERS_PARALLELISM=false
|
||||
|
||||
# Performance environment variable.
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
|
||||
158
Dockerfile.rocm_base
Normal file
158
Dockerfile.rocm_base
Normal file
@ -0,0 +1,158 @@
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete
|
||||
ARG HIPBLASLT_BRANCH="4d40e36"
|
||||
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
|
||||
ARG LEGACY_HIPBLASLT_OPTION=
|
||||
ARG RCCL_BRANCH="648a58d"
|
||||
ARG RCCL_REPO="https://github.com/ROCm/rccl"
|
||||
ARG TRITON_BRANCH="e5be006"
|
||||
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
||||
ARG PYTORCH_BRANCH="8d4926e"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.19.1"
|
||||
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="b7d29fb"
|
||||
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
ENV PATH=/opt/rocm/llvm/bin:$PATH
|
||||
ENV ROCM_PATH=/opt/rocm
|
||||
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
|
||||
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
|
||||
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
RUN mkdir -p /app
|
||||
WORKDIR /app
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y software-properties-common git curl sudo vim less \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-lib2to3 python-is-python3 \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
RUN pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
|
||||
|
||||
FROM base AS build_hipblaslt
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG HIPBLAS_COMMON_BRANCH
|
||||
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
RUN git clone https://github.com/ROCm/hipBLAS-common.git
|
||||
RUN cd hipBLAS-common \
|
||||
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
|
||||
&& mkdir build \
|
||||
&& cd build \
|
||||
&& cmake .. \
|
||||
&& make package \
|
||||
&& dpkg -i ./*.deb
|
||||
RUN git clone https://github.com/ROCm/hipBLASLt
|
||||
RUN cd hipBLASLt \
|
||||
&& git checkout ${HIPBLASLT_BRANCH} \
|
||||
&& ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
|
||||
&& cd build/release \
|
||||
&& make package
|
||||
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
|
||||
|
||||
FROM base AS build_rccl
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
RUN git clone ${RCCL_REPO}
|
||||
RUN cd rccl \
|
||||
&& git checkout ${RCCL_BRANCH} \
|
||||
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
|
||||
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
|
||||
|
||||
FROM base AS build_triton
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
RUN git clone ${TRITON_REPO}
|
||||
RUN cd triton \
|
||||
&& git checkout ${TRITON_BRANCH} \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
|
||||
|
||||
FROM base AS build_amdsmi
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
&& pip wheel . --wheel-dir=dist
|
||||
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
|
||||
|
||||
FROM base AS build_pytorch
|
||||
ARG PYTORCH_BRANCH
|
||||
ARG PYTORCH_VISION_BRANCH
|
||||
ARG PYTORCH_REPO
|
||||
ARG PYTORCH_VISION_REPO
|
||||
ARG FA_BRANCH
|
||||
ARG FA_REPO
|
||||
RUN git clone ${PYTORCH_REPO} pytorch
|
||||
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
|
||||
pip install -r requirements.txt && git submodule update --init --recursive \
|
||||
&& python3 tools/amd_build/build_amd.py \
|
||||
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
|
||||
&& pip install dist/*.whl
|
||||
RUN git clone ${PYTORCH_VISION_REPO} vision
|
||||
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist \
|
||||
&& pip install dist/*.whl
|
||||
RUN git clone ${FA_REPO}
|
||||
RUN cd flash-attention \
|
||||
&& git checkout ${FA_BRANCH} \
|
||||
&& git submodule update --init \
|
||||
&& MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
|
||||
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
|
||||
&& cp /app/vision/dist/*.whl /app/install \
|
||||
&& cp /app/flash-attention/dist/*.whl /app/install
|
||||
|
||||
FROM base AS final
|
||||
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
|
||||
ARG BASE_IMAGE
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
ARG PYTORCH_BRANCH
|
||||
ARG PYTORCH_VISION_BRANCH
|
||||
ARG PYTORCH_REPO
|
||||
ARG PYTORCH_VISION_REPO
|
||||
ARG FA_BRANCH
|
||||
ARG FA_REPO
|
||||
RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
|
||||
&& echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \
|
||||
&& echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
|
||||
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
|
||||
@ -1,4 +1,4 @@
|
||||
ARG NIGHTLY_DATE="20241017"
|
||||
ARG NIGHTLY_DATE="20250122"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
||||
@ -15,11 +15,8 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
|
||||
---
|
||||
|
||||
The first vLLM meetup in 2025 is happening on January 22nd, Wednesday, with Google Cloud in San Francisco! We will talk about vLLM's performant V1 architecture, Q1 roadmap, Google Cloud's innovation around vLLM: networking, Cloud Run, Vertex, and TPU! [Register Now](https://lu.ma/zep56hui)
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing).
|
||||
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
|
||||
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
|
||||
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
|
||||
|
||||
@ -35,6 +35,7 @@ class RequestFuncOutput:
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0.0
|
||||
output_tokens: int = 0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: List[float] = field(
|
||||
default_factory=list) # List of inter-token latencies
|
||||
@ -156,7 +157,7 @@ async def async_request_trt_llm(
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
@ -244,8 +245,12 @@ async def async_request_openai_completions(
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
"ignore_eos": request_func_input.ignore_eos,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
@ -256,7 +261,6 @@ async def async_request_openai_completions(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
@ -271,15 +275,16 @@ async def async_request_openai_completions(
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
if chunk != "[DONE]":
|
||||
data = json.loads(chunk)
|
||||
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# want to check a token was generated
|
||||
if data["choices"][0]["text"]:
|
||||
if choices := data.get("choices"):
|
||||
# Note that text could be empty here
|
||||
# e.g. for special tokens
|
||||
text = choices[0].get("text")
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if not first_chunk_received:
|
||||
@ -293,7 +298,10 @@ async def async_request_openai_completions(
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += data["choices"][0]["text"]
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
if first_chunk_received:
|
||||
output.success = True
|
||||
else:
|
||||
@ -302,7 +310,7 @@ async def async_request_openai_completions(
|
||||
"Never received a valid chunk to calculate TTFT."
|
||||
"This response will be marked as failed!")
|
||||
output.generated_text = generated_text
|
||||
output.latency = latency
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
@ -341,8 +349,12 @@ async def async_request_openai_chat_completions(
|
||||
"temperature": 0.0,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"ignore_eos": request_func_input.ignore_eos,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
@ -368,17 +380,15 @@ async def async_request_openai_chat_completions(
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
delta = data["choices"][0]["delta"]
|
||||
if delta.get("content", None):
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
@ -386,13 +396,16 @@ async def async_request_openai_chat_completions(
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
generated_text += delta["content"]
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
|
||||
@ -25,6 +25,7 @@ On the client side, run:
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
@ -423,7 +424,7 @@ def calculate_metrics(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[float],
|
||||
gootput_config_dict: Dict[str, float],
|
||||
goodput_config_dict: Dict[str, float],
|
||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||
actual_output_lens: List[int] = []
|
||||
total_input = 0
|
||||
@ -436,19 +437,23 @@ def calculate_metrics(
|
||||
e2els: List[float] = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
# We use the tokenizer to count the number of output tokens for all
|
||||
# serving backends instead of looking at len(outputs[i].itl) since
|
||||
# multiple output tokens may be bundled together
|
||||
# Note : this may inflate the output token count slightly
|
||||
output_len = len(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
output_len = outputs[i].output_tokens
|
||||
|
||||
if output_len is None:
|
||||
# We use the tokenizer to count the number of output tokens
|
||||
# for some serving backends instead of looking at
|
||||
# len(outputs[i].itl) since multiple output tokens may be
|
||||
# bundled together
|
||||
# Note : this may inflate the output token count slightly
|
||||
output_len = len(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
actual_output_lens.append(output_len)
|
||||
total_input += input_requests[i][1]
|
||||
tpot = 0
|
||||
if output_len > 1:
|
||||
tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
|
||||
1)
|
||||
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
|
||||
tpot = latency_minus_ttft / (output_len - 1)
|
||||
tpots.append(tpot)
|
||||
# Note: if output_len <= 1, we regard tpot as 0 for goodput
|
||||
all_tpots.append(tpot)
|
||||
@ -459,21 +464,21 @@ def calculate_metrics(
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
|
||||
if gootput_config_dict:
|
||||
if goodput_config_dict:
|
||||
valid_metrics = []
|
||||
slo_values = []
|
||||
|
||||
if "ttft" in gootput_config_dict:
|
||||
if "ttft" in goodput_config_dict:
|
||||
valid_metrics.append(ttfts)
|
||||
slo_values.append(gootput_config_dict["ttft"] /
|
||||
slo_values.append(goodput_config_dict["ttft"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
if "tpot" in gootput_config_dict:
|
||||
if "tpot" in goodput_config_dict:
|
||||
valid_metrics.append(all_tpots)
|
||||
slo_values.append(gootput_config_dict["tpot"] /
|
||||
slo_values.append(goodput_config_dict["tpot"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
if "e2el" in gootput_config_dict:
|
||||
if "e2el" in goodput_config_dict:
|
||||
valid_metrics.append(e2els)
|
||||
slo_values.append(gootput_config_dict["e2el"] /
|
||||
slo_values.append(goodput_config_dict["e2el"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
|
||||
for req_metric in zip(*valid_metrics):
|
||||
@ -537,7 +542,7 @@ async def benchmark(
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[str],
|
||||
ignore_eos: bool,
|
||||
gootput_config_dict: Dict[str, float],
|
||||
goodput_config_dict: Dict[str, float],
|
||||
max_concurrency: Optional[int],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
@ -661,7 +666,7 @@ async def benchmark(
|
||||
tokenizer=tokenizer,
|
||||
selected_percentile_metrics=selected_percentile_metrics,
|
||||
selected_percentiles=selected_percentiles,
|
||||
gootput_config_dict=gootput_config_dict,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
@ -673,7 +678,7 @@ async def benchmark(
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
if gootput_config_dict:
|
||||
if goodput_config_dict:
|
||||
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
|
||||
metrics.request_goodput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
@ -688,7 +693,7 @@ async def benchmark(
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"request_goodput:":
|
||||
metrics.request_goodput if gootput_config_dict else None,
|
||||
metrics.request_goodput if goodput_config_dict else None,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
@ -744,11 +749,11 @@ async def benchmark(
|
||||
|
||||
def check_goodput_args(args):
|
||||
# Check and parse goodput arguments
|
||||
gootput_config_dict = {}
|
||||
goodput_config_dict = {}
|
||||
VALID_NAMES = ["ttft", "tpot", "e2el"]
|
||||
if args.goodput:
|
||||
gootput_config_dict = parse_goodput(args.goodput)
|
||||
for slo_name, slo_val in gootput_config_dict.items():
|
||||
goodput_config_dict = parse_goodput(args.goodput)
|
||||
for slo_name, slo_val in goodput_config_dict.items():
|
||||
if slo_name not in VALID_NAMES:
|
||||
raise ValueError(
|
||||
f"Invalid metric name found, {slo_name}: {slo_val}. "
|
||||
@ -759,22 +764,22 @@ def check_goodput_args(args):
|
||||
f"Invalid value found, {slo_name}: {slo_val}. "
|
||||
"The service level objective value should be "
|
||||
"non-negative.")
|
||||
return gootput_config_dict
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
def parse_goodput(slo_pairs):
|
||||
gootput_config_dict = {}
|
||||
goodput_config_dict = {}
|
||||
try:
|
||||
for slo_pair in slo_pairs:
|
||||
slo_name, slo_val = slo_pair.split(":")
|
||||
gootput_config_dict[slo_name] = float(slo_val)
|
||||
goodput_config_dict[slo_name] = float(slo_val)
|
||||
except ValueError as err:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format found for service level objectives. "
|
||||
"Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||
"pairs, where the key is a metric name, and the value is a "
|
||||
"number in milliseconds.") from err
|
||||
return gootput_config_dict
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -874,7 +879,11 @@ def main(args: argparse.Namespace):
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
|
||||
gootput_config_dict = check_goodput_args(args)
|
||||
goodput_config_dict = check_goodput_args(args)
|
||||
|
||||
# Avoid GC processing "static" data - reduce pause times.
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
|
||||
benchmark_result = asyncio.run(
|
||||
benchmark(
|
||||
@ -896,7 +905,7 @@ def main(args: argparse.Namespace):
|
||||
float(p) for p in args.metric_percentiles.split(",")
|
||||
],
|
||||
ignore_eos=args.ignore_eos,
|
||||
gootput_config_dict=gootput_config_dict,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
max_concurrency=args.max_concurrency,
|
||||
))
|
||||
|
||||
|
||||
@ -12,10 +12,10 @@ from transformers import AutoConfig
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser, is_navi
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
|
||||
) and not is_navi() else torch.float8_e4m3fn
|
||||
) else torch.float8_e4m3fn
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
|
||||
@ -98,7 +98,9 @@ def main(
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = 1.0
|
||||
k_scale = v_scale = torch.tensor(1.0,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
|
||||
@ -105,7 +105,7 @@ __device__ void paged_attention_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const float* k_scale, const float* v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
|
||||
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
|
||||
k_vec_quant, k_scale);
|
||||
k_vec_quant, *k_scale);
|
||||
}
|
||||
}
|
||||
|
||||
@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
|
||||
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec.
|
||||
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
|
||||
v_scale);
|
||||
*v_scale);
|
||||
}
|
||||
if (block_idx == num_seq_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the
|
||||
@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const float* k_scale, const float* v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||
@ -549,7 +549,7 @@ __global__ void paged_attention_v2_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const float* k_scale, const float* v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||
|
||||
@ -41,7 +41,7 @@
|
||||
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
|
||||
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
|
||||
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
|
||||
blocksparse_vert_stride, blocksparse_block_size, \
|
||||
blocksparse_head_sliding_step);
|
||||
|
||||
@ -53,10 +53,10 @@ void paged_attention_v1_launcher(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -80,6 +80,8 @@ void paged_attention_v1_launcher(
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int padded_max_seq_len =
|
||||
@ -177,8 +179,9 @@ void paged_attention_v1(
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||
|
||||
@ -37,7 +37,7 @@
|
||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
|
||||
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
||||
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
||||
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
|
||||
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
|
||||
blocksparse_local_blocks, blocksparse_vert_stride, \
|
||||
blocksparse_block_size, blocksparse_head_sliding_step); \
|
||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
|
||||
@ -54,10 +54,10 @@ void paged_attention_v2_launcher(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -84,6 +84,8 @@ void paged_attention_v2_launcher(
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
|
||||
@ -188,8 +190,9 @@ void paged_attention_v2(
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||
|
||||
@ -18,15 +18,15 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale);
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
||||
|
||||
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype,
|
||||
const double k_scale, const double v_scale);
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
|
||||
@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
// block_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride, const int value_stride, const int num_heads,
|
||||
const int head_size, const int block_size, const int x, const float k_scale,
|
||||
const float v_scale) {
|
||||
const int head_size, const int block_size, const int x,
|
||||
const float* k_scale, const float* v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
if (slot_idx < 0) {
|
||||
@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
|
||||
value_cache[tgt_value_idx] = tgt_value;
|
||||
} else {
|
||||
key_cache[tgt_key_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
|
||||
value_cache[tgt_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, const int key_stride, const int value_stride,
|
||||
const int num_heads, const int head_size, const int block_size,
|
||||
const float k_scale, const float v_scale) {
|
||||
const float* k_scale, const float* v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@ -239,9 +239,9 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
value_cache[tgt_key_value_idx] = tgt_value;
|
||||
} else {
|
||||
key_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
|
||||
value_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -258,7 +258,9 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
|
||||
num_heads, head_size, block_size, x, k_scale, v_scale);
|
||||
num_heads, head_size, block_size, x, \
|
||||
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
@ -268,8 +270,8 @@ void reshape_and_cache(
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale) {
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
@ -299,7 +301,9 @@ void reshape_and_cache(
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
||||
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
|
||||
value_stride, num_heads, head_size, block_size, \
|
||||
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
@ -308,8 +312,8 @@ void reshape_and_cache_flash(
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale) {
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
|
||||
// slot_mapping.size(0) because of padding for CUDA graphs.
|
||||
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
|
||||
|
||||
@ -460,11 +460,11 @@ void paged_attention_v1(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||
@ -782,11 +782,11 @@ void paged_attention_v2(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||
|
||||
@ -107,10 +107,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype, double k_scale,
|
||||
double v_scale) {
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
|
||||
@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
@ -148,7 +148,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" Tensor! key_cache, Tensor! value_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||
}
|
||||
|
||||
|
||||
310
csrc/cumem_allocator.cpp
Normal file
310
csrc/cumem_allocator.cpp
Normal file
@ -0,0 +1,310 @@
|
||||
// A CUDAPluggableAllocator based on cumem* APIs.
|
||||
// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle*
|
||||
// need to be unsigned long long
|
||||
#include <iostream>
|
||||
|
||||
extern "C" {
|
||||
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda.h>
|
||||
|
||||
#define CUDA_CHECK(condition) \
|
||||
do { \
|
||||
CUresult error = condition; \
|
||||
if (error != 0) { \
|
||||
char* error_string; \
|
||||
cuGetErrorString(error, (const char**)&error_string); \
|
||||
std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \
|
||||
<< __LINE__ << std::endl; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Global references to Python callables
|
||||
// NOTE: this is borrowed reference, so we don't need to DECREF them.
|
||||
// This brings the limitation that the allocator needs to be singleton.
|
||||
static PyObject* g_python_malloc_callback = nullptr;
|
||||
static PyObject* g_python_free_callback = nullptr;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper functions:
|
||||
|
||||
void ensure_context(unsigned long long device) {
|
||||
CUcontext pctx;
|
||||
CUDA_CHECK(cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
// Ensure device context.
|
||||
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK(cuCtxSetCurrent(pctx));
|
||||
}
|
||||
}
|
||||
|
||||
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
|
||||
CUmemGenericAllocationHandle* p_memHandle) {
|
||||
ensure_context(device);
|
||||
// Define memory allocation properties
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
// Allocate memory using cuMemCreate
|
||||
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
|
||||
CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0));
|
||||
|
||||
CUmemAccessDesc accessDesc = {};
|
||||
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
accessDesc.location.id = device;
|
||||
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
||||
|
||||
CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1));
|
||||
// std::cout << "create_and_map: device=" << device << ", size=" << size << ",
|
||||
// d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
|
||||
}
|
||||
|
||||
void unmap_and_release(unsigned long long device, ssize_t size,
|
||||
CUdeviceptr d_mem,
|
||||
CUmemGenericAllocationHandle* p_memHandle) {
|
||||
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
|
||||
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
|
||||
ensure_context(device);
|
||||
CUDA_CHECK(cuMemUnmap(d_mem, size));
|
||||
CUDA_CHECK(cuMemRelease(*p_memHandle));
|
||||
}
|
||||
|
||||
PyObject* create_tuple_from_c_integers(unsigned long long a,
|
||||
unsigned long long b,
|
||||
unsigned long long c,
|
||||
unsigned long long d) {
|
||||
// Create a new tuple of size 4
|
||||
PyObject* tuple = PyTuple_New(4);
|
||||
if (!tuple) {
|
||||
return NULL; // Return NULL on failure
|
||||
}
|
||||
|
||||
// Convert integers to Python objects and set them in the tuple
|
||||
PyTuple_SetItem(
|
||||
tuple, 0,
|
||||
PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong
|
||||
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
|
||||
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
|
||||
PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d));
|
||||
|
||||
// Note: PyTuple_SetItem "steals" a reference to each object,
|
||||
// so we do not need to Py_DECREF the PyLong objects explicitly.
|
||||
|
||||
return tuple; // Return the created tuple
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Our exported C functions that call Python:
|
||||
|
||||
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
|
||||
void* my_malloc(ssize_t size, int device, CUstream stream) {
|
||||
ensure_context(device);
|
||||
|
||||
// first allocation, align the size, and reserve an address, and also allocate
|
||||
// a CUmemGenericAllocationHandle
|
||||
|
||||
// Define memory allocation properties
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
// Check if the allocation is supported
|
||||
size_t granularity;
|
||||
CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop,
|
||||
CU_MEM_ALLOC_GRANULARITY_MINIMUM));
|
||||
|
||||
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
|
||||
|
||||
CUdeviceptr d_mem;
|
||||
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
|
||||
|
||||
// allocate the CUmemGenericAllocationHandle
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)malloc(
|
||||
sizeof(CUmemGenericAllocationHandle));
|
||||
|
||||
if (!g_python_malloc_callback) {
|
||||
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Acquire GIL (not in stable ABI officially, but often works)
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
|
||||
PyObject* arg_tuple = create_tuple_from_c_integers(
|
||||
(unsigned long long)device, (unsigned long long)alignedSize,
|
||||
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
|
||||
|
||||
// Call g_python_malloc_callback
|
||||
PyObject* py_result =
|
||||
PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
|
||||
Py_DECREF(arg_tuple);
|
||||
|
||||
if (!py_result) {
|
||||
PyErr_Print();
|
||||
PyGILState_Release(gstate);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// do the final mapping
|
||||
create_and_map(device, alignedSize, d_mem, p_memHandle);
|
||||
|
||||
return (void*)d_mem;
|
||||
}
|
||||
|
||||
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
|
||||
void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
|
||||
// get memory handle from the pointer
|
||||
if (!g_python_free_callback) {
|
||||
std::cerr << "ERROR: g_python_free_callback not set.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// Acquire GIL (not in stable ABI officially, but often works)
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
|
||||
PyObject* py_ptr =
|
||||
PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
|
||||
|
||||
PyObject* py_result =
|
||||
PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
|
||||
|
||||
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
|
||||
&recv_d_mem, &recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return;
|
||||
}
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// recv_size == size
|
||||
// recv_device == device
|
||||
|
||||
// Free memory
|
||||
|
||||
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
unmap_and_release(device, size, d_mem, p_memHandle);
|
||||
|
||||
// free address and the handle
|
||||
CUDA_CHECK(cuMemAddressFree(d_mem, size));
|
||||
free(p_memHandle);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Python extension boilerplate:
|
||||
|
||||
// Python-exposed function: init_module(python_malloc, python_free)
|
||||
static PyObject* py_init_module(PyObject* self, PyObject* args) {
|
||||
PyObject* malloc_callback = nullptr;
|
||||
PyObject* free_callback = nullptr;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
|
||||
PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Save the Python callables
|
||||
// This module does not handle GC of these objects, so they must be kept alive
|
||||
// outside of this module.
|
||||
g_python_malloc_callback = malloc_callback;
|
||||
g_python_free_callback = free_callback;
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
|
||||
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
|
||||
&recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
|
||||
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
|
||||
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
|
||||
&recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
|
||||
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyMethodDef module_methods[] = {
|
||||
{"init_module", (PyCFunction)py_init_module, METH_VARARGS,
|
||||
"Initialize module with python_malloc and python_free callables."},
|
||||
{"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
|
||||
"Create and map memory on the device."},
|
||||
{"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
|
||||
METH_VARARGS, "Unmap and release memory on the device."},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
static struct PyModuleDef cumem_allocator_module = {
|
||||
PyModuleDef_HEAD_INIT, "cumem_allocator",
|
||||
"cumem-based allocator for CUDAPluggableAllocator", -1, module_methods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_cumem_allocator(void) {
|
||||
// Initialize the module
|
||||
PyObject* module = PyModule_Create(&cumem_allocator_module);
|
||||
if (!module) {
|
||||
return NULL;
|
||||
}
|
||||
return module;
|
||||
}
|
||||
} // extern "C"
|
||||
@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename token_cnts_t>
|
||||
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
int32_t* sorted_token_ids,
|
||||
int32_t* expert_ids,
|
||||
@ -32,12 +32,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
|
||||
int32_t* tokens_cnts =
|
||||
shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
|
||||
int32_t* cumsum =
|
||||
shared_mem +
|
||||
(blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
|
||||
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + blockDim.x + 1);
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
@ -74,7 +70,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
block_size) *
|
||||
block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
*total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@ -224,26 +220,46 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// If we have very large number of experts, we can no longer use shared
|
||||
// memory.
|
||||
// TODO(simon): the right solution should be calculating the exact right
|
||||
// amount of shared memory and use that. The num_experts >= 256 is just a
|
||||
// temporary solution to unblock Deepseek V3.
|
||||
if (num_experts >= 256) {
|
||||
int device_max_shared_mem;
|
||||
auto dev = topk_ids.get_device();
|
||||
cudaDeviceGetAttribute(&device_max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem_i32 =
|
||||
((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
|
||||
const int32_t shared_mem_i16 =
|
||||
((num_thread + 1) * num_experts) * sizeof(uint16_t) +
|
||||
(num_experts + 1) * sizeof(int32_t);
|
||||
|
||||
bool use_global_memory = false;
|
||||
bool use_i16 = false; // Use uint16_t for shared memory token counts
|
||||
if (shared_mem_i32 < device_max_shared_mem) {
|
||||
// Do nothing in this case. We're all set to use int32_t token counts
|
||||
} else if (shared_mem_i16 < device_max_shared_mem &&
|
||||
topk_ids.numel() <= 65535) {
|
||||
// when nelements of topk_ids is smaller than 65535 (max value of uint16),
|
||||
// element value of token_cnts would also smaller than 65535,
|
||||
// so we can use uint16 as dtype of token_cnts
|
||||
use_i16 = true;
|
||||
} else {
|
||||
use_global_memory = true;
|
||||
}
|
||||
|
||||
if (use_global_memory) {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
|
||||
const int32_t mem_tokens_cnts =
|
||||
((num_experts + 1) * num_experts) * sizeof(int32_t);
|
||||
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
|
||||
// allocate global memory
|
||||
int32_t* tokens_cnts;
|
||||
int32_t* cumsum;
|
||||
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
|
||||
cudaMalloc(&cumsum, mem_cumsum);
|
||||
auto options_int = torch::TensorOptions()
|
||||
.dtype(torch::kInt)
|
||||
.device(topk_ids.device());
|
||||
torch::Tensor token_cnts_buffer =
|
||||
torch::empty({(num_experts + 1) * num_experts}, options_int);
|
||||
torch::Tensor cumsum_buffer =
|
||||
torch::empty({num_experts + 1}, options_int);
|
||||
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
|
||||
@ -252,25 +268,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel(), tokens_cnts, cumsum);
|
||||
cudaFree(tokens_cnts);
|
||||
cudaFree(cumsum);
|
||||
topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
|
||||
cumsum_buffer.data_ptr<int32_t>());
|
||||
});
|
||||
} else if (use_i16) {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// set dynamic shared mem
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem_i16));
|
||||
kernel<<<1, num_thread, shared_mem_i16, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel());
|
||||
});
|
||||
} else {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem =
|
||||
((num_thread + 1) * num_experts + (num_experts + 1)) *
|
||||
sizeof(int32_t);
|
||||
|
||||
// set dynamic shared mem
|
||||
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem));
|
||||
kernel<<<1, num_thread, shared_mem, stream>>>(
|
||||
(void*)kernel, shared_mem_i32));
|
||||
kernel<<<1, num_thread, shared_mem_i32, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
|
||||
10
csrc/ops.h
10
csrc/ops.h
@ -34,8 +34,9 @@ void paged_attention_v1(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
@ -45,8 +46,9 @@ void paged_attention_v2(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
|
||||
@ -218,7 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
|
||||
int max_ctx_blocks, float k_scale, float v_scale) {
|
||||
int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) {
|
||||
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
@ -406,7 +406,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
|
||||
const _B8x8 Vlocalb8 = v_ptrh8be[d];
|
||||
Vlocal[h][b * BLOCK_SIZE / 8 + d] =
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, *v_scale_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -416,7 +416,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
#pragma unroll
|
||||
for (int d = 0; d < KHELOOP; d++) {
|
||||
Klocal[d] =
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
|
||||
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], *k_scale_ptr);
|
||||
}
|
||||
}
|
||||
|
||||
@ -890,7 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
|
||||
int max_ctx_blocks, float k_scale, float v_scale) {
|
||||
int max_ctx_blocks, const float* k_scale, const float* v_scale) {
|
||||
UNREACHABLE_CODE
|
||||
}
|
||||
|
||||
@ -919,7 +919,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
|
||||
k_scale, v_scale);
|
||||
k_scale_ptr, v_scale_ptr);
|
||||
|
||||
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
|
||||
@ -929,7 +929,7 @@ void paged_attention_custom_launcher(
|
||||
torch::Tensor& value_cache, const int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& context_lens,
|
||||
int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
float k_scale, float v_scale) {
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -953,6 +953,8 @@ void paged_attention_custom_launcher(
|
||||
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
|
||||
const int max_num_partitions =
|
||||
@ -1087,7 +1089,8 @@ void paged_attention(
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_context_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale) {
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
const int head_size = query.size(2);
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Half) {
|
||||
|
||||
@ -10,5 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
|
||||
torch::Tensor& context_lens, int64_t block_size,
|
||||
int64_t max_context_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale,
|
||||
double v_scale);
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale);
|
||||
|
||||
@ -27,7 +27,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
" int max_context_len,"
|
||||
" Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
|
||||
}
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
@ -449,7 +449,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" Tensor! key_cache, Tensor! value_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
@ -459,7 +459,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" Tensor! value_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
|
||||
&reshape_and_cache_flash);
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: vllm.multimodal.inputs.MultiModalInputsV2
|
||||
.. autoclass:: vllm.multimodal.inputs.MultiModalInputs
|
||||
:members:
|
||||
:show-inheritance:
|
||||
```
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
|
||||
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
|
||||
|
||||
- [The eighth vLLM meetup](https://lu.ma/zep56hui), with Google Cloud, January 22nd 2025. [[Slides]](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing)
|
||||
- [The seventh vLLM meetup](https://lu.ma/h0qvrajz), with Snowflake, November 14th 2024. [[Slides]](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing)
|
||||
- [The sixth vLLM meetup](https://lu.ma/87q3nvnh), with NVIDIA, September 9th 2024. [[Slides]](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing)
|
||||
- [The fifth vLLM meetup](https://lu.ma/lp0gyjqr), with AWS, July 24th 2024. [[Slides]](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing)
|
||||
|
||||
@ -41,3 +41,20 @@ You may use the `#security` channel in the [VLLM Slack](https://slack.vllm.ai)
|
||||
to discuss security-related topics. However, please do not disclose any
|
||||
vulnerabilities in this channel. If you need to report a vulnerability, please
|
||||
use the GitHub security advisory system or contact a VMT member privately.
|
||||
|
||||
## Vulnerability Disclosure
|
||||
|
||||
The process for disclosing vulnerabilities is the following:
|
||||
|
||||
- The VMT will work with the project maintainers to develop a fix for the
|
||||
vulnerability.
|
||||
- The VMT will coordinate with the reporter and project maintainers to prepare a
|
||||
security advisory that adequately describes the vulnerability and its impact.
|
||||
- The VMT will coordinate with the project maintainers to publish a fix and
|
||||
release an update that includes that fix.
|
||||
- The VMT will publish the security advisory on GitHub. Release notes will be
|
||||
updated to include a reference to the security advisory.
|
||||
|
||||
The VMT and project maintainers will work to minimize the amount of time in
|
||||
between disclosing any public information about the vulnerability and making a
|
||||
release and advisory available.
|
||||
|
||||
@ -307,7 +307,7 @@ Check the '✗' with links to see tracking issue for unsupported feature/hardwar
|
||||
- ✅
|
||||
- ?
|
||||
- ?
|
||||
- ✅
|
||||
- [✗](gh-issue:11484)
|
||||
- ✅
|
||||
- ✗
|
||||
- ?
|
||||
|
||||
@ -1,44 +0,0 @@
|
||||
(fp8-e4m3-kvcache)=
|
||||
|
||||
# FP8 E4M3 KV Cache
|
||||
|
||||
Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache,
|
||||
improving throughput. OCP (Open Compute Project www.opencompute.org) specifies two common 8-bit floating point data formats: E5M2
|
||||
(5 exponent bits and 2 mantissa bits) and E4M3FN (4 exponent bits and 3 mantissa bits), often shortened as E4M3. One benefit of
|
||||
the E4M3 format over E5M2 is that floating point numbers are represented in higher precision. However, the small dynamic range of
|
||||
FP8 E4M3 (±240.0 can be represented) typically necessitates the use of a higher-precision (typically FP32) scaling factor alongside
|
||||
each quantized tensor. For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling
|
||||
factors of a finer granularity (e.g. per-channel).
|
||||
|
||||
These scaling factors can be specified by passing an optional quantization param JSON to the LLM engine at load time. If
|
||||
this JSON is not specified, scaling factors default to 1.0. These scaling factors are typically obtained when running an
|
||||
unquantized model through a quantizer tool (e.g. AMD quantizer or NVIDIA AMMO).
|
||||
|
||||
To install AMMO (AlgorithMic Model Optimization):
|
||||
|
||||
```console
|
||||
pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo
|
||||
```
|
||||
|
||||
Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy. The most recent silicon
|
||||
offerings e.g. AMD MI300, NVIDIA Hopper or later support native hardware conversion to and from fp32, fp16, bf16, etc.
|
||||
Thus, LLM inference is greatly accelerated with minimal accuracy loss.
|
||||
|
||||
Here is an example of how to enable this feature:
|
||||
|
||||
```python
|
||||
# two float8_e4m3fn kv cache scaling factor files are provided under tests/fp8_kv, please refer to
|
||||
# https://github.com/vllm-project/vllm/blob/main/examples/other/fp8/README.md to generate kv_cache_scales.json of your own.
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
sampling_params = SamplingParams(temperature=1.3, top_p=0.8)
|
||||
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
|
||||
kv_cache_dtype="fp8",
|
||||
quantization_param_path="./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
|
||||
prompt = "London is the capital of"
|
||||
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
|
||||
print(out)
|
||||
|
||||
# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial,
|
||||
# output w/o scaling factors: England, located in the southeastern part of the country. It is known
|
||||
```
|
||||
@ -1,31 +0,0 @@
|
||||
(fp8-kv-cache)=
|
||||
|
||||
# FP8 E5M2 KV Cache
|
||||
|
||||
The int8/int4 quantization scheme requires additional scale GPU memory storage, which reduces the expected GPU memory benefits.
|
||||
The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bfloat16 and fp8 to each other.
|
||||
|
||||
Here is an example of how to enable this feature:
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
```
|
||||
@ -14,6 +14,5 @@ bnb
|
||||
gguf
|
||||
int8
|
||||
fp8
|
||||
fp8_e5m2_kvcache
|
||||
fp8_e4m3_kvcache
|
||||
quantized_kvcache
|
||||
```
|
||||
|
||||
147
docs/source/features/quantization/quantized_kvcache.md
Normal file
147
docs/source/features/quantization/quantized_kvcache.md
Normal file
@ -0,0 +1,147 @@
|
||||
(quantized-kvcache)=
|
||||
|
||||
# Quantized KV Cache
|
||||
|
||||
## FP8 KV Cache
|
||||
|
||||
Quantizing the KV cache to FP8 reduces its memory footprint. This increases the number of tokens that can be stored in the cache, improving throughput.
|
||||
|
||||
### FP8 Formats
|
||||
|
||||
[OCP (Open Compute Project)](https://www.opencompute.org) specifies two common 8-bit floating point data formats:
|
||||
|
||||
- E5M2 (5 exponent bits and 2 mantissa bits)
|
||||
- E4M3FN (4 exponent bits and 3 mantissa bits, often shortened as E4M3)
|
||||
|
||||
The E4M3 format offers higher precision compared to E5M2. However, due to its small dynamic range (±240.0), E4M3 typically requires a higher-precision (FP32) scaling factor alongside each quantized tensor.
|
||||
|
||||
### Current Limitations
|
||||
|
||||
For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of a finer granularity (e.g. per-channel).
|
||||
|
||||
### Performance Impact
|
||||
|
||||
The current FP8 KV cache implementation primarily benefits throughput by allowing approximately double the amount of space for KV cache allocation. This enables either:
|
||||
|
||||
- Processing longer context lengths for individual requests, or
|
||||
- Handling more concurrent request batches
|
||||
|
||||
However, there are currently no latency improvements as the implementation does not yet include fused dequantization and attention operations. Future releases will support quantized attention with hardware acceleration, which should provide additional performance benefits. While the most recent silicon offerings (e.g. AMD MI300, NVIDIA Hopper or later) support native hardware conversion between FP8 and other formats (fp32, fp16, bf16), this benefit is not yet fully realized.
|
||||
|
||||
Studies have shown that FP8 E4M3 quantization typically only minimally degrades inference accuracy, making it a practical choice for throughput optimization.
|
||||
|
||||
## Usage Example
|
||||
|
||||
Here is an example of how to enable FP8 quantization:
|
||||
|
||||
```python
|
||||
# To calculate kv cache scales on the fly enable the calculate_kv_scales
|
||||
# parameter
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
|
||||
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
|
||||
kv_cache_dtype="fp8",
|
||||
calculate_kv_scales=True)
|
||||
prompt = "London is the capital of"
|
||||
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
|
||||
print(out)
|
||||
```
|
||||
|
||||
The `kv_cache_dtype` argument specifies the data type for KV cache storage:
|
||||
- `"auto"`: Uses the model's default "unquantized" data type
|
||||
- `"fp8"` or `"fp8_e4m3"`: Supported on CUDA 11.8+ and ROCm (AMD GPU)
|
||||
- `"fp8_e5m2"`: Supported on CUDA 11.8+
|
||||
|
||||
## Calibrated Scales for Better Accuracy
|
||||
|
||||
For optimal model quality when using FP8 KV Cache, we recommend using calibrated scales tuned to representative inference data. [LLM Compressor](https://github.com/vllm-project/llm-compressor/) is the recommended tool for this process.
|
||||
|
||||
### Installation
|
||||
|
||||
First, install the required dependencies:
|
||||
|
||||
```console
|
||||
pip install llmcompressor
|
||||
```
|
||||
|
||||
### Example Usage
|
||||
|
||||
Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models can use this same pattern):
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from llmcompressor.transformers import oneshot
|
||||
|
||||
# Select model and load it
|
||||
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
|
||||
# Select calibration dataset
|
||||
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
|
||||
DATASET_SPLIT = "train_sft"
|
||||
|
||||
# Configure calibration parameters
|
||||
NUM_CALIBRATION_SAMPLES = 512 # 512 samples is a good starting point
|
||||
MAX_SEQUENCE_LENGTH = 2048
|
||||
|
||||
# Load and preprocess dataset
|
||||
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
|
||||
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
|
||||
|
||||
def process_and_tokenize(example):
|
||||
text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
|
||||
return tokenizer(
|
||||
text,
|
||||
padding=False,
|
||||
max_length=MAX_SEQUENCE_LENGTH,
|
||||
truncation=True,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
|
||||
ds = ds.map(process_and_tokenize, remove_columns=ds.column_names)
|
||||
|
||||
# Configure quantization settings
|
||||
recipe = """
|
||||
quant_stage:
|
||||
quant_modifiers:
|
||||
QuantizationModifier:
|
||||
kv_cache_scheme:
|
||||
num_bits: 8
|
||||
type: float
|
||||
strategy: tensor
|
||||
dynamic: false
|
||||
symmetric: true
|
||||
"""
|
||||
|
||||
# Apply quantization
|
||||
oneshot(
|
||||
model=model,
|
||||
dataset=ds,
|
||||
recipe=recipe,
|
||||
max_seq_length=MAX_SEQUENCE_LENGTH,
|
||||
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
|
||||
)
|
||||
|
||||
# Save quantized model
|
||||
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV"
|
||||
model.save_pretrained(SAVE_DIR, save_compressed=True)
|
||||
tokenizer.save_pretrained(SAVE_DIR)
|
||||
```
|
||||
|
||||
The above script will create a folder in your current directory containing your quantized model (e.g., `Llama-3.1-8B-Instruct-FP8-KV`) with calibrated scales.
|
||||
|
||||
When running the model you must specify `kv_cache_dtype="fp8"` in order to enable the kv cache quantization and use the scales.
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.7, top_p=0.8)
|
||||
llm = LLM(model="Llama-3.1-8B-Instruct-FP8-KV", kv_cache_dtype="fp8")
|
||||
prompt = "London is the capital of"
|
||||
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
|
||||
print(out)
|
||||
```
|
||||
@ -13,6 +13,14 @@ vLLM supports AMD GPUs with ROCm 6.2.
|
||||
|
||||
Currently, there are no pre-built ROCm wheels.
|
||||
|
||||
However, the [AMD Infinity hub for vLLM](https://hub.docker.com/r/rocm/vllm/tags) offers a prebuilt, optimized
|
||||
docker image designed for validating inference performance on the AMD Instinct™ MI300X accelerator.
|
||||
|
||||
```{tip}
|
||||
Please check [LLM inference performance validation on AMD Instinct MI300X](https://rocm.docs.amd.com/en/latest/how-to/performance-validation/mi300x/vllm-benchmark.html)
|
||||
for instructions on how to use this prebuilt docker image.
|
||||
```
|
||||
|
||||
### Build wheel from source
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
@ -123,11 +131,10 @@ It is important that the user kicks off the docker build using buildkit. Either
|
||||
<gh-file:Dockerfile.rocm> uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches.
|
||||
It provides flexibility to customize the build of docker image using the following arguments:
|
||||
|
||||
- `BASE_IMAGE`: specifies the base image used when running `docker build`, specifically the PyTorch on ROCm base image.
|
||||
- `BUILD_FA`: specifies whether to build CK flash-attention. The default is 1. For [Radeon RX 7900 series (gfx1100)](https://rocm.docs.amd.com/projects/radeon/en/latest/index.html), this should be set to 0 before flash-attention supports this target.
|
||||
- `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build CK flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
|
||||
- `FA_BRANCH`: specifies the branch used to build the CK flash-attention in [ROCm's flash-attention repo](https://github.com/ROCmSoftwarePlatform/flash-attention). The default is `ae7928c`
|
||||
- `BUILD_TRITON`: specifies whether to build triton flash-attention. The default value is 1.
|
||||
- `BASE_IMAGE`: specifies the base image used when running `docker build`. The default value `rocm/vllm-dev:base` is an image published and maintained by AMD. It is being built using <gh-file:Dockerfile.rocm_base>
|
||||
- `USE_CYTHON`: An option to run cython compilation on a subset of python files upon docker build
|
||||
- `BUILD_RPD`: Include RocmProfileData profiling tool in the image
|
||||
- `ARG_PYTORCH_ROCM_ARCH`: Allows to override the gfx architecture values from the base docker image
|
||||
|
||||
Their values can be passed in when running `docker build` with `--build-arg` options.
|
||||
|
||||
@ -137,10 +144,10 @@ To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default:
|
||||
DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm .
|
||||
```
|
||||
|
||||
To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should specify `BUILD_FA` as below:
|
||||
To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should pick the alternative base image:
|
||||
|
||||
```console
|
||||
DOCKER_BUILDKIT=1 docker build --build-arg BUILD_FA="0" -f Dockerfile.rocm -t vllm-rocm .
|
||||
DOCKER_BUILDKIT=1 docker build --build-arg BASE_IMAGE="rocm/vllm-dev:navi_base" -f Dockerfile.rocm -t vllm-rocm .
|
||||
```
|
||||
|
||||
To run the above docker image `vllm-rocm`, use the below command:
|
||||
|
||||
@ -22,9 +22,9 @@ It'd be better to store the model in a local disk. Additionally, have a look at
|
||||
To isolate the model downloading and loading issue, you can use the `--load-format dummy` argument to skip loading the model weights. This way, you can check if the model downloading and loading is the bottleneck.
|
||||
```
|
||||
|
||||
## Model is too large
|
||||
## Out of memory
|
||||
|
||||
If the model is too large to fit in a single GPU, you might want to [consider tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
|
||||
If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider [using tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
|
||||
|
||||
## Enable more logging
|
||||
|
||||
@ -197,6 +197,63 @@ if __name__ == '__main__':
|
||||
llm = vllm.LLM(...)
|
||||
```
|
||||
|
||||
## `torch.compile` Error
|
||||
|
||||
vLLM heavily depends on `torch.compile` to optimize the model for better performance, which introduces the dependency on the `torch.compile` functionality and the `triton` library. By default, we use `torch.compile` to [optimize some functions](https://github.com/vllm-project/vllm/pull/10406) in the model. Before running vLLM, you can check if `torch.compile` is working as expected by running the following script:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
@torch.compile
|
||||
def f(x):
|
||||
# a simple function to test torch.compile
|
||||
x = x + 1
|
||||
x = x * 2
|
||||
x = x.sin()
|
||||
return x
|
||||
|
||||
x = torch.randn(4, 4).cuda()
|
||||
print(f(x))
|
||||
```
|
||||
|
||||
If it raises errors from `torch/_inductor` directory, usually it means you have a custom `triton` library that is not compatible with the version of PyTorch you are using. See [this issue](https://github.com/vllm-project/vllm/issues/12219) for example.
|
||||
|
||||
## Model failed to be inspected
|
||||
|
||||
If you see an error like:
|
||||
|
||||
```text
|
||||
File "vllm/model_executor/models/registry.py", line xxx, in _raise_for_unsupported
|
||||
raise ValueError(
|
||||
ValueError: Model architectures ['<arch>'] failed to be inspected. Please check the logs for more details.
|
||||
```
|
||||
|
||||
It means that vLLM failed to import the model file.
|
||||
Usually, it is related to missing dependencies or outdated binaries in the vLLM build.
|
||||
Please read the logs carefully to determine the root cause of the error.
|
||||
|
||||
## Model not supported
|
||||
|
||||
If you see an error like:
|
||||
|
||||
```text
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
File "vllm/model_executor/models/registry.py", line xxx, in inspect_model_cls
|
||||
for arch in architectures:
|
||||
TypeError: 'NoneType' object is not iterable
|
||||
```
|
||||
|
||||
or:
|
||||
|
||||
```text
|
||||
File "vllm/model_executor/models/registry.py", line xxx, in _raise_for_unsupported
|
||||
raise ValueError(
|
||||
ValueError: Model architectures ['<arch>'] are not supported for now. Supported architectures: [...]
|
||||
```
|
||||
|
||||
But you are sure that the model is in the [list of supported models](#supported-models), there may be some issue with vLLM's model resolution. In that case, please follow [these steps](#model-resolution) to explicitly specify the vLLM implementation for the model.
|
||||
|
||||
## Known Issues
|
||||
|
||||
- In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759).
|
||||
|
||||
@ -302,8 +302,8 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
* - `Phi3ForCausalLM`
|
||||
- Phi-3
|
||||
- `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc.
|
||||
- Phi-4, Phi-3
|
||||
- `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
* - `Phi3SmallForCausalLM`
|
||||
|
||||
@ -31,6 +31,8 @@ Please refer to the above pages for more details about each API.
|
||||
This section lists the most common options for running the vLLM engine.
|
||||
For a full list, refer to the [Engine Arguments](#engine-args) page.
|
||||
|
||||
(model-resolution)=
|
||||
|
||||
### Model resolution
|
||||
|
||||
vLLM loads HuggingFace-compatible models by inspecting the `architectures` field in `config.json` of the model repository
|
||||
@ -41,37 +43,6 @@ Nevertheless, our model resolution may fail for the following reasons:
|
||||
- Unofficial repositories refer to a model using alternative names which are not recorded in vLLM.
|
||||
- The same architecture name is used for multiple models, creating ambiguity as to which model should be loaded.
|
||||
|
||||
In those cases, vLLM may throw an error like:
|
||||
|
||||
```text
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
File "vllm/model_executor/models/registry.py", line xxx, in inspect_model_cls
|
||||
for arch in architectures:
|
||||
TypeError: 'NoneType' object is not iterable
|
||||
```
|
||||
|
||||
or:
|
||||
|
||||
```text
|
||||
File "vllm/model_executor/models/registry.py", line xxx, in _raise_for_unsupported
|
||||
raise ValueError(
|
||||
ValueError: Model architectures ['<arch>'] are not supported for now. Supported architectures: [...]
|
||||
```
|
||||
|
||||
:::{note}
|
||||
The above error is distinct from the following similar but different error:
|
||||
|
||||
```text
|
||||
File "vllm/model_executor/models/registry.py", line xxx, in _raise_for_unsupported
|
||||
raise ValueError(
|
||||
ValueError: Model architectures ['<arch>'] failed to be inspected. Please check the logs for more details.
|
||||
```
|
||||
|
||||
This error means that vLLM failed to import the model file. Usually, it is related to missing dependencies or outdated
|
||||
binaries in the vLLM build. Please read the logs carefully to determine the real cause of the error.
|
||||
:::
|
||||
|
||||
To fix this, explicitly specify the model architecture by passing `config.json` overrides to the `hf_overrides` option.
|
||||
For example:
|
||||
|
||||
|
||||
@ -8,10 +8,10 @@ prompts = [
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
sampling_params = SamplingParams() #temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_model_len=512, max_num_seqs=16)
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
@ -19,4 +19,4 @@ outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
@ -19,7 +19,7 @@ from ray.util.placement_group import placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams, configure_as_vllm_process
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import get_ip, get_open_port
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
@ -98,12 +98,7 @@ class MyLLM(LLM):
|
||||
"""
|
||||
Start the training process, here we use huggingface transformers
|
||||
as an example to hold a model on GPU 0.
|
||||
|
||||
It is important for all the processes outside of vLLM to call
|
||||
`configure_as_vllm_process` to set some common environment variables
|
||||
the same as vLLM workers.
|
||||
"""
|
||||
configure_as_vllm_process()
|
||||
|
||||
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
train_model.to("cuda:0")
|
||||
|
||||
@ -26,14 +26,12 @@ def run_aria(question: str, modality: str):
|
||||
|
||||
# NOTE: Need L40 (or equivalent) to avoid OOM
|
||||
llm = LLM(model=model_name,
|
||||
tokenizer_mode="slow",
|
||||
dtype="bfloat16",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
|
||||
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
|
||||
prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
|
||||
"<|im_end|>\n<|im_start|>assistant\n")
|
||||
|
||||
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
|
||||
|
||||
@ -1,96 +0,0 @@
|
||||
# FP8 KV Cache
|
||||
|
||||
This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Python 3.x
|
||||
- PyTorch
|
||||
- NumPy
|
||||
- Hugging Face Transformers
|
||||
- Hugging Face Hub
|
||||
- AMMO
|
||||
|
||||
Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps:
|
||||
1. Install all necessary prerequisites and dependencies.
|
||||
2. Convert HF model into a quantized HF model.
|
||||
3. Extract KV Cache Scaling Factors from quantized HF model.
|
||||
4. Load KV Cache Scaling Factors into VLLM.
|
||||
|
||||
### 2. Convert HF model into a quantized HF model.
|
||||
Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md).
|
||||
|
||||
`quantize.py` (examples/other/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format).
|
||||
|
||||
The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/other/fp8/quantizer/README.md`.
|
||||
|
||||
### 3. Extract KV Cache Scaling Factors from quantized HF model.
|
||||
`extract_scales.py` (examples/other/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model, however at the moment, this tool exclusively supports Llama 2 models. It is also important to note the following:
|
||||
1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename.
|
||||
|
||||
2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM.
|
||||
|
||||
3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks.
|
||||
|
||||
```python
|
||||
# prerequisites:
|
||||
# - Quantized HF LLaMa 2 model
|
||||
python3 examples/other/fp8/extract_scales.py --help
|
||||
Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE]
|
||||
|
||||
KV Scale Extraction Example
|
||||
|
||||
optional arguments:
|
||||
--quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU).
|
||||
Optional arguments:
|
||||
--cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None)
|
||||
--load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto)
|
||||
--revision: Specify the model's revision number. (Default: None)
|
||||
--output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None)
|
||||
--output_name: Specify the output filename. (Default: kv_cache_scales.json)
|
||||
--tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None)
|
||||
```
|
||||
```python
|
||||
Example:
|
||||
python3 examples/other/fp8/extract_scales.py --quantized_model <QUANTIZED_MODEL_DIR> --tp_size <TENSOR_PARALLEL_SIZE> --output_dir <PATH_TO_OUTPUT_DIR>
|
||||
```
|
||||
### 4. Load KV Cache Scaling Factors into VLLM.
|
||||
This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. The recently generated KV cache scaling factors are now integrated into the benchmarking process and allow for KV cache scaling factors to be utilized for FP8.
|
||||
```
|
||||
# prerequisites:
|
||||
# - LLaMa 2 kv_cache_scales.json file
|
||||
|
||||
python3 benchmarks/benchmark_throughput.py --help
|
||||
usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL]
|
||||
[--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N]
|
||||
[--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code]
|
||||
[--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}]
|
||||
[--quantization-param-path KV_CACHE_quantization_param_path]
|
||||
|
||||
Benchmark Throughput Example
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
--backend {vllm,hf,mii}
|
||||
--dataset DATASET Path to the dataset.
|
||||
--input-len INPUT_LEN Input prompt length for each request
|
||||
--output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset.
|
||||
--model MODEL
|
||||
--tokenizer TOKENIZER
|
||||
--quantization {awq,gptq,None}, -q {awq,gptq,None}
|
||||
--tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE
|
||||
--n N Number of generated sequences per prompt.
|
||||
--use-beam-search
|
||||
--num-prompts NUM_PROMPTS Number of prompts to process.
|
||||
--seed SEED
|
||||
--hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend.
|
||||
--trust-remote-code trust remote code from huggingface
|
||||
--max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model.
|
||||
--dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
|
||||
--enforce-eager enforce eager execution
|
||||
--kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported ```for common inference criteria.
|
||||
--quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria.
|
||||
```
|
||||
Example:
|
||||
```console
|
||||
python3 benchmarks/benchmark_throughput.py --input-len <INPUT_LEN> --output-len <OUTPUT_LEN> -tp <TENSOR_PARALLEL_SIZE> --kv-cache-dtype fp8 --quantization-param-path <path/to/kv_cache_scales.json> --model <path-to-llama2>
|
||||
```
|
||||
@ -1,367 +0,0 @@
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||
|
||||
|
||||
# Adapted from vllm/model_executor/model_loader/weight_utils.py
|
||||
# The main differences are that we add the NPZ format and simplify
|
||||
# its functionality drastically for our purposes (e.g. we assume that
|
||||
# the quantized model exists locally and there is no need to download it)
|
||||
def _prepare_hf_weights(
|
||||
quantized_model_dir: str,
|
||||
load_format: str = "auto",
|
||||
fall_back_to_pt: bool = True,
|
||||
) -> Tuple[List[str], bool]:
|
||||
if not os.path.isdir(quantized_model_dir):
|
||||
raise FileNotFoundError(
|
||||
f"The quantized model directory `{quantized_model_dir}` "
|
||||
"does not exist.")
|
||||
use_safetensors = False
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == "auto":
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif load_format == "safetensors":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == "pt":
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == "npz":
|
||||
allow_patterns = ["*.npz"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
|
||||
hf_weights_files: List[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(
|
||||
os.path.join(quantized_model_dir, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
if not use_safetensors:
|
||||
# Exclude files that are not needed for inference.
|
||||
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
blacklist = [
|
||||
"training_args.bin",
|
||||
"optimizer.bin",
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"scaler.pt",
|
||||
]
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files
|
||||
if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{quantized_model_dir}`")
|
||||
|
||||
return hf_weights_files, use_safetensors
|
||||
|
||||
|
||||
# Adapted from vllm/model_executor/model_loader/weight_utils.py
|
||||
def _hf_tensorfile_iterator(filename: str, load_format: str,
|
||||
use_safetensors: bool):
|
||||
if load_format == "npz":
|
||||
assert not use_safetensors
|
||||
with np.load(filename) as data:
|
||||
for name in data.files:
|
||||
param = torch.from_numpy(data[name])
|
||||
yield name, param
|
||||
elif use_safetensors:
|
||||
with safe_open(filename, framework="pt") as f:
|
||||
for name in f.keys(): # NOQA: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
else:
|
||||
state = torch.load(filename, map_location="cpu")
|
||||
for name, param in state.items():
|
||||
yield name, param
|
||||
del state
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def _kv_scales_extractor(
|
||||
hf_tensor_files: List[str],
|
||||
use_safetensors: bool,
|
||||
rank_keyword: str = "rank",
|
||||
expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
|
||||
"""
|
||||
Given a list of files containing tensor data, attempt to extract KV cache
|
||||
scales from these files. Intended as a helper function taking in the output
|
||||
from _prepare_hf_weights.
|
||||
Args:
|
||||
rank_keyword Matches the number immediately after this keyword in the
|
||||
tensor filename to determine the TP rank corresponding
|
||||
to said tensor file
|
||||
expected_tp_size If specified, the TP size of the tensor files is checked
|
||||
against this and an error is raised if they don't match.
|
||||
Returns a dictionary mapping TP ranks to their relevant KV cache scales.
|
||||
The per-rank scales are themselves represented as a dictionary of layer
|
||||
indices to the respective per-layer scale.
|
||||
"""
|
||||
for char in rank_keyword:
|
||||
assert not char.isdecimal(
|
||||
), f"Rank keyword {rank_keyword} contains a numeric character!"
|
||||
rank_scales_map: Dict[int, Dict[int, float]] = {}
|
||||
for tensor_file in hf_tensor_files:
|
||||
try:
|
||||
rank_idx = tensor_file.find(rank_keyword)
|
||||
if rank_idx != -1:
|
||||
start_idx = rank_idx + len(rank_keyword)
|
||||
stop_idx = start_idx
|
||||
while stop_idx < len(
|
||||
tensor_file) and tensor_file[stop_idx].isdecimal():
|
||||
stop_idx += 1
|
||||
if stop_idx == start_idx:
|
||||
raise RuntimeError("Did not find rank # in filename.")
|
||||
rank = int(tensor_file[start_idx:stop_idx])
|
||||
elif len(hf_tensor_files) == 1:
|
||||
# Since there is only one tensor file, we can assume
|
||||
# that it's intended for TP rank 0
|
||||
rank = 0
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Filename does not contain '{rank_keyword}'.")
|
||||
except RuntimeError:
|
||||
print("Unable to determine TP rank "
|
||||
f"corresponding to file '{tensor_file}'")
|
||||
raise
|
||||
|
||||
if rank not in rank_scales_map:
|
||||
layer_scales_map: Dict[int, float] = {}
|
||||
rank_scales_map[rank] = layer_scales_map
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Tensor file '{tensor_file}' shares TP rank {rank} "
|
||||
"with another tensor file.")
|
||||
|
||||
module_delimiter = ":" if args.load_format == "npz" else "."
|
||||
for name, param in _hf_tensorfile_iterator(tensor_file,
|
||||
args.load_format,
|
||||
use_safetensors):
|
||||
if "kv_cache_scaling_factor" in name:
|
||||
nums = [
|
||||
int(s) for s in name.split(module_delimiter)
|
||||
if s.isdecimal()
|
||||
]
|
||||
assert len(
|
||||
nums) == 1, f"Could not determine layer idx for {name}"
|
||||
layer_idx = nums[0]
|
||||
assert layer_idx not in layer_scales_map, f"Duplicate scaling"\
|
||||
f" factor corresponding to layer {layer_idx}"
|
||||
try:
|
||||
layer_scales_map[layer_idx] = param.item()
|
||||
except RuntimeError:
|
||||
print(
|
||||
"This utility supports only per-tensor scalar scales "
|
||||
f"for now. The tensor\n {name} = {param} \nis an "
|
||||
"invalid scale factor.")
|
||||
raise
|
||||
|
||||
if all(
|
||||
len(layer_scales_map) == 0
|
||||
for layer_scales_map in rank_scales_map.values()):
|
||||
# Note: this is true even if the rank_scales_map is empty
|
||||
print("WARNING: No KV cache scale factors found. No output saved.")
|
||||
return None
|
||||
empirical_tp_world_size = max(rank_scales_map.keys()) + 1
|
||||
if expected_tp_size is not None:
|
||||
assert expected_tp_size == empirical_tp_world_size, \
|
||||
f"User expected TP world size = {expected_tp_size} " \
|
||||
"from model but tool is expecting TP world size = " \
|
||||
f"{empirical_tp_world_size} from model instead."
|
||||
for i in range(empirical_tp_world_size):
|
||||
assert i in rank_scales_map, "Expected TP world size = "\
|
||||
f"{empirical_tp_world_size} but did not find KV " \
|
||||
f"cache scaling factors for TP rank {i}"
|
||||
print(f"Found TP world size = {empirical_tp_world_size} "
|
||||
"when extracting KV cache scales!")
|
||||
return rank_scales_map
|
||||
|
||||
|
||||
def _metadata_extractor(quantized_model_dir: str,
|
||||
metadata_extract_fns: \
|
||||
Dict[str, Callable[[Dict[str, Any]], Any]]) \
|
||||
-> Dict[str, Any]:
|
||||
"""
|
||||
Given a directory containing quantized model files, this function
|
||||
aims to extract metadata from the JSON files within this directory.
|
||||
Each JSON file is expected to represent a dictionary in JSON
|
||||
format (referred to as a "JSON-dictionary"). Metadata extraction is
|
||||
defined by a dictionary called metadata_extract_fns, where each
|
||||
metadata field name is mapped to an extraction function.
|
||||
|
||||
These extraction functions are designed to take a JSON-dictionary
|
||||
as their only argument and return the corresponding metadata.
|
||||
While extraction functions are permitted to raise exceptions, they
|
||||
should only raise a KeyError or ValueError if the metadata field
|
||||
cannot be extracted from the current JSON-dictionary, yet there's
|
||||
a possibility of finding it in another JSON-dictionary.
|
||||
|
||||
The function returns a dictionary that maps metadata fields to
|
||||
their extracted data. The keys of this dictionary correspond exactly
|
||||
to those in metadata_extract_fns. If any fields fail to be extracted,
|
||||
their corresponding values are set to None, and a warning is printed.
|
||||
"""
|
||||
if not os.path.isdir(quantized_model_dir):
|
||||
raise FileNotFoundError(
|
||||
f"The quantized model directory `{quantized_model_dir}` "
|
||||
"does not exist.")
|
||||
metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))
|
||||
|
||||
result: Dict[str, Any] = {}
|
||||
for file in metadata_files:
|
||||
with open(file) as f:
|
||||
try:
|
||||
metadata = json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
print(f"Could not parse `{file}` as a valid metadata file,"
|
||||
" skipping it.")
|
||||
continue
|
||||
if not isinstance(metadata, dict):
|
||||
print(f"The file `{file}` does not correspond to a "
|
||||
"JSON-serialized dictionary, skipping it.")
|
||||
continue
|
||||
for metadata_name, extract_fn in metadata_extract_fns.items():
|
||||
try:
|
||||
metadata_info = extract_fn(metadata)
|
||||
if metadata_name not in result:
|
||||
result[metadata_name] = metadata_info
|
||||
elif metadata_info != result[metadata_name]:
|
||||
raise RuntimeError(
|
||||
"Metadata mismatch! Originally found "
|
||||
f"{metadata_name} = {result[metadata_name]} but "
|
||||
f"now found {metadata_name} = {metadata_info} in "
|
||||
f"`{file}`")
|
||||
except KeyError:
|
||||
# It is possible that a given file does not contain some
|
||||
# of our selected metadata as it could be located in some
|
||||
# other metadata file.
|
||||
# 'EFINAE': extract_fn failure is not an error.
|
||||
pass
|
||||
except ValueError:
|
||||
# See above.
|
||||
pass
|
||||
|
||||
# Warn if we cannot find any of the requested metadata
|
||||
for metadata_name in metadata_extract_fns:
|
||||
if metadata_name not in result:
|
||||
print("WARNING: Unable to find requested metadata field "
|
||||
f"`{metadata_name}`, setting it to None.")
|
||||
result[metadata_name] = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main(args):
|
||||
metadata_extract_fns = {
|
||||
"model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"],
|
||||
"tp_size": lambda json_dict: int(json_dict["tensor_parallel"]),
|
||||
"model_dtype": lambda json_dict: json_dict["dtype"]
|
||||
}
|
||||
recovered_metadata = _metadata_extractor(args.quantized_model,
|
||||
metadata_extract_fns)
|
||||
if args.tp_size is not None:
|
||||
metadata_tp_size = recovered_metadata["tp_size"]
|
||||
if metadata_tp_size is not None:
|
||||
assert args.tp_size == metadata_tp_size, \
|
||||
f"User expected TP world size = {args.tp_size} " \
|
||||
f"but found TP world size = {metadata_tp_size} from metadata!"
|
||||
expected_tp_size = args.tp_size or recovered_metadata["tp_size"]
|
||||
rank_keyword = "rank"
|
||||
hf_tensor_files, use_safetensors = _prepare_hf_weights(
|
||||
args.quantized_model, args.load_format)
|
||||
rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors,
|
||||
rank_keyword, expected_tp_size)
|
||||
# Postprocess: formatting to the current schema. Consider pulling it
|
||||
# out into a dedicated function should it ever become more complicated.
|
||||
rank_scales_map = {
|
||||
rank: {k: scale[k]
|
||||
for k in sorted(scale.keys())}
|
||||
for rank, scale in rank_scales_map.items()
|
||||
}
|
||||
# TODO: Expand this with activation and weights scaling factors when
|
||||
# they are used in the future
|
||||
schema = QuantParamSchema(
|
||||
model_type=recovered_metadata["model_type"],
|
||||
kv_cache={
|
||||
"dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else
|
||||
recovered_metadata["model_dtype"]),
|
||||
"scaling_factor":
|
||||
rank_scales_map
|
||||
},
|
||||
)
|
||||
|
||||
if args.output_dir is None:
|
||||
output_file = os.path.join(args.quantized_model, args.output_name)
|
||||
else:
|
||||
if not os.path.isdir(args.output_dir):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
output_file = os.path.join(args.output_dir, args.output_name)
|
||||
|
||||
with open(output_file, 'w') as f:
|
||||
f.write(schema.model_dump_json(indent=4))
|
||||
print(f"Completed! KV cache scaling factors saved to {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="This simple utility extracts the "
|
||||
"KV cache scaling factors from a quantized HF model "
|
||||
"and saves them to a JSON file compatible with later "
|
||||
"use by vLLM (pass this file to the appropriate "
|
||||
"runtime typically using the argument "
|
||||
"--quantization-param-path <filename>). This is only used "
|
||||
"if the KV cache dtype is FP8 and on ROCm (AMD GPU).")
|
||||
parser.add_argument(
|
||||
"--quantized-model",
|
||||
help="Specify the directory containing a single quantized HF model. "
|
||||
"It is expected that the quantization format is FP8_E4M3, for use "
|
||||
"on ROCm (AMD GPU).",
|
||||
required=True)
|
||||
parser.add_argument(
|
||||
"--load_format",
|
||||
help="Optionally specify the format of the model's tensor files "
|
||||
"containing the KV cache scaling factors.",
|
||||
choices=["auto", "safetensors", "npz", "pt"],
|
||||
default="auto")
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
help="Optionally specify the output directory. By default the "
|
||||
"KV cache scaling factors will be saved in the model directory, "
|
||||
"however you can override this behavior here.",
|
||||
default=None)
|
||||
parser.add_argument(
|
||||
"--output-name",
|
||||
help="Optionally specify the output filename.",
|
||||
# TODO: Change this once additional scaling factors are enabled
|
||||
default="kv_cache_scales.json")
|
||||
parser.add_argument(
|
||||
"--tp-size",
|
||||
help="Optionally specify the tensor-parallel (TP) size that the "
|
||||
"quantized model should correspond to. If specified, during KV "
|
||||
"cache scaling factor extraction the observed TP size will be "
|
||||
"checked against this and an error will be raised if there is "
|
||||
"a mismatch. If not specified, the quantized model's expected "
|
||||
"TP size is instead inferred from the largest TP rank observed. "
|
||||
"The expected TP size is cross-checked against the TP ranks "
|
||||
"observed in the quantized model and an error is raised if any "
|
||||
"discrepancies are found.",
|
||||
default=None,
|
||||
type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@ -1,32 +0,0 @@
|
||||
### Quantizer Utilities
|
||||
`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported
|
||||
from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py)
|
||||
|
||||
### Prerequisite
|
||||
|
||||
#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later
|
||||
`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo`
|
||||
|
||||
#### AMMO Download (code and docs)
|
||||
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz`
|
||||
`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz`
|
||||
|
||||
### Usage
|
||||
|
||||
#### Run on H100 system for speed if FP8; number of GPUs depends on the model size
|
||||
|
||||
#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache:
|
||||
`python quantize.py --model-dir ./ll2-7b --dtype float16 --qformat fp8 --kv-cache-dtype fp8 --output-dir ./ll2_7b_fp8 --calib-size 512 --tp-size 1`
|
||||
|
||||
Outputs: model structure, quantized model & parameters (with scaling factors) are in JSON and Safetensors (npz is generated only for the reference)
|
||||
```
|
||||
# ll ./ll2_7b_fp8/
|
||||
total 19998244
|
||||
drwxr-xr-x 2 root root 4096 Feb 7 01:08 ./
|
||||
drwxrwxr-x 8 1060 1061 4096 Feb 7 01:08 ../
|
||||
-rw-r--r-- 1 root root 176411 Feb 7 01:08 llama_tp1.json
|
||||
-rw-r--r-- 1 root root 13477087480 Feb 7 01:09 llama_tp1_rank0.npz
|
||||
-rw-r--r-- 1 root root 7000893272 Feb 7 01:08 rank0.safetensors
|
||||
#
|
||||
```
|
||||
|
||||
@ -1,367 +0,0 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Adapted from examples/quantization/hf_ptq.py
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
|
||||
import ammo.torch.quantization as atq
|
||||
import numpy as np
|
||||
import torch
|
||||
from ammo.torch.export import export_model_config
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
RAND_SEED = 1234
|
||||
MAX_SEQ_LEN = 2048
|
||||
|
||||
EMPTY_CFG = {
|
||||
"quant_cfg": {
|
||||
"*weight_quantizer": {
|
||||
"enable": False,
|
||||
},
|
||||
"*input_quantizer": {
|
||||
"enable": False
|
||||
},
|
||||
"*lm_head*": {
|
||||
"enable": False
|
||||
},
|
||||
"*output_layer*": {
|
||||
"enable": False
|
||||
},
|
||||
"default": {
|
||||
"enable": False
|
||||
},
|
||||
},
|
||||
"algorithm": "max",
|
||||
}
|
||||
|
||||
KV_CACHE_CFG = {
|
||||
"*.query_key_value.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.Wqkv.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.W_pack.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.c_attn.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.k_proj.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
"*.v_proj.output_quantizer": {
|
||||
"num_bits": 8,
|
||||
"axis": None,
|
||||
"enable": True
|
||||
},
|
||||
}
|
||||
|
||||
QUANT_CFG_CHOICES = {
|
||||
"int8_sq": atq.INT8_SMOOTHQUANT_CFG,
|
||||
"fp8": atq.FP8_DEFAULT_CFG,
|
||||
"int4_awq": atq.INT4_AWQ_CFG,
|
||||
"w4a8_awq": atq.W4A8_AWQ_BETA_CFG,
|
||||
"int8_wo": EMPTY_CFG,
|
||||
"int4_wo": EMPTY_CFG,
|
||||
"full_prec": EMPTY_CFG,
|
||||
}
|
||||
|
||||
MODEL_NAME_PATTERN_MAP = {
|
||||
"GPT2": "gpt2",
|
||||
"Xverse": "llama",
|
||||
"Llama": "llama",
|
||||
"Mistral": "llama",
|
||||
"GPTJ": "gptj",
|
||||
"FalconForCausalLM": "falcon",
|
||||
"RWForCausalLM": "falcon",
|
||||
"baichuan": "baichuan",
|
||||
"MPT": "mpt",
|
||||
"Bloom": "bloom",
|
||||
"ChatGLM": "chatglm",
|
||||
"QWen": "qwen",
|
||||
}
|
||||
|
||||
|
||||
def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None):
|
||||
print(f"Initializing tokenizer from {ckpt_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
ckpt_path,
|
||||
model_max_length=max_seq_len,
|
||||
padding_side="left",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if model_type and model_type == "qwen":
|
||||
# qwen use token id 151643 as pad and eos tokens
|
||||
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643)
|
||||
tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643)
|
||||
|
||||
# can't set attribute 'pad_token' for "<unk>"
|
||||
if tokenizer.pad_token != "<unk>":
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
assert (tokenizer.pad_token
|
||||
is not None), f"Pad token for {model_type} cannot be set!"
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(ckpt_path, dtype="fp16", device="cuda"):
|
||||
print(f"Initializing model from {ckpt_path}")
|
||||
if dtype == "bf16" or dtype == "bfloat16":
|
||||
dtype = torch.bfloat16
|
||||
elif dtype == "fp16" or dtype == "float16":
|
||||
dtype = torch.float16
|
||||
elif dtype == "fp32" or dtype == "float32":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown dtype {dtype}")
|
||||
|
||||
# model_kwargs = {"torch_dtype": dtype}
|
||||
model_kwargs = {"torch_dtype": "auto"}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt_path,
|
||||
device_map="auto",
|
||||
**model_kwargs,
|
||||
trust_remote_code=True)
|
||||
model.eval()
|
||||
|
||||
model_dtype = next(model.parameters()).dtype
|
||||
if dtype != model_dtype:
|
||||
print("[TensorRT-LLM][WARNING] The manually set model data type is "
|
||||
f"{dtype}, but the data type of the HuggingFace model is "
|
||||
f"{model_dtype}.")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_model_type(model):
|
||||
for k, v in MODEL_NAME_PATTERN_MAP.items():
|
||||
if k.lower() in type(model).__name__.lower():
|
||||
return v
|
||||
return None
|
||||
|
||||
|
||||
def get_calib_dataloader(data="cnn_dailymail",
|
||||
tokenizer=None,
|
||||
batch_size=1,
|
||||
calib_size=512,
|
||||
block_size=512,
|
||||
device=None):
|
||||
print("Loading calibration dataset")
|
||||
if data == "pileval":
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
|
||||
split="train")
|
||||
dataset = dataset["text"][:calib_size]
|
||||
elif data == "cnn_dailymail":
|
||||
dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
|
||||
dataset = dataset["article"][:calib_size]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
batch_encoded = tokenizer.batch_encode_plus(dataset,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=block_size)
|
||||
if device:
|
||||
batch_encoded = batch_encoded.to(device)
|
||||
batch_encoded = batch_encoded["input_ids"]
|
||||
|
||||
calib_dataloader = DataLoader(batch_encoded,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
return calib_dataloader
|
||||
|
||||
|
||||
def quantize_model(model, quant_cfg, calib_dataloader=None):
|
||||
|
||||
def calibrate_loop():
|
||||
if calib_dataloader is None:
|
||||
return
|
||||
"""Adjusts weights and scaling factors based on selected algorithms."""
|
||||
for idx, data in enumerate(calib_dataloader):
|
||||
print(f"Calibrating batch {idx}")
|
||||
model(data)
|
||||
|
||||
print("Starting quantization...")
|
||||
start_time = time.time()
|
||||
atq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
|
||||
end_time = time.time()
|
||||
print("Quantization done. Total time used: {:.2f} s.".format(end_time -
|
||||
start_time))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def main(args):
|
||||
if not torch.cuda.is_available():
|
||||
raise OSError("GPU is required for inference.")
|
||||
|
||||
random.seed(RAND_SEED)
|
||||
np.random.seed(RAND_SEED)
|
||||
|
||||
model = get_model(args.model_dir, args.dtype, args.device)
|
||||
model_type = get_model_type(model)
|
||||
tokenizer = get_tokenizer(args.model_dir, model_type=model_type)
|
||||
|
||||
if args.qformat in ["full_prec", "int8_wo", "int4_wo"
|
||||
] and args.kv_cache_dtype is None:
|
||||
print(f"No quantization applied, export {args.dtype} model")
|
||||
else:
|
||||
if "awq" in args.qformat:
|
||||
if args.calib_size > 32:
|
||||
print("AWQ calibration could take longer with calib_size = "
|
||||
f"{args.calib_size}, Using calib_size=32 instead")
|
||||
args.calib_size = 32
|
||||
print("\nAWQ calibration could take longer than other calibration "
|
||||
"methods. Please increase the batch size to speed up the "
|
||||
"calibration process. Batch size can be set by adding the "
|
||||
"argument --batch_size <batch_size> to the command line.\n")
|
||||
|
||||
calib_dataloader = get_calib_dataloader(
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
calib_size=args.calib_size,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
if args.qformat in QUANT_CFG_CHOICES:
|
||||
quant_cfg = QUANT_CFG_CHOICES[args.qformat]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization format: {args.qformat}")
|
||||
|
||||
if "awq" in args.qformat:
|
||||
quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat])
|
||||
weight_quantizer = quant_cfg["quant_cfg"][
|
||||
"*weight_quantizer"] # type: ignore
|
||||
if isinstance(weight_quantizer, list):
|
||||
weight_quantizer = weight_quantizer[0]
|
||||
weight_quantizer["block_sizes"][-1] = args.awq_block_size
|
||||
|
||||
if args.kv_cache_dtype is not None:
|
||||
if args.kv_cache_dtype == "fp8":
|
||||
for value in KV_CACHE_CFG.values():
|
||||
value.update({"num_bits": (4, 3)}) # type: ignore
|
||||
quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore
|
||||
|
||||
print(quant_cfg)
|
||||
|
||||
model = quantize_model(model, quant_cfg, calib_dataloader)
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type is None:
|
||||
print(f"Unknown model type {type(model).__name__}. Continue "
|
||||
"exporting...")
|
||||
model_type = f"unknown:{type(model).__name__}"
|
||||
|
||||
export_path = args.output_dir
|
||||
start_time = time.time()
|
||||
|
||||
if args.qformat == "int4_awq" and model_type == "qwen":
|
||||
torch.save(model.state_dict(), export_path)
|
||||
else:
|
||||
export_npz = (model_type not in [
|
||||
'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan'
|
||||
])
|
||||
|
||||
# export safetensors
|
||||
export_model_config(
|
||||
model,
|
||||
model_type,
|
||||
getattr(torch, args.dtype),
|
||||
export_dir=export_path,
|
||||
inference_tensor_parallel=args.tp_size,
|
||||
inference_pipeline_parallel=args.pp_size,
|
||||
# export_tensorrt_llm_config=(not export_npz),
|
||||
export_tensorrt_llm_config=False,
|
||||
export_npz=export_npz)
|
||||
|
||||
# Workaround for wo quantization
|
||||
if args.qformat in ["int8_wo", "int4_wo", "full_prec"]:
|
||||
with open(f"{export_path}/config.json") as f:
|
||||
tensorrt_llm_config = json.load(f)
|
||||
if args.qformat == "int8_wo":
|
||||
tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16'
|
||||
elif args.qformat == "int4_wo":
|
||||
tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16'
|
||||
else:
|
||||
tensorrt_llm_config["quantization"]["quant_algo"] = None
|
||||
with open(f"{export_path}/config.json", "w") as f:
|
||||
json.dump(tensorrt_llm_config, f, indent=4)
|
||||
|
||||
end_time = time.time()
|
||||
print("Quantized model exported to {} \nTotal time used {:.2f} s.".
|
||||
format(export_path, end_time - start_time))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--model-dir",
|
||||
help="Specify where the HuggingFace model is",
|
||||
required=True)
|
||||
parser.add_argument("--device", default="cuda")
|
||||
parser.add_argument("--dtype", help="Model data type.", default="float16")
|
||||
parser.add_argument(
|
||||
"--qformat",
|
||||
help="Quantization format.",
|
||||
default="full_prec",
|
||||
choices=[
|
||||
"fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo",
|
||||
"full_prec"
|
||||
],
|
||||
)
|
||||
parser.add_argument("--batch-size",
|
||||
help="Batch size for calibration.",
|
||||
type=int,
|
||||
default=1)
|
||||
parser.add_argument("--calib-size",
|
||||
help="Number of samples for calibration.",
|
||||
type=int,
|
||||
default=512)
|
||||
parser.add_argument("--output-dir", default="exported_model")
|
||||
parser.add_argument("--tp-size", type=int, default=1)
|
||||
parser.add_argument("--pp-size", type=int, default=1)
|
||||
parser.add_argument("--awq-block-size", type=int, default=128)
|
||||
parser.add_argument("--kv-cache-dtype",
|
||||
help="KV Cache dtype.",
|
||||
default=None,
|
||||
choices=["int8", "fp8", None])
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@ -5,7 +5,7 @@ requests >= 2.26.0
|
||||
tqdm
|
||||
blake3
|
||||
py-cpuinfo
|
||||
transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL.
|
||||
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
|
||||
tokenizers >= 0.19.1 # Required for Llama 3.
|
||||
protobuf # Required by LlamaTokenizer.
|
||||
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
|
||||
@ -19,7 +19,7 @@ pillow # Required for image processing
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||
lm-format-enforcer >= 0.10.9, < 0.11
|
||||
outlines == 0.1.11 # Requires pytorch
|
||||
outlines == 0.1.11
|
||||
lark == 1.2.2
|
||||
xgrammar >= 0.1.6; platform_machine == "x86_64"
|
||||
typing_extensions >= 4.10
|
||||
@ -34,6 +34,6 @@ pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
|
||||
einops # Required for Qwen2-VL.
|
||||
compressed-tensors == 0.8.1 # required for compressed-tensors, requires pytorch
|
||||
compressed-tensors == 0.9.1 # required for compressed-tensors
|
||||
depyf==0.18.0 # required for profiling and debugging with compilation config
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
# Dependencies for HPU code
|
||||
ray
|
||||
triton
|
||||
triton==3.1.0
|
||||
pandas
|
||||
tabulate
|
||||
setuptools>=61
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# This file is autogenerated by pip-compile with Python 3.12
|
||||
# by the following command:
|
||||
#
|
||||
# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt
|
||||
# python3.12 -m piptools compile requirements-test.in -o requirements-test.txt
|
||||
#
|
||||
absl-py==2.1.0
|
||||
# via rouge-score
|
||||
@ -106,9 +106,17 @@ dnspython==2.7.0
|
||||
docutils==0.16
|
||||
# via awscli
|
||||
einops==0.8.0
|
||||
# via -r requirements-test.in
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# encodec
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
einx==0.3.0
|
||||
# via vector-quantize-pytorch
|
||||
email-validator==2.2.0
|
||||
# via pydantic
|
||||
encodec==0.1.1
|
||||
# via vocos
|
||||
evaluate==0.4.3
|
||||
# via lm-eval
|
||||
fastparquet==2024.11.0
|
||||
@ -125,6 +133,8 @@ filelock==3.16.1
|
||||
# triton
|
||||
fonttools==4.54.1
|
||||
# via matplotlib
|
||||
frozendict==2.4.6
|
||||
# via einx
|
||||
frozenlist==1.5.0
|
||||
# via
|
||||
# aiohttp
|
||||
@ -159,6 +169,7 @@ huggingface-hub==0.26.2
|
||||
# timm
|
||||
# tokenizers
|
||||
# transformers
|
||||
# vocos
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
@ -261,6 +272,8 @@ numpy==1.26.4
|
||||
# cupy-cuda12x
|
||||
# datasets
|
||||
# decord
|
||||
# einx
|
||||
# encodec
|
||||
# evaluate
|
||||
# fastparquet
|
||||
# genai-perf
|
||||
@ -283,6 +296,7 @@ numpy==1.26.4
|
||||
# torchvision
|
||||
# transformers
|
||||
# tritonclient
|
||||
# vocos
|
||||
nvidia-cublas-cu12==12.4.5.8
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
@ -455,6 +469,7 @@ pyyaml==6.0.2
|
||||
# responses
|
||||
# timm
|
||||
# transformers
|
||||
# vocos
|
||||
ray[adag]==2.40.0
|
||||
# via -r requirements-test.in
|
||||
redis==5.2.0
|
||||
@ -517,6 +532,7 @@ scipy==1.13.1
|
||||
# scikit-learn
|
||||
# sentence-transformers
|
||||
# statsmodels
|
||||
# vocos
|
||||
sentence-transformers==3.2.1
|
||||
# via -r requirements-test.in
|
||||
sentencepiece==0.2.0
|
||||
@ -540,7 +556,9 @@ sqlitedict==2.1.0
|
||||
statsmodels==0.14.4
|
||||
# via genai-perf
|
||||
sympy==1.13.1
|
||||
# via torch
|
||||
# via
|
||||
# einx
|
||||
# torch
|
||||
tabledata==1.3.3
|
||||
# via pytablewriter
|
||||
tabulate==0.9.0
|
||||
@ -568,12 +586,21 @@ torch==2.5.1
|
||||
# -r requirements-test.in
|
||||
# accelerate
|
||||
# bitsandbytes
|
||||
# encodec
|
||||
# lm-eval
|
||||
# peft
|
||||
# sentence-transformers
|
||||
# tensorizer
|
||||
# timm
|
||||
# torchaudio
|
||||
# torchvision
|
||||
# vector-quantize-pytorch
|
||||
# vocos
|
||||
torchaudio==2.5.1
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# encodec
|
||||
# vocos
|
||||
torchvision==0.20.1
|
||||
# via timm
|
||||
tqdm==4.66.6
|
||||
@ -584,13 +611,15 @@ tqdm==4.66.6
|
||||
# lm-eval
|
||||
# nltk
|
||||
# peft
|
||||
# pqdm
|
||||
# sentence-transformers
|
||||
# tqdm-multiprocess
|
||||
# transformers
|
||||
tqdm-multiprocess==0.0.11
|
||||
# via lm-eval
|
||||
transformers==4.47.0
|
||||
transformers==4.48.2
|
||||
# via
|
||||
# -r requirements-test.in
|
||||
# genai-perf
|
||||
# lm-eval
|
||||
# peft
|
||||
@ -615,6 +644,7 @@ typing-extensions==4.12.2
|
||||
# huggingface-hub
|
||||
# librosa
|
||||
# mistral-common
|
||||
# pqdm
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# torch
|
||||
@ -626,6 +656,10 @@ urllib3==2.2.3
|
||||
# requests
|
||||
# responses
|
||||
# tritonclient
|
||||
vector-quantize-pytorch==1.21.2
|
||||
# via -r requirements-test.in
|
||||
vocos==0.1.0
|
||||
# via -r requirements-test.in
|
||||
word2number==1.1
|
||||
# via lm-eval
|
||||
xxhash==3.5.0
|
||||
@ -638,4 +672,4 @@ zstandard==0.23.0
|
||||
# via lm-eval
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
# setuptools
|
||||
@ -13,13 +13,11 @@ ray[default]
|
||||
# Install torch_xla
|
||||
--pre
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
--find-links https://storage.googleapis.com/libtpu-wheels/index.html
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
torch==2.6.0.dev20241126+cpu
|
||||
torchvision==0.20.0.dev20241126+cpu
|
||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
jaxlib==0.4.36.dev20241122
|
||||
jax==0.4.36.dev20241122
|
||||
torch==2.6.0.dev20241216+cpu
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
16
setup.py
16
setup.py
@ -228,8 +228,11 @@ class cmake_build_ext(build_ext):
|
||||
|
||||
# CMake appends the extension prefix to the install path,
|
||||
# and outdir already contains that prefix, so we need to remove it.
|
||||
# We assume only the final component of extension prefix is added by
|
||||
# CMake, this is currently true for current extensions but may not
|
||||
# always be the case.
|
||||
prefix = outdir
|
||||
for i in range(ext.name.count('.')):
|
||||
if '.' in ext.name:
|
||||
prefix = prefix.parent
|
||||
|
||||
# prefix here should actually be the same for all components
|
||||
@ -298,9 +301,11 @@ class repackage_wheel(build_ext):
|
||||
files_to_copy = [
|
||||
"vllm/_C.abi3.so",
|
||||
"vllm/_moe_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/flash_attn_interface.py",
|
||||
"vllm/vllm_flash_attn/__init__.py",
|
||||
"vllm/cumem_allocator.abi3.so",
|
||||
# "vllm/_version.py", # not available in nightly wheels yet
|
||||
]
|
||||
file_members = filter(lambda x: x.filename in files_to_copy,
|
||||
@ -549,7 +554,7 @@ def get_requirements() -> List[str]:
|
||||
return resolved_requirements
|
||||
|
||||
if _no_device():
|
||||
requirements = _read_requirements("requirements-cuda.txt")
|
||||
requirements = _read_requirements("requirements-cpu.txt")
|
||||
elif _is_cuda():
|
||||
requirements = _read_requirements("requirements-cuda.txt")
|
||||
cuda_major, cuda_minor = torch.version.cuda.split(".")
|
||||
@ -592,8 +597,9 @@ if _is_hip():
|
||||
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
|
||||
|
||||
if _is_cuda():
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
||||
112
tests/basic_correctness/test_cumem.py
Normal file
112
tests/basic_correctness/test_cumem.py
Normal file
@ -0,0 +1,112 @@
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.device_allocator.cumem import CuMemAllocator
|
||||
from vllm.utils import GiB_bytes
|
||||
|
||||
from ..utils import fork_new_process_for_each_test
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_basic_cumem():
|
||||
# some tensors from default memory pool
|
||||
shape = (1024, 1024)
|
||||
x = torch.empty(shape, device='cuda')
|
||||
x.zero_()
|
||||
|
||||
# some tensors from custom memory pool
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
# custom memory pool
|
||||
y = torch.empty(shape, device='cuda')
|
||||
y.zero_()
|
||||
y += 1
|
||||
z = torch.empty(shape, device='cuda')
|
||||
z.zero_()
|
||||
z += 2
|
||||
|
||||
# they can be used together
|
||||
output = x + y + z
|
||||
assert torch.allclose(output, torch.ones_like(output) * 3)
|
||||
|
||||
free_bytes = torch.cuda.mem_get_info()[0]
|
||||
allocator.sleep()
|
||||
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
|
||||
assert free_bytes_after_sleep > free_bytes
|
||||
allocator.wake_up()
|
||||
|
||||
# they can be used together
|
||||
output = x + y + z
|
||||
assert torch.allclose(output, torch.ones_like(output) * 3)
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_cumem_with_cudagraph():
|
||||
allocator = CuMemAllocator.get_instance()
|
||||
with allocator.use_memory_pool():
|
||||
weight = torch.eye(1024, device='cuda')
|
||||
with allocator.use_memory_pool(tag="discard"):
|
||||
cache = torch.empty(1024, 1024, device='cuda')
|
||||
|
||||
def model(x):
|
||||
out = x @ weight
|
||||
cache[:out.size(0)].copy_(out)
|
||||
return out + 1
|
||||
|
||||
x = torch.empty(128, 1024, device='cuda')
|
||||
|
||||
# warmup
|
||||
model(x)
|
||||
|
||||
# capture cudagraph
|
||||
model_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(model_graph):
|
||||
y = model(x)
|
||||
|
||||
free_bytes = torch.cuda.mem_get_info()[0]
|
||||
allocator.sleep()
|
||||
free_bytes_after_sleep = torch.cuda.mem_get_info()[0]
|
||||
assert free_bytes_after_sleep > free_bytes
|
||||
allocator.wake_up()
|
||||
|
||||
# after waking up, the content in the weight tensor
|
||||
# should be restored, but the content in the cache tensor
|
||||
# should be discarded
|
||||
|
||||
# this operation is also compatible with cudagraph
|
||||
|
||||
x.random_()
|
||||
model_graph.replay()
|
||||
|
||||
# cache content is as expected
|
||||
assert torch.allclose(x, cache[:x.size(0)])
|
||||
|
||||
# output content is as expected
|
||||
assert torch.allclose(y, x + 1)
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
def test_end_to_end():
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
used_bytes_baseline = total - free # in case other process is running
|
||||
llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True)
|
||||
prompt = "How are you?"
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
output = llm.generate(prompt, sampling_params)
|
||||
|
||||
# the benefit of `llm.sleep(level=2)` is mainly CPU memory usage,
|
||||
# which is difficult to measure in the test. therefore, we only
|
||||
# test sleep level 1 here.
|
||||
llm.sleep(level=1)
|
||||
|
||||
free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
|
||||
# now the memory usage is mostly cudagraph memory pool,
|
||||
# and it should be less than the model weights (1B model, 2GiB weights)
|
||||
assert used_bytes < 2 * GiB_bytes
|
||||
|
||||
llm.wake_up()
|
||||
output2 = llm.generate(prompt, sampling_params)
|
||||
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output2[0].outputs[0].text
|
||||
@ -796,6 +796,44 @@ class TestPrefixCachingBlockAllocator:
|
||||
block_hashes=block_hashes_seq1)
|
||||
assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
|
||||
|
||||
# Test reset prefix cache
|
||||
@staticmethod
|
||||
@pytest.mark.parametrize("num_blocks", [10])
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
def test_reset_prefix_cache(num_blocks: int, block_size: int):
|
||||
"""This test case simulates the case of resetting the prefix cache."""
|
||||
|
||||
allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
|
||||
block_size=block_size)
|
||||
token_ids = list(range(3 * block_size))
|
||||
|
||||
first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
|
||||
block_size=block_size,
|
||||
token_ids=token_ids,
|
||||
allocator=allocator,
|
||||
)
|
||||
|
||||
# Free each block in the first chain.
|
||||
for block in first_chain:
|
||||
allocator.free(block)
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
assert not allocator.reset_prefix_cache()
|
||||
assert allocator.get_prefix_cache_hit_rate() > 0.0
|
||||
|
||||
# Free each block in the second chain.
|
||||
for block in second_chain:
|
||||
allocator.free(block)
|
||||
|
||||
# Reset prefix cache.
|
||||
assert allocator.reset_prefix_cache()
|
||||
assert allocator.get_prefix_cache_hit_rate() == 0.0
|
||||
|
||||
@staticmethod
|
||||
def create_immutable_chain(
|
||||
block_size: int,
|
||||
|
||||
@ -20,7 +20,7 @@ TASK = "gsm8k"
|
||||
FILTER = "exact_match,strict-match"
|
||||
RTOL = 0.03
|
||||
EXPECTED_VALUE = 0.58
|
||||
DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
|
||||
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
|
||||
MORE_ARGS_LIST = [
|
||||
[], # Default
|
||||
["--enable-chunked-prefill"], # Chunked
|
||||
@ -66,14 +66,21 @@ def run_test(more_args):
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="V1 currently only supported on CUDA")
|
||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu(),
|
||||
reason="V1 currently only supported on CUDA and TPU")
|
||||
def test_lm_eval_accuracy_v1_engine(monkeypatch):
|
||||
"""Run with the V1 Engine."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
run_test([])
|
||||
more_args = []
|
||||
|
||||
# Limit compilation time for V1
|
||||
if current_platform.is_tpu():
|
||||
more_args = ["--max-num-seqs", "64"]
|
||||
|
||||
run_test(more_args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
|
||||
|
||||
@ -1,90 +0,0 @@
|
||||
{
|
||||
"model_type": "llama",
|
||||
"kv_cache": {
|
||||
"dtype": "float8_e4m3fn",
|
||||
"scaling_factor": {
|
||||
"0": {
|
||||
"0": 0.0230364128947258,
|
||||
"1": 0.01979283057153225,
|
||||
"2": 0.0241350457072258,
|
||||
"3": 0.0308314748108387,
|
||||
"4": 0.0430733822286129,
|
||||
"5": 0.0370396226644516,
|
||||
"6": 0.0306222103536129,
|
||||
"7": 0.0357491634786129,
|
||||
"8": 0.0358189195394516,
|
||||
"9": 0.0443289652466774,
|
||||
"10": 0.0433175228536129,
|
||||
"11": 0.0416782945394516,
|
||||
"12": 0.0366908498108387,
|
||||
"13": 0.0432477705180645,
|
||||
"14": 0.0410505048930645,
|
||||
"15": 0.0457589291036129,
|
||||
"16": 0.0418526791036129,
|
||||
"17": 0.0432477705180645,
|
||||
"18": 0.0469447560608387,
|
||||
"19": 0.0514787957072258,
|
||||
"20": 0.0541294664144516,
|
||||
"21": 0.0587681382894516,
|
||||
"22": 0.0625,
|
||||
"23": 0.0585588738322258,
|
||||
"24": 0.0600237175822258,
|
||||
"25": 0.0588030144572258,
|
||||
"26": 0.0531180277466774,
|
||||
"27": 0.06396484375,
|
||||
"28": 0.0603027381002903,
|
||||
"29": 0.0582101047039032,
|
||||
"30": 0.0625348836183548,
|
||||
"31": 0.0585588738322258,
|
||||
"32": 0.0582798570394516,
|
||||
"33": 0.0575125589966774,
|
||||
"34": 0.0590820349752903,
|
||||
"35": 0.0614188089966774,
|
||||
"36": 0.0631975457072258,
|
||||
"37": 0.0615931935608387,
|
||||
"38": 0.0601283498108387,
|
||||
"39": 0.0571986623108387,
|
||||
"40": 0.0670340433716774,
|
||||
"41": 0.0523507259786129,
|
||||
"42": 0.0547223798930645,
|
||||
"43": 0.0631975457072258,
|
||||
"44": 0.0663713738322258,
|
||||
"45": 0.0603376142680645,
|
||||
"46": 0.0652204304933548,
|
||||
"47": 0.0734514519572258,
|
||||
"48": 0.0693708211183548,
|
||||
"49": 0.0725446492433548,
|
||||
"50": 0.0627790242433548,
|
||||
"51": 0.0691266804933548,
|
||||
"52": 0.0688825398683548,
|
||||
"53": 0.068429134786129,
|
||||
"54": 0.0605119988322258,
|
||||
"55": 0.0799386203289032,
|
||||
"56": 0.0853097140789032,
|
||||
"57": 0.0661969929933548,
|
||||
"58": 0.0689871683716774,
|
||||
"59": 0.0724051371216774,
|
||||
"60": 0.0541643425822258,
|
||||
"61": 0.0626743882894516,
|
||||
"62": 0.0628487765789032,
|
||||
"63": 0.0607212632894516,
|
||||
"64": 0.0589076466858387,
|
||||
"65": 0.0451660193502903,
|
||||
"66": 0.0453055277466774,
|
||||
"67": 0.0414341539144516,
|
||||
"68": 0.0385044664144516,
|
||||
"69": 0.0414341539144516,
|
||||
"70": 0.0466308631002903,
|
||||
"71": 0.0399693101644516,
|
||||
"72": 0.0437011756002903,
|
||||
"73": 0.0434221550822258,
|
||||
"74": 0.0428989976644516,
|
||||
"75": 0.0401785746216774,
|
||||
"76": 0.0431082621216774,
|
||||
"77": 0.0484444759786129,
|
||||
"78": 0.0417829267680645,
|
||||
"79": 0.0418178029358387
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,42 +0,0 @@
|
||||
{
|
||||
"model_type": "llama",
|
||||
"kv_cache": {
|
||||
"dtype": "float8_e4m3fn",
|
||||
"scaling_factor": {
|
||||
"0": {
|
||||
"0": 0.0152239128947258,
|
||||
"1": 0.0188860222697258,
|
||||
"2": 0.0354178324341774,
|
||||
"3": 0.0376674123108387,
|
||||
"4": 0.0418526791036129,
|
||||
"5": 0.0433175228536129,
|
||||
"6": 0.0397600457072258,
|
||||
"7": 0.0424455925822258,
|
||||
"8": 0.0415387861430645,
|
||||
"9": 0.0408412404358387,
|
||||
"10": 0.0395856611430645,
|
||||
"11": 0.0377371683716774,
|
||||
"12": 0.0400739423930645,
|
||||
"13": 0.040771484375,
|
||||
"14": 0.0393415205180645,
|
||||
"15": 0.0369001142680645,
|
||||
"16": 0.03857421875,
|
||||
"17": 0.0387486070394516,
|
||||
"18": 0.0403180830180645,
|
||||
"19": 0.0396205373108387,
|
||||
"20": 0.0375627800822258,
|
||||
"21": 0.0407366082072258,
|
||||
"22": 0.0432477705180645,
|
||||
"23": 0.0377022884786129,
|
||||
"24": 0.0399693101644516,
|
||||
"25": 0.0374581478536129,
|
||||
"26": 0.0413295216858387,
|
||||
"27": 0.0442243330180645,
|
||||
"28": 0.0424804724752903,
|
||||
"29": 0.0456891767680645,
|
||||
"30": 0.0409109964966774,
|
||||
"31": 0.0482352152466774
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -182,7 +182,7 @@ def test_paged_attention(
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = 1.0
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Call the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
|
||||
@ -210,7 +210,7 @@ def test_paged_attention(
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = 1.0
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
tp_rank = 0
|
||||
|
||||
# Call the paged attention kernel.
|
||||
|
||||
@ -160,7 +160,7 @@ def test_reshape_and_cache(
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = 1.0
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
|
||||
@ -258,8 +258,8 @@ def test_reshape_and_cache_flash(
|
||||
del key_caches
|
||||
del value_caches
|
||||
|
||||
k_scale = key.amax().item() / 256
|
||||
v_scale = value.amax().item() / 256
|
||||
k_scale = (key.amax() / 256.0).to(torch.float32)
|
||||
v_scale = (value.amax() / 256.0).to(torch.float32)
|
||||
|
||||
# Clone the KV caches.
|
||||
if kv_cache_dtype == "fp8":
|
||||
@ -284,12 +284,12 @@ def test_reshape_and_cache_flash(
|
||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(result_key_cache,
|
||||
key_cache,
|
||||
k_scale,
|
||||
k_scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(result_value_cache,
|
||||
value_cache,
|
||||
v_scale,
|
||||
v_scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
|
||||
# Run the reference implementation.
|
||||
|
||||
@ -78,6 +78,7 @@ CASES = [
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 50])
|
||||
@pytest.mark.parametrize("num_blocks", [2048])
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@torch.inference_mode()
|
||||
def test_cascade(
|
||||
seq_lens_and_common_prefix: Tuple[List[Tuple[int, int]], int],
|
||||
@ -87,8 +88,14 @@ def test_cascade(
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
num_blocks: int,
|
||||
fa_version: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
|
||||
or torch.cuda.get_device_capability() == (8, 9)):
|
||||
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
|
||||
"insufficient shared memory for some shapes")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
window_size = (-1, -1)
|
||||
@ -118,9 +125,7 @@ def test_cascade(
|
||||
cu_query_lens = torch.tensor([0] + query_lens,
|
||||
dtype=torch.int32).cumsum(dim=0,
|
||||
dtype=torch.int32)
|
||||
cu_kv_lens = torch.tensor([0] + kv_lens,
|
||||
dtype=torch.int32).cumsum(dim=0,
|
||||
dtype=torch.int32)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
num_blocks,
|
||||
@ -140,7 +145,7 @@ def test_cascade(
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
seqused_k=kv_lens_tensor,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
@ -154,10 +159,8 @@ def test_cascade(
|
||||
assert all(common_prefix_len < kv_len for kv_len in kv_lens)
|
||||
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
|
||||
dtype=torch.int32)
|
||||
cu_prefix_kv_lens = torch.tensor([0, common_prefix_len], dtype=torch.int32)
|
||||
cu_suffix_kv_lens = (
|
||||
cu_kv_lens -
|
||||
torch.arange(num_seqs + 1, dtype=torch.int32) * common_prefix_len)
|
||||
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
|
||||
suffix_kv_lens = kv_lens_tensor - common_prefix_len
|
||||
output = torch.empty_like(query)
|
||||
cascade_attention(
|
||||
output=output,
|
||||
@ -167,8 +170,8 @@ def test_cascade(
|
||||
cu_query_lens=cu_query_lens,
|
||||
max_query_len=max_query_len,
|
||||
cu_prefix_query_lens=cu_prefix_query_lens,
|
||||
cu_prefix_kv_lens=cu_prefix_kv_lens,
|
||||
cu_suffix_kv_lens=cu_suffix_kv_lens,
|
||||
prefix_kv_lens=prefix_kv_lens,
|
||||
suffix_kv_lens=suffix_kv_lens,
|
||||
max_kv_len=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
alibi_slopes=None,
|
||||
@ -176,6 +179,7 @@ def test_cascade(
|
||||
logits_soft_cap=soft_cap if soft_cap is not None else 0,
|
||||
block_table=block_tables,
|
||||
common_prefix_len=common_prefix_len,
|
||||
fa_version=fa_version,
|
||||
)
|
||||
|
||||
# Compare the results.
|
||||
|
||||
@ -80,6 +80,7 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("sliding_window", [None, 256])
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@torch.inference_mode()
|
||||
def test_flash_attn_with_paged_kv(
|
||||
use_out: bool,
|
||||
@ -91,8 +92,14 @@ def test_flash_attn_with_paged_kv(
|
||||
soft_cap: Optional[float],
|
||||
num_blocks: int,
|
||||
sliding_window: Optional[int],
|
||||
fa_version: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
|
||||
or torch.cuda.get_device_capability() == (8, 9)):
|
||||
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
|
||||
"insufficient shared memory for some shapes")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
num_query_heads = num_heads[0]
|
||||
@ -131,6 +138,7 @@ def test_flash_attn_with_paged_kv(
|
||||
cache_seqlens=kv_lens_tensor,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
window_size=window_size,
|
||||
fa_version=fa_version,
|
||||
)
|
||||
output = output if not use_out else out
|
||||
output = output.squeeze(1)
|
||||
@ -159,6 +167,7 @@ def test_flash_attn_with_paged_kv(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("fa_version", [2, 3])
|
||||
@torch.inference_mode()
|
||||
def test_varlen_with_paged_kv(
|
||||
use_out: bool,
|
||||
@ -170,8 +179,14 @@ def test_varlen_with_paged_kv(
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
num_blocks: int,
|
||||
fa_version: int,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
|
||||
or torch.cuda.get_device_capability() == (8, 9)):
|
||||
pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
|
||||
"insufficient shared memory for some shapes")
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
@ -198,9 +213,7 @@ def test_varlen_with_paged_kv(
|
||||
cu_query_lens = torch.tensor([0] + query_lens,
|
||||
dtype=torch.int32).cumsum(dim=0,
|
||||
dtype=torch.int32)
|
||||
cu_kv_lens = torch.tensor([0] + kv_lens,
|
||||
dtype=torch.int32).cumsum(dim=0,
|
||||
dtype=torch.int32)
|
||||
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
@ -215,7 +228,7 @@ def test_varlen_with_paged_kv(
|
||||
v=value_cache,
|
||||
out=out,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
cu_seqlens_k=cu_kv_lens,
|
||||
seqused_k=kv_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
@ -223,6 +236,7 @@ def test_varlen_with_paged_kv(
|
||||
window_size=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
fa_version=fa_version,
|
||||
)
|
||||
output = output if not use_out else out
|
||||
|
||||
|
||||
@ -138,6 +138,7 @@ def test_contexted_kv_attention(
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_kv_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
@ -153,6 +154,8 @@ def test_contexted_kv_attention(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
@ -168,6 +171,8 @@ def test_contexted_kv_attention(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
@ -366,6 +371,7 @@ def test_contexted_kv_attention_alibi(
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_kv_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
@ -381,6 +387,8 @@ def test_contexted_kv_attention_alibi(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
@ -396,6 +404,8 @@ def test_contexted_kv_attention_alibi(
|
||||
b_seq_len,
|
||||
b_ctx_len,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
|
||||
@ -39,6 +39,23 @@ def get_8bit_types():
|
||||
return types
|
||||
|
||||
|
||||
# This test is to check regressions for int8 support on ROCm.
|
||||
@pytest.mark.parametrize("model_path", [
|
||||
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
|
||||
])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="Should only run on ROCm")
|
||||
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
|
||||
max_tokens, num_logprobs):
|
||||
dtype = "bfloat16"
|
||||
|
||||
with vllm_runner(model_path, dtype=dtype) as vllm_model:
|
||||
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [1, 33, 64, 512])
|
||||
@pytest.mark.parametrize("N", [256, 971, 20486])
|
||||
@pytest.mark.parametrize("K", [128, 496, 1024])
|
||||
|
||||
@ -909,6 +909,7 @@ def make_test_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
@ -958,6 +959,7 @@ def make_test_metadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=kv_mmap.slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
|
||||
@ -19,18 +19,17 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="fp8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize(
|
||||
"kv_cache_dtype,base_model,test_model,scale_path",
|
||||
"kv_cache_dtype,base_model,test_model",
|
||||
[
|
||||
# Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors.
|
||||
("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct",
|
||||
"nm-testing/Llama-3.2-1B-Instruct-FP8-KV", None),
|
||||
"nm-testing/Llama-3.2-1B-Instruct-FP8-KV"),
|
||||
# Test FP16 checkpoint w. fp8_e5m2 kv-cache.
|
||||
("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct",
|
||||
"meta-llama/Llama-3.2-1B-Instruct", None),
|
||||
"meta-llama/Llama-3.2-1B-Instruct"),
|
||||
# Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json.
|
||||
("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf",
|
||||
"meta-llama/Llama-2-7b-chat-hf",
|
||||
"./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
|
||||
"meta-llama/Llama-2-7b-chat-hf")
|
||||
])
|
||||
# Due to low-precision numerical divergence, we only test logprob of 4 tokens
|
||||
@pytest.mark.parametrize("max_tokens", [4])
|
||||
@ -48,7 +47,6 @@ def test_models(
|
||||
kv_cache_dtype: str,
|
||||
base_model: str,
|
||||
test_model: str,
|
||||
scale_path: Optional[str],
|
||||
max_tokens: int,
|
||||
enforce_eager: bool,
|
||||
backend: str,
|
||||
@ -76,10 +74,6 @@ def test_models(
|
||||
baseline_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS)
|
||||
|
||||
extra_kwargs = {}
|
||||
if scale_path is not None:
|
||||
extra_kwargs["quantization_param_path"] = scale_path
|
||||
|
||||
with vllm_runner(
|
||||
test_model,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
@ -87,7 +81,6 @@ def test_models(
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
**extra_kwargs,
|
||||
) as vllm_model:
|
||||
test_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS)
|
||||
|
||||
@ -66,12 +66,16 @@ STARCODER_CONFIG = GGUFTestConfig(
|
||||
gguf_filename="starcoder2-3b.Q6_K.gguf",
|
||||
)
|
||||
|
||||
DOLPHIN_CONFIG = GGUFTestConfig(
|
||||
# Test VocabParallelEmbedding sharding issue.
|
||||
original_model="cognitivecomputations/TinyDolphin-2.8-1.1b",
|
||||
gguf_repo="tsunemoto/TinyDolphin-2.8-1.1b-GGUF",
|
||||
gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf",
|
||||
)
|
||||
|
||||
MODELS = [
|
||||
LLAMA_CONFIG,
|
||||
QWEN2_CONFIG,
|
||||
PHI3_CONFIG,
|
||||
GPT2_CONFIG,
|
||||
STABLELM_CONFIG,
|
||||
LLAMA_CONFIG, QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG,
|
||||
DOLPHIN_CONFIG
|
||||
# STARCODER_CONFIG, # broken
|
||||
]
|
||||
|
||||
@ -106,15 +110,18 @@ def test_models(
|
||||
messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
# Run unquantized model.
|
||||
with vllm_runner(model_name=model.original_model,
|
||||
dtype=dtype,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=tp_size) as original_model:
|
||||
with vllm_runner(
|
||||
model_name=model.original_model,
|
||||
enforce_eager=True, # faster tests
|
||||
dtype=dtype,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
tensor_parallel_size=tp_size) as original_model:
|
||||
original_outputs = original_model.generate_greedy_logprobs(
|
||||
example_prompts[:-1], max_tokens, num_logprobs)
|
||||
|
||||
# Run gguf model.
|
||||
with vllm_runner(model_name=model.gguf_model,
|
||||
enforce_eager=True,
|
||||
tokenizer_name=model.original_model,
|
||||
dtype=dtype,
|
||||
max_model_len=MAX_MODEL_LEN,
|
||||
|
||||
@ -10,7 +10,6 @@ from typing import Type
|
||||
import pytest
|
||||
from transformers import AutoModelForVision2Seq
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import identity
|
||||
@ -140,9 +139,7 @@ VLM_TEST_SETTINGS = {
|
||||
#### Extended model tests
|
||||
"aria": VLMTestInfo(
|
||||
models=["rhymes-ai/Aria"],
|
||||
tokenizer_mode="slow",
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
dtype="bfloat16",
|
||||
prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501
|
||||
img_idx_to_prompt=lambda idx: "<fim_prefix><|img|><fim_suffix>\n",
|
||||
max_model_len=4096,
|
||||
@ -158,8 +155,8 @@ VLM_TEST_SETTINGS = {
|
||||
max_tokens=64,
|
||||
marks=[
|
||||
pytest.mark.skipif(
|
||||
not is_flash_attn_2_available(),
|
||||
reason="Model needs flash-attn for numeric convergence.",
|
||||
TRANSFORMERS_VERSION < "4.48.0",
|
||||
reason="HF model requires transformers>=4.48.0",
|
||||
),
|
||||
large_gpu_mark(min_gb=64),
|
||||
],
|
||||
|
||||
@ -11,6 +11,7 @@ from vllm.multimodal.processing import ProcessingCache
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
|
||||
from ....multimodal.utils import random_audio, random_image, random_video
|
||||
from ...registry import HF_EXAMPLE_MODELS
|
||||
|
||||
|
||||
def _test_processing_correctness(
|
||||
@ -20,12 +21,9 @@ def _test_processing_correctness(
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
|
||||
hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
|
||||
elif model_id == "deepseek-ai/deepseek-vl2-tiny":
|
||||
hf_overrides = {"architectures": ["DeepseekVLV2ForCausalLM"]}
|
||||
else:
|
||||
hf_overrides = {}
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
limit_mm_per_prompt = {
|
||||
modality: 3 if supports_multi else 1
|
||||
@ -37,11 +35,11 @@ def _test_processing_correctness(
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
hf_overrides=hf_overrides,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import AbstractSet, Mapping, Optional
|
||||
from typing import AbstractSet, Any, Literal, Mapping, Optional
|
||||
|
||||
import pytest
|
||||
from packaging.version import Version
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -38,6 +42,50 @@ class _HfExamplesInfo:
|
||||
trust_remote_code: bool = False
|
||||
"""The ``trust_remote_code`` level required to load the model."""
|
||||
|
||||
hf_overrides: dict[str, Any] = field(default_factory=dict)
|
||||
"""The ``hf_overrides`` required to load the model."""
|
||||
|
||||
def check_transformers_version(
|
||||
self,
|
||||
*,
|
||||
on_fail: Literal["error", "skip"],
|
||||
) -> None:
|
||||
"""
|
||||
If the installed transformers version does not meet the requirements,
|
||||
perform the given action.
|
||||
"""
|
||||
if self.min_transformers_version is None:
|
||||
return
|
||||
|
||||
current_version = TRANSFORMERS_VERSION
|
||||
required_version = self.min_transformers_version
|
||||
if Version(current_version) < Version(required_version):
|
||||
msg = (
|
||||
f"You have `transformers=={current_version}` installed, but "
|
||||
f"`transformers>={required_version}` is required to run this "
|
||||
"model")
|
||||
|
||||
if on_fail == "error":
|
||||
raise RuntimeError(msg)
|
||||
else:
|
||||
pytest.skip(msg)
|
||||
|
||||
def check_available_online(
|
||||
self,
|
||||
*,
|
||||
on_fail: Literal["error", "skip"],
|
||||
) -> None:
|
||||
"""
|
||||
If the model is not available online, perform the given action.
|
||||
"""
|
||||
if not self.is_available_online:
|
||||
msg = "Model is not available online"
|
||||
|
||||
if on_fail == "error":
|
||||
raise RuntimeError(msg)
|
||||
else:
|
||||
pytest.skip(msg)
|
||||
|
||||
|
||||
# yapf: disable
|
||||
_TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
@ -48,8 +96,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct",
|
||||
trust_remote_code=True),
|
||||
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
|
||||
trust_remote_code=True),
|
||||
"BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B",
|
||||
trust_remote_code=True),
|
||||
"BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
|
||||
@ -176,6 +222,8 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
|
||||
|
||||
_MULTIMODAL_EXAMPLE_MODELS = {
|
||||
# [Decoder-only]
|
||||
"AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria",
|
||||
min_transformers_version="4.48"),
|
||||
"Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501
|
||||
"ChameleonForConditionalGeneration": _HfExamplesInfo("facebook/chameleon-7b"), # noqa: E501
|
||||
"ChatGLMModel": _HfExamplesInfo("THUDM/glm-4v-9b",
|
||||
@ -183,7 +231,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"ChatGLMForConditionalGeneration": _HfExamplesInfo("chatglm2-6b",
|
||||
is_available_online=False),
|
||||
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny"), # noqa: E501
|
||||
"DeepseekVLV2ForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-vl2-tiny", # noqa: E501
|
||||
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501
|
||||
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
|
||||
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m"),
|
||||
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
|
||||
@ -194,7 +243,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501
|
||||
"LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501
|
||||
"LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
|
||||
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501
|
||||
"MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3", # noqa: E501
|
||||
hf_overrides={"architectures": ["MantisForConditionalGeneration"]}), # noqa: E501
|
||||
"MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
|
||||
trust_remote_code=True),
|
||||
"MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924",
|
||||
@ -211,7 +261,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
|
||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
|
||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3"),
|
||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
|
||||
trust_remote_code=True),
|
||||
# [Encoder-decoder]
|
||||
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
|
||||
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
|
||||
@ -247,5 +298,17 @@ class HfExampleModels:
|
||||
def get_hf_info(self, model_arch: str) -> _HfExamplesInfo:
|
||||
return self.hf_models[model_arch]
|
||||
|
||||
def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
|
||||
for info in self.hf_models.values():
|
||||
if info.default == model_id:
|
||||
return info
|
||||
|
||||
# Fallback to extras
|
||||
for info in self.hf_models.values():
|
||||
if any(extra == model_id for extra in info.extras.values()):
|
||||
return info
|
||||
|
||||
raise ValueError(f"No example model defined for {model_id}")
|
||||
|
||||
|
||||
HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from packaging.version import Version
|
||||
from transformers import PretrainedConfig
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
@ -13,16 +11,8 @@ from .registry import HF_EXAMPLE_MODELS
|
||||
@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
|
||||
def test_can_initialize(model_arch):
|
||||
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
|
||||
if not model_info.is_available_online:
|
||||
pytest.skip("Model is not available online")
|
||||
if model_info.min_transformers_version is not None:
|
||||
current_version = TRANSFORMERS_VERSION
|
||||
required_version = model_info.min_transformers_version
|
||||
if Version(current_version) < Version(required_version):
|
||||
pytest.skip(
|
||||
f"You have `transformers=={current_version}` installed, but "
|
||||
f"`transformers>={required_version}` is required to run this "
|
||||
"model")
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
# Avoid OOM
|
||||
def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
|
||||
@ -21,6 +21,9 @@ from .registry import HF_EXAMPLE_MODELS
|
||||
|
||||
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
|
||||
def test_registry_imports(model_arch):
|
||||
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
# Ensure all model classes can be imported successfully
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)
|
||||
|
||||
|
||||
@ -7,12 +7,16 @@ import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.processing import (PlaceholderInfo, PromptReplacement,
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
|
||||
PromptReplacement,
|
||||
find_mm_placeholders,
|
||||
find_text_matches, find_token_matches,
|
||||
iter_token_matches,
|
||||
replace_text_matches,
|
||||
replace_token_matches)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import MultiModalProfiler
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
@ -433,19 +437,19 @@ def test_find_replace_tokens(
|
||||
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
{
|
||||
"pattern_1": [
|
||||
PlaceholderInfo(
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=6,
|
||||
replacement=[32000, 32000],
|
||||
tokens=[32000, 32000],
|
||||
),
|
||||
],
|
||||
"pattern_4": [
|
||||
PlaceholderInfo(
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_4",
|
||||
item_idx=0,
|
||||
start_idx=3,
|
||||
replacement=[32000],
|
||||
tokens=[32000],
|
||||
),
|
||||
],
|
||||
}
|
||||
@ -455,25 +459,25 @@ def test_find_replace_tokens(
|
||||
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
|
||||
{
|
||||
"pattern_1": [
|
||||
PlaceholderInfo(
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
tokens=[32000, 32000],
|
||||
),
|
||||
PlaceholderInfo(
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=1,
|
||||
start_idx=5,
|
||||
replacement=[32000, 32000],
|
||||
tokens=[32000, 32000],
|
||||
),
|
||||
],
|
||||
"pattern_3": [
|
||||
PlaceholderInfo(
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_3",
|
||||
item_idx=0,
|
||||
start_idx=7,
|
||||
replacement=[1550, 918, 1550],
|
||||
tokens=[1550, 918, 1550],
|
||||
),
|
||||
],
|
||||
# No match for pattern_4 as it has lower priority than pattern_1
|
||||
@ -483,33 +487,33 @@ def test_find_replace_tokens(
|
||||
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
|
||||
{
|
||||
"pattern_1": [
|
||||
PlaceholderInfo(
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
tokens=[32000, 32000],
|
||||
),
|
||||
PlaceholderInfo(
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=1,
|
||||
start_idx=3,
|
||||
replacement=[32000, 32000],
|
||||
tokens=[32000, 32000],
|
||||
),
|
||||
],
|
||||
"pattern_4": [
|
||||
PlaceholderInfo(
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_4",
|
||||
item_idx=0,
|
||||
start_idx=5,
|
||||
replacement=[32000],
|
||||
tokens=[32000],
|
||||
),
|
||||
],
|
||||
"pattern_3": [
|
||||
PlaceholderInfo(
|
||||
PlaceholderFeaturesInfo(
|
||||
modality="pattern_3",
|
||||
item_idx=0,
|
||||
start_idx=6,
|
||||
replacement=[1550, 918, 1550],
|
||||
tokens=[1550, 918, 1550],
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
@ -31,7 +31,7 @@ def test_random_sample_with_seed(
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
# Parameters to ensure sufficient randomness
|
||||
temperature=2.0,
|
||||
temperature=3.0,
|
||||
top_p=min(random.random() + 0.3, 1),
|
||||
top_k=random.randint(5, 20),
|
||||
n=random.randint(1, 10),
|
||||
@ -75,3 +75,8 @@ def test_random_sample_with_seed(
|
||||
# verify requests with the same seed match
|
||||
assert outputs[1] == outputs[4]
|
||||
assert outputs[2] == outputs[5]
|
||||
|
||||
# verify generations within the same parallel sampling group differ
|
||||
for output in outputs:
|
||||
for sub_output_a, sub_output_b in combinations(output, 2):
|
||||
assert sub_output_a != sub_output_b
|
||||
|
||||
@ -100,32 +100,32 @@ def test_traces(trace_service):
|
||||
|
||||
attributes = decode_attributes(
|
||||
request.resource_spans[0].scope_spans[0].spans[0].attributes)
|
||||
assert attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) == model
|
||||
assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_REQUEST_ID) == outputs[0].request_id
|
||||
SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
|
||||
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
|
||||
) == sampling_params.temperature
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_REQUEST_TEMPERATURE) == sampling_params.temperature
|
||||
SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
|
||||
assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
|
||||
assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
|
||||
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
|
||||
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
|
||||
assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
|
||||
outputs[0].prompt_token_ids)
|
||||
completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS) == completion_tokens
|
||||
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
|
||||
metrics = outputs[0].metrics
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue
|
||||
ttft = metrics.first_token_time - metrics.arrival_time
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
|
||||
e2e_time = metrics.finished_time - metrics.arrival_time
|
||||
assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
|
||||
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time
|
||||
assert metrics.scheduler_time > 0
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER) == metrics.scheduler_time
|
||||
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER
|
||||
) == metrics.scheduler_time
|
||||
# Model forward and model execute should be none, since detailed traces is
|
||||
# not enabled.
|
||||
assert metrics.model_forward_time is None
|
||||
@ -166,37 +166,37 @@ def test_traces_with_detailed_steps(trace_service):
|
||||
|
||||
attributes = decode_attributes(
|
||||
request.resource_spans[0].scope_spans[0].spans[0].attributes)
|
||||
assert attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) == model
|
||||
assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_REQUEST_ID) == outputs[0].request_id
|
||||
SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
|
||||
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
|
||||
) == sampling_params.temperature
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_REQUEST_TEMPERATURE) == sampling_params.temperature
|
||||
SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
|
||||
assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
|
||||
assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
|
||||
SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
|
||||
assert attributes.get(SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
|
||||
assert attributes.get(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
|
||||
outputs[0].prompt_token_ids)
|
||||
completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS) == completion_tokens
|
||||
SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
|
||||
metrics = outputs[0].metrics
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue
|
||||
ttft = metrics.first_token_time - metrics.arrival_time
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
|
||||
e2e_time = metrics.finished_time - metrics.arrival_time
|
||||
assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
|
||||
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time
|
||||
assert metrics.scheduler_time > 0
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER) == metrics.scheduler_time
|
||||
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER
|
||||
) == metrics.scheduler_time
|
||||
assert metrics.model_forward_time > 0
|
||||
assert attributes.get(
|
||||
SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD) == pytest.approx(
|
||||
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD) == pytest.approx(
|
||||
metrics.model_forward_time / 1000)
|
||||
assert metrics.model_execute_time > 0
|
||||
assert attributes.get(SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE
|
||||
assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE
|
||||
) == metrics.model_execute_time
|
||||
assert metrics.model_forward_time < 1000 * metrics.model_execute_time
|
||||
|
||||
@ -587,3 +587,72 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
|
||||
assert {block.ref_cnt for block in block_part1[:3]} == {1}
|
||||
# Block 3-5 are free.
|
||||
assert {block.ref_cnt for block in block_part1[3:]} == {0}
|
||||
|
||||
|
||||
def test_reset_prefix_cache():
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=0,
|
||||
)
|
||||
|
||||
full_block_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
unique_token_ids = [3] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req0 = make_request("0", all_token_ids)
|
||||
blocks = manager.allocate_slots(req0, 55, [])
|
||||
assert [b.block_id for b in blocks] == [0, 1, 2, 3]
|
||||
|
||||
unique_token_ids = [4] * 7
|
||||
all_token_ids = full_block_token_ids + unique_token_ids
|
||||
req1 = make_request("1", all_token_ids)
|
||||
computed_blocks, _ = manager.get_computed_blocks(req1)
|
||||
assert len(req1.kv_block_hashes) == 3
|
||||
assert len(computed_blocks) == 3
|
||||
blocks = manager.allocate_slots(req1, 7, computed_blocks)
|
||||
assert [b.block_id for b in blocks] == [4]
|
||||
|
||||
# Failed to reset prefix cache because some blocks are not freed yet.
|
||||
assert not manager.reset_prefix_cache()
|
||||
assert manager.cached_block_hash_to_block
|
||||
|
||||
# Free the blocks.
|
||||
manager.free(req0)
|
||||
manager.free(req1)
|
||||
|
||||
assert manager.reset_prefix_cache()
|
||||
assert not manager.cached_block_hash_to_block
|
||||
assert all([blk.block_hash is None for blk in manager.block_pool])
|
||||
|
||||
|
||||
def test_uncache_blocks():
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=0,
|
||||
)
|
||||
|
||||
req0 = make_request("0", list(range(30)))
|
||||
blocks = manager.allocate_slots(req0, 30, [])
|
||||
assert [b.block_id for b in blocks] == [0, 1]
|
||||
assert len(manager.cached_block_hash_to_block) == 1
|
||||
|
||||
req0.num_computed_tokens = 30
|
||||
|
||||
# Simulate speculative tokens.
|
||||
for _ in range(5):
|
||||
req0.append_output_token_ids(8)
|
||||
manager.append_slots(req0, 5)
|
||||
assert len(manager.cached_block_hash_to_block) == 2
|
||||
|
||||
# After sampling, assuming only 1 token is accepted.
|
||||
req0.num_computed_tokens = 31
|
||||
num_uncached_blocks = manager.uncache_blocks(req0)
|
||||
assert num_uncached_blocks == 1
|
||||
assert len(manager.cached_block_hash_to_block) == 1
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from contextlib import ExitStack
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
@ -6,6 +7,7 @@ import pytest
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
@ -18,28 +20,39 @@ ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
|
||||
|
||||
|
||||
async def generate(engine: AsyncLLM, request_id: str,
|
||||
output_kind: RequestOutputKind,
|
||||
max_tokens: int) -> Tuple[int, str]:
|
||||
count = 0
|
||||
async for _ in engine.generate(request_id=request_id,
|
||||
prompt="Hello my name is Robert and",
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=max_tokens, temperature=0)):
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens,
|
||||
output_kind=output_kind,
|
||||
temperature=0)
|
||||
async for out in engine.generate(request_id=request_id,
|
||||
prompt="Hello my name is Robert and",
|
||||
sampling_params=sampling_params):
|
||||
|
||||
num_tokens = len(out.outputs[0].token_ids)
|
||||
if output_kind == RequestOutputKind.DELTA:
|
||||
count += num_tokens
|
||||
else:
|
||||
count = num_tokens
|
||||
|
||||
count += 1
|
||||
await asyncio.sleep(0.)
|
||||
|
||||
return count, request_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
@pytest.mark.asyncio
|
||||
async def test_load(monkeypatch):
|
||||
async def test_load(monkeypatch, output_kind: RequestOutputKind):
|
||||
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1
|
||||
# so that in the future when we switch, we don't have to change all the
|
||||
# tests.
|
||||
with monkeypatch.context() as m:
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
NUM_REQUESTS = 10000
|
||||
NUM_EXPECTED_TOKENS = 10
|
||||
@ -51,26 +64,33 @@ async def test_load(monkeypatch):
|
||||
for request_id in request_ids:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(engine, request_id, NUM_EXPECTED_TOKENS)))
|
||||
generate(engine, request_id, output_kind,
|
||||
NUM_EXPECTED_TOKENS)))
|
||||
|
||||
# Confirm that we got all the EXPECTED tokens from the requests.
|
||||
for task in tasks:
|
||||
done, pending = await asyncio.wait(tasks,
|
||||
return_when=asyncio.FIRST_EXCEPTION)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
for task in done:
|
||||
num_generated_tokens, request_id = await task
|
||||
assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
|
||||
f"{request_id} generated {num_generated_tokens} but "
|
||||
f"expected {NUM_EXPECTED_TOKENS}")
|
||||
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
engine.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort(monkeypatch):
|
||||
async def test_abort(monkeypatch, output_kind: RequestOutputKind):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine = AsyncLLM.from_engine_args(ENGINE_ARGS)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
NUM_REQUESTS = 100
|
||||
NUM_EXPECTED_TOKENS = 100
|
||||
@ -83,7 +103,8 @@ async def test_abort(monkeypatch):
|
||||
for request_id in request_ids:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(engine, request_id, NUM_EXPECTED_TOKENS)))
|
||||
generate(engine, request_id, output_kind,
|
||||
NUM_EXPECTED_TOKENS)))
|
||||
|
||||
# API server cancels requests when they disconnect.
|
||||
for idx in REQUEST_IDS_TO_ABORT:
|
||||
@ -108,9 +129,7 @@ async def test_abort(monkeypatch):
|
||||
# Confirm we can do another generation.
|
||||
request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
|
||||
task = asyncio.create_task(
|
||||
generate(engine, request_id, NUM_EXPECTED_TOKENS))
|
||||
generate(engine, request_id, output_kind, NUM_EXPECTED_TOKENS))
|
||||
num_generated_tokens, request_id = await task
|
||||
assert num_generated_tokens == NUM_EXPECTED_TOKENS
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
engine.shutdown()
|
||||
|
||||
300
tests/v1/test_stats.py
Normal file
300
tests/v1/test_stats.py
Normal file
@ -0,0 +1,300 @@
|
||||
import pytest
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.stats.common import RequestStats, RequestStatsUpdate
|
||||
|
||||
|
||||
def make_update(
|
||||
request_id: str,
|
||||
update_type: RequestStatsUpdate.Type,
|
||||
monotonic_ts_s: float,
|
||||
**kwargs,
|
||||
):
|
||||
if update_type == RequestStatsUpdate.Type.INPUT_PROCESSED:
|
||||
kwargs.setdefault("sampling_params", SamplingParams(n=1))
|
||||
kwargs.setdefault("num_prompt_tokens", 10)
|
||||
elif update_type == RequestStatsUpdate.Type.PREFILLING:
|
||||
kwargs.setdefault("num_computed_tokens", 10)
|
||||
kwargs.setdefault("num_cached_tokens", 10)
|
||||
elif update_type == RequestStatsUpdate.Type.DETOKENIZED:
|
||||
kwargs.setdefault("num_new_tokens", 10)
|
||||
elif update_type == RequestStatsUpdate.Type.FINISHED:
|
||||
kwargs.setdefault("finish_reason", "test_reason")
|
||||
|
||||
return RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=update_type,
|
||||
monotonic_ts_s=monotonic_ts_s,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_request_update():
|
||||
request_id = "test_request"
|
||||
update_specific_required_fields = {
|
||||
RequestStatsUpdate.Type.INPUT_PROCESSED: [
|
||||
"sampling_params",
|
||||
"num_prompt_tokens",
|
||||
],
|
||||
RequestStatsUpdate.Type.PREFILLING: [
|
||||
"num_computed_tokens",
|
||||
"num_cached_tokens",
|
||||
],
|
||||
RequestStatsUpdate.Type.DETOKENIZED: ["num_new_tokens"],
|
||||
RequestStatsUpdate.Type.FINISHED: ["finish_reason"],
|
||||
}
|
||||
|
||||
# Missing a required field should raise an assertion error.
|
||||
for update_type in RequestStatsUpdate.Type:
|
||||
required_fields = update_specific_required_fields.get(update_type, [])
|
||||
|
||||
# Try to miss one of the required fields.
|
||||
kwargs = {field: object() for field in required_fields}
|
||||
for field in required_fields:
|
||||
copy_kwargs = kwargs.copy()
|
||||
copy_kwargs.pop(field)
|
||||
with pytest.raises(ValueError):
|
||||
RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=update_type,
|
||||
**copy_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_request_update_transition():
|
||||
# Test invalid transition type.
|
||||
for src in RequestStatsUpdate.Type:
|
||||
for dst in RequestStatsUpdate.Type:
|
||||
if dst not in RequestStatsUpdate._VALID_TRANSITIONS[src]:
|
||||
with pytest.raises(AssertionError):
|
||||
RequestStatsUpdate.check_valid_update(
|
||||
make_update(
|
||||
update_type=dst,
|
||||
request_id="test_request",
|
||||
monotonic_ts_s=1,
|
||||
),
|
||||
last_update_type=src,
|
||||
last_updated_ts_s=0,
|
||||
)
|
||||
else:
|
||||
RequestStatsUpdate.check_valid_update(
|
||||
make_update(
|
||||
request_id="test_request",
|
||||
update_type=dst,
|
||||
monotonic_ts_s=1,
|
||||
),
|
||||
last_update_type=src,
|
||||
last_updated_ts_s=0,
|
||||
)
|
||||
|
||||
# Test invalid timestamp.
|
||||
with pytest.raises(AssertionError):
|
||||
RequestStatsUpdate.check_valid_update(
|
||||
make_update(
|
||||
request_id="test_request",
|
||||
update_type=RequestStatsUpdate.Type.ARRIVED,
|
||||
monotonic_ts_s=1,
|
||||
),
|
||||
last_update_type=None,
|
||||
last_updated_ts_s=2,
|
||||
)
|
||||
|
||||
|
||||
def test_lifecycle_updates():
|
||||
request_id = "test_request"
|
||||
stats = RequestStats(request_id=request_id)
|
||||
|
||||
# Test the below scenario:
|
||||
arrived_ts = 0
|
||||
input_processed_ts = 1
|
||||
queued_ts = 2
|
||||
prefilling_ts = 3
|
||||
decoded_ts = 5
|
||||
detokenized_ts = 6
|
||||
decoded_2_ts = 7
|
||||
detokenized_2_ts = 8
|
||||
preempted_ts = 9
|
||||
resumed_ts = 10
|
||||
decoded_3_ts = 11
|
||||
detokenized_3_ts = 12
|
||||
finished_ts = 13
|
||||
|
||||
# Test ARRIVED
|
||||
arrived_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.ARRIVED,
|
||||
monotonic_ts_s=arrived_ts,
|
||||
)
|
||||
stats.update_from(arrived_update)
|
||||
assert stats.arrival_ts_s == arrived_ts
|
||||
assert stats.last_updated_ts_s == arrived_ts
|
||||
|
||||
# Test INPUT_PROCESSED
|
||||
sampling_params = SamplingParams(n=1)
|
||||
input_processed_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.INPUT_PROCESSED,
|
||||
monotonic_ts_s=input_processed_ts,
|
||||
sampling_params=sampling_params,
|
||||
num_prompt_tokens=6,
|
||||
)
|
||||
stats.update_from(input_processed_update)
|
||||
assert stats.input_processor_end_ts_s == input_processed_ts
|
||||
assert stats.last_updated_ts_s == input_processed_ts
|
||||
assert stats.num_prompt_tokens == 6
|
||||
assert stats.sampling_params == sampling_params
|
||||
|
||||
assert stats.first_token_ts_s is None
|
||||
assert stats.prefill_ts_s is None
|
||||
|
||||
# Test QUEUED
|
||||
queued_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.QUEUED,
|
||||
monotonic_ts_s=queued_ts,
|
||||
)
|
||||
stats.update_from(queued_update)
|
||||
assert stats.queued_ts_s == queued_ts
|
||||
assert stats.last_updated_ts_s == queued_ts
|
||||
|
||||
# Test PREFILLING
|
||||
prefilling_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.PREFILLING,
|
||||
monotonic_ts_s=prefilling_ts,
|
||||
num_computed_tokens=3,
|
||||
num_cached_tokens=1,
|
||||
)
|
||||
stats.update_from(prefilling_update)
|
||||
assert stats.prefill_ts_s == prefilling_ts
|
||||
assert stats.num_computed_tokens == 3
|
||||
assert stats.num_cached_tokens == 1
|
||||
assert stats.queue_duration_s == prefilling_ts - queued_ts
|
||||
|
||||
# Test DECODING
|
||||
decoded_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.DECODING,
|
||||
monotonic_ts_s=decoded_ts,
|
||||
)
|
||||
stats.update_from(decoded_update)
|
||||
assert stats.last_updated_ts_s == decoded_ts
|
||||
|
||||
# Test DETOKENIZED
|
||||
detokenized_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.DETOKENIZED,
|
||||
monotonic_ts_s=detokenized_ts,
|
||||
num_new_tokens=1,
|
||||
)
|
||||
stats.update_from(detokenized_update)
|
||||
assert stats.last_updated_ts_s == detokenized_ts
|
||||
assert stats.num_output_tokens == 1
|
||||
# Since arrival
|
||||
assert stats.first_token_latency_s == detokenized_ts - arrived_ts
|
||||
# Since first scheduled
|
||||
assert stats.prefill_latency_s == detokenized_ts - prefilling_ts
|
||||
|
||||
# Test another DECODING and DETOKENIZED should
|
||||
# yield correct inter token latency
|
||||
decoded_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.DECODING,
|
||||
monotonic_ts_s=decoded_2_ts,
|
||||
)
|
||||
stats.update_from(decoded_update)
|
||||
|
||||
detokenized_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.DETOKENIZED,
|
||||
monotonic_ts_s=detokenized_2_ts,
|
||||
num_new_tokens=1,
|
||||
)
|
||||
stats.update_from(detokenized_update)
|
||||
assert stats.output_token_latency_s_lst == [
|
||||
detokenized_2_ts - detokenized_ts,
|
||||
]
|
||||
assert stats.num_output_tokens == 2
|
||||
|
||||
# Test PREEMPTED
|
||||
preempted_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.PREEMPTED,
|
||||
monotonic_ts_s=preempted_ts,
|
||||
)
|
||||
stats.update_from(preempted_update)
|
||||
assert stats.last_updated_ts_s == preempted_ts
|
||||
assert stats.preempted_ts_s_lst == [preempted_ts]
|
||||
# States should be reset
|
||||
assert stats.num_computed_tokens == 0
|
||||
assert stats.num_cached_tokens == 0
|
||||
# These states should not be reset
|
||||
assert stats.num_output_tokens == 2
|
||||
assert stats.output_token_latency_s_lst == [
|
||||
detokenized_2_ts - detokenized_ts,
|
||||
]
|
||||
assert stats.prefill_latency_s == prefilling_ts - arrived_ts
|
||||
assert stats.num_prompt_tokens == 6
|
||||
assert stats.prefill_start_ts_s_lst == [prefilling_ts]
|
||||
|
||||
# Test resumed
|
||||
resumed_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.PREFILLING,
|
||||
monotonic_ts_s=resumed_ts,
|
||||
num_computed_tokens=6,
|
||||
num_cached_tokens=2,
|
||||
)
|
||||
stats.update_from(resumed_update)
|
||||
# prefill timestamp should not be updated since it's a resumed prefill
|
||||
assert stats.prefill_ts_s == prefilling_ts
|
||||
assert stats.num_computed_tokens == 6
|
||||
assert stats.num_cached_tokens == 2
|
||||
assert stats.prefill_start_ts_s_lst == [
|
||||
prefilling_ts,
|
||||
resumed_ts,
|
||||
]
|
||||
assert stats.last_updated_ts_s == resumed_ts
|
||||
|
||||
# Test another DECODED/DETOKENIZED should yield correct first token latency.
|
||||
decoded_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.DECODING,
|
||||
monotonic_ts_s=decoded_3_ts,
|
||||
)
|
||||
detokenized_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.DETOKENIZED,
|
||||
monotonic_ts_s=detokenized_3_ts,
|
||||
num_new_tokens=1,
|
||||
)
|
||||
stats.update_from(decoded_update)
|
||||
stats.update_from(detokenized_update)
|
||||
assert stats.first_token_ts_s == detokenized_ts - arrived_ts
|
||||
assert stats.num_output_tokens == 3
|
||||
assert stats.output_token_latency_s_lst == [
|
||||
detokenized_2_ts - detokenized_ts,
|
||||
detokenized_3_ts - detokenized_2_ts,
|
||||
]
|
||||
|
||||
# Test FINISHED
|
||||
finished_update = RequestStatsUpdate(
|
||||
request_id=request_id,
|
||||
type=RequestStatsUpdate.Type.FINISHED,
|
||||
monotonic_ts_s=finished_ts,
|
||||
finish_reason="test_reason",
|
||||
)
|
||||
stats.update_from(finished_update)
|
||||
assert stats.last_updated_ts_s == finished_ts
|
||||
assert stats.e2e_latency_s == finished_ts - arrived_ts
|
||||
assert stats.inference_latency_s == finished_ts - prefilling_ts
|
||||
assert stats.prefill_latency_s == detokenized_ts - prefilling_ts
|
||||
assert stats.decode_latency_s == finished_ts - detokenized_ts
|
||||
assert stats.first_token_latency_s == detokenized_ts - arrived_ts
|
||||
assert stats.queue_duration_s == prefilling_ts - queued_ts
|
||||
assert stats.is_finished
|
||||
assert stats.finish_reason == "test_reason"
|
||||
|
||||
# TODO(rickyx): Add model forward/execute time.
|
||||
assert stats.model_forward_duration_s == 0.0
|
||||
assert stats.model_execute_duration_s == 0.0
|
||||
@ -74,6 +74,7 @@ def test_model_runner_input():
|
||||
num_decode_tokens=3,
|
||||
slot_mapping=torch.zeros(1),
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
)
|
||||
model_input = ModelInputForGPUWithSamplingMetadata(
|
||||
input_tokens=torch.ones(10),
|
||||
@ -126,6 +127,7 @@ def test_embedding_model_runner_input():
|
||||
num_decode_tokens=3,
|
||||
slot_mapping=torch.zeros(1),
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
)
|
||||
model_input = ModelInputForGPUWithPoolingMetadata(
|
||||
input_tokens=torch.ones(10),
|
||||
@ -177,6 +179,7 @@ def test_multi_step_model_runner_input():
|
||||
num_decode_tokens=3,
|
||||
slot_mapping=torch.zeros(1),
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
)
|
||||
frozen_model_input = ModelInputForGPUWithSamplingMetadata(
|
||||
input_tokens=torch.ones(10),
|
||||
|
||||
@ -1,12 +1,16 @@
|
||||
#!/bin/bash
|
||||
|
||||
CI=${1:-0}
|
||||
PYTHON_VERSION=${2:-3.9}
|
||||
PYTHON_VERSION=${2:-local}
|
||||
|
||||
if [ "$CI" -eq 1 ]; then
|
||||
set -e
|
||||
fi
|
||||
|
||||
if [ $PYTHON_VERSION == "local" ]; then
|
||||
PYTHON_VERSION=$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
||||
fi
|
||||
|
||||
run_mypy() {
|
||||
echo "Running mypy on $1"
|
||||
if [ "$CI" -eq 1 ] && [ -z "$1" ]; then
|
||||
@ -30,4 +34,4 @@ run_mypy vllm/plugins
|
||||
run_mypy vllm/prompt_adapter
|
||||
run_mypy vllm/spec_decode
|
||||
run_mypy vllm/worker
|
||||
run_mypy vllm/v1
|
||||
run_mypy vllm/v1
|
||||
@ -1,4 +1,7 @@
|
||||
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
@ -17,43 +20,18 @@ from vllm.sampling_params import SamplingParams
|
||||
|
||||
from .version import __version__, __version_tuple__
|
||||
|
||||
# set some common config/environment variables that should be set
|
||||
# for all processes created by vllm and all processes
|
||||
# that interact with vllm workers.
|
||||
# they are executed whenever `import vllm` is called.
|
||||
|
||||
def configure_as_vllm_process():
|
||||
"""
|
||||
set some common config/environment variables that should be set
|
||||
for all processes created by vllm and all processes
|
||||
that interact with vllm workers.
|
||||
"""
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
# see https://github.com/NVIDIA/nccl/issues/1234
|
||||
os.environ['NCCL_CUMEM_ENABLE'] = '0'
|
||||
|
||||
# see https://github.com/vllm-project/vllm/issues/10480
|
||||
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
|
||||
# see https://github.com/vllm-project/vllm/issues/10619
|
||||
torch._inductor.config.compile_threads = 1
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_xpu():
|
||||
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
|
||||
torch._dynamo.config.disable = True
|
||||
elif current_platform.is_hpu():
|
||||
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
|
||||
# does not support torch.compile
|
||||
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
|
||||
# torch.compile support
|
||||
is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1'
|
||||
if is_lazy:
|
||||
torch._dynamo.config.disable = True
|
||||
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
|
||||
# requires enabling lazy collectives
|
||||
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
|
||||
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
|
||||
# see https://github.com/NVIDIA/nccl/issues/1234
|
||||
os.environ['NCCL_CUMEM_ENABLE'] = '0'
|
||||
|
||||
# see https://github.com/vllm-project/vllm/issues/10480
|
||||
os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
|
||||
# see https://github.com/vllm-project/vllm/issues/10619
|
||||
torch._inductor.config.compile_threads = 1
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
@ -80,5 +58,4 @@ __all__ = [
|
||||
"AsyncEngineArgs",
|
||||
"initialize_ray_cluster",
|
||||
"PoolingParams",
|
||||
"configure_as_vllm_process",
|
||||
]
|
||||
|
||||
@ -48,8 +48,8 @@ def paged_attention_v1(
|
||||
max_seq_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 0,
|
||||
@ -80,8 +80,8 @@ def paged_attention_v2(
|
||||
max_seq_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 0,
|
||||
@ -112,8 +112,8 @@ def paged_attention_rocm(
|
||||
max_seq_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
|
||||
key_cache, value_cache, num_kv_heads,
|
||||
@ -956,8 +956,8 @@ def reshape_and_cache(
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
|
||||
value_cache, slot_mapping,
|
||||
@ -971,8 +971,8 @@ def reshape_and_cache_flash(
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
|
||||
value_cache, slot_mapping,
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
|
||||
Tuple, Type, TypeVar)
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
|
||||
Protocol, Set, Tuple, Type, TypeVar)
|
||||
|
||||
import torch
|
||||
|
||||
@ -65,11 +65,6 @@ class AttentionBackend(ABC):
|
||||
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def make_metadata_builder(cls, *args,
|
||||
**kwargs) -> "AttentionMetadataBuilder":
|
||||
return cls.get_builder_cls()(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_kv_cache_shape(
|
||||
@ -128,6 +123,10 @@ class AttentionMetadata:
|
||||
multi_modal_placeholder_index_maps: Optional[Dict[
|
||||
str, MultiModalPlaceholderMap.IndexMap]]
|
||||
|
||||
# Enable/disable KV scales calculation. This is so that we can disable the
|
||||
# calculation until after prefill and cuda graph capture.
|
||||
enable_kv_scales_calculation: bool
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
|
||||
@ -214,6 +213,12 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
|
||||
"""Create the builder, remember some configuration and parameters."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def prepare(self) -> None:
|
||||
"""Prepare for one batch."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@ -223,6 +228,24 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AttentionLayer(Protocol):
|
||||
|
||||
_k_scale: torch.Tensor
|
||||
_v_scale: torch.Tensor
|
||||
_k_scale_float: float
|
||||
_v_scale_float: float
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
|
||||
|
||||
class AttentionImpl(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
@ -244,13 +267,12 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
CommonMetadataBuilder)
|
||||
@ -221,6 +222,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
|
||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
max_query_len=self.max_query_len,
|
||||
@ -250,6 +252,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
max_query_len=None,
|
||||
@ -358,13 +361,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: BlocksparseFlashAttentionMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
@ -401,8 +403,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
@ -439,8 +441,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
tp_rank=self.tp_rank,
|
||||
blocksparse_local_blocks=self.local_blocks,
|
||||
blocksparse_vert_stride=self.vert_stride,
|
||||
|
||||
@ -8,6 +8,7 @@ import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType)
|
||||
@ -16,7 +17,9 @@ from vllm.attention.backends.utils import (
|
||||
compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
|
||||
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
|
||||
is_all_encoder_attn_metadata_set, is_block_tables_empty)
|
||||
from vllm.envs import VLLM_FLASH_ATTN_VERSION
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -24,7 +27,8 @@ if TYPE_CHECKING:
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
|
||||
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache)
|
||||
flash_attn_with_kvcache,
|
||||
is_fa_version_supported)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
@ -226,6 +230,7 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
@ -270,6 +275,7 @@ class FlashAttentionMetadata(AttentionMetadata):
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
@ -374,6 +380,12 @@ class FlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
@ -387,11 +399,6 @@ class FlashAttentionMetadataBuilder(
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
|
||||
@ -552,6 +559,7 @@ class FlashAttentionMetadataBuilder(
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_decode_query_len=max_decode_query_len,
|
||||
@ -632,15 +640,28 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
f"Supported head sizes are: {support_head_sizes}.")
|
||||
self.attn_type = attn_type
|
||||
|
||||
# if hopper default to FA3, otherwise stick to FA2 for now
|
||||
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
|
||||
# use FA3 as default for both
|
||||
if current_platform.get_device_capability()[0] >= 9:
|
||||
self.fa_version = 3 if is_fa_version_supported(3) else 2
|
||||
else:
|
||||
self.fa_version = 2
|
||||
|
||||
if VLLM_FLASH_ATTN_VERSION is not None:
|
||||
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
|
||||
self.fa_version = VLLM_FLASH_ATTN_VERSION
|
||||
|
||||
assert is_fa_version_supported(self.fa_version)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
@ -657,7 +678,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
NOTE: It in-place updates the output tensor.
|
||||
"""
|
||||
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
|
||||
"key/v_scale is not supported in FlashAttention.")
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
@ -709,8 +730,8 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
kv_cache[1],
|
||||
updated_slot_mapping.flatten(), # type: ignore[union-attr]
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||
@ -751,6 +772,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
out=prefill_output,
|
||||
fa_version=self.fa_version,
|
||||
)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
@ -764,7 +786,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
v=value_cache,
|
||||
cu_seqlens_q=prefill_meta.query_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_query_len,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
seqused_k=prefill_meta.seq_lens_tensor,
|
||||
max_seqlen_k=max_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
@ -773,6 +795,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
block_table=prefill_meta.block_tables,
|
||||
softcap=logits_soft_cap,
|
||||
out=prefill_output,
|
||||
fa_version=self.fa_version,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
@ -792,7 +815,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
v=value_cache,
|
||||
cu_seqlens_q=decode_meta.query_start_loc,
|
||||
max_seqlen_q=decode_meta.max_decode_query_len,
|
||||
cu_seqlens_k=decode_meta.seq_start_loc,
|
||||
seqused_k=decode_meta.seq_lens_tensor,
|
||||
max_seqlen_k=decode_meta.max_decode_seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
@ -801,6 +824,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
softcap=logits_soft_cap,
|
||||
block_table=decode_meta.block_tables,
|
||||
out=decode_output,
|
||||
fa_version=self.fa_version,
|
||||
)
|
||||
else:
|
||||
# Use flash_attn_with_kvcache for normal decoding.
|
||||
@ -821,6 +845,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
alibi_slopes=alibi_slopes,
|
||||
softcap=logits_soft_cap,
|
||||
out=decode_output.unsqueeze(1),
|
||||
fa_version=self.fa_version,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@ -23,6 +23,7 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
@ -218,6 +219,7 @@ class FlashInferState(AttentionState):
|
||||
num_prefills=0,
|
||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
max_prefill_seq_len=0,
|
||||
@ -487,6 +489,14 @@ class FlashInferMetadata(AttentionMetadata):
|
||||
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
@ -499,12 +509,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
|
||||
# for the precise definition of the following fields.
|
||||
# An example:
|
||||
@ -730,6 +734,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=False,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
@ -792,13 +797,12 @@ class FlashInferImpl(AttentionImpl):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashInferMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@ -826,8 +830,8 @@ class FlashInferImpl(AttentionImpl):
|
||||
kv_cache[:, 1],
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
||||
# to process the cache when the kv_cache_dtype is fp8
|
||||
@ -886,8 +890,8 @@ class FlashInferImpl(AttentionImpl):
|
||||
kv_cache,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
causal=True,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
window_left=window_left)
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
assert decode_meta is not None
|
||||
@ -897,8 +901,8 @@ class FlashInferImpl(AttentionImpl):
|
||||
kv_cache,
|
||||
sm_scale=softmax_scale,
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
k_scale=layer._k_scale_float,
|
||||
v_scale=layer._v_scale_float,
|
||||
window_left=window_left)
|
||||
|
||||
if prefill_output is None and decode_output is not None:
|
||||
|
||||
@ -11,6 +11,7 @@ import vllm_hpu_extension.ops as ops
|
||||
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
|
||||
@ -152,13 +153,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: HPUAttentionMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with xFormers and PagedAttention.
|
||||
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
@ -171,13 +172,12 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: IpexAttnMetadata, # type: ignore
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
||||
@ -193,7 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert k_scale == 1.0 and v_scale == 1.0
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
@ -210,8 +210,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
@ -296,8 +296,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
@ -329,8 +329,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
|
||||
@ -5,6 +5,7 @@ import torch
|
||||
import torch_xla.experimental.custom_kernel # Required to register custom ops.
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
|
||||
@ -150,13 +151,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
attn_metadata: PallasMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
@ -173,7 +173,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
Returns:
|
||||
shape = [batch_size, seq_len, num_heads * head_size]
|
||||
"""
|
||||
assert k_scale == 1.0 and v_scale == 1.0
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
batch_size, seq_len, hidden_size = query.shape
|
||||
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
||||
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
@ -140,6 +140,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
max_decode_query_len=0,
|
||||
@ -173,6 +174,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
@ -253,6 +255,11 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
|
||||
def prepare(self):
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
@ -263,9 +270,6 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
@ -378,6 +382,7 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||
CommonMetadataBuilder)
|
||||
@ -152,6 +153,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
max_query_len=self.max_query_len,
|
||||
@ -181,6 +183,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
max_query_len=None,
|
||||
@ -414,13 +417,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
@ -458,8 +460,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
@ -567,8 +569,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
prefill_meta.max_query_len,
|
||||
self.alibi_slopes,
|
||||
self.sliding_window[0],
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
@ -613,8 +615,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
||||
@ -628,8 +630,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType)
|
||||
@ -281,7 +282,10 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
|
||||
|
||||
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
|
||||
self.chunked_prefill = input_builder.chunked_prefill
|
||||
self.input_data = input_builder.input_data
|
||||
self.input_builder = input_builder
|
||||
|
||||
def prepare(self):
|
||||
self.input_data = self.input_builder.input_data
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
|
||||
@ -375,6 +379,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
|
||||
prefill_block_tables=prefill_block_tables,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
@ -429,13 +434,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: TorchSDPAMetadata, # type: ignore
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with torch SDPA and PagedAttention.
|
||||
@ -451,7 +455,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
assert k_scale == 1.0 and v_scale == 1.0
|
||||
attn_type = self.attn_type
|
||||
if (attn_type == AttentionType.ENCODER
|
||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||
@ -493,11 +496,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
# Update self-attention KV cache (prefill/decode)
|
||||
updated_slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
k_scale, v_scale)
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key, value, key_cache, value_cache, updated_slot_mapping,
|
||||
self.kv_cache_dtype, layer._k_scale, layer._v_scale)
|
||||
|
||||
if attn_type != AttentionType.ENCODER:
|
||||
# Decoder self-attention supports chunked prefill.
|
||||
@ -571,8 +572,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
|
||||
@ -122,6 +122,13 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
|
||||
_metadata_cls: Type[TAttentionMetadata]
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
@ -134,12 +141,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool):
|
||||
@ -264,6 +265,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
@ -316,6 +318,7 @@ class CommonAttentionState(AttentionState):
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||
max_query_len=1,
|
||||
|
||||
@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
|
||||
LowerTriangularMaskWithTensorBias)
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType)
|
||||
from vllm.attention.backends.utils import (
|
||||
CommonAttentionState, CommonMetadataBuilder,
|
||||
@ -217,6 +218,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
@ -261,6 +263,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
@ -412,13 +415,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor],
|
||||
value: Optional[torch.Tensor],
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: "XFormersMetadata",
|
||||
k_scale: float = 1.0,
|
||||
v_scale: float = 1.0,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with xFormers and PagedAttention.
|
||||
@ -524,11 +526,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
# If kv_cache is not provided, the new key and value tensors are
|
||||
# not cached. This happens during the initial memory
|
||||
# profiling run.
|
||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||
value_cache,
|
||||
updated_slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
k_scale, v_scale)
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key, value, key_cache, value_cache, updated_slot_mapping,
|
||||
self.kv_cache_dtype, layer._k_scale, layer._v_scale)
|
||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||
num_decode_query_tokens) = \
|
||||
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||
@ -580,8 +580,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
prefill_meta.max_query_len,
|
||||
self.alibi_slopes,
|
||||
self.sliding_window,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
assert output[:num_prefill_query_tokens].shape == out.shape
|
||||
output[:num_prefill_query_tokens] = out
|
||||
@ -607,8 +607,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
k_scale,
|
||||
v_scale,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
|
||||
@ -5,6 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionMetadata, AttentionType
|
||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
@ -57,10 +58,12 @@ class Attention(nn.Module):
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
is_attention_free = cache_config.is_attention_free
|
||||
calculate_kv_scales = cache_config.calculate_kv_scales
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
is_attention_free = False
|
||||
calculate_kv_scales = False
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
|
||||
@ -70,8 +73,15 @@ class Attention(nn.Module):
|
||||
# expect the pre-quantized k/v_scale to be loaded along
|
||||
# with the model weights.
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self._k_scale = 1.0
|
||||
self._v_scale = 1.0
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
|
||||
# We also keep the float32 versions of k/v_scale for attention
|
||||
# backends that don't support tensors (Flashinfer)
|
||||
self._k_scale_float = 1.0
|
||||
self._v_scale_float = 1.0
|
||||
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self, prefix=prefix) if quant_config else None
|
||||
if quant_method is not None:
|
||||
@ -127,6 +137,9 @@ class Attention(nn.Module):
|
||||
).parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@ -135,6 +148,9 @@ class Attention(nn.Module):
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
if self.calculate_kv_scales and \
|
||||
attn_metadata.enable_kv_scales_calculation:
|
||||
self.calc_kv_scales(key, value)
|
||||
if self.use_output:
|
||||
output = torch.empty_like(query)
|
||||
hidden_size = query.size(-1)
|
||||
@ -161,6 +177,14 @@ class Attention(nn.Module):
|
||||
return torch.ops.vllm.unified_attention(
|
||||
query, key, value, self.layer_name)
|
||||
|
||||
def calc_kv_scales(self, key, value):
|
||||
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
|
||||
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
|
||||
self._k_scale_float = self._k_scale.item()
|
||||
self._v_scale_float = self._v_scale.item()
|
||||
# We only calculate the scales once
|
||||
self.calculate_kv_scales = False
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"head_size={self.impl.head_size}" # type: ignore
|
||||
s += f", num_heads={self.impl.num_heads}" # type: ignore
|
||||
@ -243,8 +267,7 @@ def unified_attention(
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
self = forward_context.attn_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
|
||||
self._k_scale, self._v_scale)
|
||||
return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
|
||||
|
||||
|
||||
def unified_attention_fake(
|
||||
@ -276,13 +299,12 @@ def unified_attention_with_output(
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
self = forward_context.attn_layers[layer_name]
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
self.impl.forward(query,
|
||||
self.impl.forward(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
self._k_scale,
|
||||
self._v_scale,
|
||||
output=output)
|
||||
|
||||
|
||||
|
||||
@ -52,8 +52,8 @@ class _PagedAttention:
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
ops.reshape_and_cache(
|
||||
@ -80,8 +80,8 @@ class _PagedAttention:
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
tp_rank: int = 0
|
||||
@ -149,8 +149,8 @@ class _IPEXPagedAttention(_PagedAttention):
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
ipex_modules.PagedAttention.reshape_and_cache(
|
||||
@ -170,8 +170,8 @@ class _IPEXPagedAttention(_PagedAttention):
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
*args,
|
||||
) -> None:
|
||||
block_size = value_cache.shape[2]
|
||||
|
||||
@ -69,8 +69,8 @@ class PagedAttention:
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> None:
|
||||
ops.reshape_and_cache(
|
||||
key,
|
||||
@ -95,8 +95,8 @@ class PagedAttention:
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
tp_rank: int = 0,
|
||||
blocksparse_local_blocks: int = 0,
|
||||
blocksparse_vert_stride: int = 0,
|
||||
@ -204,8 +204,8 @@ class PagedAttention:
|
||||
max_query_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
sliding_window: Optional[int],
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output = torch.empty_like(query)
|
||||
context_attention_fwd(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user