Compare commits
119 Commits
v1_fix_pro ... pd_schedul
| SHA1 | Author | Date |
|---|---|---|
| 161010c384 | |||
| 3d3ab3689f | |||
| 686623c5e7 | |||
| aadb656562 | |||
| 87e067de41 | |||
| 26507f8973 | |||
| 9c1d5b456d | |||
| e31045f95c | |||
| aaec845f8e | |||
| 7bdfd29a35 | |||
| e78587a64c | |||
| 7eb4255628 | |||
| 6a0f547561 | |||
| 30ed81b7ca | |||
| 7a4a5de729 | |||
| c16fb5dae8 | |||
| e37073efd7 | |||
| 183dad7a85 | |||
| 3408e47159 | |||
| 0377b8310b | |||
| e4755f7fac | |||
| 92edf35826 | |||
| eb5819b2d9 | |||
| 5989f4684d | |||
| 5125d72f02 | |||
| a018e555fd | |||
| 6211b92273 | |||
| 05fcd1b430 | |||
| 7c02d6a137 | |||
| 11c3b98491 | |||
| dbe7f07001 | |||
| c69bf4ee06 | |||
| d27ea94034 | |||
| 99ed526101 | |||
| 207da28186 | |||
| 5b1aca2ae3 | |||
| d8e557b5e5 | |||
| 61a44a0b22 | |||
| a6481525b8 | |||
| 8cac35ba43 | |||
| 9dbf7a2dc1 | |||
| 607029e515 | |||
| cb072ce93b | |||
| 95aca283b4 | |||
| 2b05b8ce69 | |||
| 3c776dcefb | |||
| 2cbd4d2999 | |||
| 3092375e27 | |||
| 3cd91dc955 | |||
| 8a7368e069 | |||
| 93e561ec4d | |||
| e1b004839a | |||
| ee378f3d49 | |||
| e82ee40de3 | |||
| facbe2a114 | |||
| 7168920491 | |||
| 21378a2323 | |||
| 976711d9db | |||
| 44fa4d556c | |||
| 3ac98edcb1 | |||
| 966c742ed2 | |||
| 0d7d05f4b6 | |||
| 96bb8aa68b | |||
| 3badb0213b | |||
| fdcb850f14 | |||
| 54a66e5fee | |||
| 280d62b8a2 | |||
| 1666e66443 | |||
| 1575c1701a | |||
| 6ae996a873 | |||
| b590adfdc1 | |||
| b4fe16c75b | |||
| bc5dd4f669 | |||
| dbb036cf61 | |||
| 70e7ed841d | |||
| d06ba4ed3f | |||
| 6b40996ae8 | |||
| d2020acac7 | |||
| 1eb3c2ed48 | |||
| c64ee87267 | |||
| b1308b84a3 | |||
| 7b5ecf79bd | |||
| 9883a18859 | |||
| b3f2fddd17 | |||
| aa29841ede | |||
| 6bf27affb6 | |||
| 1dd23386ec | |||
| 7cbfc10943 | |||
| ce4ddd2d1a | |||
| e51929ebca | |||
| dc1b4a6f13 | |||
| 63d2705edb | |||
| d085a44082 | |||
| f49e5aff11 | |||
| 6c11ecf8d3 | |||
| 93e5f3c5fb | |||
| 70363bccfa | |||
| 3cdc57669f | |||
| 68bb122eb4 | |||
| d9fc8cd9da | |||
| f069f3ea74 | |||
| c5bc0e7fcc | |||
| 4a3a518722 | |||
| fbf722c6e6 | |||
| e92d7085bf | |||
| bd6028d6b0 | |||
| 802329dee9 | |||
| 41cc883c29 | |||
| 57504a4bcf | |||
| ed4792c990 | |||
| 87b836ba77 | |||
| 56c76c2e0e | |||
| c09632a66c | |||
| a3bf8d4a2b | |||
| 16eda8c43a | |||
| cd77382ac1 | |||
| 71b9cde010 | |||
| 5285589f37 | |||
| f41647ee6b |
@@ -5,10 +5,34 @@

set -ex

# Setup cleanup
remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; }
remove_docker_container() { podman rm -f cpu-test-ubi9-ppc || true; podman system prune -f; }
trap remove_docker_container EXIT
remove_docker_container

# Try building the docker image
docker build -t cpu-test -f docker/Dockerfile.ppc64le .
podman build -t cpu-test-ubi9-ppc -f docker/Dockerfile.ppc64le .

# Run the image
podman run -itd --entrypoint /bin/bash -v /tmp/:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --name cpu-test-ubi9-ppc cpu-test-ubi9-ppc

function cpu_tests() {

  # offline inference
  podman exec cpu-test-ubi9-ppc bash -c "
    set -e
    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"

  # Run basic model test
  podman exec cpu-test-ubi9-ppc bash -c "
    set -e
    pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
    pip install sentence-transformers datamodel_code_generator
    pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach]
    pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]
    pytest -v -s tests/models/encoder_decoder/language -m cpu_model"
}

# All of the CPU tests are expected to finish in less than 40 mins.
export -f cpu_tests
timeout 40m bash -c cpu_tests
13  .buildkite/scripts/hardware_ci/run-cpu-test-s390x.sh  Executable file
@@ -0,0 +1,13 @@
#!/bin/bash

# This script builds the CPU docker image and runs the offline inference inside the container.
# It serves as a sanity check for compilation and basic model usage.
set -ex

# Setup cleanup
remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; }
trap remove_docker_container EXIT
remove_docker_container

# Try building the docker image
docker build -t cpu-test -f docker/Dockerfile.s390x .
@@ -17,10 +17,12 @@ source /etc/environment
docker run --privileged --net host --shm-size=16G -it \
    -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
    vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
    && python3 -m pip install pytest \
    && python3 -m pip install pytest tpu-info \
    && python3 -m pip install lm_eval[api]==0.4.4 \
    && export VLLM_USE_V1=1 \
    && export VLLM_XLA_CHECK_RECOMPILATION=1 \
    && echo HARDWARE \
    && tpu-info \
    && echo TEST_0 \
    && pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
    && echo TEST_1 \
@@ -118,7 +118,7 @@ steps:
  - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
  - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
  - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/
  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py
  - pytest -v -s entrypoints/test_chat_utils.py
  - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

@@ -341,6 +341,13 @@ steps:
  commands:
  - bash scripts/run-benchmarks.sh

- label: Benchmarks CLI Test # 10min
  source_file_dependencies:
  - vllm/
  - tests/benchmarks/
  commands:
  - pytest -v -s benchmarks/

- label: Quantization Test # 33min
  source_file_dependencies:
  - csrc/

@@ -378,8 +385,10 @@ steps:
  source_file_dependencies:
  - vllm/
  - tests/tool_use
  - tests/mistral_tool_use
  commands:
  - pytest -v -s tool_use
  - pytest -v -s mistral_tool_use

##### models test #####

@@ -391,8 +400,9 @@ steps:
  - pytest -v -s models/test_transformers.py
  - pytest -v -s models/test_registry.py
  # V1 Test: https://github.com/vllm-project/vllm/issues/14531
  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4'
  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'

- label: Language Models Test (Standard) # 32min
  #mirror_hardwares: [amd]

@@ -402,6 +412,8 @@ steps:
  - tests/models/embedding/language
  - tests/models/encoder_decoder/language
  commands:
  # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
  - pip install causal-conv1d
  - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
  - pytest -v -s models/embedding/language -m core_model

@@ -413,6 +425,8 @@ steps:
  - tests/models/embedding/language
  - tests/models/encoder_decoder/language
  commands:
  # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
  - pip install causal-conv1d
  - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
  - pytest -v -s models/embedding/language -m 'not core_model'

@@ -538,6 +552,7 @@ steps:
  # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
  - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
  - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown

- label: Plugin Tests (2 GPUs) # 40min
  working_dir: "/vllm-workspace/tests"
2  .github/ISSUE_TEMPLATE/200-installation.yml  vendored
@@ -14,7 +14,7 @@ body:
    description: |
      Please run the following and paste the output below.
      ```sh
      wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
      wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```

2  .github/ISSUE_TEMPLATE/300-usage.yml  vendored
@@ -14,7 +14,7 @@ body:
    description: |
      Please run the following and paste the output below.
      ```sh
      wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
      wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```

2  .github/ISSUE_TEMPLATE/400-bug-report.yml  vendored
@@ -14,7 +14,7 @@ body:
    description: |
      Please run the following and paste the output below.
      ```sh
      wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
      wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```

@@ -35,7 +35,7 @@ body:
    description: |
      Please run the following and paste the output below.
      ```sh
      wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
      wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```
3  .gitignore  vendored
@@ -203,3 +203,6 @@ benchmarks/**/*.json

# Linting
actionlint
shellcheck*/

# Ignore moe/marlin_moe gen code
csrc/moe/marlin_moe_wna16/kernel_*
@ -11,7 +11,6 @@ repos:
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.3
|
||||
hooks:
|
||||
|
||||
@@ -609,21 +609,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
  cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
  if (MARLIN_MOE_ARCHS)
    set(MARLIN_MOE_SRC
        "csrc/moe/marlin_kernels/marlin_moe_kernel.h"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
        "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
        "csrc/moe/marlin_moe_ops.cu")

    #
    # For the Marlin MOE kernels we automatically generate sources for various
    # preselected input type pairs and schedules.
    # Generate sources:
    set(MOE_MARLIN_GEN_SCRIPT
      ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
    file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)

    message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
    message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")

    if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
        OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
      execute_process(
        COMMAND ${CMAKE_COMMAND} -E env
        PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
        ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
        RESULT_VARIABLE moe_marlin_generation_result
        OUTPUT_VARIABLE moe_marlin_generation_output
        OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
        ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
      )

      if (NOT moe_marlin_generation_result EQUAL 0)
        message(FATAL_ERROR "Marlin MOE generation failed."
                " Result: \"${moe_marlin_generation_result}\""
                "\nCheck the log for details: "
                "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
      else()
        set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
            CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
        message(STATUS "Marlin MOE generation completed successfully.")
      endif()
    else()
      message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
    endif()

    file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
    set_gencode_flags_for_srcs(
      SRCS "${MARLIN_MOE_SRC}"
      SRCS "${MOE_WNAA16_MARLIN_SRC}"
      CUDA_ARCHS "${MARLIN_MOE_ARCHS}")

    list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
    list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})

    message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
  else()
    message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
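Note: the CMake hunk above regenerates the Marlin MOE kernel sources only when the MD5 hash of `generate_kernels.py` differs from the hash stored in the CMake cache. As a rough, hedged illustration of the same hash-gating pattern outside CMake (the `STAMP_FILE` path and the standalone driver are hypothetical, not part of this PR):

```python
import hashlib
import json
import subprocess
from pathlib import Path

# Paths for illustration only; GEN_SCRIPT mirrors the path used in the CMake hunk.
GEN_SCRIPT = Path("csrc/moe/marlin_moe_wna16/generate_kernels.py")
STAMP_FILE = Path("build/.moe_marlin_gen_hash.json")  # hypothetical cache location


def file_md5(path: Path) -> str:
    """MD5 of the file's bytes, mirroring CMake's file(MD5 ...)."""
    return hashlib.md5(path.read_bytes()).hexdigest()


def maybe_generate() -> None:
    current = file_md5(GEN_SCRIPT)
    previous = None
    if STAMP_FILE.exists():
        previous = json.loads(STAMP_FILE.read_text()).get("hash")

    if previous == current:
        print("Generator unchanged, skipping kernel generation.")
        return

    # Re-run the generator only when the script itself has changed.
    subprocess.run(["python3", str(GEN_SCRIPT)], check=True)
    STAMP_FILE.parent.mkdir(parents=True, exist_ok=True)
    STAMP_FILE.write_text(json.dumps({"hash": current}))


if __name__ == "__main__":
    maybe_generate()
```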
@@ -10,7 +10,7 @@ Easy, fast, and cheap LLM serving for everyone
</h3>

<p align="center">
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
</p>

---
@@ -597,18 +597,26 @@ if __name__ == "__main__":
    parser.add_argument(
        "--prefix-len",
        type=int,
        default=0,
        help="Number of fixed prefix tokens before the random "
        "context in a request (default: 0).",
    )
        default=None,
        help=f"Number of prefix tokens to be used in RandomDataset "
        "and SonnetDataset. For RandomDataset, the total input "
        "length is the sum of prefix-len (default: "
        f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
        "sampled from [input_len * (1 - range_ratio), "
        "input_len * (1 + range_ratio)]. For SonnetDataset, "
        f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
        "controls how much of the input is fixed lines versus "
        "random lines, but the total input length remains approximately "
        "input_len tokens.")
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range ratio for sampling input/output length, "
        "used only for RandomDataset. Must be in the range [0, 1) to define "
        "a symmetric sampling range "
        default=None,
        help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
        "for sampling input/output length, "
        "used only for RandomDataset. Must be in the range [0, 1) to "
        "define a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )
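Note: per the updated help text, a RandomDataset request's total input length is the fixed prefix plus a context length sampled symmetrically around `input_len`. A minimal sketch of that sampling rule, written from the help string alone (illustrative only, not the benchmark's actual implementation):

```python
import random


def sample_total_input_len(input_len: int,
                           prefix_len: int = 0,
                           range_ratio: float = 0.0) -> int:
    """Total input length = prefix_len + a context length sampled from
    [input_len * (1 - range_ratio), input_len * (1 + range_ratio)]."""
    assert 0.0 <= range_ratio < 1.0, "range_ratio must be in [0, 1)"
    low = int(input_len * (1 - range_ratio))
    high = int(input_len * (1 + range_ratio))
    context_len = random.randint(low, high)
    return prefix_len + context_len


# Example: input_len=1024, prefix_len=100, range_ratio=0.2
# gives a total length in [100 + 819, 100 + 1228].
print(sample_total_input_len(1024, prefix_len=100, range_ratio=0.2))
```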
@@ -38,7 +38,7 @@ else()
  FetchContent_Declare(
        vllm-flash-attn
        GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
        GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
        GIT_TAG 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
        GIT_PROGRESS TRUE
        # Don't share the vllm-flash-attn build between build types
        BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -107,13 +107,14 @@ __global__ void merge_attn_states_kernel(

#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)                       \
  {                                                                           \
    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>(   \
        reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr,       \
        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),                \
        reinterpret_cast<float*>(prefix_lse.data_ptr()),                      \
        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),                \
        reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,          \
        num_heads, head_size);                                                \
    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS>                     \
        <<<grid, block, 0, stream>>>(                                         \
            reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr,   \
            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),            \
            reinterpret_cast<float*>(prefix_lse.data_ptr()),                  \
            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),            \
            reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,      \
            num_heads, head_size);                                            \
  }

/*@brief Merges the attention states from prefix and suffix

@@ -122,10 +123,10 @@ __global__ void merge_attn_states_kernel(
 * @param output [n,h,d] The output tensor to store the merged attention states.
 * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
 * @param prefix_output [n,h,d] The prefix attention states.
 * @param prefix_lse [h,d] The log-sum-exp values for the prefix attention
 * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
 * states.
 * @param suffix_output [n,h,d] The suffix attention states.
 * @param suffix_lse [h,d] The log-sum-exp values for the suffix attention
 * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
 * states.
 */
template <typename scalar_t>

@@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
  if (output_lse.has_value()) {
    output_lse_ptr = output_lse.value().data_ptr<float>();
  }
  // process one pack elements per thread. float -> 4, half/bf16 -> 8
  // Process one pack elements per thread. for float, the
  // pack_size is 4 for half/bf16, the pack_size is 8.
  const uint threads_per_head = head_size / pack_size;
  const uint total_threads = num_tokens * num_heads * threads_per_head;

  dim3 block(NUM_THREADS);
  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);

  const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}
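Note: the launcher above sizes the grid from a per-thread pack width: each thread processes one pack (4 floats or 8 half/bf16 values), so the total thread count is `num_tokens * num_heads * head_size / pack_size`, rounded up to whole blocks. A small Python sketch of that arithmetic (names mirror the CUDA code above; this is not an actual vLLM API):

```python
import math


def launch_dims(num_tokens: int, num_heads: int, head_size: int,
                dtype: str, num_threads: int = 128) -> tuple[int, int]:
    """Return (grid_blocks, block_threads) the way the launcher computes them."""
    # One thread handles one 16-byte pack: 4 floats or 8 half/bf16 values.
    pack_size = 4 if dtype == "float32" else 8
    assert head_size % pack_size == 0, "head_size must be divisible by pack_size"
    threads_per_head = head_size // pack_size
    total_threads = num_tokens * num_heads * threads_per_head
    grid = math.ceil(total_threads / num_threads)  # (total + N - 1) // N
    return grid, num_threads


# Example: 256 tokens, 32 heads, head_size 128, bf16, 128 threads per block.
print(launch_dims(256, 32, 128, "bfloat16"))  # -> (1024, 128)
```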
@@ -375,7 +375,7 @@ class CustomAllreduce {
  bool fully_connected_;

  RankSignals sg_;
  // Stores an map from a pointer to its peer pointers from all ranks.
  // Stores a map from a pointer to its peer pointers from all ranks.
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;
103  csrc/moe/marlin_moe_wna16/generate_kernels.py  Normal file
@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
import glob
import itertools
import os
import subprocess

import jinja2

FILE_HEAD = """
// auto generated by generate.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {
""".strip()

TEMPLATE = ("template __global__ void Marlin<"
            "{{scalar_t}}, "
            "{{w_type_id}}, "
            "{{threads}}, "
            "{{thread_m_blocks}}, "
            "{{thread_n_blocks}}, "
            "{{thread_k_blocks}}, "
            "{{'true' if m_block_size_8 else 'false'}}, "
            "{{stages}}, "
            "{{'true' if has_act_order else 'false'}}, "
            "{{'true' if has_zp else 'false'}}, "
            "{{group_blocks}}, "
            "{{'true' if is_zp_float else 'false'}}>"
            "( MARLIN_KERNEL_PARAMS );")

# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
#   = 0 : act order case
#   = -1 : channelwise quantization
#   > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]


def remove_old_kernels():
    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
        subprocess.call(["rm", "-f", filename])


def generate_new_kernels():
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        has_zp = "B" not in scalar_type
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):

            has_act_order = group_blocks == 0
            if has_zp and has_act_order:
                continue
            if thread_configs[2] == 256:
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

            template_str = jinja2.Template(TEMPLATE).render(
                scalar_t=c_dtype,
                w_type_id=scalar_type + ".id()",
                threads=threads,
                thread_m_blocks=max(m_blocks, 1),
                thread_n_blocks=n_blocks,
                thread_k_blocks=k_blocks,
                m_block_size_8=m_blocks == 0.5,
                stages="pipe_stages",
                has_act_order=has_act_order,
                has_zp=has_zp,
                group_blocks=group_blocks,
                is_zp_float=False,
            )

            all_template_str_list.append(template_str)

        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
        filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"

        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
            f.write(file_content)


if __name__ == "__main__":
    remove_old_kernels()
    generate_new_kernels()
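Note: to see what the generator emits, the standalone snippet below renders a single template combination and prints it. It only needs `jinja2` and reuses the `TEMPLATE` string from the new file above (abridged to one string literal); the actual set of emitted combinations is determined by the loops and filters in `generate_new_kernels()`.

```python
import jinja2

# Same template as generate_kernels.py, collapsed into one string for brevity.
TEMPLATE = ("template __global__ void Marlin<"
            "{{scalar_t}}, {{w_type_id}}, {{threads}}, {{thread_m_blocks}}, "
            "{{thread_n_blocks}}, {{thread_k_blocks}}, "
            "{{'true' if m_block_size_8 else 'false'}}, {{stages}}, "
            "{{'true' if has_act_order else 'false'}}, "
            "{{'true' if has_zp else 'false'}}, {{group_blocks}}, "
            "{{'true' if is_zp_float else 'false'}}>( MARLIN_KERNEL_PARAMS );")

# One example combination: bf16, kU4B8, the (128, 128, 256) thread config,
# group_blocks=8 (i.e. group_size = 16 * 8 = 128).
print(jinja2.Template(TEMPLATE).render(
    scalar_t="nv_bfloat16",
    w_type_id="vllm::kU4B8.id()",
    threads=256,
    thread_m_blocks=1,
    thread_n_blocks=8,
    thread_k_blocks=8,
    m_block_size_8=False,
    stages="pipe_stages",
    has_act_order=False,
    has_zp=False,
    group_blocks=8,
    is_zp_float=False,
))
# -> template __global__ void Marlin<nv_bfloat16, vllm::kU4B8.id(), 256, 1, 8, 8,
#    false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
```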
44  csrc/moe/marlin_moe_wna16/kernel.h  Normal file
@@ -0,0 +1,44 @@

#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

#define MARLIN_KERNEL_PARAMS                                                  \
  const int4 *__restrict__ A, const int4 *__restrict__ B,                     \
      int4 *__restrict__ C, int4 *__restrict__ C_tmp,                         \
      const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr,   \
      const int *__restrict__ g_idx,                                          \
      const int32_t *__restrict__ sorted_token_ids_ptr,                       \
      const int32_t *__restrict__ expert_ids_ptr,                             \
      const int32_t *__restrict__ num_tokens_past_padded_ptr,                 \
      const float *__restrict__ topk_weights_ptr, int top_k,                  \
      bool mul_topk_weights, bool is_ep, int num_groups, int prob_m,          \
      int prob_n, int prob_k, int *locks, bool use_atomic_add,                \
      bool use_fp32_reduce

namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t,          // compute dtype, half or nv_float16
          const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
          const int threads,          // number of threads in a threadblock
          const int thread_m_blocks,  // number of 16x16 blocks in the m
                                      // dimension (batchsize) of the
                                      // threadblock
          const int thread_n_blocks,  // same for n dimension (output)
          const int thread_k_blocks,  // same for k dimension (reduction)
          const bool m_block_size_8,  // whether m_block_size == 8
                                      // only works when thread_m_blocks == 1
          const int stages,  // number of stages for the async global->shared
                             // fetch pipeline
          const bool has_act_order,  // whether act_order is enabled
          const bool has_zp,         // whether zero-points are enabled
          const int group_blocks,    // number of consecutive 16x16 blocks
                                     // with a separate quantization scale
          const bool is_zp_float     // is zero point of float16 type?
          >
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

}
1917  csrc/moe/marlin_moe_wna16/marlin_template.h  Normal file  (file diff suppressed because it is too large)

927  csrc/moe/marlin_moe_wna16/ops.cu  Normal file
@@ -0,0 +1,927 @@
/*
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||
#endif
|
||||
|
||||
#include "kernel.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||
"only float16 and bfloat16 is supported");
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||
|
||||
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
template <int moe_block_size>
|
||||
__global__ void permute_cols_kernel(
|
||||
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr,
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||
const int32_t* __restrict__ expert_ids_ptr,
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||
int size_k, int top_k) {};
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// For a given "a" of size [M,K] performs a permutation of the K columns based
|
||||
// on the given "perm" indices.
|
||||
template <int moe_block_size>
|
||||
__global__ void permute_cols_kernel(
|
||||
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||
int4* __restrict__ out_int4_ptr,
|
||||
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||
const int32_t* __restrict__ expert_ids_ptr,
|
||||
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||
int size_k, int top_k) {
|
||||
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||
int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
|
||||
int32_t block_sorted_ids[moe_block_size];
|
||||
int block_num_valid_tokens = 0;
|
||||
int64_t old_expert_id = 0;
|
||||
int64_t expert_id = 0;
|
||||
int row_stride = size_k * sizeof(half) / 16;
|
||||
|
||||
auto read_moe_block_data = [&](int block_id) {
|
||||
block_num_valid_tokens = moe_block_size;
|
||||
int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
|
||||
for (int i = 0; i < moe_block_size / 4; i++) {
|
||||
tmp_block_sorted_ids[i] =
|
||||
((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
|
||||
}
|
||||
for (int i = 0; i < moe_block_size; i++) {
|
||||
if (block_sorted_ids[i] >= size_m * top_k) {
|
||||
block_num_valid_tokens = i;
|
||||
break;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
auto permute_row = [&](int row) {
|
||||
int iters = size_k / default_threads;
|
||||
int rest = size_k % default_threads;
|
||||
|
||||
int in_offset = (row / top_k) * row_stride;
|
||||
int out_offset = row * row_stride;
|
||||
|
||||
half const* a_row_half =
|
||||
reinterpret_cast<half const*>(a_int4_ptr + in_offset);
|
||||
half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);
|
||||
|
||||
int base_k = 0;
|
||||
|
||||
for (int i = 0; i < iters; i++) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
|
||||
base_k += default_threads;
|
||||
}
|
||||
|
||||
if (rest) {
|
||||
if (threadIdx.x < rest) {
|
||||
int cur_k = base_k + threadIdx.x;
|
||||
int src_pos = perm_int_ptr[cur_k];
|
||||
|
||||
out_half[cur_k] = a_row_half[src_pos];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
|
||||
old_expert_id = expert_id;
|
||||
int tmp_expert_id = expert_ids_ptr[index];
|
||||
if (tmp_expert_id == -1) continue;
|
||||
expert_id = tmp_expert_id;
|
||||
perm_int_ptr += (expert_id - old_expert_id) * size_k;
|
||||
read_moe_block_data(index);
|
||||
|
||||
for (int i = 0; i < block_num_valid_tokens; i++)
|
||||
permute_row(block_sorted_ids[i]);
|
||||
}
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
int thread_k;
|
||||
int thread_n;
|
||||
int num_threads;
|
||||
} thread_config_t;
|
||||
|
||||
thread_config_t small_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{128, 128, 256},
|
||||
{64, 128, 128}};
|
||||
|
||||
thread_config_t large_batch_thread_configs[] = {
|
||||
// Ordered by priority
|
||||
|
||||
// thread_k, thread_n, num_threads
|
||||
{64, 256, 256},
|
||||
{64, 128, 128}};
|
||||
|
||||
typedef struct {
|
||||
int blocks_per_sm;
|
||||
thread_config_t tb_cfg;
|
||||
} exec_config_t;
|
||||
|
||||
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||
int prob_n, int prob_k, int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full) {
|
||||
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_k = th_config.thread_k;
|
||||
|
||||
// Get max scale groups per thread-block
|
||||
int tb_groups;
|
||||
if (group_size == -1) {
|
||||
tb_groups = 1;
|
||||
} else if (group_size == 0) {
|
||||
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
|
||||
} else {
|
||||
tb_groups = div_ceil(tb_k, group_size);
|
||||
}
|
||||
|
||||
if (cache_scales_chunk) {
|
||||
int load_groups =
|
||||
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||
return load_groups * tb_n * 2;
|
||||
|
||||
} else {
|
||||
int tb_scales = tb_groups * tb_n * 2;
|
||||
|
||||
return tb_scales * pipe_stages;
|
||||
}
|
||||
}
|
||||
|
||||
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float) {
|
||||
int pack_factor = 32 / num_bits;
|
||||
|
||||
// Get B size
|
||||
int tb_k = th_config.thread_k;
|
||||
int tb_n = th_config.thread_n;
|
||||
int tb_m = thread_m_blocks * 16;
|
||||
|
||||
// shm size for block_sorted_ids/block_topk_weights
|
||||
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
|
||||
int sh_block_meta_size = tb_m * 4 * 2;
|
||||
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
|
||||
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
|
||||
int sh_s_size =
|
||||
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full);
|
||||
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
|
||||
int sh_zp_size = 0;
|
||||
if (has_zp) {
|
||||
if (is_zp_float)
|
||||
sh_zp_size = sh_s_size;
|
||||
else if (num_bits == 4)
|
||||
sh_zp_size = sh_s_size / 4;
|
||||
else if (num_bits == 8)
|
||||
sh_zp_size = sh_s_size / 2;
|
||||
}
|
||||
|
||||
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
|
||||
sh_g_idx_size + sh_block_meta_size;
|
||||
|
||||
return total_size;
|
||||
}
|
||||
|
||||
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
|
||||
int prob_m, int prob_n, int prob_k, int num_bits,
|
||||
int group_size, bool has_act_order, bool is_k_full,
|
||||
int has_zp, int is_zp_float, int max_shared_mem) {
|
||||
// Sanity
|
||||
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
|
||||
th_config.num_threads == -1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify K/N are divisible by thread K/N
|
||||
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Verify min for thread K/N
|
||||
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// num_threads must be at least 128 (= 4 warps)
|
||||
if (th_config.num_threads < 128) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that pipeline fits into cache
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
|
||||
has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
return cache_size <= max_shared_mem;
|
||||
}
|
||||
|
||||
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
|
||||
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
NUM_THREADS, IS_ZP_FLOAT) \
|
||||
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
m_block_size_8 == M_BLOCK_SIZE_8 && \
|
||||
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
|
||||
is_zp_float == IS_ZP_FLOAT) { \
|
||||
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
|
||||
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
IS_ZP_FLOAT>; \
|
||||
}
|
||||
|
||||
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
|
||||
false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false) \
|
||||
\
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, false) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
|
||||
NUM_THREADS, false)
|
||||
|
||||
// We currently have 4-bit models only with group_blocks == 4
|
||||
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
|
||||
true) \
|
||||
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true) \
|
||||
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
|
||||
NUM_THREADS, true)
|
||||
|
||||
template <typename scalar_t>
|
||||
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
|
||||
int thread_m_blocks, int thread_n_blocks,
|
||||
int thread_k_blocks, bool m_block_size_8,
|
||||
bool has_act_order, bool has_zp,
|
||||
int group_blocks, int num_threads,
|
||||
bool is_zp_float) {
|
||||
int num_bits = q_type.size_bits();
|
||||
auto kernel = MarlinDefault;
|
||||
if (false) {
|
||||
}
|
||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
|
||||
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
|
||||
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
|
||||
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
|
||||
|
||||
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
|
||||
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
|
||||
|
||||
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
|
||||
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
|
||||
|
||||
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
|
||||
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
|
||||
int prob_n, int prob_k, int thread_m_blocks,
|
||||
bool m_block_size_8, int num_bits,
|
||||
int group_size, bool has_act_order,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool is_zp_float, int max_shared_mem) {
|
||||
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
|
||||
thread_config_t* thread_configs = thread_m_blocks > 1
|
||||
? large_batch_thread_configs
|
||||
: small_batch_thread_configs;
|
||||
int thread_configs_size =
|
||||
thread_m_blocks > 1
|
||||
? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
|
||||
: sizeof(small_batch_thread_configs) / sizeof(thread_config_t);
|
||||
|
||||
int count = 0;
|
||||
constexpr int device_max_reg_size = 255 * 1024;
|
||||
for (int i = 0; i < thread_configs_size; i++) {
|
||||
thread_config_t th_config = thread_configs[i];
|
||||
|
||||
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp,
|
||||
is_zp_float, max_shared_mem)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int cache_size = get_kernel_cache_size(
|
||||
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
|
||||
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
|
||||
|
||||
int group_blocks = 0;
|
||||
if (!has_act_order) {
|
||||
group_blocks = group_size == -1 ? -1 : group_size / 16;
|
||||
}
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, th_config.thread_n / 16,
|
||||
th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
|
||||
group_blocks, th_config.num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) continue;
|
||||
|
||||
if (thread_m_blocks > 1) {
|
||||
exec_cfg = {1, th_config};
|
||||
break;
|
||||
} else {
|
||||
cudaFuncAttributes attr;
|
||||
cudaFuncGetAttributes(&attr, kernel);
|
||||
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
|
||||
int allow_count = min(device_max_reg_size / reg_size,
|
||||
max_shared_mem / (cache_size + 1024));
|
||||
allow_count = max(min(allow_count, 4), 1);
|
||||
if (allow_count > count) {
|
||||
count = allow_count;
|
||||
exec_cfg = {count, th_config};
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return exec_cfg;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* zp, void* g_idx, void* perm, void* a_tmp,
|
||||
void* sorted_token_ids, void* expert_ids,
|
||||
void* num_tokens_past_padded, void* topk_weights,
|
||||
int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
|
||||
int prob_m, int prob_n, int prob_k, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k, int thread_n,
|
||||
int sms, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
int thread_m_blocks = div_ceil(moe_block_size, 16);
|
||||
bool m_block_size_8 = moe_block_size == 8;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
|
||||
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
int group_blocks = 0;
|
||||
if (has_act_order) {
|
||||
if (is_k_full) {
|
||||
TORCH_CHECK(group_size != -1);
|
||||
group_blocks = group_size / 16;
|
||||
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by group_blocks = ", group_blocks);
|
||||
} else {
|
||||
TORCH_CHECK(group_size == 0);
|
||||
group_blocks = 0;
|
||||
}
|
||||
} else {
|
||||
if (group_size == -1) {
|
||||
group_blocks = -1;
|
||||
} else {
|
||||
group_blocks = group_size / 16;
|
||||
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by group_blocks = ", group_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
int num_bits = q_type.size_bits();
|
||||
const int4* A_ptr = (const int4*)A;
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
int4* a_tmp_ptr = (int4*)a_tmp;
|
||||
const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;
|
||||
const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;
|
||||
const int32_t* num_tokens_past_padded_ptr =
|
||||
(const int32_t*)num_tokens_past_padded;
|
||||
const float* topk_weights_ptr = (const float*)topk_weights;
|
||||
int* locks = (int*)workspace;
|
||||
|
||||
if (has_act_order) {
|
||||
// Permute A columns
|
||||
auto kernel = permute_cols_kernel<8>;
|
||||
if (moe_block_size == 8) {
|
||||
} else if (moe_block_size == 16)
|
||||
kernel = permute_cols_kernel<16>;
|
||||
else if (moe_block_size == 32)
|
||||
kernel = permute_cols_kernel<32>;
|
||||
else if (moe_block_size == 48)
|
||||
kernel = permute_cols_kernel<48>;
|
||||
else if (moe_block_size == 64)
|
||||
kernel = permute_cols_kernel<64>;
|
||||
else
|
||||
TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
|
||||
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<sms, default_threads, 0, stream>>>(
|
||||
A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,
|
||||
num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
|
||||
// clang-format on
|
||||
A_ptr = a_tmp_ptr;
|
||||
prob_m = prob_m * top_k;
|
||||
top_k = 1;
|
||||
|
||||
// If we have a full K, then we can run the non-act-order version of Marlin
|
||||
// (since the weight rows are reordered by increasing group ids, and by
|
||||
// having a full K, we have full original groups)
|
||||
if (is_k_full) has_act_order = false;
|
||||
}
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
// Set thread config
|
||||
exec_config_t exec_cfg;
|
||||
thread_config_t thread_tfg;
|
||||
if (thread_k != -1 && thread_n != -1) {
|
||||
thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
|
||||
exec_cfg = exec_config_t{1, thread_tfg};
|
||||
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
|
||||
" is not divisible by thread_n = ", thread_n);
|
||||
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
|
||||
" is not divisible by thread_k = ", thread_k);
|
||||
} else {
|
||||
// Auto config
|
||||
exec_cfg = determine_exec_config<scalar_t>(
|
||||
q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8,
|
||||
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
|
||||
max_shared_mem);
|
||||
thread_tfg = exec_cfg.tb_cfg;
|
||||
}
|
||||
|
||||
int num_threads = thread_tfg.num_threads;
|
||||
thread_k = thread_tfg.thread_k;
|
||||
thread_n = thread_tfg.thread_n;
|
||||
int blocks = sms * exec_cfg.blocks_per_sm;
|
||||
if (exec_cfg.blocks_per_sm > 1)
|
||||
max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;
|
||||
|
||||
int thread_k_blocks = thread_k / 16;
|
||||
int thread_n_blocks = thread_n / 16;
|
||||
|
||||
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
|
||||
prob_k, num_bits, group_size, has_act_order,
|
||||
is_k_full, has_zp, is_zp_float, max_shared_mem),
|
||||
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_k = ", thread_tfg.thread_k,
|
||||
", thread_n = ", thread_tfg.thread_n,
|
||||
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
|
||||
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
|
||||
", group_size = ", group_size,
|
||||
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
|
||||
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
|
||||
", max_shared_mem = ", max_shared_mem);
|
||||
|
||||
auto kernel = get_marlin_kernel<scalar_t>(
|
||||
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
|
||||
has_act_order, has_zp, group_blocks, num_threads, is_zp_float);
|
||||
|
||||
if (kernel == MarlinDefault) {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
", ", prob_k, "]", ", has_act_order = ", has_act_order,
|
||||
", num_groups = ", num_groups, ", group_size = ", group_size,
|
||||
", thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_n_blocks = ", thread_n_blocks,
|
||||
", thread_k_blocks = ", thread_k_blocks,
|
||||
", num_bits = ", num_bits);
|
||||
}
|
||||
|
||||
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
max_shared_mem);
|
||||
// avoid ">>>" being formatted to "> > >"
|
||||
// clang-format off
|
||||
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
|
||||
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
|
||||
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
|
||||
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||
bool is_zp_float) {
|
||||
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
|
||||
int pack_factor = 32 / b_q_type.size_bits();
|
||||
|
||||
if (moe_block_size != 8) {
|
||||
TORCH_CHECK(moe_block_size % 16 == 0,
|
||||
"unsupported moe_block_size=", moe_block_size);
|
||||
TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64,
|
||||
"unsupported moe_block_size=", moe_block_size);
|
||||
}
|
||||
|
||||
// Verify A
|
||||
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
||||
", size_m = ", size_m);
|
||||
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
|
||||
", size_k = ", size_k);
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK(
|
||||
size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1),
|
||||
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||
", size_k = ", size_k,
|
||||
", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
TORCH_CHECK(
|
||||
b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0,
|
||||
"b_q_weight.size(2) = ", b_q_weight.size(2),
|
||||
" is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
|
||||
int actual_size_n =
|
||||
(b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
|
||||
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
||||
", actual_size_n = ", actual_size_n);
|
||||
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
|
||||
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
||||
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
||||
|
||||
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_k = -1;
|
||||
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_n = -1;
|
||||
// sms: number of SMs to use for the kernel
|
||||
int sms = -1;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
torch::Tensor c;
|
||||
if (c_or_none.has_value()) {
|
||||
c = c_or_none.value();
|
||||
TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
|
||||
TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
|
||||
TORCH_CHECK(c.size(0) == size_m * top_k,
|
||||
"Shape mismatch: c.size(0) = ", c.size(0),
|
||||
", size_m * topk = ", size_m * top_k);
|
||||
TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1),
|
||||
", size_n = ", size_n);
|
||||
} else {
|
||||
c = torch::empty({size_m * top_k, size_n}, options);
|
||||
}
|
||||
|
||||
// Alloc C tmp buffer that is going to be used for the global reduce
|
||||
torch::Tensor c_tmp;
|
||||
auto options_fp32 =
|
||||
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
||||
if (use_fp32_reduce && !use_atomic_add) {
|
||||
// max num of threadblocks is sms * 4
|
||||
long max_c_tmp_size = min(
|
||||
(long)size_n * sorted_token_ids.size(0),
|
||||
(long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
|
||||
if (moe_block_size == 8) max_c_tmp_size *= 2;
|
||||
c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
|
||||
} else {
|
||||
c_tmp = torch::empty({0}, options_fp32);
|
||||
}
|
||||
|
||||
// Detect groupsize and act_order
|
||||
int num_groups = -1;
|
||||
int group_size = -1;
|
||||
|
||||
int rank = b_scales.sizes().size();
|
||||
TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3");
|
||||
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
|
||||
" is not size_n = ", size_n);
|
||||
num_groups = b_scales.size(1);
|
||||
|
||||
torch::Tensor g_idx, perm, a_tmp;
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
|
||||
g_idx = g_idx_or_none.value();
|
||||
perm = perm_or_none.value();
|
||||
|
||||
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
|
||||
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
|
||||
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
|
||||
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
|
||||
|
||||
// Verify g_idx and perm
|
||||
TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) ||
|
||||
(g_idx.size(-1) == size_k && perm.size(-1) == size_k),
|
||||
"Unexpected g_idx.size(-1) = ", g_idx.size(-1),
|
||||
" and perm.size(-1) = ", perm.size(-1),
|
||||
", where size_k = ", size_k);
|
||||
} else {
|
||||
g_idx = torch::empty({0}, options);
|
||||
perm = torch::empty({0}, options);
|
||||
a_tmp = torch::empty({0}, options);
|
||||
}
|
||||
bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;
|
||||
|
||||
if (has_act_order) {
|
||||
a_tmp = torch::empty({size_m * top_k, size_k}, options);
|
||||
if (is_k_full) {
|
||||
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
|
||||
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
|
||||
", is not divisible by num_groups = ", num_groups);
|
||||
group_size = size_k / num_groups;
|
||||
} else {
|
||||
group_size = 0;
|
||||
}
|
||||
|
||||
} else {
|
||||
a_tmp = torch::empty({0}, options);
|
||||
if (num_groups > 1) {
|
||||
TORCH_CHECK(
|
||||
size_k % num_groups == 0, "size_k = ", size_k,
|
||||
", is not divisible by b_scales.size(1) = ", b_scales.size(1));
|
||||
group_size = size_k / num_groups;
|
||||
} else {
|
||||
group_size = -1;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor b_zeros;
|
||||
if (b_zeros_or_none.has_value()) {
|
||||
b_zeros = b_zeros_or_none.value();
|
||||
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
|
||||
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
|
||||
} else {
|
||||
b_zeros = torch::empty({0}, options);
|
||||
}
|
||||
bool has_zp = b_zeros.size(-1) > 0;
|
||||
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4,
|
||||
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
|
||||
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
b_q_type.str());
|
||||
}
|
||||
|
||||
if (has_zp && is_zp_float) {
|
||||
TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
|
||||
"Computation type must be float16 (half) when using float zero "
|
||||
"points.");
|
||||
}
|
||||
|
||||
// Verify b_zeros
|
||||
if (has_zp) {
|
||||
int rank = b_zeros.sizes().size();
|
||||
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
|
||||
if (is_zp_float) {
|
||||
TORCH_CHECK(b_zeros.size(2) == size_n,
|
||||
"b_zeros dim 2 = ", b_zeros.size(2),
|
||||
" is not size_n = ", size_n);
|
||||
TORCH_CHECK(num_groups == b_zeros.size(1),
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
|
||||
} else {
|
||||
TORCH_CHECK(b_zeros.size(1) == num_groups,
|
||||
"b_zeros dim 1 = ", b_zeros.size(1),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
|
||||
"b_zeros dim 2 = ", b_zeros.size(2),
|
||||
" is not size_n / pack_factor = ", size_n / pack_factor);
|
||||
}
|
||||
}
|
||||
|
||||
// Verify workspace size
|
||||
TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
|
||||
"size_n = ", size_n, ", is not divisible by min_thread_n = ",
|
||||
MARLIN_NAMESPACE_NAME::min_thread_n);
|
||||
|
||||
int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n;
|
||||
int min_workspace_size = min(
|
||||
max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4);
|
||||
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
||||
"workspace.numel = ", workspace.numel(),
|
||||
" is below min_workspace_size = ", min_workspace_size);
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
|
||||
expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
|
||||
topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
|
||||
size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
|
||||
is_k_full, has_zp, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
|
||||
num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
|
||||
moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
|
||||
workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"moe_wna16_marlin_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
|
||||
return c;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
|
||||
}
|
||||
@ -13,7 +13,6 @@
|
||||
template <typename scalar_t, int bit, int GROUPS>
|
||||
__global__ void moe_wna16_gemm_kernel(
|
||||
const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
|
||||
|
||||
const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
|
||||
const uint32_t* __restrict__ qzeros,
|
||||
|
||||
@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
|
||||
if (token_index / top_k >= size_m) break;
|
||||
|
||||
num_valid_tokens = m + 1;
|
||||
if (blockIdx.z == 0 && offset_n < size_n)
|
||||
output[token_index * size_n + offset_n] = Dtype::int2num(0);
|
||||
|
||||
if (expert_id != -1) {
|
||||
int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
|
||||
@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
|
||||
int64_t BLOCK_SIZE_K, int64_t bit) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(input.dtype()).device(input.device());
|
||||
output.zero_();
|
||||
|
||||
const int num_experts = b_qweight.size(0);
|
||||
const int size_m = input.size(0);
|
||||
@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
|
||||
const uint32_t* b_qzeros_ptr;
|
||||
if (b_qzeros.has_value())
|
||||
b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
|
||||
const float* topk_weights_ptr;
|
||||
const float* topk_weights_ptr = nullptr;
|
||||
if (topk_weights.has_value())
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
|
||||
topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
|
||||
|
||||
int groups_per_block_row = BLOCK_SIZE_K / group_size;
|
||||
TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
|
||||
|
||||
@ -43,14 +43,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
|
||||
|
||||
m.def(
|
||||
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
|
||||
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
|
||||
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
|
||||
"int b_q_type, SymInt size_m, "
|
||||
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
|
||||
"topk, "
|
||||
"int moe_block_size, bool replicate_input, bool apply_weights)"
|
||||
" -> Tensor");
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
|
||||
"Tensor sorted_token_ids,"
|
||||
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
|
||||
"Tensor! topk_weights, int moe_block_size, int top_k, "
|
||||
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
|
||||
"int size_m, int size_n, int size_k,"
|
||||
"bool is_full_k, bool use_atomic_add,"
|
||||
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
|
||||
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
|
||||
@ -9,7 +9,11 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace marlin {
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
// Marlin params
|
||||
|
||||
@ -23,6 +27,7 @@ static constexpr int pipe_stages =
|
||||
|
||||
static constexpr int min_thread_n = 64;
|
||||
static constexpr int min_thread_k = 64;
|
||||
static constexpr int max_thread_n = 256;
|
||||
|
||||
static constexpr int tile_size = 16;
|
||||
static constexpr int max_par = 16;
|
||||
@ -84,4 +89,4 @@ __device__ inline void cp_async_wait() {
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace marlin
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
@ -5,7 +5,11 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
namespace marlin {
|
||||
#ifndef MARLIN_NAMESPACE_NAME
|
||||
#define MARLIN_NAMESPACE_NAME marlin
|
||||
#endif
|
||||
|
||||
namespace MARLIN_NAMESPACE_NAME {
|
||||
|
||||
template <typename scalar_t>
|
||||
class ScalarType {};
|
||||
@ -54,7 +58,7 @@ class ScalarType<nv_bfloat16> {
|
||||
using FragS = Vec<nv_bfloat162, 1>;
|
||||
using FragZP = Vec<nv_bfloat162, 4>;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
|
||||
static __device__ float inline num2float(const nv_bfloat16 x) {
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
@ -74,6 +78,6 @@ class ScalarType<nv_bfloat16> {
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace marlin
|
||||
} // namespace MARLIN_NAMESPACE_NAME
|
||||
|
||||
#endif
|
||||
|
||||
@ -240,6 +240,8 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \
|
||||
fi
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
|
||||
# Although we build Flashinfer with AOT mode, there's still
|
||||
# some issues w.r.t. JIT compilation. Therefore we need to
|
||||
|
||||
@ -121,6 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
ADD ./tests/ ./tests/
|
||||
ADD ./examples/ ./examples/
|
||||
ADD ./benchmarks/ ./benchmarks/
|
||||
ADD ./vllm/collect_env.py .
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
||||
@ -126,13 +126,16 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
FROM base-builder AS cv-builder
|
||||
|
||||
ARG MAX_JOBS
|
||||
ARG OPENCV_VERSION=84
|
||||
ARG OPENCV_VERSION=86
|
||||
# patch for version 4.11.0.86
|
||||
ARG OPENCV_PATCH=97f3f39
|
||||
ARG ENABLE_HEADLESS=1
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone --recursive https://github.com/opencv/opencv-python.git -b ${OPENCV_VERSION} && \
|
||||
cd opencv-python && \
|
||||
sed -i 's/"setuptools==59.2.0",/"setuptools<70.0",/g' pyproject.toml && \
|
||||
sed -i -E -e 's/"setuptools.+",/"setuptools",/g' pyproject.toml && \
|
||||
cd opencv && git cherry-pick --no-commit $OPENCV_PATCH && cd .. && \
|
||||
python -m build --wheel --installer=uv --outdir /opencvwheels/
|
||||
|
||||
###############################################################
|
||||
@ -148,9 +151,15 @@ COPY --from=arrow-builder /tmp/control /dev/null
|
||||
COPY --from=cv-builder /tmp/control /dev/null
|
||||
|
||||
ARG VLLM_TARGET_DEVICE=cpu
|
||||
ARG GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1
|
||||
|
||||
# this step installs vllm and populates uv cache
|
||||
# with all the transitive dependencies
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
git clone https://github.com/huggingface/xet-core.git && cd xet-core/hf_xet/ && \
|
||||
uv pip install maturin && \
|
||||
uv build --wheel --out-dir /hf_wheels/
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
|
||||
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
|
||||
@ -159,7 +168,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
source /opt/rh/gcc-toolset-13/enable && \
|
||||
uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl && \
|
||||
sed -i -e 's/.*torch.*//g' /src/pyproject.toml /src/requirements/*.txt && \
|
||||
uv pip install pandas pythran pybind11 && \
|
||||
uv pip install pandas pythran pybind11 /hf_wheels/*.whl && \
|
||||
# sentencepiece.pc is in some pkgconfig inside uv cache
|
||||
export PKG_CONFIG_PATH=$(find / -type d -name "pkgconfig" 2>/dev/null | tr '\n' ':') && \
|
||||
uv pip install -r /src/requirements/common.txt -r /src/requirements/cpu.txt -r /src/requirements/build.txt --no-build-isolation && \
|
||||
@ -247,8 +256,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=torch-builder,source=/torchwheels/,target=/torchwheels/,ro \
|
||||
--mount=type=bind,from=arrow-builder,source=/arrowwheels/,target=/arrowwheels/,ro \
|
||||
--mount=type=bind,from=cv-builder,source=/opencvwheels/,target=/opencvwheels/,ro \
|
||||
--mount=type=bind,from=vllmcache-builder,source=/hf_wheels/,target=/hf_wheels/,ro \
|
||||
--mount=type=bind,from=vllmcache-builder,source=/vllmwheel/,target=/vllmwheel/,ro \
|
||||
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /vllmwheel/*.whl
|
||||
HOME=/root uv pip install /opencvwheels/*.whl /arrowwheels/*.whl /torchwheels/*.whl /hf_wheels/*.whl /vllmwheel/*.whl
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
@ -58,7 +58,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
cd ../../python && \
|
||||
export PYARROW_PARALLEL=4 && \
|
||||
export ARROW_BUILD_TYPE=release && \
|
||||
uv pip install -r requirements/build.txt && \
|
||||
uv pip install -r requirements-build.txt && \
|
||||
python setup.py build_ext --build-type=$ARROW_BUILD_TYPE --bundle-arrow-cpp bdist_wheel
|
||||
|
||||
FROM python-install AS numa-build
|
||||
@ -96,6 +96,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install -v torch==${TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
FROM python-install AS hf-xet-builder
|
||||
# Install hf-xet
|
||||
WORKDIR /tmp
|
||||
ENV CARGO_HOME=/root/.cargo
|
||||
ENV RUSTUP_HOME=/root/.rustup
|
||||
ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=rust,source=/root/.cargo,target=/root/.cargo,rw \
|
||||
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
|
||||
git clone https://github.com/huggingface/xet-core.git && \
|
||||
cd xet-core/hf_xet/ && \
|
||||
uv pip install maturin patchelf && \
|
||||
python -m maturin build --release --out dist && \
|
||||
mkdir -p /tmp/hf-xet/dist && \
|
||||
cp dist/*.whl /tmp/hf-xet/dist/
|
||||
|
||||
# Final build stage
|
||||
FROM python-install AS vllm-cpu
|
||||
ARG PYTHON_VERSION
|
||||
@ -120,12 +136,15 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=rust,source=/root/.rustup,target=/root/.rustup,rw \
|
||||
--mount=type=bind,from=pyarrow,source=/tmp/arrow/python/dist,target=/tmp/arrow-wheels \
|
||||
--mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \
|
||||
--mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \
|
||||
sed -i '/^torch/d' requirements/build.txt && \
|
||||
ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \
|
||||
VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \
|
||||
HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \
|
||||
uv pip install -v \
|
||||
$ARROW_WHL_FILE \
|
||||
$VISION_WHL_FILE \
|
||||
$HF_XET_WHL_FILE \
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/cpu \
|
||||
--index-strategy unsafe-best-match \
|
||||
-r requirements/build.txt \
|
||||
@ -149,4 +168,5 @@ USER 2000
|
||||
WORKDIR /home/vllm
|
||||
|
||||
# Set the default entrypoint
|
||||
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
|
||||
|
||||
BIN docs/source/assets/deployment/open_webui.png (new binary file, 68 KiB; not shown)
@ -19,6 +19,18 @@ $ docker run --runtime nvidia --gpus all \
|
||||
--model mistralai/Mistral-7B-v0.1
|
||||
```
|
||||
|
||||
This image can also be used with other container engines such as [Podman](https://podman.io/).
|
||||
|
||||
```console
|
||||
$ podman run --gpus all \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
-p 8000:8000 \
|
||||
--ipc=host \
|
||||
vllm/vllm-openai:latest \
|
||||
--model mistralai/Mistral-7B-v0.1
|
||||
```
|
||||
|
||||
You can add any other <project:#engine-args> you need after the image tag (`vllm/vllm-openai:latest`).
|
||||
|
||||
:::{note}
|
||||
|
||||
@ -9,6 +9,7 @@ dstack
|
||||
helm
|
||||
lws
|
||||
modal
|
||||
open-webui
|
||||
skypilot
|
||||
triton
|
||||
:::
|
||||
|
||||
29 docs/source/deployment/frameworks/open-webui.md Normal file
@ -0,0 +1,29 @@
(deployment-open-webui)=

# Open WebUI

1. Install [Docker](https://docs.docker.com/engine/install/).

2. Start the vLLM server with a supported chat-completion model, e.g.

```console
vllm serve qwen/Qwen1.5-0.5B-Chat
```

3. Start the [Open WebUI](https://github.com/open-webui/open-webui) Docker container (replace `<vllm serve host>` and `<vllm serve port>` with the host and port of your vLLM server):

```console
docker run -d -p 3000:8080 \
--name open-webui \
-v open-webui:/app/backend/data \
-e OPENAI_API_BASE_URL=http://<vllm serve host>:<vllm serve port>/v1 \
--restart always \
ghcr.io/open-webui/open-webui:main
```

4. Open it in the browser: <http://open-webui-host:3000/>

At the top of the web page, you can see the model `qwen/Qwen1.5-0.5B-Chat`.

:::{image} /assets/deployment/open_webui.png
:::
|
||||
@ -16,7 +16,7 @@ Ensure that you have a running Kubernetes environment with GPU (you can follow [
|
||||
|
||||
## Deployment using vLLM production stack
|
||||
|
||||
The standard vLLM production stack install uses a Helm chart. You can run this [bash script](https://github.com/vllm-project/production-stack/blob/main/tutorials/install-helm.sh) to install Helm on your GPU server.
|
||||
The standard vLLM production stack is installed using a Helm chart. You can run this [bash script](https://github.com/vllm-project/production-stack/blob/main/utils/install-helm.sh) to install Helm on your GPU server.
|
||||
|
||||
To install the vLLM production stack, run the following commands on your desktop:
|
||||
|
||||
|
||||
@ -66,8 +66,8 @@ vLLM also provides [a reference example](https://docs.vllm.ai/en/latest/getting_
|
||||
The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important:
|
||||
|
||||
- `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds
|
||||
- `vllm:prompt_tokens_total` - Prompt Tokens/Sec
|
||||
- `vllm:generation_tokens_total` - Generation Tokens/Sec
|
||||
- `vllm:prompt_tokens_total` - Prompt Tokens
|
||||
- `vllm:generation_tokens_total` - Generation Tokens
|
||||
- `vllm:time_per_output_token_seconds` - Inter token latency (Time Per Output Token, TPOT) in second.
|
||||
- `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds.
|
||||
- `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in RUNNING, WAITING, and SWAPPED state
|
||||
@ -86,6 +86,17 @@ See [the PR which added this Dashboard](gh-pr:2316) for interesting and useful b
|
||||
|
||||
Prometheus support was initially added [using the aioprometheus library](gh-pr:1890), but we quickly switched to [prometheus_client](gh-pr:2730). The rationale is discussed in both linked PRs.
|
||||
|
||||
With the switch to `prometheus_client`, we lost the `MetricsMiddleware` used to track HTTP metrics, but this was reinstated [using prometheus_fastapi_instrumentator](gh-pr:15657):
|
||||
|
||||
```bash
|
||||
$ curl http://0.0.0.0:8000/metrics 2>/dev/null | grep -P '^http_(?!.*(_bucket|_created|_sum)).*'
|
||||
http_requests_total{handler="/v1/completions",method="POST",status="2xx"} 201.0
|
||||
http_request_size_bytes_count{handler="/v1/completions"} 201.0
|
||||
http_response_size_bytes_count{handler="/v1/completions"} 201.0
|
||||
http_request_duration_highr_seconds_count 201.0
|
||||
http_request_duration_seconds_count{handler="/v1/completions",method="POST"} 201.0
|
||||
```
|
||||
|
||||
### Multi-process Mode
|
||||
|
||||
In v0, metrics are collected in the engine core process and we use multi-process mode to make them available in the API server process. See <gh-pr:7279>.
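
As an illustration of the mechanism only (not vLLM's actual wiring), here is a minimal `prometheus_client` multi-process sketch; the metric name and directory are placeholders:

```python
import os

# Multi-process mode must be enabled via this env var *before*
# prometheus_client is imported, so metric writes go to mmap files
# in this directory instead of process-local memory.
os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prom_demo"
os.makedirs("/tmp/prom_demo", exist_ok=True)

from prometheus_client import CollectorRegistry, Counter, generate_latest, multiprocess

# A counter incremented in a worker process (the same process here, for brevity).
requests_total = Counter("demo_requests_total", "Requests handled by the worker")
requests_total.inc()

# The serving process aggregates samples from every worker at scrape time.
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
print(generate_latest(registry).decode())
```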
|
||||
|
||||
@ -134,6 +134,6 @@ The cudagraphs are captured and managed by the compiler backend, and replayed wh
|
||||
|
||||
By default, vLLM will try to determine a set of sizes to capture cudagraph. You can also override it using the config `cudagraph_capture_sizes`:
|
||||
|
||||
`VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.2-1B --compilation_config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"`
|
||||
`VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"`
|
||||
|
||||
Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture.
|
||||
|
||||
@ -106,19 +106,18 @@ curl http://localhost:8000/v1/completions \
|
||||
|
||||
## Dynamically serving LoRA Adapters
|
||||
|
||||
In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
|
||||
LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility
|
||||
to change models on-the-fly is needed.
|
||||
In addition to serving LoRA adapters at server startup, the vLLM server supports dynamically configuring LoRA adapters at runtime through dedicated API endpoints and plugins. This feature can be particularly useful when the flexibility to change models on-the-fly is needed.
|
||||
|
||||
Note: Enabling this feature in production environments is risky, since it allows users to load and unload model adapters at runtime.
|
||||
|
||||
To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
|
||||
is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.
|
||||
To enable dynamic LoRA configuration, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
|
||||
is set to `True`.
|
||||
|
||||
```bash
|
||||
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
|
||||
```
|
||||
|
||||
### Using API Endpoints
|
||||
Loading a LoRA Adapter:
|
||||
|
||||
To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
|
||||
@ -153,6 +152,58 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \
|
||||
}'
|
||||
```
|
||||
|
||||
### Using Plugins
|
||||
Alternatively, you can use a LoRAResolver plugin to dynamically load LoRA adapters. LoRAResolver plugins enable you to load LoRA adapters from both local and remote sources, such as the local file system or S3. Whenever a request names a model that has not been loaded yet, the LoRAResolver will try to resolve and load the corresponding LoRA adapter.
|
||||
|
||||
You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds.
|
||||
|
||||
You can either install existing plugins or implement your own.
|
||||
|
||||
Steps to implement your own LoRAResolver plugin:
|
||||
1. Implement the LoRAResolver interface.
|
||||
|
||||
Example of a simple S3 LoRAResolver implementation:
|
||||
|
||||
```python
|
||||
import os
|
||||
import s3fs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver
|
||||
|
||||
class S3LoRAResolver(LoRAResolver):
|
||||
def __init__(self):
|
||||
self.s3 = s3fs.S3FileSystem()
|
||||
self.s3_path_format = os.getenv("S3_PATH_TEMPLATE")
|
||||
self.local_path_format = os.getenv("LOCAL_PATH_TEMPLATE")
|
||||
|
||||
async def resolve_lora(self, base_model_name, lora_name):
|
||||
s3_path = self.s3_path_format.format(base_model_name=base_model_name, lora_name=lora_name)
|
||||
local_path = self.local_path_format.format(base_model_name=base_model_name, lora_name=lora_name)
|
||||
|
||||
# Download the LoRA from S3 to the local path
|
||||
await self.s3._get(
|
||||
s3_path, local_path, recursive=True, maxdepth=1
|
||||
)
|
||||
|
||||
lora_request = LoRARequest(
|
||||
lora_name=lora_name,
|
||||
lora_path=local_path,
|
||||
lora_int_id=abs(hash(lora_name))
|
||||
)
|
||||
return lora_request
|
||||
```
|
||||
|
||||
2. Register LoRAResolver plugin.
|
||||
|
||||
```python
|
||||
from vllm.lora.resolver import LoRAResolverRegistry
|
||||
|
||||
s3_resolver = S3LoRAResolver()
|
||||
LoRAResolverRegistry.register_resolver("s3_resolver", s3_resolver)
|
||||
```
|
||||
|
||||
For more details, refer to the [vLLM's Plugins System](../design/plugin_system.md).
|
||||
|
||||
## New format for `--lora-modules`
|
||||
|
||||
In the previous version, users would provide LoRA modules via the following format, either as a key-value pair or in JSON format. For example:
|
||||
|
||||
@ -245,6 +245,8 @@ Example supported models:
|
||||
* `meta-llama/Llama-3.2-3B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`)
|
||||
* `Team-ACE/ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`)
|
||||
* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`)
|
||||
* `meta-llama/Llama-4-Scout-17B-16E-Instruct`\* (use with `examples/tool_chat_template_llama4_pythonic.jinja`)
|
||||
* `meta-llama/Llama-4-Maverick-17B-128E-Instruct`\* (use with `examples/tool_chat_template_llama4_pythonic.jinja`)
|
||||
|
||||
Flags: `--tool-call-parser pythonic --chat-template {see_above}`
|
||||
|
||||
|
||||
@ -272,7 +272,7 @@ $ python examples/offline_inference/basic/basic.py
|
||||
|
||||
- Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance.
|
||||
|
||||
- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.inc.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is a option for better performance.
|
||||
- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is an option for better performance.
|
||||
|
||||
- Tensor Parallel is supported for serving and offline inference. In general, each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving:
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ LLM inference is a fast-evolving field, and the latest code may contain bug fixe
|
||||
##### Install the latest code using `pip`
|
||||
|
||||
```console
|
||||
pip install vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
|
||||
pip install -U vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
|
||||
```
|
||||
|
||||
`--pre` is required for `pip` to consider pre-released versions.
|
||||
@ -65,9 +65,11 @@ Note that the wheels are built with Python 3.8 ABI (see [PEP 425](https://peps.p
|
||||
Another way to install the latest code is to use `uv`:
|
||||
|
||||
```console
|
||||
uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly
|
||||
uv pip install -U vllm --extra-index-url https://wheels.vllm.ai/nightly
|
||||
```
|
||||
|
||||
##### Install specific revisions using `uv`
|
||||
|
||||
If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL:
|
||||
|
||||
```console
|
||||
|
||||
@ -24,7 +24,7 @@ To isolate the model downloading and loading issue, you can use the `--load-form
|
||||
|
||||
## Out of memory
|
||||
|
||||
If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider [using tensor parallelism](#distributed-serving) to split the model across multiple GPUs. In that case, every process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism). You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
|
||||
If the model is too large to fit in a single GPU, you will get an out-of-memory (OOM) error. Consider adopting [these options](#reducing-memory-usage) to reduce the memory consumption.
|
||||
|
||||
## Generation quality changed
|
||||
|
||||
|
||||
@ -141,3 +141,77 @@ Our [OpenAI-Compatible Server](#openai-compatible-server) provides endpoints tha
|
||||
- [Pooling API](#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models.
|
||||
- [Embeddings API](#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](#multimodal-inputs) for embedding models.
|
||||
- [Score API](#score-api) is similar to `LLM.score` for cross-encoder models.
|
||||
|
||||
## Matryoshka Embeddings
|
||||
|
||||
[Matryoshka Embeddings](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings) or [Matryoshka Representation Learning (MRL)](https://arxiv.org/abs/2205.13147) is a technique used in training embedding models. It allows the user to trade off performance against cost.
|
||||
|
||||
:::{warning}
|
||||
Not all embedding models are trained using Matryoshka Representation Learning. To avoid misuse of the `dimensions` parameter, vLLM returns an error for requests that attempt to change the output dimension of models that do not support Matryoshka Embeddings.
|
||||
|
||||
For example, setting `dimensions` parameter while using the `BAAI/bge-m3` model will result in the following error.
|
||||
|
||||
```json
|
||||
{"object":"error","message":"Model \"BAAI/bge-m3\" does not support matryoshka representation, changing output dimensions will lead to poor results.","type":"BadRequestError","param":null,"code":400}
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
### Manually enable Matryoshka Embeddings
|
||||
|
||||
There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, we simply check the existence of the fields `is_matryoshka` or `matryoshka_dimensions` inside `config.json`.
|
||||
|
||||
For models that support Matryoshka Embeddings but are not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` (offline) or `--hf_overrides '{"is_matryoshka": true}'` (online).
|
||||
|
||||
Here is an example of serving a model with Matryoshka Embeddings enabled.
|
||||
|
||||
```text
|
||||
vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"is_matryoshka":true}'
|
||||
```
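
For offline inference, the same override can be passed to the `LLM` constructor (a minimal sketch reusing the model from the serving example above):

```python
from vllm import LLM

# Force-enable Matryoshka support for a model whose config.json
# does not declare it.
llm = LLM(model="Snowflake/snowflake-arctic-embed-m-v1.5",
          task="embed",
          hf_overrides={"is_matryoshka": True})
```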
|
||||
|
||||
### Offline Inference
|
||||
|
||||
You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter in {class}`~vllm.PoolingParams`.
|
||||
|
||||
```python
|
||||
from vllm import LLM, PoolingParams
|
||||
|
||||
model = LLM(model="jinaai/jina-embeddings-v3",
|
||||
task="embed",
|
||||
trust_remote_code=True)
|
||||
outputs = model.embed(["Follow the white rabbit."],
|
||||
pooling_params=PoolingParams(dimensions=32))
|
||||
print(outputs[0].outputs)
|
||||
```
|
||||
|
||||
A code example can be found here: <gh-file:examples/offline_inference/embed_matryoshka_fy.py>
|
||||
|
||||
### Online Inference
|
||||
|
||||
Use the following command to start the vLLM server.
|
||||
|
||||
```text
|
||||
vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
|
||||
```
|
||||
|
||||
You can change the output dimensions of embedding models that support Matryoshka Embeddings by using the dimensions parameter.
|
||||
|
||||
```text
|
||||
curl http://127.0.0.1:8000/v1/embeddings \
|
||||
-H 'accept: application/json' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"input": "Follow the white rabbit.",
|
||||
"model": "jinaai/jina-embeddings-v3",
|
||||
"encoding_format": "float",
|
||||
"dimensions": 1
|
||||
}'
|
||||
```
|
||||
|
||||
Expected output:
|
||||
|
||||
```json
|
||||
{"id":"embd-0aab28c384d348c3b8f0eb783109dc5f","object":"list","created":1744195454,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-1.0]}],"usage":{"prompt_tokens":10,"total_tokens":10,"completion_tokens":0,"prompt_tokens_details":null}}
|
||||
```
|
||||
|
||||
An OpenAI client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py>
|
||||
|
||||
@ -55,6 +55,10 @@ If your model is neither supported natively by vLLM or Transformers, you can sti
|
||||
Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers.
|
||||
Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM!
|
||||
|
||||
:::{tip}
|
||||
If you have not yet created your custom model, you can follow this guide on [customising models in Transformers](https://huggingface.co/docs/transformers/en/custom_models).
|
||||
:::
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
|
||||
@ -334,7 +338,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* ✅︎
|
||||
- * `Glm4ForCausalLM`
|
||||
* GLM-4-0414
|
||||
* `THUDM/GLM-4-32B-Chat-0414`, etc.
|
||||
* `THUDM/GLM-4-32B-0414`, etc.
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
- * `GPT2LMHeadModel`
|
||||
@ -497,6 +501,11 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc.
|
||||
*
|
||||
* ✅︎
|
||||
- * `Plamo2ForCausalLM`
|
||||
* PLaMo2
|
||||
* `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc.
|
||||
*
|
||||
*
|
||||
- * `QWenLMHeadModel`
|
||||
* Qwen
|
||||
* `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.
|
||||
@ -735,6 +744,11 @@ If your model is not in the above list, we will try to automatically convert the
|
||||
* `BAAI/bge-reranker-v2-m3`, etc.
|
||||
*
|
||||
*
|
||||
- * `ModernBertForSequenceClassification`
|
||||
* ModernBert-based
|
||||
* `Alibaba-NLP/gte-reranker-modernbert-base`, etc.
|
||||
*
|
||||
*
|
||||
:::
|
||||
|
||||
(supported-mm-models)=
|
||||
@ -759,7 +773,7 @@ On the other hand, modalities separated by `/` are mutually exclusive.
|
||||
See [this page](#multimodal-inputs) on how to pass multi-modal inputs to the model.
|
||||
|
||||
:::{important}
|
||||
To enable multiple multi-modal items per text prompt, you have to set `limit_mm_per_prompt` (offline inference)
|
||||
**To enable multiple multi-modal items per text prompt in vLLM V0**, you have to set `limit_mm_per_prompt` (offline inference)
|
||||
or `--limit-mm-per-prompt` (online serving). For example, to enable passing up to 4 images per text prompt:
|
||||
|
||||
Offline inference:
|
||||
@ -774,9 +788,11 @@ llm = LLM(
|
||||
Online serving:
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}'
|
||||
```
|
||||
|
||||
**This is no longer required if you are using vLLM V1.**
|
||||
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
@ -878,14 +894,21 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
*
|
||||
* ✅︎
|
||||
- * `InternVLChatModel`
|
||||
* InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0
|
||||
* InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0
|
||||
* T + I<sup>E+</sup>
|
||||
* `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc.
|
||||
* `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc.
|
||||
*
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
- * `KimiVLForConditionalGeneration`
|
||||
* Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking
|
||||
* T + I<sup>+</sup>
|
||||
* `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking`
|
||||
*
|
||||
*
|
||||
* ✅︎
|
||||
- * `Llama4ForConditionalGeneration`
|
||||
* Llama-4-17B-Omni-Instruct
|
||||
* Llama 4
|
||||
* T + I<sup>+</sup>
|
||||
* `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc.
|
||||
*
|
||||
|
||||
@ -228,7 +228,7 @@ First, launch the OpenAI-compatible server:
|
||||
|
||||
```bash
|
||||
vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
|
||||
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
|
||||
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'
|
||||
```
|
||||
|
||||
Then, you can use the OpenAI client as follows:
|
||||
|
||||
@ -28,6 +28,8 @@ Please refer to the above pages for more details about each API.
|
||||
[API Reference](/api/offline_inference/index)
|
||||
:::
|
||||
|
||||
(configuration-options)=
|
||||
|
||||
## Configuration Options
|
||||
|
||||
This section lists the most common options for running the vLLM engine.
|
||||
@ -59,6 +61,8 @@ model = LLM(
|
||||
|
||||
Our [list of supported models](#supported-models) shows the model architectures that are recognized by vLLM.
|
||||
|
||||
(reducing-memory-usage)=
|
||||
|
||||
### Reducing memory usage
|
||||
|
||||
Large models might cause your machine to run out of memory (OOM). Here are some options that help alleviate this problem.
|
||||
@ -81,6 +85,12 @@ before initializing vLLM. Otherwise, you may run into an error like `RuntimeErro
|
||||
To control which devices are used, please instead set the `CUDA_VISIBLE_DEVICES` environment variable.
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
With tensor parallelism enabled, each process will read the whole model and split it into chunks, which makes the disk reading time even longer (proportional to the size of tensor parallelism).
|
||||
|
||||
You can convert the model checkpoint to a sharded checkpoint using <gh-file:examples/offline_inference/save_sharded_state.py>. The conversion process might take some time, but later you can load the sharded checkpoint much faster. The model loading time should remain constant regardless of the size of tensor parallelism.
|
||||
:::
|
||||
|
||||
#### Quantization
|
||||
|
||||
Quantized models take less memory at the cost of lower precision.
|
||||
@ -103,6 +113,39 @@ llm = LLM(model="adept/fuyu-8b",
|
||||
max_num_seqs=2)
|
||||
```
|
||||
|
||||
#### Reduce CUDA Graphs
|
||||
|
||||
By default, we optimize model inference using CUDA graphs, which take up extra memory on the GPU.
|
||||
|
||||
:::{important}
|
||||
CUDA graph capture takes up more memory in V1 than in V0.
|
||||
:::
|
||||
|
||||
You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
from vllm.config import CompilationConfig, CompilationLevel
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
# By default, it goes up to max_num_seqs
|
||||
cudagraph_capture_sizes=[1, 2, 4, 8, 16],
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
You can disable graph capturing completely via the `enforce_eager` flag:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
enforce_eager=True)
|
||||
```
|
||||
|
||||
#### Adjust cache size
|
||||
|
||||
If you run out of CPU RAM, try the following options:
|
||||
@ -110,6 +153,62 @@ If you run out of CPU RAM, try the following options:
|
||||
- (Multi-modal models only) you can set the size of the multi-modal input cache using the `VLLM_MM_INPUT_CACHE_GIB` environment variable (default 4 GiB).
|
||||
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
|
||||
|
||||
#### Multi-modal input limits
|
||||
|
||||
You can allow a smaller number of multi-modal items per prompt to reduce the memory footprint of the model:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
# Accept up to 3 images and 1 video per prompt
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
limit_mm_per_prompt={"image": 3, "video": 1})
|
||||
```
|
||||
|
||||
You can go a step further and disable unused modalities completely by setting their limits to zero.
|
||||
For example, if your application only accepts image input, there is no need to allocate any memory for videos.
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
# Accept any number of images but no videos
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
limit_mm_per_prompt={"video": 0})
|
||||
```
|
||||
|
||||
You can even run a multi-modal model for text-only inference:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
# Don't accept images. Just text.
|
||||
llm = LLM(model="google/gemma-3-27b-it",
|
||||
limit_mm_per_prompt={"image": 0})
|
||||
```
|
||||
|
||||
#### Multi-modal processor arguments
|
||||
|
||||
For certain models, you can adjust the multi-modal processor arguments to
|
||||
reduce the size of the processed multi-modal inputs, which in turn saves memory.
|
||||
|
||||
Here are some examples:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
# Available for Qwen2-VL series models
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
mm_processor_kwargs={
|
||||
"max_pixels": 768 * 768, # Default is 1280 * 28 * 28
|
||||
})
|
||||
|
||||
# Available for InternVL series models
|
||||
llm = LLM(model="OpenGVLab/InternVL2-2B",
|
||||
mm_processor_kwargs={
|
||||
"max_dynamic_patch": 4, # Default is 12
|
||||
})
|
||||
```
|
||||
|
||||
### Performance optimization and tuning
|
||||
|
||||
You can potentially improve the performance of vLLM by finetuning various options.
|
||||
|
||||
@ -2,15 +2,15 @@
|
||||
|
||||
# OpenAI-Compatible Server
|
||||
|
||||
vLLM provides an HTTP server that implements OpenAI's [Completions API](https://platform.openai.com/docs/api-reference/completions), [Chat API](https://platform.openai.com/docs/api-reference/chat), and more!
|
||||
vLLM provides an HTTP server that implements OpenAI's [Completions API](https://platform.openai.com/docs/api-reference/completions), [Chat API](https://platform.openai.com/docs/api-reference/chat), and more! This functionality lets you serve models and interact with them using an HTTP client.
|
||||
|
||||
You can start the server via the [`vllm serve`](#vllm-serve) command, or through [Docker](#deployment-docker):
|
||||
In your terminal, you can [install](../getting_started/installation.md) vLLM, then start the server with the [`vllm serve`](#vllm-serve) command. (You can also use our [Docker](#deployment-docker) image.)
|
||||
|
||||
```bash
|
||||
vllm serve NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123
|
||||
```
|
||||
|
||||
To call the server, you can use the [official OpenAI Python client](https://github.com/openai/openai-python), or any other HTTP client.
|
||||
To call the server, in your preferred text editor, create a script that uses an HTTP client. Include any messages that you want to send to the model. Then run that script. Below is an example script using the [official OpenAI Python client](https://github.com/openai/openai-python).
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
@ -33,11 +33,13 @@ print(completion.choices[0].message)
|
||||
vLLM supports some parameters that are not supported by OpenAI, `top_k` for example.
|
||||
You can pass these parameters to vLLM using the OpenAI client in the `extra_body` parameter of your requests, i.e. `extra_body={"top_k": 50}` for `top_k`.
|
||||
:::
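
For example (a minimal sketch; it assumes the server started above and reuses its model name and API key, with the default port 8000):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")

completion = client.chat.completions.create(
    model="NousResearch/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    # vLLM-specific parameters go in extra_body.
    extra_body={"top_k": 50},
)
print(completion.choices[0].message)
```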
|
||||
|
||||
:::{important}
|
||||
By default, the server applies `generation_config.json` from the Hugging Face model repository if it exists. This means the default values of certain sampling parameters can be overridden by those recommended by the model creator.
|
||||
|
||||
To disable this behavior, please pass `--generation-config vllm` when launching the server.
|
||||
:::
|
||||
|
||||
## Supported APIs
|
||||
|
||||
We currently support the following OpenAI APIs:
|
||||
@ -172,6 +174,12 @@ print(completion._request_id)
|
||||
|
||||
The `vllm serve` command is used to launch the OpenAI-compatible server.
|
||||
|
||||
:::{tip}
|
||||
The vast majority of command-line arguments are based on those for offline inference.
|
||||
|
||||
See [here](configuration-options) for some common options.
|
||||
:::
|
||||
|
||||
:::{argparse}
|
||||
:module: vllm.entrypoints.openai.cli_args
|
||||
:func: create_parser_for_docs
|
||||
|
||||
@ -187,6 +187,33 @@ model_example_map = {
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'audio language models')
|
||||
parser.add_argument('--model-type',
|
||||
'-m',
|
||||
type=str,
|
||||
default="ultravox",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of prompts to run.')
|
||||
parser.add_argument("--num-audios",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[0, 1, 2],
|
||||
help="Number of audio items per prompt.")
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
model = args.model_type
|
||||
if model not in model_example_map:
|
||||
@ -196,6 +223,11 @@ def main(args):
|
||||
req_data = model_example_map[model](question_per_audio_count[audio_count],
|
||||
audio_count)
|
||||
|
||||
# Disable other modalities to save memory
|
||||
default_limits = {"image": 0, "video": 0, "audio": 0}
|
||||
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
|
||||
req_data.engine_args.limit_mm_per_prompt or {})
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
@ -235,28 +267,5 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'audio language models')
|
||||
parser.add_argument('--model-type',
|
||||
'-m',
|
||||
type=str,
|
||||
default="ultravox",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of prompts to run.')
|
||||
parser.add_argument("--num-audios",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[0, 1, 2],
|
||||
help="Number of audio items per prompt.")
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@ -12,16 +12,23 @@ prompts = [
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Output: {generated_text!r}")
|
||||
print("-" * 60)
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects
|
||||
# that contain the prompt, generated text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
print("\nGenerated Outputs:\n" + "-" * 60)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}")
|
||||
print(f"Output: {generated_text!r}")
|
||||
print("-" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -4,6 +4,24 @@ from vllm import LLM, EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = FlexibleArgumentParser()
|
||||
# Add engine args
|
||||
engine_group = parser.add_argument_group("Engine arguments")
|
||||
EngineArgs.add_cli_args(engine_group)
|
||||
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
|
||||
# Add sampling params
|
||||
sampling_group = parser.add_argument_group("Sampling parameters")
|
||||
sampling_group.add_argument("--max-tokens", type=int)
|
||||
sampling_group.add_argument("--temperature", type=float)
|
||||
sampling_group.add_argument("--top-p", type=float)
|
||||
sampling_group.add_argument("--top-k", type=int)
|
||||
# Add example params
|
||||
parser.add_argument("--chat-template-path", type=str)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(args: dict):
|
||||
# Pop arguments not used by LLM
|
||||
max_tokens = args.pop("max_tokens")
|
||||
@ -82,18 +100,6 @@ def main(args: dict):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
# Add engine args
|
||||
engine_group = parser.add_argument_group("Engine arguments")
|
||||
EngineArgs.add_cli_args(engine_group)
|
||||
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
|
||||
# Add sampling params
|
||||
sampling_group = parser.add_argument_group("Sampling parameters")
|
||||
sampling_group.add_argument("--max-tokens", type=int)
|
||||
sampling_group.add_argument("--temperature", type=float)
|
||||
sampling_group.add_argument("--top-p", type=float)
|
||||
sampling_group.add_argument("--top-k", type=int)
|
||||
# Add example params
|
||||
parser.add_argument("--chat-template-path", type=str)
|
||||
parser = create_parser()
|
||||
args: dict = vars(parser.parse_args())
|
||||
main(args)
|
||||
|
||||
@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
|
||||
task="classify",
|
||||
enforce_eager=True)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
@ -34,11 +44,5 @@ def main(args: Namespace):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(model="jason9693/Qwen2.5-1.5B-apeach",
|
||||
task="classify",
|
||||
enforce_eager=True)
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
|
||||
task="embed",
|
||||
enforce_eager=True)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
@ -34,11 +44,5 @@ def main(args: Namespace):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(model="intfloat/e5-mistral-7b-instruct",
|
||||
task="embed",
|
||||
enforce_eager=True)
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@ -4,6 +4,22 @@ from vllm import LLM, EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def create_parser():
|
||||
parser = FlexibleArgumentParser()
|
||||
# Add engine args
|
||||
engine_group = parser.add_argument_group("Engine arguments")
|
||||
EngineArgs.add_cli_args(engine_group)
|
||||
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
|
||||
# Add sampling params
|
||||
sampling_group = parser.add_argument_group("Sampling parameters")
|
||||
sampling_group.add_argument("--max-tokens", type=int)
|
||||
sampling_group.add_argument("--temperature", type=float)
|
||||
sampling_group.add_argument("--top-p", type=float)
|
||||
sampling_group.add_argument("--top-k", type=int)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(args: dict):
|
||||
# Pop arguments not used by LLM
|
||||
max_tokens = args.pop("max_tokens")
|
||||
@ -35,23 +51,15 @@ def main(args: dict):
|
||||
]
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
# Print the outputs.
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
# Add engine args
|
||||
engine_group = parser.add_argument_group("Engine arguments")
|
||||
EngineArgs.add_cli_args(engine_group)
|
||||
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
|
||||
# Add sampling params
|
||||
sampling_group = parser.add_argument_group("Sampling parameters")
|
||||
sampling_group.add_argument("--max-tokens", type=int)
|
||||
sampling_group.add_argument("--temperature", type=float)
|
||||
sampling_group.add_argument("--top-p", type=float)
|
||||
sampling_group.add_argument("--top-k", type=int)
|
||||
parser = create_parser()
|
||||
args: dict = vars(parser.parse_args())
|
||||
main(args)
|
||||
|
||||
@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
|
||||
task="score",
|
||||
enforce_eager=True)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
# Sample prompts.
|
||||
text_1 = "What is the capital of France?"
|
||||
@ -30,11 +40,5 @@ def main(args: Namespace):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(model="BAAI/bge-reranker-v2-m3",
|
||||
task="score",
|
||||
enforce_eager=True)
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
90
examples/offline_inference/batch_llm_inference.py
Normal file
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use Ray Data for data parallel batch inference.

Ray Data is a data processing framework that can handle large datasets
and integrates tightly with vLLM for data-parallel inference.

As of Ray 2.44, Ray Data has a native integration with
vLLM (under ray.data.llm).

Ray Data provides functionality for:
* Reading and writing to cloud storage (S3, GCS, etc.)
* Automatic sharding and load-balancing across a cluster
* Optimized configuration of vLLM using continuous batching
* Compatible with tensor/pipeline parallel inference as well.

Learn more about Ray Data's LLM integration:
https://docs.ray.io/en/latest/data/working-with-llms.html
"""
import ray
from packaging.version import Version
from ray.data.llm import build_llm_processor, vLLMEngineProcessorConfig

assert Version(ray.__version__) >= Version(
    "2.44.1"), "Ray version must be at least 2.44.1"

# Uncomment to reduce clutter in stdout
# ray.init(log_to_driver=False)
# ray.data.DataContext.get_current().enable_progress_bars = False

# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
print(ds.schema())

size = ds.count()
print(f"Size of dataset: {size} prompts")

# Configure vLLM engine.
config = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "enable_chunked_prefill": True,
        "max_num_batched_tokens": 4096,
        "max_model_len": 16384,
    },
    concurrency=1,  # set the number of parallel vLLM replicas
    batch_size=64,
)

# Create a Processor object, which will be used to
# do batch inference on the dataset
vllm_processor = build_llm_processor(
    config,
    preprocess=lambda row: dict(
        messages=[{
            "role": "system",
            "content": "You are a bot that responds with haikus."
        }, {
            "role": "user",
            "content": row["text"]
        }],
        sampling_params=dict(
            temperature=0.3,
            max_tokens=250,
        )),
    postprocess=lambda row: dict(
        answer=row["generated_text"],
        **row  # This will return all the original columns in the dataset.
    ),
)

ds = vllm_processor(ds)

# Peek first 10 results.
# NOTE: This is for local testing and debugging. For production use case,
# one should write full result out as shown below.
outputs = ds.take(limit=10)

for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}")
    print(f"Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
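The docstring above states that the Ray Data integration is also compatible with tensor/pipeline parallel inference. As an illustration only, not part of this diff, a config along the following lines should shard each replica across multiple GPUs by forwarding tensor_parallel_size through engine_kwargs; the field names mirror the vLLMEngineProcessorConfig fields already used above, while the GPU counts and the config name config_tp are assumptions.

# Sketch (not in this PR): 4 data-parallel replicas, each sharded over 2 GPUs
# via vLLM tensor parallelism, so 8 GPUs are assumed to be available.
config_tp = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    engine_kwargs={
        "tensor_parallel_size": 2,  # forwarded to each vLLM engine replica
        "max_model_len": 16384,
    },
    concurrency=4,  # number of parallel vLLM replicas
    batch_size=64,
)

Whether extra placement or resource settings are needed for a given cluster is not covered by this diff; see the Ray Data LLM documentation linked above.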
@@ -3,9 +3,12 @@
This file demonstrates the example usage of cpu offloading
with LMCache.

Note that `pip install lmcache` is needed to run this example.
Learn more about LMCache in https://github.com/LMCache/LMCache.
Note that `lmcache` is needed to run this example.
Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1
Learn more about LMCache environment setup, please refer to:
https://docs.lmcache.ai/getting_started/installation.html
"""
import contextlib
import os
import time

@ -15,51 +18,83 @@ from lmcache.integration.vllm.utils import ENGINE_NAME
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
# LMCache-related environment variables
|
||||
# Use experimental features in LMCache
|
||||
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
|
||||
# LMCache is set to use 256 tokens per chunk
|
||||
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
|
||||
# Enable local CPU backend in LMCache
|
||||
os.environ["LMCACHE_LOCAL_CPU"] = "True"
|
||||
# Set local CPU memory limit to 5.0 GB
|
||||
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
|
||||
|
||||
# This example script runs two requests with a shared prefix.
|
||||
shared_prompt = "Hello, how are you?" * 1000
|
||||
first_prompt = [
|
||||
shared_prompt + "Hello, my name is",
|
||||
]
|
||||
second_prompt = [
|
||||
shared_prompt + "Tell me a very long story",
|
||||
]
|
||||
def setup_environment_variables():
|
||||
# LMCache-related environment variables
|
||||
# Use experimental features in LMCache
|
||||
os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True"
|
||||
# LMCache is set to use 256 tokens per chunk
|
||||
os.environ["LMCACHE_CHUNK_SIZE"] = "256"
|
||||
# Enable local CPU backend in LMCache
|
||||
os.environ["LMCACHE_LOCAL_CPU"] = "True"
|
||||
# Set local CPU memory limit to 5.0 GB
|
||||
os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0"
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
|
||||
|
||||
ktc = KVTransferConfig.from_cli(
|
||||
'{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}')
|
||||
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
|
||||
# memory. Reduce the value if your GPU has less memory.
|
||||
# Note that LMCache is not compatible with chunked prefill for now.
|
||||
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=8000,
|
||||
enable_chunked_prefill=False,
|
||||
gpu_memory_utilization=0.8)
|
||||
@contextlib.contextmanager
|
||||
def build_llm_with_lmcache():
|
||||
ktc = KVTransferConfig.from_cli(
|
||||
'{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}')
|
||||
# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
|
||||
# memory. Reduce the value if your GPU has less memory.
|
||||
# Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392).
|
||||
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=8000,
|
||||
enable_chunked_prefill=True,
|
||||
gpu_memory_utilization=0.8)
|
||||
|
||||
outputs = llm.generate(first_prompt, sampling_params)
|
||||
for output in outputs:
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Generated text: {generated_text!r}")
|
||||
print("First request done.")
|
||||
try:
|
||||
yield llm
|
||||
finally:
|
||||
# Clean up lmcache backend
|
||||
LMCacheEngineBuilder.destroy(ENGINE_NAME)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
outputs = llm.generate(second_prompt, sampling_params)
|
||||
for output in outputs:
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Generated text: {generated_text!r}")
|
||||
print("Second request done.")
|
||||
def print_output(
|
||||
llm: LLM,
|
||||
prompt: list[str],
|
||||
sampling_params: SamplingParams,
|
||||
req_str: str,
|
||||
):
|
||||
start = time.time()
|
||||
outputs = llm.generate(prompt, sampling_params)
|
||||
print("-" * 50)
|
||||
for output in outputs:
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Generated text: {generated_text!r}")
|
||||
print(f"Generation took {time.time() - start:.2f} seconds, "
|
||||
f"{req_str} request done.")
|
||||
print("-" * 50)
|
||||
|
||||
# Clean up lmcache backend
|
||||
LMCacheEngineBuilder.destroy(ENGINE_NAME)
|
||||
|
||||
def main():
|
||||
setup_environment_variables()
|
||||
|
||||
with build_llm_with_lmcache() as llm:
|
||||
|
||||
# This example script runs two requests with a shared prefix.
|
||||
# Define the shared prompt and specific prompts
|
||||
shared_prompt = "Hello, how are you?" * 1000
|
||||
first_prompt = [
|
||||
shared_prompt + "Hello, my name is",
|
||||
]
|
||||
second_prompt = [
|
||||
shared_prompt + "Tell me a very long story",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0,
|
||||
top_p=0.95,
|
||||
max_tokens=10)
|
||||
|
||||
# Print the first output
|
||||
print_output(llm, first_prompt, sampling_params, "first")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
# print the second output
|
||||
print_output(llm, second_prompt, sampling_params, "second")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -34,6 +34,40 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
|
||||
def parse_args():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Data Parallel Inference")
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="ibm-research/PowerMoE-3b",
|
||||
help="Model name or path")
|
||||
parser.add_argument("--dp-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Data parallel size")
|
||||
parser.add_argument("--tp-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Tensor parallel size")
|
||||
parser.add_argument("--node-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Total number of nodes")
|
||||
parser.add_argument("--node-rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Rank of the current node")
|
||||
parser.add_argument("--master-addr",
|
||||
type=str,
|
||||
default="",
|
||||
help="Master node IP address")
|
||||
parser.add_argument("--master-port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Master node port")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
||||
dp_master_port, GPUs_per_dp_rank):
|
||||
os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
|
||||
@ -95,37 +129,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Data Parallel Inference")
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="ibm-research/PowerMoE-3b",
|
||||
help="Model name or path")
|
||||
parser.add_argument("--dp-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Data parallel size")
|
||||
parser.add_argument("--tp-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Tensor parallel size")
|
||||
parser.add_argument("--node-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Total number of nodes")
|
||||
parser.add_argument("--node-rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Rank of the current node")
|
||||
parser.add_argument("--master-addr",
|
||||
type=str,
|
||||
default="",
|
||||
help="Master node IP address")
|
||||
parser.add_argument("--master-port",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Master node port")
|
||||
args = parser.parse_args()
|
||||
|
||||
args = parse_args()
|
||||
|
||||
dp_size = args.dp_size
|
||||
tp_size = args.tp_size
|
||||
|
||||
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Read prompts from output.txt
prompts = []
try:
    with open("output.txt") as f:
        for line in f:
            prompts.append(line.strip())
    print(f"Loaded {len(prompts)} prompts from output.txt")
except FileNotFoundError:
    print("Error: output.txt file not found")
    exit(-1)

sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
    gpu_memory_utilization=0.8,
    max_num_batched_tokens=64,
    max_num_seqs=16,
    kv_transfer_config=KVTransferConfig.from_cli(
        '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",'
        '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}'
    ))  #, max_model_len=2048, max_num_batched_tokens=2048)

# 1ST generation (prefill instance)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
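For reference, a minimal timing helper, an assumption-laden sketch rather than part of this diff, mirrors the print_output pattern used in the LMCache example earlier; wrapping llm.generate with it makes it easy to observe how much faster this decode side runs once the prefilled KV cache is loaded from local_storage. The helper name timed_generate is hypothetical.

import time

def timed_generate(llm, prompts, sampling_params):
    # Measure wall-clock time for a single generate() call.
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    print(f"Generation took {time.time() - start:.2f} seconds "
          f"for {len(prompts)} prompts")
    return outputs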
@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
context = "Hi " * 1000
|
||||
context2 = "Hey " * 500
|
||||
prompts = [
|
||||
context + "Hello, my name is",
|
||||
context + "The capital of France is",
|
||||
context2 + "Your name is",
|
||||
context2 + "The capital of China is",
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
||||
|
||||
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.8,
|
||||
kv_transfer_config=KVTransferConfig.from_cli(
|
||||
'{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", '
|
||||
'"kv_connector_extra_config": '
|
||||
'{"shared_storage_path": "local_storage"}}')
|
||||
) #, max_model_len=2048, max_num_batched_tokens=2048)
|
||||
|
||||
# 1ST generation (prefill instance)
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
sampling_params,
|
||||
)
|
||||
|
||||
new_prompts = []
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
new_prompts.append(prompt + generated_text)
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# Write new_prompts to output.txt
|
||||
with open("output.txt", "w") as f:
|
||||
for prompt in new_prompts:
|
||||
f.write(prompt + "\n")
|
||||
print(f"Saved {len(new_prompts)} prompts to output.txt")
|
||||
@@ -0,0 +1,5 @@
rm -rf local_storage/
rm output.txt

VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 prefill_example.py
VLLM_ENABLE_V1_MULTIPROCESSING=0 CUDA_VISIBLE_DEVICES=0 python3 decode_example.py
@ -95,7 +95,7 @@ def run_decode(prefill_done):
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
prefill_done = Event()
|
||||
prefill_process = Process(target=run_prefill, args=(prefill_done, ))
|
||||
decode_process = Process(target=run_decode, args=(prefill_done, ))
|
||||
@ -109,3 +109,7 @@ if __name__ == "__main__":
|
||||
# Terminate the prefill node when decode is finished
|
||||
decode_process.join()
|
||||
prefill_process.terminate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -38,6 +38,10 @@ os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
|
||||
# `naive` indicates using raw bytes of the tensor without any compression
|
||||
os.environ["LMCACHE_REMOTE_SERDE"] = "naive"
|
||||
|
||||
prompts = [
|
||||
"Hello, how are you?" * 1000,
|
||||
]
|
||||
|
||||
|
||||
def run_prefill(prefill_done, prompts):
|
||||
# We use GPU 0 for prefill node.
|
||||
@ -106,12 +110,7 @@ def run_lmcache_server(port):
|
||||
return server_proc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
prompts = [
|
||||
"Hello, how are you?" * 1000,
|
||||
]
|
||||
|
||||
def main():
|
||||
prefill_done = Event()
|
||||
prefill_process = Process(target=run_prefill, args=(prefill_done, prompts))
|
||||
decode_process = Process(target=run_decode, args=(prefill_done, prompts))
|
||||
@ -128,3 +127,7 @@ if __name__ == "__main__":
|
||||
prefill_process.terminate()
|
||||
lmcache_server_process.terminate()
|
||||
lmcache_server_process.wait()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -1,109 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This example shows how to use Ray Data for running offline batch inference
|
||||
distributively on a multi-nodes cluster.
|
||||
|
||||
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
from packaging.version import Version
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
assert Version(ray.__version__) >= Version(
|
||||
"2.22.0"), "Ray version must be at least 2.22.0"
|
||||
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Set tensor parallelism per instance.
|
||||
tensor_parallel_size = 1
|
||||
|
||||
# Set number of instances. Each instance will use tensor_parallel_size GPUs.
|
||||
num_instances = 1
|
||||
|
||||
|
||||
# Create a class to do batch inference.
|
||||
class LLMPredictor:
|
||||
|
||||
def __init__(self):
|
||||
# Create an LLM.
|
||||
self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
|
||||
tensor_parallel_size=tensor_parallel_size)
|
||||
|
||||
def __call__(self, batch: dict[str, np.ndarray]) -> dict[str, list]:
|
||||
# Generate texts from the prompts.
|
||||
# The output is a list of RequestOutput objects that contain the prompt,
|
||||
# generated text, and other information.
|
||||
outputs = self.llm.generate(batch["text"], sampling_params)
|
||||
prompt: list[str] = []
|
||||
generated_text: list[str] = []
|
||||
for output in outputs:
|
||||
prompt.append(output.prompt)
|
||||
generated_text.append(' '.join([o.text for o in output.outputs]))
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"generated_text": generated_text,
|
||||
}
|
||||
|
||||
|
||||
# Read one text file from S3. Ray Data supports reading multiple files
|
||||
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
|
||||
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
|
||||
|
||||
|
||||
# For tensor_parallel_size > 1, we need to create placement groups for vLLM
|
||||
# to use. Every actor has to have its own placement group.
|
||||
def scheduling_strategy_fn():
|
||||
# One bundle per tensor parallel worker
|
||||
pg = ray.util.placement_group(
|
||||
[{
|
||||
"GPU": 1,
|
||||
"CPU": 1
|
||||
}] * tensor_parallel_size,
|
||||
strategy="STRICT_PACK",
|
||||
)
|
||||
return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
pg, placement_group_capture_child_tasks=True))
|
||||
|
||||
|
||||
resources_kwarg: dict[str, Any] = {}
|
||||
if tensor_parallel_size == 1:
|
||||
# For tensor_parallel_size == 1, we simply set num_gpus=1.
|
||||
resources_kwarg["num_gpus"] = 1
|
||||
else:
|
||||
# Otherwise, we have to set num_gpus=0 and provide
|
||||
# a function that will create a placement group for
|
||||
# each instance.
|
||||
resources_kwarg["num_gpus"] = 0
|
||||
resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn
|
||||
|
||||
# Apply batch inference for all input data.
|
||||
ds = ds.map_batches(
|
||||
LLMPredictor,
|
||||
# Set the concurrency to the number of LLM instances.
|
||||
concurrency=num_instances,
|
||||
# Specify the batch size for inference.
|
||||
batch_size=32,
|
||||
**resources_kwarg,
|
||||
)
|
||||
|
||||
# Peek first 10 results.
|
||||
# NOTE: This is for local testing and debugging. For production use case,
|
||||
# one should write full result out as shown below.
|
||||
outputs = ds.take(limit=10)
|
||||
for output in outputs:
|
||||
prompt = output["prompt"]
|
||||
generated_text = output["generated_text"]
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# Write inference output data out as Parquet files to S3.
|
||||
# Multiple files would be written to the output destination,
|
||||
# and each task would write one or more files separately.
|
||||
#
|
||||
# ds.write_parquet("s3://<your-output-bucket>")
|
||||
@ -27,7 +27,7 @@ def load_prompts(dataset_path, num_prompts):
|
||||
return prompts[:num_prompts]
|
||||
|
||||
|
||||
def main():
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
@ -45,7 +45,12 @@ def main():
|
||||
parser.add_argument("--enable_chunked_prefill", action='store_true')
|
||||
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
|
||||
parser.add_argument("--temp", type=float, default=0)
|
||||
args = parser.parse_args()
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
args = parse_args()
|
||||
|
||||
model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
|
||||
|
||||
@ -6,6 +6,16 @@ from vllm import LLM, EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(model="jinaai/jina-embeddings-v3",
|
||||
task="embed",
|
||||
trust_remote_code=True)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
@ -40,11 +50,5 @@ def main(args: Namespace):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(model="jinaai/jina-embeddings-v3",
|
||||
task="embed",
|
||||
trust_remote_code=True)
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
52
examples/offline_inference/embed_matryoshka_fy.py
Normal file
@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0

from argparse import Namespace

from vllm import LLM, EngineArgs, PoolingParams
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    # Set example specific arguments
    parser.set_defaults(model="jinaai/jina-embeddings-v3",
                        task="embed",
                        trust_remote_code=True)
    return parser.parse_args()


def main(args: Namespace):
    # Sample prompts.
    prompts = [
        "Follow the white rabbit.",  # English
        "Sigue al conejo blanco.",  # Spanish
        "Suis le lapin blanc.",  # French
        "跟着白兔走。",  # Chinese
        "اتبع الأرنب الأبيض.",  # Arabic
        "Folge dem weißen Kaninchen.",  # German
    ]

    # Create an LLM.
    # You should pass task="embed" for embedding models
    model = LLM(**vars(args))

    # Generate embedding. The output is a list of EmbeddingRequestOutputs.
    outputs = model.embed(prompts, pooling_params=PoolingParams(dimensions=32))

    # Print the outputs.
    print("\nGenerated Outputs:")
    print("-" * 60)
    for prompt, output in zip(prompts, outputs):
        embeds = output.outputs.embedding
        embeds_trimmed = ((str(embeds[:16])[:-1] +
                           ", ...]") if len(embeds) > 16 else embeds)
        print(f"Prompt: {prompt!r} \n"
              f"Embeddings: {embeds_trimmed} "
              f"(size={len(embeds)})")
    print("-" * 60)


if __name__ == "__main__":
    args = parse_args()
    main(args)

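Since the example truncates the jina-embeddings-v3 outputs to 32 Matryoshka dimensions via PoolingParams(dimensions=32), a short follow-up sketch, an illustration only and not part of this diff, shows how two of those truncated embeddings could be compared with cosine similarity. It assumes numpy is available; the helper name is hypothetical, while output.outputs.embedding matches the attribute used in the example above.

import numpy as np

def cosine_similarity(a, b):
    # Cosine similarity between two embedding vectors.
    a, b = np.asarray(a), np.asarray(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# e.g. compare the English and Spanish prompts' 32-dim embeddings:
# cosine_similarity(outputs[0].outputs.embedding, outputs[1].outputs.embedding)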
@ -8,94 +8,112 @@ from vllm import LLM, SamplingParams
|
||||
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
|
||||
TokensPrompt, zip_enc_dec_prompts)
|
||||
|
||||
dtype = "float"
|
||||
|
||||
# Create a BART encoder/decoder model instance
|
||||
llm = LLM(
|
||||
model="facebook/bart-large-cnn",
|
||||
dtype=dtype,
|
||||
)
|
||||
def create_prompts(tokenizer):
|
||||
# Test prompts
|
||||
#
|
||||
# This section shows all of the valid ways to prompt an
|
||||
# encoder/decoder model.
|
||||
#
|
||||
# - Helpers for building prompts
|
||||
text_prompt_raw = "Hello, my name is"
|
||||
text_prompt = TextPrompt(prompt="The president of the United States is")
|
||||
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
|
||||
prompt="The capital of France is"))
|
||||
# - Pass a single prompt to encoder/decoder model
|
||||
# (implicitly encoder input prompt);
|
||||
# decoder input prompt is assumed to be None
|
||||
|
||||
# Get BART tokenizer
|
||||
tokenizer = llm.llm_engine.get_tokenizer_group()
|
||||
single_text_prompt_raw = text_prompt_raw # Pass a string directly
|
||||
single_text_prompt = text_prompt # Pass a TextPrompt
|
||||
single_tokens_prompt = tokens_prompt # Pass a TokensPrompt
|
||||
|
||||
# Test prompts
|
||||
#
|
||||
# This section shows all of the valid ways to prompt an
|
||||
# encoder/decoder model.
|
||||
#
|
||||
# - Helpers for building prompts
|
||||
text_prompt_raw = "Hello, my name is"
|
||||
text_prompt = TextPrompt(prompt="The president of the United States is")
|
||||
tokens_prompt = TokensPrompt(prompt_token_ids=tokenizer.encode(
|
||||
prompt="The capital of France is"))
|
||||
# - Pass a single prompt to encoder/decoder model
|
||||
# (implicitly encoder input prompt);
|
||||
# decoder input prompt is assumed to be None
|
||||
# ruff: noqa: E501
|
||||
# - Pass explicit encoder and decoder input prompts within one data structure.
|
||||
# Encoder and decoder prompts can both independently be text or tokens, with
|
||||
# no requirement that they be the same prompt type. Some example prompt-type
|
||||
# combinations are shown below, note that these are not exhaustive.
|
||||
|
||||
single_text_prompt_raw = text_prompt_raw # Pass a string directly
|
||||
single_text_prompt = text_prompt # Pass a TextPrompt
|
||||
single_tokens_prompt = tokens_prompt # Pass a TokensPrompt
|
||||
enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
|
||||
# Pass encoder prompt string directly, &
|
||||
# pass decoder prompt tokens
|
||||
encoder_prompt=single_text_prompt_raw,
|
||||
decoder_prompt=single_tokens_prompt,
|
||||
)
|
||||
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
|
||||
# Pass TextPrompt to encoder, and
|
||||
# pass decoder prompt string directly
|
||||
encoder_prompt=single_text_prompt,
|
||||
decoder_prompt=single_text_prompt_raw,
|
||||
)
|
||||
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
|
||||
# Pass encoder prompt tokens directly, and
|
||||
# pass TextPrompt to decoder
|
||||
encoder_prompt=single_tokens_prompt,
|
||||
decoder_prompt=single_text_prompt,
|
||||
)
|
||||
|
||||
# - Pass explicit encoder and decoder input prompts within one data structure.
|
||||
# Encoder and decoder prompts can both independently be text or tokens, with
|
||||
# no requirement that they be the same prompt type. Some example prompt-type
|
||||
# combinations are shown below, note that these are not exhaustive.
|
||||
# - Finally, here's a useful helper function for zipping encoder and
|
||||
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
|
||||
# instances
|
||||
zipped_prompt_list = zip_enc_dec_prompts(
|
||||
['An encoder prompt', 'Another encoder prompt'],
|
||||
['A decoder prompt', 'Another decoder prompt'])
|
||||
|
||||
enc_dec_prompt1 = ExplicitEncoderDecoderPrompt(
|
||||
# Pass encoder prompt string directly, &
|
||||
# pass decoder prompt tokens
|
||||
encoder_prompt=single_text_prompt_raw,
|
||||
decoder_prompt=single_tokens_prompt,
|
||||
)
|
||||
enc_dec_prompt2 = ExplicitEncoderDecoderPrompt(
|
||||
# Pass TextPrompt to encoder, and
|
||||
# pass decoder prompt string directly
|
||||
encoder_prompt=single_text_prompt,
|
||||
decoder_prompt=single_text_prompt_raw,
|
||||
)
|
||||
enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
|
||||
# Pass encoder prompt tokens directly, and
|
||||
# pass TextPrompt to decoder
|
||||
encoder_prompt=single_tokens_prompt,
|
||||
decoder_prompt=single_text_prompt,
|
||||
)
|
||||
# - Let's put all of the above example prompts together into one list
|
||||
# which we will pass to the encoder/decoder LLM.
|
||||
return [
|
||||
single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
|
||||
enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
|
||||
] + zipped_prompt_list
|
||||
|
||||
# - Finally, here's a useful helper function for zipping encoder and
|
||||
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
|
||||
# instances
|
||||
zipped_prompt_list = zip_enc_dec_prompts(
|
||||
['An encoder prompt', 'Another encoder prompt'],
|
||||
['A decoder prompt', 'Another decoder prompt'])
|
||||
|
||||
# - Let's put all of the above example prompts together into one list
|
||||
# which we will pass to the encoder/decoder LLM.
|
||||
prompts = [
|
||||
single_text_prompt_raw, single_text_prompt, single_tokens_prompt,
|
||||
enc_dec_prompt1, enc_dec_prompt2, enc_dec_prompt3
|
||||
] + zipped_prompt_list
|
||||
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
min_tokens=0,
|
||||
max_tokens=20,
|
||||
)
|
||||
def create_sampling_params():
|
||||
return SamplingParams(
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
min_tokens=0,
|
||||
max_tokens=20,
|
||||
)
|
||||
|
||||
# Generate output tokens from the prompts. The output is a list of
|
||||
# RequestOutput objects that contain the prompt, generated
|
||||
# text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
print("-" * 50)
|
||||
for i, output in enumerate(outputs):
|
||||
prompt = output.prompt
|
||||
encoder_prompt = output.encoder_prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Output {i+1}:")
|
||||
print(f"Encoder prompt: {encoder_prompt!r}\n"
|
||||
f"Decoder prompt: {prompt!r}\n"
|
||||
f"Generated text: {generated_text!r}")
|
||||
def print_outputs(outputs):
|
||||
print("-" * 50)
|
||||
for i, output in enumerate(outputs):
|
||||
prompt = output.prompt
|
||||
encoder_prompt = output.encoder_prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Output {i+1}:")
|
||||
print(f"Encoder prompt: {encoder_prompt!r}\n"
|
||||
f"Decoder prompt: {prompt!r}\n"
|
||||
f"Generated text: {generated_text!r}")
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def main():
|
||||
dtype = "float"
|
||||
|
||||
# Create a BART encoder/decoder model instance
|
||||
llm = LLM(
|
||||
model="facebook/bart-large-cnn",
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get BART tokenizer
|
||||
tokenizer = llm.llm_engine.get_tokenizer_group()
|
||||
|
||||
prompts = create_prompts(tokenizer)
|
||||
sampling_params = create_sampling_params()
|
||||
|
||||
# Generate output tokens from the prompts. The output is a list of
|
||||
# RequestOutput objects that contain the prompt, generated
|
||||
# text, and other information.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
print_outputs(outputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -22,7 +22,7 @@ class ModelRequestData(NamedTuple):
|
||||
def run_florence2():
|
||||
engine_args = EngineArgs(
|
||||
model="microsoft/Florence-2-large",
|
||||
tokenizer="facebook/bart-large",
|
||||
tokenizer="Isotr0py/Florence-2-tokenizer",
|
||||
max_num_seqs=8,
|
||||
trust_remote_code=True,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
@ -126,6 +126,23 @@ model_example_map = {
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'vision language models for text generation')
|
||||
parser.add_argument('--model-type',
|
||||
'-m',
|
||||
type=str,
|
||||
default="mllama",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
model = args.model_type
|
||||
if model not in model_example_map:
|
||||
@ -133,6 +150,11 @@ def main(args):
|
||||
|
||||
req_data = model_example_map[model]()
|
||||
|
||||
# Disable other modalities to save memory
|
||||
default_limits = {"image": 0, "video": 0, "audio": 0}
|
||||
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
|
||||
req_data.engine_args.limit_mm_per_prompt or {})
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
@ -143,6 +165,7 @@ def main(args):
|
||||
temperature=0,
|
||||
top_p=1.0,
|
||||
max_tokens=64,
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
@ -166,19 +189,5 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'vision language models for text generation')
|
||||
parser.add_argument('--model-type',
|
||||
'-m',
|
||||
type=str,
|
||||
default="mllama",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@ -50,6 +50,13 @@ def initialize_engine(args: argparse.Namespace) -> LLMEngine:
|
||||
return LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using the LLMEngine class directly')
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
"""Main function that sets up and runs the prompt processing."""
|
||||
engine = initialize_engine(args)
|
||||
@ -58,8 +65,5 @@ def main(args: argparse.Namespace):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using the LLMEngine class directly')
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@ -16,11 +16,11 @@ from vllm.sampling_params import SamplingParams
|
||||
# # Mistral format
|
||||
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
|
||||
# --tokenizer-mode mistral --config-format mistral --load-format mistral \
|
||||
# --limit-mm-per-prompt 'image=4' --max-model-len 16384
|
||||
# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
|
||||
#
|
||||
# # HF format
|
||||
# vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
|
||||
# --limit-mm-per-prompt 'image=4' --max-model-len 16384
|
||||
# --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
|
||||
# ```
|
||||
#
|
||||
# - Client:
|
||||
@ -168,7 +168,7 @@ def run_advanced_demo(args: argparse.Namespace):
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def main():
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run a demo in simple or advanced mode.")
|
||||
|
||||
@ -187,8 +187,11 @@ def main():
|
||||
'--disable-mm-preprocessor-cache',
|
||||
action='store_true',
|
||||
help='If True, disables caching of multi-modal preprocessor/mapper.')
|
||||
return parser.parse_args()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.mode == "simple":
|
||||
print("Running simple demo...")
|
||||
|
||||
@ -34,8 +34,7 @@ def time_generation(llm: LLM, prompts: list[str],
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
template = (
|
||||
"Below is an instruction that describes a task. Write a response "
|
||||
"that appropriately completes the request.\n\n### Instruction:\n{}"
|
||||
@ -66,3 +65,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
time_generation(llm, prompts, sampling_params, "With speculation")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -417,6 +417,38 @@ def run_model(input_data,
|
||||
return pred_imgs
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
default="./India_900498_S2Hand.tif",
|
||||
help="Path to the file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="output",
|
||||
help="Path to the directory where to save outputs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_indices",
|
||||
default=[1, 2, 3, 8, 11, 12],
|
||||
type=int,
|
||||
nargs="+",
|
||||
help=
|
||||
"0-based indices of the six Prithvi channels to be selected from the "
|
||||
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rgb_outputs",
|
||||
action="store_true",
|
||||
help="If present, output files will only contain RGB channels. "
|
||||
"Otherwise, all bands will be saved.",
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
data_file: str,
|
||||
output_dir: str,
|
||||
@ -496,35 +528,7 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--data_file",
|
||||
type=str,
|
||||
default="./India_900498_S2Hand.tif",
|
||||
help="Path to the file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="output",
|
||||
help="Path to the directory where to save outputs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_indices",
|
||||
default=[1, 2, 3, 8, 11, 12],
|
||||
type=int,
|
||||
nargs="+",
|
||||
help=
|
||||
"0-based indices of the six Prithvi channels to be selected from the "
|
||||
"input. By default selects [1,2,3,8,11,12] for S2L1C data.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rgb_outputs",
|
||||
action="store_true",
|
||||
help="If present, output files will only contain RGB channels. "
|
||||
"Otherwise, all bands will be saved.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
|
||||
main(**vars(args))
|
||||
|
||||
@ -359,7 +359,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
|
||||
f" in folder {context.save_chrome_traces_folder}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(description="""
|
||||
Profile a model
|
||||
|
||||
@ -449,7 +449,10 @@ Profile a model
|
||||
|
||||
EngineArgs.add_cli_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
context = ProfileContext(
|
||||
engine_args=EngineArgs.from_cli_args(args),
|
||||
**{
|
||||
@ -458,3 +461,8 @@ Profile a model
|
||||
if k in inspect.signature(ProfileContext).parameters
|
||||
})
|
||||
run_profile(context, csv_output=args.csv, json_output=args.json)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@ -29,20 +29,23 @@ from pathlib import Path
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
parser = FlexibleArgumentParser()
|
||||
EngineArgs.add_cli_args(parser)
|
||||
parser.add_argument("--output",
|
||||
"-o",
|
||||
required=True,
|
||||
type=str,
|
||||
help="path to output checkpoint")
|
||||
parser.add_argument("--file-pattern",
|
||||
type=str,
|
||||
help="string pattern of saved filenames")
|
||||
parser.add_argument("--max-file-size",
|
||||
type=str,
|
||||
default=5 * 1024**3,
|
||||
help="max size (in bytes) of each safetensors file")
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
EngineArgs.add_cli_args(parser)
|
||||
parser.add_argument("--output",
|
||||
"-o",
|
||||
required=True,
|
||||
type=str,
|
||||
help="path to output checkpoint")
|
||||
parser.add_argument("--file-pattern",
|
||||
type=str,
|
||||
help="string pattern of saved filenames")
|
||||
parser.add_argument("--max-file-size",
|
||||
type=str,
|
||||
default=5 * 1024**3,
|
||||
help="max size (in bytes) of each safetensors file")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -87,5 +90,5 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@ -18,8 +18,8 @@ prompts = [
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
# Create an LLM.
|
||||
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
|
||||
|
||||
@ -42,3 +42,7 @@ if __name__ == "__main__":
|
||||
# Add a buffer to wait for profiler in the background process
|
||||
# (in case MP is on) to finish writing profiling output.
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -45,7 +45,7 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
dtype="bfloat16",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
|
||||
@ -71,7 +71,7 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
mm_processor_kwargs={"crop_to_patches": True},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
prompts = [
|
||||
f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|><image>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
|
||||
@ -92,7 +92,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
prompts = [f"Question: {question} Answer:" for question in questions]
|
||||
engine_args = EngineArgs(
|
||||
model="Salesforce/blip2-opt-6.7b",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@ -110,7 +110,7 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model="facebook/chameleon-7b",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@ -129,8 +129,8 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model=model_name,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
prompts = [
|
||||
@ -150,12 +150,12 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model="microsoft/Florence-2-large",
|
||||
tokenizer="facebook/bart-large",
|
||||
tokenizer="Isotr0py/Florence-2-tokenizer",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
prompts = ["<MORE_DETAILED_CAPTION>" for _ in questions]
|
||||
@ -175,7 +175,7 @@ def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model="adept/fuyu-8b",
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@ -194,7 +194,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
max_model_len=2048,
|
||||
max_num_seqs=2,
|
||||
mm_processor_kwargs={"do_pan_and_scan": True},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
prompts = [("<bos><start_of_turn>user\n"
|
||||
@ -219,7 +219,7 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
prompts = [
|
||||
@ -246,7 +246,7 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=8192,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
@ -287,7 +287,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
"longest_edge": 3 * 364
|
||||
},
|
||||
},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
prompts = [(
|
||||
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
|
||||
@ -314,7 +314,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
|
||||
"longest_edge": 384
|
||||
},
|
||||
},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
prompts = [
|
||||
(f"<|im_start|>User:<image>{question}<end_of_utterance>\nAssistant:")
|
||||
@ -337,7 +337,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
@ -364,6 +364,29 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
# Kimi-VL
|
||||
def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
prompts = [
|
||||
"<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
|
||||
f"<|media_pad|><|media_end|>{question}<|im_end|>"
|
||||
"<|im_assistant|>assistant<|im_middle|>" for question in questions
|
||||
]
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model="moonshotai/Kimi-VL-A3B-Instruct",
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# LLaVA-1.5
|
||||
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -375,7 +398,7 @@ def run_llava(questions: list[str], modality: str) -> ModelRequestData:
|
||||
engine_args = EngineArgs(
|
||||
model="llava-hf/llava-1.5-7b-hf",
|
||||
max_model_len=4096,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@ -392,7 +415,7 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
|
||||
engine_args = EngineArgs(
|
||||
model="llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
max_model_len=8192,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@ -414,7 +437,7 @@ def run_llava_next_video(questions: list[str],
|
||||
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@ -442,7 +465,7 @@ def run_llava_onevision(questions: list[str],
|
||||
engine_args = EngineArgs(
|
||||
model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
|
||||
max_model_len=16384,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@ -465,7 +488,7 @@ def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model="TIGER-Lab/Mantis-8B-siglip-llama3",
|
||||
max_model_len=4096,
|
||||
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
stop_token_ids = [128009]
|
||||
|
||||
@ -506,7 +529,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
trust_remote_code=True,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
|
||||
# 2.0
|
||||
@ -561,7 +584,7 @@ def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
tensor_parallel_size=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
|
||||
@ -587,7 +610,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model=model_name,
|
||||
max_model_len=8192,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
@ -611,7 +634,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def run_llama4(questions: list[str], modality: str):
|
||||
def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
|
||||
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
|
||||
@ -621,8 +644,8 @@ def run_llama4(questions: list[str], modality: str):
|
||||
max_model_len=8192,
|
||||
max_num_seqs=4,
|
||||
tensor_parallel_size=8,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
gpu_memory_utilization=0.4,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
@ -657,7 +680,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
prompts = [
|
||||
@ -683,7 +706,7 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
tensor_parallel_size=4,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
@ -710,7 +733,8 @@ def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
|
||||
prompts = ["caption en" for _ in questions]
|
||||
engine_args = EngineArgs(
|
||||
model="google/paligemma-3b-mix-224",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -726,7 +750,8 @@ def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData:
|
||||
prompts = ["caption en" for _ in questions]
|
||||
engine_args = EngineArgs(
|
||||
model="google/paligemma2-3b-ft-docci-448",
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
@ -762,7 +787,7 @@ def run_phi3v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
max_num_seqs=2,
|
||||
# Note - mm_processor_kwargs can also be passed to generate/chat calls
|
||||
mm_processor_kwargs={"num_crops": 16},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@ -793,6 +818,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
|
||||
max_num_seqs=2,
|
||||
enable_lora=True,
|
||||
max_lora_rank=320,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@ -813,7 +839,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model=model_name,
|
||||
max_model_len=6144,
|
||||
max_num_seqs=2,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
|
||||
@ -834,7 +860,7 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
max_model_len=1024,
|
||||
max_num_seqs=2,
|
||||
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
prompts = [f"{question}Picture 1: <img></img>\n" for question in questions]
|
||||
@ -859,7 +885,7 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
"min_pixels": 28 * 28,
|
||||
"max_pixels": 1280 * 28 * 28,
|
||||
},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
if modality == "image":
|
||||
@ -894,7 +920,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
"max_pixels": 1280 * 28 * 28,
|
||||
"fps": 1,
|
||||
},
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
if modality == "image":
|
||||
@ -925,7 +951,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
max_model_len=4096,
|
||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name,
|
||||
@ -963,6 +989,7 @@ model_example_map = {
|
||||
"h2ovl_chat": run_h2ovl,
|
||||
"idefics3": run_idefics3,
|
||||
"internvl_chat": run_internvl,
|
||||
"kimi_vl": run_kimi_vl,
|
||||
"llava": run_llava,
|
||||
"llava-next": run_llava_next,
|
||||
"llava-next-video": run_llava_next_video,
|
||||
@ -1070,6 +1097,59 @@ def time_counter(enable: bool):
|
||||
yield
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'vision language models for text generation')
|
||||
parser.add_argument('--model-type',
|
||||
'-m',
|
||||
type=str,
|
||||
default="llava",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
default=4,
|
||||
help='Number of prompts to run.')
|
||||
parser.add_argument('--modality',
|
||||
type=str,
|
||||
default="image",
|
||||
choices=['image', 'video'],
|
||||
help='Modality of the input.')
|
||||
parser.add_argument('--num-frames',
|
||||
type=int,
|
||||
default=16,
|
||||
help='Number of frames to extract from the video.')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
parser.add_argument(
|
||||
'--image-repeat-prob',
|
||||
type=float,
|
||||
default=None,
|
||||
help='Simulates the hit-ratio for multi-modal preprocessor cache'
|
||||
' (if enabled)')
|
||||
|
||||
parser.add_argument(
|
||||
'--disable-mm-preprocessor-cache',
|
||||
action='store_true',
|
||||
help='If True, disables caching of multi-modal preprocessor/mapper.')
|
||||
|
||||
parser.add_argument(
|
||||
'--time-generate',
|
||||
action='store_true',
|
||||
help='If True, then print the total generate() call time')
|
||||
|
||||
parser.add_argument(
|
||||
'--use-different-prompt-per-request',
|
||||
action='store_true',
|
||||
help='If True, then use different prompt (with the same multi-modal '
|
||||
'data) for each request.')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args):
|
||||
model = args.model_type
|
||||
if model not in model_example_map:
|
||||
@@ -1082,15 +1162,16 @@ def main(args):

    req_data = model_example_map[model](questions, modality)

    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
    llm = LLM(**engine_args)
    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {})

    # To maintain code compatibility in this script, we add LoRA here.
    # You can also add LoRA using:
    # llm.generate(prompts, lora_request=lora_request,...)
    if req_data.lora_requests:
        for lora_request in req_data.lora_requests:
            llm.llm_engine.add_lora(lora_request=lora_request)
    engine_args = asdict(req_data.engine_args) | {
        "seed": args.seed,
        "disable_mm_preprocessor_cache": args.disable_mm_preprocessor_cache,
    }
    llm = LLM(**engine_args)

    # Don't want to check the flag multiple times, so just hijack `prompts`.
    prompts = req_data.prompts if args.use_different_prompt_per_request else [
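A note on the `default_limits | dict(...)` pattern introduced above: it relies on PEP 584 dict union (Python 3.9+), where keys on the right override the zeroed defaults on the left. A minimal sketch; the `{"image": 1}` value is illustrative rather than taken from a specific model entry:

```python
# Sketch of the dict-union merge used above (PEP 584, Python 3.9+).
# The per-model limits on the right override the zeroed defaults on the left.
default_limits = {"image": 0, "video": 0, "audio": 0}
model_limits = {"image": 1}  # hypothetical value from EngineArgs.limit_mm_per_prompt

merged = default_limits | dict(model_limits or {})
print(merged)  # {'image': 1, 'video': 0, 'audio': 0}
```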
@ -1147,55 +1228,5 @@ def main(args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'vision language models for text generation')
|
||||
parser.add_argument('--model-type',
|
||||
'-m',
|
||||
type=str,
|
||||
default="llava",
|
||||
choices=model_example_map.keys(),
|
||||
help='Huggingface "model_type".')
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
default=4,
|
||||
help='Number of prompts to run.')
|
||||
parser.add_argument('--modality',
|
||||
type=str,
|
||||
default="image",
|
||||
choices=['image', 'video'],
|
||||
help='Modality of the input.')
|
||||
parser.add_argument('--num-frames',
|
||||
type=int,
|
||||
default=16,
|
||||
help='Number of frames to extract from the video.')
|
||||
parser.add_argument("--seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
|
||||
parser.add_argument(
|
||||
'--image-repeat-prob',
|
||||
type=float,
|
||||
default=None,
|
||||
help='Simulates the hit-ratio for multi-modal preprocessor cache'
|
||||
' (if enabled)')
|
||||
|
||||
parser.add_argument(
|
||||
'--disable-mm-preprocessor-cache',
|
||||
action='store_true',
|
||||
help='If True, disables caching of multi-modal preprocessor/mapper.')
|
||||
|
||||
parser.add_argument(
|
||||
'--time-generate',
|
||||
action='store_true',
|
||||
help='If True, then print the total generate() call time')
|
||||
|
||||
parser.add_argument(
|
||||
'--use-different-prompt-per-request',
|
||||
action='store_true',
|
||||
help='If True, then use different prompt (with the same multi-modal '
|
||||
'data) for each request.')
|
||||
|
||||
args = parser.parse_args()
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@ -63,6 +63,7 @@ def run_e5_v(query: Query) -> ModelRequestData:
|
||||
model="royokong/e5-v",
|
||||
task="embed",
|
||||
max_model_len=4096,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@ -93,6 +94,7 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
|
||||
task="embed",
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs={"num_crops": 4},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
@@ -131,6 +133,11 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
    query = get_query(modality)
    req_data = model_example_map[model](query)

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {})

    engine_args = asdict(req_data.engine_args) | {"seed": seed}
    llm = LLM(**engine_args)
|
||||
@ -149,16 +156,13 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
run_encode(args.model_name, args.modality, args.seed)
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"e5_v": run_e5_v,
|
||||
"vlm2vec": run_vlm2vec,
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'vision language models for multimodal embedding')
|
||||
@ -177,6 +181,13 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=None,
|
||||
help="Set the seed when initializing `vllm.LLM`.")
|
||||
return parser.parse_args()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def main(args: Namespace):
|
||||
run_encode(args.model_name, args.modality, args.seed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -326,6 +326,44 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
    )


def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "moonshotai/Kimi-VL-A3B-Instruct"

    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=4,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [{
        "role":
        "user",
        "content": [
            *placeholders,
            {
                "type": "text",
                "text": question
            },
        ],
    }]

    processor = AutoProcessor.from_pretrained(model_name,
                                              trust_remote_code=True)

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=[fetch_image(url) for url in image_urls],
    )
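For orientation, a sketch of how a `ModelRequestData` like the one returned by `load_kimi_vl` gets consumed further down in this script; `IMAGE_URLS` and `load_kimi_vl` come from the same example file, while the `SamplingParams` values here are illustrative and not part of this change:

```python
# Sketch only: feeding the Kimi-VL request data to vLLM, following the
# pattern used by run_generate in this example script.
from dataclasses import asdict

from vllm import LLM, SamplingParams

req_data = load_kimi_vl("Describe these images.", IMAGE_URLS[:2])
llm = LLM(**asdict(req_data.engine_args))

outputs = llm.generate(
    {
        "prompt": req_data.prompt,
        "multi_modal_data": {"image": req_data.image_data},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```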
|
||||
|
||||
def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
|
||||
@ -640,6 +678,7 @@ model_example_map = {
|
||||
"h2ovl_chat": load_h2ovl,
|
||||
"idefics3": load_idefics3,
|
||||
"internvl_chat": load_internvl,
|
||||
"kimi_vl": load_kimi_vl,
|
||||
"llama4": load_llama4,
|
||||
"mistral3": load_mistral3,
|
||||
"mllama": load_mllama,
|
||||
@ -687,16 +726,14 @@ def run_chat(model: str, question: str, image_urls: list[str],
|
||||
seed: Optional[int]):
|
||||
req_data = model_example_map[model](question, image_urls)
|
||||
|
||||
# Disable other modalities to save memory
|
||||
default_limits = {"image": 0, "video": 0, "audio": 0}
|
||||
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
|
||||
req_data.engine_args.limit_mm_per_prompt or {})
|
||||
|
||||
engine_args = asdict(req_data.engine_args) | {"seed": seed}
|
||||
llm = LLM(**engine_args)
|
||||
|
||||
# To maintain code compatibility in this script, we add LoRA here.
|
||||
# You can also add LoRA using:
|
||||
# llm.generate(prompts, lora_request=lora_request,...)
|
||||
if req_data.lora_requests:
|
||||
for lora_request in req_data.lora_requests:
|
||||
llm.llm_engine.add_lora(lora_request=lora_request)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=256,
|
||||
stop_token_ids=req_data.stop_token_ids)
|
||||
@ -729,22 +766,7 @@ def run_chat(model: str, question: str, image_urls: list[str],
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
model = args.model_type
|
||||
method = args.method
|
||||
seed = args.seed
|
||||
|
||||
image_urls = IMAGE_URLS[:args.num_images]
|
||||
|
||||
if method == "generate":
|
||||
run_generate(model, QUESTION, image_urls, seed)
|
||||
elif method == "chat":
|
||||
run_chat(model, QUESTION, image_urls, seed)
|
||||
else:
|
||||
raise ValueError(f"Invalid method: {method}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using vLLM for offline inference with '
|
||||
'vision language models that support multi-image input for text '
|
||||
@ -770,6 +792,24 @@ if __name__ == "__main__":
|
||||
choices=list(range(1, 13)), # 12 is the max number of images
|
||||
default=2,
|
||||
help="Number of images to use for the demo.")
|
||||
return parser.parse_args()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def main(args: Namespace):
|
||||
model = args.model_type
|
||||
method = args.method
|
||||
seed = args.seed
|
||||
|
||||
image_urls = IMAGE_URLS[:args.num_images]
|
||||
|
||||
if method == "generate":
|
||||
run_generate(model, QUESTION, image_urls, seed)
|
||||
elif method == "chat":
|
||||
run_chat(model, QUESTION, image_urls, seed)
|
||||
else:
|
||||
raise ValueError(f"Invalid method: {method}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -58,6 +58,16 @@ def get_response(response: requests.Response) -> list[str]:
    return output


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--n", type=int, default=1)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    parser.add_argument("--stream", action="store_true")
    return parser.parse_args()


def main(args: Namespace):
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
@@ -82,11 +92,5 @@ def main(args: Namespace):


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--n", type=int, default=1)
    parser.add_argument("--prompt", type=str, default="San Francisco is a")
    parser.add_argument("--stream", action="store_true")
    args = parser.parse_args()
    args = parse_args()
    main(args)
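With `parse_args()` split out from `main()`, the client can also be driven without a shell; a small sketch, assuming the `Namespace` fields simply mirror the flags defined above:

```python
# Sketch: invoking the refactored client without the command line.
# The Namespace fields mirror the argparse options defined above.
from argparse import Namespace

main(Namespace(host="localhost",
               port=8000,
               n=1,
               prompt="San Francisco is a",
               stream=False))
```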
@@ -2,32 +2,46 @@
"""
Example of using the OpenAI entrypoint's rerank API which is compatible with
the Cohere SDK: https://github.com/cohere-ai/cohere-python
Note that `pip install cohere` is needed to run this example.

run: vllm serve BAAI/bge-reranker-base
"""
from typing import Union

import cohere
from cohere import Client, ClientV2

# cohere v1 client
co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
rerank_v1_result = co.rerank(
    model="BAAI/bge-reranker-base",
    query="What is the capital of France?",
    documents=[
        "The capital of France is Paris", "Reranking is fun!",
        "vLLM is an open-source framework for fast AI serving"
    ])
model = "BAAI/bge-reranker-base"

print(rerank_v1_result)
query = "What is the capital of France?"

# or the v2
co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
documents = [
    "The capital of France is Paris", "Reranking is fun!",
    "vLLM is an open-source framework for fast AI serving"
]

v2_rerank_result = co2.rerank(
    model="BAAI/bge-reranker-base",
    query="What is the capital of France?",
    documents=[
        "The capital of France is Paris", "Reranking is fun!",
        "vLLM is an open-source framework for fast AI serving"
    ])

print(v2_rerank_result)
def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str,
                  documents: list[str]) -> dict:
    return client.rerank(model=model, query=query, documents=documents)


def main():
    # cohere v1 client
    cohere_v1 = cohere.Client(base_url="http://localhost:8000",
                              api_key="sk-fake-key")
    rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
    print("-" * 50)
    print("rerank_v1_result:\n", rerank_v1_result)
    print("-" * 50)

    # or the v2
    cohere_v2 = cohere.ClientV2("sk-fake-key",
                                base_url="http://localhost:8000")
    rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
    print("rerank_v2_result:\n", rerank_v2_result)
    print("-" * 50)


if __name__ == "__main__":
    main()
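To read the ranked output back, something like the following works, assuming the Cohere SDK's usual response shape (`results` entries carrying `index` and `relevance_score`); treat those field names as an assumption to verify against your SDK version:

```python
# Sketch: printing documents in ranked order, reusing model/query/documents
# and cohere_rerank from the example above. Field names follow the Cohere
# SDK response model (results[i].index / .relevance_score); adjust if your
# SDK version differs.
result = cohere_rerank(
    cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key"),
    model, query, documents)
for item in result.results:
    print(f"{item.relevance_score:.4f}  {documents[item.index]}")
```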
@ -1,52 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Example for starting a Gradio OpenAI Chatbot Webserver
|
||||
Start vLLM API server:
|
||||
vllm serve meta-llama/Llama-2-7b-chat-hf
|
||||
|
||||
Start Gradio OpenAI Chatbot Webserver:
|
||||
python examples/online_serving/gradio_openai_chatbot_webserver.py \
|
||||
-m meta-llama/Llama-2-7b-chat-hf
|
||||
|
||||
Note that `pip install --upgrade gradio` is needed to run this example.
|
||||
More details: https://github.com/gradio-app/gradio
|
||||
|
||||
If your antivirus software blocks the download of frpc for gradio,
|
||||
you can install it manually by following these steps:
|
||||
|
||||
1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
|
||||
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
|
||||
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import gradio as gr
|
||||
from openai import OpenAI
|
||||
|
||||
# Argument parser setup
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Chatbot Interface with Customizable Parameters')
|
||||
parser.add_argument('--model-url',
|
||||
type=str,
|
||||
default='http://localhost:8000/v1',
|
||||
help='Model URL')
|
||||
parser.add_argument('-m',
|
||||
'--model',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Model name for the chatbot')
|
||||
parser.add_argument('--temp',
|
||||
type=float,
|
||||
default=0.8,
|
||||
help='Temperature for text generation')
|
||||
parser.add_argument('--stop-token-ids',
|
||||
type=str,
|
||||
default='',
|
||||
help='Comma-separated stop token IDs')
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=int, default=8001)
|
||||
|
||||
# Parse the arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = args.model_url
|
||||
|
||||
# Create an OpenAI client to interact with the API server
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
|
||||
def predict(message, history):
|
||||
# Convert chat history to OpenAI format
|
||||
def format_history_to_openai(history):
|
||||
history_openai_format = [{
|
||||
"role": "system",
|
||||
"content": "You are a great ai assistant."
|
||||
"content": "You are a great AI assistant."
|
||||
}]
|
||||
for human, assistant in history:
|
||||
history_openai_format.append({"role": "user", "content": human})
|
||||
@ -54,31 +34,92 @@ def predict(message, history):
|
||||
"role": "assistant",
|
||||
"content": assistant
|
||||
})
|
||||
return history_openai_format
|
||||
|
||||
|
||||
def predict(message, history, client, model_name, temp, stop_token_ids):
|
||||
# Format history to OpenAI chat format
|
||||
history_openai_format = format_history_to_openai(history)
|
||||
history_openai_format.append({"role": "user", "content": message})
|
||||
|
||||
# Create a chat completion request and send it to the API server
|
||||
# Send request to OpenAI API (vLLM server)
|
||||
stream = client.chat.completions.create(
|
||||
model=args.model, # Model name to use
|
||||
messages=history_openai_format, # Chat history
|
||||
temperature=args.temp, # Temperature for text generation
|
||||
stream=True, # Stream response
|
||||
model=model_name,
|
||||
messages=history_openai_format,
|
||||
temperature=temp,
|
||||
stream=True,
|
||||
extra_body={
|
||||
'repetition_penalty':
|
||||
1,
|
||||
'stop_token_ids': [
|
||||
int(id.strip()) for id in args.stop_token_ids.split(',')
|
||||
if id.strip()
|
||||
] if args.stop_token_ids else []
|
||||
'stop_token_ids':
|
||||
[int(id.strip())
|
||||
for id in stop_token_ids.split(',')] if stop_token_ids else []
|
||||
})
|
||||
|
||||
# Read and return generated text from response stream
|
||||
partial_message = ""
|
||||
# Collect all chunks and concatenate them into a full message
|
||||
full_message = ""
|
||||
for chunk in stream:
|
||||
partial_message += (chunk.choices[0].delta.content or "")
|
||||
yield partial_message
|
||||
full_message += (chunk.choices[0].delta.content or "")
|
||||
|
||||
# Return the full message as a single response
|
||||
return full_message
|
||||
|
||||
|
||||
# Create and launch a chat interface with Gradio
|
||||
gr.ChatInterface(predict).queue().launch(server_name=args.host,
|
||||
server_port=args.port,
|
||||
share=True)
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Chatbot Interface with Customizable Parameters')
|
||||
parser.add_argument('--model-url',
|
||||
type=str,
|
||||
default='http://localhost:8000/v1',
|
||||
help='Model URL')
|
||||
parser.add_argument('-m',
|
||||
'--model',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Model name for the chatbot')
|
||||
parser.add_argument('--temp',
|
||||
type=float,
|
||||
default=0.8,
|
||||
help='Temperature for text generation')
|
||||
parser.add_argument('--stop-token-ids',
|
||||
type=str,
|
||||
default='',
|
||||
help='Comma-separated stop token IDs')
|
||||
parser.add_argument("--host", type=str, default=None)
|
||||
parser.add_argument("--port", type=int, default=8001)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_gradio_interface(client, model_name, temp, stop_token_ids):
|
||||
|
||||
def chat_predict(message, history):
|
||||
return predict(message, history, client, model_name, temp,
|
||||
stop_token_ids)
|
||||
|
||||
return gr.ChatInterface(fn=chat_predict,
|
||||
title="Chatbot Interface",
|
||||
description="A simple chatbot powered by vLLM")
|
||||
|
||||
|
||||
def main():
|
||||
# Parse the arguments
|
||||
args = parse_args()
|
||||
|
||||
# Set OpenAI's API key and API base to use vLLM's API server
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = args.model_url
|
||||
|
||||
# Create an OpenAI client
|
||||
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
|
||||
|
||||
# Define the Gradio chatbot interface using the predict function
|
||||
gradio_interface = build_gradio_interface(client, args.model, args.temp,
|
||||
args.stop_token_ids)
|
||||
|
||||
gradio_interface.queue().launch(server_name=args.host,
|
||||
server_port=args.port,
|
||||
share=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,5 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
"""Example for starting a Gradio Webserver
Start vLLM API server:
    python -m vllm.entrypoints.api_server \
        --model meta-llama/Llama-2-7b-chat-hf

Start Webserver:
    python examples/online_serving/gradio_webserver.py

Note that `pip install --upgrade gradio` is needed to run this example.
More details: https://github.com/gradio-app/gradio

If your antivirus software blocks the download of frpc for gradio,
you can install it manually by following these steps:

1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
"""
import argparse
import json

@@ -39,16 +56,23 @@ def build_demo():
    return demo


if __name__ == "__main__":
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--model-url",
                        type=str,
                        default="http://localhost:8000/generate")
    args = parser.parse_args()
    return parser.parse_args()


def main(args):
    demo = build_demo()
    demo.queue().launch(server_name=args.host,
                        server_port=args.port,
                        share=True)


if __name__ == "__main__":
    args = parse_args()
    main(args)
@@ -23,12 +23,19 @@ data = {
        "The capital of France is Paris.", "Horses and cows are both animals"
    ]
}
response = requests.post(url, headers=headers, json=data)

# Check the response
if response.status_code == 200:
    print("Request successful!")
    print(json.dumps(response.json(), indent=2))
else:
    print(f"Request failed with status code: {response.status_code}")
    print(response.text)


def main():
    response = requests.post(url, headers=headers, json=data)

    # Check the response
    if response.status_code == 200:
        print("Request successful!")
        print(json.dumps(response.json(), indent=2))
    else:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)


if __name__ == "__main__":
    main()
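An equivalent, slightly stricter variant of the same request is sketched below; it reuses the module-level `url`, `headers`, and `data` from this example and leans on the standard `requests` `raise_for_status()` call instead of branching on the status code:

```python
# Sketch: same request, letting requests raise on HTTP errors instead of
# checking status_code manually.
def main_strict():
    response = requests.post(url, headers=headers, json=data)
    response.raise_for_status()  # raises requests.HTTPError on 4xx/5xx
    print(json.dumps(response.json(), indent=2))
```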
@ -1,38 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
"""Example Python client for OpenAI Chat Completion using vLLM API server
|
||||
NOTE: start a supported chat completion model server with `vllm serve`, e.g.
|
||||
vllm serve meta-llama/Llama-2-7b-chat-hf
|
||||
"""
|
||||
from openai import OpenAI
|
||||
|
||||
# Modify OpenAI's API key and API base to use vLLM's API server.
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Who won the world series in 2020?"
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "The Los Angeles Dodgers won the World Series in 2020."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Where was it played?"
|
||||
}]
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=[{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Who won the world series in 2020?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"The Los Angeles Dodgers won the World Series in 2020."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Where was it played?"
|
||||
}],
|
||||
model=model,
|
||||
)
|
||||
def main():
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
print("Chat completion results:")
|
||||
print(chat_completion)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model,
|
||||
)
|
||||
|
||||
print("-" * 50)
|
||||
print("Chat completion results:")
|
||||
print(chat_completion)
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -9,7 +9,7 @@ vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
|
||||
|
||||
(multi-image inference with Phi-3.5-vision-instruct)
|
||||
vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
|
||||
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
|
||||
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'
|
||||
|
||||
(audio inference with Ultravox)
|
||||
vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096
|
||||
@ -303,12 +303,7 @@ example_function_map = {
|
||||
}
|
||||
|
||||
|
||||
def main(args) -> None:
|
||||
chat_type = args.chat_type
|
||||
example_function_map[chat_type]()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Demo on using OpenAI client for online serving with '
|
||||
'multimodal language models served with vLLM.')
|
||||
@ -318,5 +313,14 @@ if __name__ == "__main__":
|
||||
default="single-image",
|
||||
choices=list(example_function_map.keys()),
|
||||
help='Conversation type with multimodal data.')
|
||||
args = parser.parse_args()
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args) -> None:
|
||||
chat_type = args.chat_type
|
||||
example_function_map[chat_type]()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
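Since `--limit-mm-per-prompt` now takes a JSON object, a two-image request against the Phi-3.5-vision server started above would look roughly like this; the image URLs are placeholders and the content-part layout follows the standard OpenAI chat schema rather than anything added by this change:

```python
# Sketch: two-image chat completion against the Phi-3.5-vision server
# started with --limit-mm-per-prompt '{"image":2}' above.
# The URLs below are placeholders, not part of this PR.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
model = client.models.list().data[0].id

resp = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/a.jpg"}},
            {"type": "image_url", "image_url": {"url": "https://example.com/b.jpg"}},
            {"type": "text", "text": "What differs between these two images?"},
        ],
    }],
    max_tokens=128,
)
print(resp.choices[0].message.content)
```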
||||
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
To run this example, you can start the vLLM server
|
||||
To run this example, you can start the vLLM server
|
||||
without any specific flags:
|
||||
|
||||
```bash
|
||||
@ -8,7 +8,7 @@ VLLM_USE_V1=0 vllm serve unsloth/Llama-3.2-1B-Instruct \
|
||||
--guided-decoding-backend outlines
|
||||
```
|
||||
|
||||
This example demonstrates how to generate chat completions
|
||||
This example demonstrates how to generate chat completions
|
||||
using the OpenAI Python client library.
|
||||
"""
|
||||
|
||||
@ -18,15 +18,6 @@ from openai import OpenAI
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
@ -116,21 +107,36 @@ messages = [
|
||||
},
|
||||
]
|
||||
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
stream=True # Enable streaming response
|
||||
)
|
||||
|
||||
for chunk in chat_completion:
|
||||
if chunk.choices and chunk.choices[0].delta.tool_calls:
|
||||
print(chunk.choices[0].delta.tool_calls)
|
||||
def main():
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
chat_completion = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
tool_choice="required")
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
print(chat_completion.choices[0].message.tool_calls)
|
||||
chat_completion = client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
stream=True # Enable streaming response
|
||||
)
|
||||
|
||||
for chunk in chat_completion:
|
||||
if chunk.choices and chunk.choices[0].delta.tool_calls:
|
||||
print(chunk.choices[0].delta.tool_calls)
|
||||
|
||||
chat_completion = client.chat.completions.create(messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
tool_choice="required")
|
||||
|
||||
print(chat_completion.choices[0].message.tool_calls)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -3,8 +3,8 @@
|
||||
An example shows how to generate chat completions from reasoning models
|
||||
like DeepSeekR1.
|
||||
|
||||
To run this example, you need to start the vLLM server with the reasoning
|
||||
parser:
|
||||
To run this example, you need to start the vLLM server
|
||||
with the reasoning parser:
|
||||
|
||||
```bash
|
||||
vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
|
||||
@ -21,35 +21,44 @@ from openai import OpenAI
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
def main():
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
# Round 1
|
||||
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
|
||||
response = client.chat.completions.create(model=model, messages=messages)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
content = response.choices[0].message.content
|
||||
# Round 1
|
||||
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||
# ruff: noqa: E501
|
||||
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
|
||||
response = client.chat.completions.create(model=model, messages=messages)
|
||||
|
||||
print("reasoning_content for Round 1:", reasoning_content)
|
||||
print("content for Round 1:", content)
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
content = response.choices[0].message.content
|
||||
|
||||
# Round 2
|
||||
messages.append({"role": "assistant", "content": content})
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": "How many Rs are there in the word 'strawberry'?",
|
||||
})
|
||||
response = client.chat.completions.create(model=model, messages=messages)
|
||||
print("reasoning_content for Round 1:", reasoning_content)
|
||||
print("content for Round 1:", content)
|
||||
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
content = response.choices[0].message.content
|
||||
# Round 2
|
||||
messages.append({"role": "assistant", "content": content})
|
||||
messages.append({
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"How many Rs are there in the word 'strawberry'?",
|
||||
})
|
||||
response = client.chat.completions.create(model=model, messages=messages)
|
||||
|
||||
print("reasoning_content for Round 2:", reasoning_content)
|
||||
print("content for Round 2:", content)
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
content = response.choices[0].message.content
|
||||
|
||||
print("reasoning_content for Round 2:", reasoning_content)
|
||||
print("content for Round 2:", content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
An example shows how to generate chat completions from reasoning models
|
||||
like DeepSeekR1.
|
||||
|
||||
To run this example, you need to start the vLLM server with the reasoning
|
||||
To run this example, you need to start the vLLM server with the reasoning
|
||||
parser:
|
||||
|
||||
```bash
|
||||
@ -29,41 +29,49 @@ from openai import OpenAI
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
|
||||
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
|
||||
stream = client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
stream=True)
|
||||
|
||||
print("client: Start streaming chat completions...")
|
||||
printed_reasoning_content = False
|
||||
printed_content = False
|
||||
|
||||
for chunk in stream:
|
||||
reasoning_content = None
|
||||
content = None
|
||||
# Check the content is reasoning_content or content
|
||||
if hasattr(chunk.choices[0].delta, "reasoning_content"):
|
||||
reasoning_content = chunk.choices[0].delta.reasoning_content
|
||||
elif hasattr(chunk.choices[0].delta, "content"):
|
||||
content = chunk.choices[0].delta.content
|
||||
def main():
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
if reasoning_content is not None:
|
||||
if not printed_reasoning_content:
|
||||
printed_reasoning_content = True
|
||||
print("reasoning_content:", end="", flush=True)
|
||||
print(reasoning_content, end="", flush=True)
|
||||
elif content is not None:
|
||||
if not printed_content:
|
||||
printed_content = True
|
||||
print("\ncontent:", end="", flush=True)
|
||||
# Extract and print the content
|
||||
print(content, end="", flush=True)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
# ruff: noqa: E501
|
||||
# For granite: add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
|
||||
stream = client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
stream=True)
|
||||
|
||||
print("client: Start streaming chat completions...")
|
||||
printed_reasoning_content = False
|
||||
printed_content = False
|
||||
|
||||
for chunk in stream:
|
||||
reasoning_content = None
|
||||
content = None
|
||||
# Check the content is reasoning_content or content
|
||||
if hasattr(chunk.choices[0].delta, "reasoning_content"):
|
||||
reasoning_content = chunk.choices[0].delta.reasoning_content
|
||||
elif hasattr(chunk.choices[0].delta, "content"):
|
||||
content = chunk.choices[0].delta.content
|
||||
|
||||
if reasoning_content is not None:
|
||||
if not printed_reasoning_content:
|
||||
printed_reasoning_content = True
|
||||
print("reasoning_content:", end="", flush=True)
|
||||
print(reasoning_content, end="", flush=True)
|
||||
elif content is not None:
|
||||
if not printed_content:
|
||||
printed_content = True
|
||||
print("\ncontent:", end="", flush=True)
|
||||
# Extract and print the content
|
||||
print(content, end="", flush=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -98,7 +98,7 @@ def dse_qwen2_vl(inp: dict):
    print("Embedding output:", response_json["data"][0]["embedding"])


if __name__ == '__main__':
def parse_args():
    parser = argparse.ArgumentParser(
        "Script to call a specified VLM through the API. Make sure to serve "
        "the model with --task embed before running this.")
@@ -107,8 +107,10 @@ if __name__ == '__main__':
                        choices=["vlm2vec", "dse_qwen2_vl"],
                        required=True,
                        help="Which model to call.")
    args = parser.parse_args()
    return parser.parse_args()


def main(args):
    if args.model == "vlm2vec":
        vlm2vec()
    elif args.model == "dse_qwen2_vl":
@@ -120,3 +122,8 @@ if __name__ == '__main__':
            "type": "text",
            "content": "What is the weather like today?",
        })


if __name__ == '__main__':
    args = parse_args()
    main(args)
@ -6,28 +6,36 @@ from openai import OpenAI
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = "http://localhost:8000/v1"
|
||||
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
def main():
|
||||
client = OpenAI(
|
||||
# defaults to os.environ.get("OPENAI_API_KEY")
|
||||
api_key=openai_api_key,
|
||||
base_url=openai_api_base,
|
||||
)
|
||||
|
||||
# Completion API
|
||||
stream = False
|
||||
completion = client.completions.create(
|
||||
model=model,
|
||||
prompt="A robot may not injure a human being",
|
||||
echo=False,
|
||||
n=2,
|
||||
stream=stream,
|
||||
logprobs=3)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
|
||||
print("Completion results:")
|
||||
if stream:
|
||||
for c in completion:
|
||||
print(c)
|
||||
else:
|
||||
print(completion)
|
||||
# Completion API
|
||||
stream = False
|
||||
completion = client.completions.create(
|
||||
model=model,
|
||||
prompt="A robot may not injure a human being",
|
||||
echo=False,
|
||||
n=2,
|
||||
stream=stream,
|
||||
logprobs=3)
|
||||
|
||||
print("-" * 50)
|
||||
print("Completion results:")
|
||||
if stream:
|
||||
for c in completion:
|
||||
print(c)
|
||||
else:
|
||||
print(completion)
|
||||
print("-" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -16,13 +16,15 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response:
|
||||
return response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3")
|
||||
return parser.parse_args()
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def main(args):
|
||||
api_url = f"http://{args.host}:{args.port}/score"
|
||||
model_name = args.model
|
||||
|
||||
@ -30,9 +32,9 @@ if __name__ == "__main__":
|
||||
text_2 = "The capital of Brazil is Brasilia."
|
||||
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
|
||||
score_response = post_http_request(prompt=prompt, api_url=api_url)
|
||||
print("Prompt when text_1 and text_2 are both strings:")
|
||||
print("\nPrompt when text_1 and text_2 are both strings:")
|
||||
pprint.pprint(prompt)
|
||||
print("Score Response:")
|
||||
print("\nScore Response:")
|
||||
pprint.pprint(score_response.json())
|
||||
|
||||
text_1 = "What is the capital of France?"
|
||||
@ -41,9 +43,9 @@ if __name__ == "__main__":
|
||||
]
|
||||
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
|
||||
score_response = post_http_request(prompt=prompt, api_url=api_url)
|
||||
print("Prompt when text_1 is string and text_2 is a list:")
|
||||
print("\nPrompt when text_1 is string and text_2 is a list:")
|
||||
pprint.pprint(prompt)
|
||||
print("Score Response:")
|
||||
print("\nScore Response:")
|
||||
pprint.pprint(score_response.json())
|
||||
|
||||
text_1 = [
|
||||
@ -54,7 +56,12 @@ if __name__ == "__main__":
|
||||
]
|
||||
prompt = {"model": model_name, "text_1": text_1, "text_2": text_2}
|
||||
score_response = post_http_request(prompt=prompt, api_url=api_url)
|
||||
print("Prompt when text_1 and text_2 are both lists:")
|
||||
print("\nPrompt when text_1 and text_2 are both lists:")
|
||||
pprint.pprint(prompt)
|
||||
print("Score Response:")
|
||||
print("\nScore Response:")
|
||||
pprint.pprint(score_response.json())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
@@ -6,22 +6,29 @@ from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id
def main():
    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    responses = client.embeddings.create(
        input=[
            "Hello my name is",
            "The best thing about vLLM is that it supports many different models"
        ],
        model=model,
    )
    models = client.models.list()
    model = models.data[0].id

    for data in responses.data:
        print(data.embedding)  # List of float of len 4096
    responses = client.embeddings.create(
        # ruff: noqa: E501
        input=[
            "Hello my name is",
            "The best thing about vLLM is that it supports many different models"
        ],
        model=model,
    )

    for data in responses.data:
        print(data.embedding)  # List of float of len 4096


if __name__ == "__main__":
    main()
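A quick way to sanity-check the two embeddings returned in `main()` above is cosine similarity; a small numpy sketch, assuming `responses` is still in scope:

```python
# Sketch: cosine similarity between the two embeddings returned above.
import numpy as np

a = np.asarray(responses.data[0].embedding)
b = np.asarray(responses.data[1].embedding)
cosine = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
print(f"cosine similarity: {cosine:.4f}")
```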
examples/online_serving/openai_embedding_matryoshka_fy.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
"""Example Python client for embedding API dimensions using vLLM API server
NOTE:
    start a supported Matryoshka Embeddings model server with `vllm serve`, e.g.
    vllm serve jinaai/jina-embeddings-v3 --trust-remote-code
"""

from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"


def main():
    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    model = models.data[0].id

    responses = client.embeddings.create(
        input=["Follow the white rabbit."],
        model=model,
        dimensions=1,
    )

    for data in responses.data:
        print(data.embedding)  # List of float of len 1


if __name__ == "__main__":
    main()
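Matryoshka models let the client pick the embedding width per request via `dimensions`; a hedged sketch reusing `client` and `model` from `main()` above, with illustrative sizes that should be checked against the model card:

```python
# Sketch: the `dimensions` parameter controls the truncated embedding width
# for Matryoshka-capable models. The sizes below are illustrative only.
for dim in (32, 64, 128):
    resp = client.embeddings.create(
        input=["Follow the white rabbit."],
        model=model,
        dimensions=dim,
    )
    assert len(resp.data[0].embedding) == dim
    print(dim, "ok")
```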
@@ -17,7 +17,7 @@ def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    return response


if __name__ == "__main__":
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
@@ -25,15 +25,20 @@ if __name__ == "__main__":
                        type=str,
                        default="jason9693/Qwen2.5-1.5B-apeach")

    args = parser.parse_args()
    return parser.parse_args()


def main(args):
    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    # Input like Completions API
    prompt = {"model": model_name, "input": "vLLM is great!"}
    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    print("-" * 50)
    print("Pooling Response:")
    pprint.pprint(pooling_response.json())
    print("-" * 50)

    # Input like Chat API
    prompt = {
@@ -50,3 +55,9 @@ if __name__ == "__main__":
    pooling_response = post_http_request(prompt=prompt, api_url=api_url)
    print("Pooling Response:")
    pprint.pprint(pooling_response.json())
    print("-" * 50)


if __name__ == "__main__":
    args = parse_args()
    main(args)
@@ -23,7 +23,7 @@ def sync_openai():
    with open(str(mary_had_lamb), "rb") as f:
        transcription = client.audio.transcriptions.create(
            file=f,
            model="openai/whisper-small",
            model="openai/whisper-large-v3",
            language="en",
            response_format="json",
            temperature=0.0)
examples/tool_chat_template_llama4_pythonic.jinja (new file, 139 lines)
@ -0,0 +1,139 @@
|
||||
{{- bos_token }}
|
||||
{%- if custom_tools is defined %}
|
||||
{%- set tools = custom_tools %}
|
||||
{%- endif %}
|
||||
{%- if not tools_in_user_message is defined %}
|
||||
{%- set tools_in_user_message = false %}
|
||||
{%- endif %}
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
|
||||
{#- This block extracts the system message, so we can slot it into the right place. #}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{%- if messages[0]['content'] is string %}
|
||||
{%- set system_message = messages[0]['content']|trim %}
|
||||
{%- else %}
|
||||
{%- set system_message = messages[0]['content'][0]['text']|trim %}
|
||||
{%- endif %}
|
||||
{%- set messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- if tools is not none %}
|
||||
{#- Add default tool system message when tools are provided #}
|
||||
{%- set system_message = "You are a helpful assistant with tool calling "
|
||||
"capabilities. Only reply with a tool call if the function exists in the "
|
||||
"library provided by the user. If it doesn't exist, just reply directly in "
|
||||
"natural language. When you receive a tool call response, use the output to "
|
||||
"format an answer to the original user question." %}
|
||||
{%- else %}
|
||||
{%- set system_message = "" %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
|
||||
{#- System message if the user supplied one, or if tools are used (default tool system message) #}
|
||||
{%- if system_message %}
|
||||
{#- always use user provided system message to override default tool system message #}
|
||||
{{- "<|header_start|>system<|header_end|>\n\n" }}
|
||||
{{- system_message }}
|
||||
{%- if tools is not none and not tools_in_user_message %}
|
||||
{{- "Tools: You have access to the following tools. You might need to use one "
|
||||
"or more function/tool calls to fulfill the task. \n"
|
||||
"If none are needed, then proceed to the response.\n\n"
|
||||
"Tool Call Syntax: You can call tools using the following syntax:\n"
|
||||
"[func_name1(params_name1=params_value1, params_name2=params_value2, ...), ...]\n"
|
||||
"Do not include anything else when calling the tools with the syntax above.\n\n"
|
||||
"Here is a list of functions in JSON format that you can invoke.\n " }}
|
||||
{%- for t in tools %}
|
||||
{{- t | tojson(indent=4) }}
|
||||
{{- "\n\n" }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- "<|eot|>" }}
|
||||
{%- endif %}
|
||||
|
||||
{#- Custom tools are passed in a user message with some extra guidance #}
|
||||
{%- if tools_in_user_message and tools is not none %}
|
||||
{#- Extract the first user message so we can plug it in here #}
|
||||
{%- if messages | length != 0 %}
|
||||
{%- if messages[0]['content'] is string %}
|
||||
{%- set first_user_message = messages[0]['content']|trim %}
|
||||
{%- else %}
|
||||
{%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %}
|
||||
{%- endif %}
|
||||
{%- set messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
|
||||
{%- endif %}
|
||||
{{- '<|header_start|>user<|header_end|>\n\n' -}}
|
||||
{{- first_user_message}}
|
||||
{{- "\nHere is a list of functions in JSON format that you can invoke:"}}
|
||||
{%- for t in tools %}
|
||||
{{- t | tojson(indent=4) }}
|
||||
{{- "\n\n" }}
|
||||
{%- endfor %}
|
||||
{{- "Should you decide to return the function call(s), put them in the format "
|
||||
"of [func_name1(params_name1=params_value1, params_name2=params_value2, "
|
||||
"...), ...]\nDo not include anything else when calling the tools with the "
|
||||
"syntax above." }}
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in messages %}
|
||||
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
|
||||
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
|
||||
{%- if message['content'] is string %}
|
||||
{{- message['content'] }}
|
||||
{%- else %}
|
||||
{%- for content in message['content'] %}
|
||||
{%- if content['type'] == 'image' %}
|
||||
{{- '<|image|>' }}
|
||||
{%- elif content['type'] == 'text' %}
|
||||
{{- content['text'] | trim }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- "<|eot|>" }}
|
||||
{%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}
|
||||
{%- set tool_call = message.tool_calls[0].function %}
|
||||
{{- '<|header_start|>assistant<|header_end|>\n\n' -}}
|
||||
{%- if message['content'] is string %}
|
||||
{{- message['content'] }}
|
||||
{%- else %}
|
||||
{%- for content in message['content'] %}
|
||||
{%- if content['type'] == 'image' %}
|
||||
{{- '<|image|>' }}
|
||||
{%- elif content['type'] == 'text' %}
|
||||
{{- content['text'] }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if tool_call.function is defined %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- tool_call.name + '(' -}}
|
||||
{%- for param in tool_call.arguments %}
|
||||
{{- param + '=' -}}
|
||||
{{- "%s" | format(tool_call.arguments[param]) -}}
|
||||
{% if not loop.last %}, {% endif %}
|
||||
{%- endfor %}
|
||||
{{- ')' -}}
|
||||
{% if not loop.last %}, {% endif %}
|
||||
{%- endfor %}
|
||||
{{- "<|eom|>" }}
|
||||
{%- elif message.role == "tool" or message.role == "ipython" %}
|
||||
{{- "<|header_start|>ipython<|header_end|>\n\n" }}
|
||||
{%- if message.content is string %}
|
||||
{{- message.content | tojson }}
|
||||
{%- else %}
|
||||
{%- for content in message['content'] %}
|
||||
{%- if content['type'] == 'text' %}
|
||||
{{- content['text'] | tojson }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- "<|eom|>" }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|header_start|>assistant<|header_end|>\n\n' }}
|
||||
{%- endif %}
|
||||
@@ -8,7 +8,7 @@ blake3
py-cpuinfo
transformers >= 4.51.1
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
tokenizers >= 0.19.1 # Required for Llama 3.
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp
@@ -26,7 +26,7 @@ xgrammar == 0.1.18; platform_machine == "x86_64" or platform_machine == "aarch64
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs
pyzmq
pyzmq >= 25.0.0
msgspec
gguf >= 0.13.0
importlib_metadata
@@ -12,9 +12,9 @@ torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x"
torchaudio==2.6.0; platform_machine == "ppc64le"

# required for the image processor of phi3v, this must be updated alongside torch
torchvision; platform_machine != "ppc64le" and platform_machine != "s390x"
torchvision; platform_machine != "ppc64le" and platform_machine != "s390x"
torchvision==0.21.0; platform_machine == "ppc64le"
datasets # for benchmark scripts

# cpu cannot use triton 3.3.0
triton==3.2.0; platform_machine != "ppc64le"
triton==3.2.0; platform_machine == "x86_64"

@@ -6,6 +6,7 @@ torch==2.6.0
torchvision==0.21.0
torchaudio==2.6.0

triton==3.2
cmake>=3.26,<4
packaging
setuptools>=61
@@ -10,6 +10,7 @@ pytest-timeout
# testing utils
awscli
backoff # required for phi4mm test
blobfile # required for kimi-vl test
einops # required for MPT, qwen-vl and Mamba
httpx
librosa # required for audio tests
@@ -26,6 +27,7 @@ torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.5.4 # required for pixtral test
num2words # required for smolvlm test
@@ -33,7 +35,9 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
transformers==4.51.1
tokenizers==0.21.1
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
schemathesis>=3.39.15 # Required for openai schema test.
# quantization
bitsandbytes>=0.45.3
buildkite-test-collector==0.1.9
@ -20,25 +20,35 @@ aiosignal==1.3.1
annotated-types==0.7.0
    # via pydantic
anyio==4.6.2.post1
    # via httpx
    # via
    #   httpx
    #   starlette
argcomplete==3.5.1
    # via datamodel-code-generator
arrow==1.3.0
    # via isoduration
attrs==24.2.0
    # via
    #   aiohttp
    #   hypothesis
    #   jsonlines
    #   jsonschema
    #   pytest-subtests
    #   referencing
audioread==3.0.1
    # via librosa
awscli==1.35.23
    # via -r requirements/test.in
backoff==2.2.1
    # via -r requirements/test.in
    # via
    #   -r requirements/test.in
    #   schemathesis
bitsandbytes==0.45.3
    # via -r requirements/test.in
black==24.10.0
    # via datamodel-code-generator
blobfile==3.0.0
    # via -r requirements/test.in
boto3==1.35.57
    # via tensorizer
botocore==1.35.57
@ -67,11 +77,13 @@ click==8.1.7
    #   jiwer
    #   nltk
    #   ray
    #   schemathesis
    #   typer
colorama==0.4.6
    # via
    #   awscli
    #   sacrebleu
    #   schemathesis
    #   tqdm-multiprocess
contourpy==1.3.0
    # via matplotlib
@ -109,6 +121,7 @@ einops==0.8.0
    # via
    #   -r requirements/test.in
    #   encodec
    #   mamba-ssm
    #   vector-quantize-pytorch
    #   vocos
einx==0.3.0
@ -127,6 +140,7 @@ fastsafetensors==0.1.10
    # via -r requirements/test.in
filelock==3.16.1
    # via
    #   blobfile
    #   datasets
    #   huggingface-hub
    #   ray
@ -134,6 +148,8 @@ filelock==3.16.1
    #   transformers
fonttools==4.54.1
    # via matplotlib
fqdn==1.5.1
    # via jsonschema
frozendict==2.4.6
    # via einx
frozenlist==1.5.0
@ -152,8 +168,12 @@ genai-perf==0.0.8
    # via -r requirements/test.in
genson==1.3.0
    # via datamodel-code-generator
graphql-core==3.2.6
    # via hypothesis-graphql
h11==0.14.0
    # via httpcore
harfile==0.3.0
    # via schemathesis
hf-xet==0.1.4
    # via huggingface-hub
hiredis==3.0.0
@ -161,7 +181,9 @@ hiredis==3.0.0
httpcore==1.0.6
    # via httpx
httpx==0.27.2
    # via -r requirements/test.in
    # via
    #   -r requirements/test.in
    #   schemathesis
huggingface-hub==0.30.1
    # via
    #   -r requirements/test.in
@ -176,17 +198,29 @@ huggingface-hub==0.30.1
    #   vocos
humanize==4.11.0
    # via runai-model-streamer
hypothesis==6.131.0
    # via
    #   hypothesis-graphql
    #   hypothesis-jsonschema
    #   schemathesis
hypothesis-graphql==0.11.1
    # via schemathesis
hypothesis-jsonschema==0.23.1
    # via schemathesis
idna==3.10
    # via
    #   anyio
    #   email-validator
    #   httpx
    #   jsonschema
    #   requests
    #   yarl
inflect==5.6.2
    # via datamodel-code-generator
iniconfig==2.0.0
    # via pytest
isoduration==20.11.0
    # via jsonschema
isort==5.13.2
    # via datamodel-code-generator
jinja2==3.1.6
@ -206,12 +240,18 @@ joblib==1.4.2
    #   scikit-learn
jsonlines==4.0.0
    # via lm-eval
jsonpointer==3.0.0
    # via jsonschema
jsonschema==4.23.0
    # via
    #   hypothesis-jsonschema
    #   mistral-common
    #   ray
    #   schemathesis
jsonschema-specifications==2024.10.1
    # via jsonschema
junit-xml==1.9
    # via schemathesis
kaleido==0.2.1
    # via genai-perf
kiwisolver==1.4.7
@ -227,11 +267,17 @@ llvmlite==0.44.0
lm-eval==0.4.8
    # via -r requirements/test.in
lxml==5.3.0
    # via sacrebleu
    # via
    #   blobfile
    #   sacrebleu
mamba-ssm==2.2.4
    # via -r requirements/test.in
markdown-it-py==3.0.0
    # via rich
markupsafe==3.0.2
    # via jinja2
    # via
    #   jinja2
    #   werkzeug
matplotlib==3.9.2
    # via -r requirements/test.in
mbstrdecoder==1.1.3
@ -263,6 +309,8 @@ mypy-extensions==1.0.0
    # via black
networkx==3.2.1
    # via torch
ninja==1.11.1.3
    # via mamba-ssm
nltk==3.9.1
    # via rouge-score
num2words==0.5.14
@ -355,6 +403,7 @@ packaging==24.1
    #   fastparquet
    #   huggingface-hub
    #   lazy-loader
    #   mamba-ssm
    #   matplotlib
    #   peft
    #   plotly
@ -426,6 +475,8 @@ pybind11==2.13.6
    # via lm-eval
pycparser==2.22
    # via cffi
pycryptodomex==3.22.0
    # via blobfile
pydantic==2.9.2
    # via
    #   datamodel-code-generator
@ -436,6 +487,8 @@ pygments==2.18.0
    # via rich
pyparsing==3.2.0
    # via matplotlib
pyrate-limiter==3.7.0
    # via schemathesis
pytablewriter==1.2.0
    # via lm-eval
pytest==8.3.3
@ -448,7 +501,9 @@ pytest==8.3.3
    #   pytest-mock
    #   pytest-rerunfailures
    #   pytest-shard
    #   pytest-subtests
    #   pytest-timeout
    #   schemathesis
pytest-asyncio==0.24.0
    # via -r requirements/test.in
pytest-forked==1.6.0
@ -459,10 +514,13 @@ pytest-rerunfailures==14.0
    # via -r requirements/test.in
pytest-shard==0.1.2
    # via -r requirements/test.in
pytest-subtests==0.14.1
    # via schemathesis
pytest-timeout==2.3.1
    # via -r requirements/test.in
python-dateutil==2.9.0.post0
    # via
    #   arrow
    #   botocore
    #   matplotlib
    #   pandas
@ -484,6 +542,7 @@ pyyaml==6.0.2
    #   peft
    #   ray
    #   responses
    #   schemathesis
    #   timm
    #   transformers
    #   vocos
@ -514,10 +573,16 @@ requests==2.32.3
    #   pooch
    #   ray
    #   responses
    #   schemathesis
    #   starlette-testclient
    #   tiktoken
    #   transformers
responses==0.25.3
    # via genai-perf
rfc3339-validator==0.1.4
    # via jsonschema
rfc3987==1.3.8
    # via jsonschema
rich==13.9.4
    # via
    #   genai-perf
@ -546,6 +611,8 @@ safetensors==0.4.5
    #   peft
    #   timm
    #   transformers
schemathesis==3.39.15
    # via -r requirements/test.in
scikit-learn==1.5.2
    # via
    #   librosa
@ -564,18 +631,23 @@ sentencepiece==0.2.0
    # via mistral-common
setuptools==75.8.0
    # via
    #   mamba-ssm
    #   pytablewriter
    #   torch
shellingham==1.5.4
    # via typer
six==1.16.0
    # via
    #   junit-xml
    #   python-dateutil
    #   rfc3339-validator
    #   rouge-score
sniffio==1.3.1
    # via
    #   anyio
    #   httpx
sortedcontainers==2.4.0
    # via hypothesis
soundfile==0.12.1
    # via
    #   -r requirements/test.in
@ -584,6 +656,12 @@ soxr==0.5.0.post1
    # via librosa
sqlitedict==2.1.0
    # via lm-eval
starlette==0.46.2
    # via
    #   schemathesis
    #   starlette-testclient
starlette-testclient==0.4.1
    # via schemathesis
statsmodels==0.14.4
    # via genai-perf
sympy==1.13.1
@ -610,8 +688,14 @@ tiktoken==0.7.0
    #   mistral-common
timm==1.0.11
    # via -r requirements/test.in
tokenizers==0.21.0
    # via transformers
tokenizers==0.21.1
    # via
    #   -r requirements/test.in
    #   transformers
tomli==2.2.1
    # via schemathesis
tomli-w==1.2.0
    # via schemathesis
torch==2.6.0
    # via
    #   -r requirements/test.in
@ -620,6 +704,7 @@ torch==2.6.0
    #   encodec
    #   fastsafetensors
    #   lm-eval
    #   mamba-ssm
    #   peft
    #   runai-model-streamer
    #   sentence-transformers
@ -657,6 +742,7 @@ transformers==4.51.1
    #   -r requirements/test.in
    #   genai-perf
    #   lm-eval
    #   mamba-ssm
    #   peft
    #   sentence-transformers
    #   transformers-stream-generator
@ -675,6 +761,8 @@ typepy==1.3.2
    #   tabledata
typer==0.15.2
    # via fastsafetensors
types-python-dateutil==2.9.0.20241206
    # via arrow
typing-extensions==4.12.2
    # via
    #   huggingface-hub
@ -687,8 +775,11 @@ typing-extensions==4.12.2
    #   typer
tzdata==2024.2
    # via pandas
uri-template==1.3.0
    # via jsonschema
urllib3==2.2.3
    # via
    #   blobfile
    #   botocore
    #   requests
    #   responses
@ -697,6 +788,10 @@ vector-quantize-pytorch==1.21.2
    # via -r requirements/test.in
vocos==0.1.0
    # via -r requirements/test.in
webcolors==24.11.1
    # via jsonschema
werkzeug==3.1.3
    # via schemathesis
word2number==1.1
    # via lm-eval
xxhash==3.5.0
@ -704,6 +799,8 @@ xxhash==3.5.0
    #   datasets
    #   evaluate
yarl==1.17.1
    # via aiohttp
    # via
    #   aiohttp
    #   schemathesis
zstandard==0.23.0
    # via lm-eval

@ -17,9 +17,8 @@ ray[data]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch==2.8.0.dev20250408
torchvision==0.22.0.dev20250408
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

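Note: the per-interpreter wheel lines above rely on PEP 508 environment markers (python_version == "3.10" here, platform_machine in the CPU requirements earlier); pip evaluates each marker against the running interpreter and keeps only the matching torch / torch_xla URL. A minimal sketch of that evaluation using the packaging library; the marker strings are copied from the requirements lines above:

# Sketch: evaluate the PEP 508 markers used above the way pip does,
# against the currently running interpreter.
from packaging.markers import Marker

for marker in (
    'python_version == "3.9"',
    'python_version == "3.10"',
    'python_version == "3.11"',
):
    # Marker.evaluate() defaults to the current interpreter's environment.
    print(marker, "->", Marker(marker).evaluate())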
Some files were not shown because too many files have changed in this diff.