Compare commits
56 Commits
split_kv_c ... codex/remo

| SHA1 |
|---|
| 85013bf094 |
| 07665f8679 |
| 9fac6aa30b |
| a53ad626d6 |
| 1c3dad22ff |
| d2a30a2d93 |
| 75fb112d80 |
| 38db529f66 |
| 064cac7bb7 |
| e19bce40a1 |
| 505805b645 |
| bbdc0f2366 |
| dc34059360 |
| c4cb0af98a |
| 1c3b1634aa |
| 2ea50e977a |
| b419937c78 |
| 5f696c33b1 |
| 67244c86f0 |
| 072d7e53e5 |
| 01a583fea4 |
| bc19d75985 |
| fbd6523ac0 |
| 470484a4f5 |
| 21da73343a |
| 66072b36db |
| 3ed1ec4af2 |
| 5a33ae9a3f |
| c9ff9e6f0c |
| eaffe4486c |
| 8ed039d527 |
| 37970105fe |
| cc935fdd7e |
| abdfcd4f3d |
| 4f02b77de4 |
| 29283e8976 |
| 05b044e698 |
| aa3f105c59 |
| ef7eefe17a |
| 350c94deb3 |
| f4cd80f944 |
| 349e0e3462 |
| 81b16a2bc9 |
| e111d5b0ae |
| a904ea78ea |
| b7433ca1a4 |
| 5c65a72bb1 |
| 9d8a2d86d2 |
| 3bc18127ff |
| bec060fd99 |
| 52bc9d5b3e |
| dc2979c585 |
| 027d37df38 |
| b98219670f |
| 32baf1d036 |
| 3127274d02 |
@@ -167,12 +167,6 @@ if [[ $commands == *" entrypoints/llm "* ]]; then
--ignore=entrypoints/llm/test_prompt_validation.py "}
fi

#Obsolete currently
##ignore certain Entrypoints/llm tests
#if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then
# commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "}
#fi

# --ignore=entrypoints/openai/test_encoder_decoder.py \
# --ignore=entrypoints/openai/test_embedding.py \
# --ignore=entrypoints/openai/test_oot_registration.py
@@ -46,22 +46,18 @@ steps:
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/async_engine
- tests/test_inputs.py
- tests/test_outputs.py
- tests/multimodal
- tests/utils_
- tests/worker
- tests/standalone_tests/lazy_imports.py
- tests/transformers_utils
commands:
- python3 standalone_tests/lazy_imports.py
- pytest -v -s async_engine # AsyncLLMEngine
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s multimodal
- pytest -v -s utils_ # Utils
- pytest -v -s worker # Worker
- pytest -v -s transformers_utils # transformers_utils

- label: Python-only Installation Test # 10min
@@ -82,14 +78,12 @@ steps:
- vllm/
- tests/basic_correctness/test_basic_correctness
- tests/basic_correctness/test_cpu_offload
- tests/basic_correctness/test_preemption
- tests/basic_correctness/test_cumem.py
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s basic_correctness/test_cumem.py
- pytest -v -s basic_correctness/test_basic_correctness.py
- pytest -v -s basic_correctness/test_cpu_offload.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py

- label: Entrypoints Unit Tests # 5min
timeout_in_minutes: 10
@@ -114,8 +108,7 @@ steps:
- tests/entrypoints/offline_mode
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

@@ -287,6 +280,7 @@ steps:
# split the test to avoid interference
- pytest -v -s v1/core
- pytest -v -s v1/executor
- pytest -v -s v1/kv_offload
- pytest -v -s v1/sample
- pytest -v -s v1/logits_processors
- pytest -v -s v1/worker
.github/CODEOWNERS (vendored): 8 changes

@@ -37,11 +37,10 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/v1/attention/backends/triton_attn.py @tdoublep
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
/vllm/v1/kv_cache_interface.py @heheda12345
/vllm/v1/worker/kv_cache_initializer_mixin.py @heheda12345
/vllm/v1/offloading @ApostaC

# Test ownership
/.buildkite/lm-eval-harness @mgoin @simon-mo
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
/tests/distributed/test_multi_node_assignment.py @youkaichao
/tests/distributed/test_pipeline_parallel.py @youkaichao
/tests/distributed/test_same_node.py @youkaichao
@@ -50,7 +49,6 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/kernels @mgoin @tlrmchlsmth @WoosukKwon @yewentao256
/tests/models @DarkLight1337 @ywang96
/tests/multimodal @DarkLight1337 @ywang96 @NickLucche
/tests/prefix_caching @comaniac @KuntaiDu
/tests/quantization @mgoin @robertgshaw2-redhat @yewentao256
/tests/test_inputs.py @DarkLight1337 @ywang96
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm
@@ -63,6 +61,10 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/v1/kv_connector @ApostaC
/tests/v1/offloading @ApostaC

# Transformers backend
/vllm/model_executor/models/transformers.py @hmellor
/tests/models/test_transformers.py @hmellor

# Docs
/docs @hmellor
mkdocs.yaml @hmellor
.github/mergify.yml (vendored): 19 changes

@@ -171,7 +171,7 @@ pull_request_rules:
- files=examples/online_serving/openai_chat_completion_structured_outputs.py
- files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py
- files~=^tests/v1/structured_output/
- files=tests/v1/entrypoints/llm/test_guided_generate.py
- files=tests/v1/entrypoints/llm/test_struct_output_generate.py
- files~=^vllm/v1/structured_output/
actions:
label:
@@ -302,3 +302,20 @@ pull_request_rules:
label:
remove:
- needs-rebase

- name: label-kv-connector
description: Automatically apply kv-connector label
conditions:
- or:
- files~=^examples/online_serving/disaggregated[^/]*/.*
- files~=^examples/offline_inference/disaggregated[^/]*/.*
- files~=^examples/others/lmcache/
- files~=^tests/v1/kv_connector/
- files~=^vllm/distributed/kv_transfer/
- title~=(?i)\bP/?D\b
- title~=(?i)NIXL
- title~=(?i)LMCache
actions:
label:
add:
- kv-connector
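For a quick sense of which pull requests the new rule would label, the title patterns above can be sanity-checked with a small Python snippet; the sample titles below are made up, while the regex alternatives come from the mergify conditions:

```python
# Illustrative check of the title patterns used by the label-kv-connector rule above.
import re

title_pattern = re.compile(r"\bP/?D\b|NIXL|LMCache", flags=re.IGNORECASE)

for title in [
    "[PD] Fix disaggregated prefill handshake",  # matches \bP/?D\b
    "Add NIXL transfer backend",                 # matches NIXL
    "Refactor scheduler internals",              # no match, no label
]:
    print(f"{title!r}: kv-connector label = {bool(title_pattern.search(title))}")
```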
@@ -164,9 +164,7 @@ repos:
name: Validate configuration has default values and that each field has a docstring
entry: python tools/validate_config.py
language: python
types: [python]
pass_filenames: true
files: vllm/config.py|tests/test_config.py|vllm/entrypoints/openai/cli_args.py
additional_dependencies: [regex]
# Keep `suggestion` last
- id: suggestion
name: Suggestion
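As a rough sketch of exercising this hook outside of `pre-commit` (assuming the entry point accepts file paths, as `pass_filenames: true` suggests), one could run:

```bash
# Hypothetical manual invocation of the hook's entry point on the files it targets
python tools/validate_config.py vllm/config.py vllm/entrypoints/openai/cli_args.py
```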
@@ -696,11 +696,11 @@ def evaluate(ret, args):
return re.match(args.regex, actual) is not None

def _eval_correctness(expected, actual):
if args.structure_type == "guided_json":
if args.structure_type == "json":
return _eval_correctness_json(expected, actual)
elif args.structure_type == "guided_regex":
elif args.structure_type == "regex":
return _eval_correctness_regex(expected, actual)
elif args.structure_type == "guided_choice":
elif args.structure_type == "choice":
return _eval_correctness_choice(expected, actual)
else:
return None
@@ -780,18 +780,18 @@ def main(args: argparse.Namespace):
)

if args.dataset == "grammar":
args.structure_type = "guided_grammar"
args.structure_type = "grammar"
elif args.dataset == "regex":
args.structure_type = "guided_regex"
args.structure_type = "regex"
elif args.dataset == "choice":
args.structure_type = "guided_choice"
args.structure_type = "choice"
else:
args.structure_type = "guided_json"
args.structure_type = "json"

if args.no_structured_output:
args.structured_output_ratio = 0
if args.save_results:
result_file_name = f"{args.structured_output_ratio}guided"
result_file_name = f"{args.structured_output_ratio}so"
result_file_name += f"_{backend}"
result_file_name += f"_{args.request_rate}qps"
result_file_name += f"_{args.model.split('/')[-1]}"
@@ -17,4 +17,8 @@
#warning "unsupported vLLM cpu implementation"
#endif

#ifdef _OPENMP
#include <omp.h>
#endif

#endif
@ -21,6 +21,7 @@
|
||||
#include <torch/all.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
namespace cg = cooperative_groups;
|
||||
@ -28,7 +29,6 @@ namespace cg = cooperative_groups;
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
constexpr float kNegInfinity = INFINITY * -1;
|
||||
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||
constexpr int32_t WARP_SIZE = 32;
|
||||
constexpr int32_t BLOCK_SIZE = 512;
|
||||
@ -411,14 +411,21 @@ __device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ inline T neg_inf() {
|
||||
// cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
|
||||
// so we need to cast from fp32
|
||||
return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void topk_with_k2(T* output, T const* input,
|
||||
cg::thread_block_tile<32> const& tile,
|
||||
int32_t const lane_id,
|
||||
int const num_experts_per_group) {
|
||||
// Get the top2 per thread
|
||||
T largest = -INFINITY;
|
||||
T second_largest = -INFINITY;
|
||||
T largest = neg_inf<T>();
|
||||
T second_largest = neg_inf<T>();
|
||||
|
||||
if (num_experts_per_group > WARP_SIZE) {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
@ -513,8 +520,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
warp_id * topk;
|
||||
s_topk_idx += warp_id * topk;
|
||||
|
||||
T value = kNegInfinity;
|
||||
T topk_group_value = kNegInfinity;
|
||||
T value = neg_inf<T>();
|
||||
T topk_group_value = neg_inf<T>();
|
||||
int32_t num_equalto_topkth_group;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
@ -525,11 +532,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
if (case_id < num_tokens) {
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
if (lane_id < n_group &&
|
||||
(isfinite(cuda_cast<float, T>(
|
||||
group_scores[lane_id])))) // The check is necessary to avoid
|
||||
// abnormal input
|
||||
{
|
||||
// The check is necessary to avoid abnormal input
|
||||
if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
@ -540,11 +544,11 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = kNegInfinity;
|
||||
value = neg_inf<T>();
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value = __popc(__ballot_sync(
|
||||
FULL_WARP_MASK, (value == cuda_cast<T, float>(kNegInfinity))));
|
||||
count_equal_to_top_value =
|
||||
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
}
|
||||
@ -552,11 +556,10 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
|
||||
/* is_stable */ true>
|
||||
queue((int32_t)topk, -INFINITY);
|
||||
queue((int32_t)topk, neg_inf<T>());
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
bool if_proceed_next_topk =
|
||||
(topk_group_value != cuda_cast<T, float>(kNegInfinity));
|
||||
bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
@ -566,10 +569,10 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates =
|
||||
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
|
||||
scores_with_bias[offset + i]))
|
||||
(i < num_experts_per_group) &&
|
||||
cuda::std::isfinite(scores_with_bias[offset + i])
|
||||
? scores_with_bias[offset + i]
|
||||
: cuda_cast<T, float>(kNegInfinity);
|
||||
: neg_inf<T>();
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
@ -598,7 +601,8 @@ __global__ void group_idx_and_topk_idx_kernel(
|
||||
if (i < topk) {
|
||||
s_topk_value[i] = value;
|
||||
}
|
||||
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
topk_sum +=
|
||||
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -365,7 +365,6 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
int32_t compute_pipeline_offset_64 = 0;

for (int32_t t = n_tokens_lower; t < n_tokens_upper; ++t) {
__nv_bfloat16 y_max_bf16 = EPS;
__nv_bfloat162 results_bf162[2];

cp_async_wait<NUM_STAGES - 2>();
@@ -405,7 +404,7 @@ __global__ void silu_mul_fp8_quant_deep_gemm_kernel(
auto _y_max2 =
__hmax2(__habs2(results_bf162[0]), __habs2(results_bf162[1]));

y_max_bf16 = __hmax(_y_max2.x, _y_max2.y);
__nv_bfloat16 y_max_bf16 = __hmax(EPS, __hmax(_y_max2.x, _y_max2.y));

// An entire group is assigned to a single warp, so a simple warp reduce
// is used.
@@ -29,7 +29,10 @@ ARG VLLM_BRANCH="main"
ONBUILD RUN git clone ${VLLM_REPO} \
&& cd vllm \
&& git fetch -v --prune -- origin ${VLLM_BRANCH} \
&& git checkout FETCH_HEAD
&& git checkout FETCH_HEAD \
&& if [ ${VLLM_REPO} != "https://github.com/vllm-project/vllm.git" ] ; then \
git remote add upstream "https://github.com/vllm-project/vllm.git" \
&& git fetch upstream ; fi
FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm

# -----------------------
@ -1,25 +1,23 @@
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.4.1-complete
|
||||
ARG HIPBLASLT_BRANCH="aa0bda7b"
|
||||
ARG HIPBLAS_COMMON_BRANCH="9b80ba8e"
|
||||
ARG LEGACY_HIPBLASLT_OPTION=
|
||||
ARG TRITON_BRANCH="e5be006"
|
||||
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
||||
ARG PYTORCH_BRANCH="f717b2af"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.21.0"
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete
|
||||
ARG TRITON_BRANCH="f9e5bf54"
|
||||
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
|
||||
ARG PYTORCH_BRANCH="b2fb6885"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.23.0"
|
||||
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="1a7f4dfa"
|
||||
ARG FA_BRANCH="0e60e394"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="4822e675"
|
||||
ARG AITER_BRANCH="2ab9f4cd"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
ENV PATH=/opt/rocm/llvm/bin:$PATH
|
||||
ENV PATH=/opt/rocm/llvm/bin:/opt/rocm/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
ENV ROCM_PATH=/opt/rocm
|
||||
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
|
||||
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx1100;gfx1101;gfx1200;gfx1201
|
||||
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201
|
||||
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
|
||||
ENV AITER_ROCM_ARCH=gfx942;gfx950
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
@ -45,29 +43,6 @@ RUN apt-get update -y \
|
||||
|
||||
RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython
|
||||
|
||||
FROM base AS build_hipblaslt
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG HIPBLAS_COMMON_BRANCH
|
||||
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
RUN git clone https://github.com/ROCm/hipBLAS-common.git
|
||||
RUN apt-get remove -y hipblaslt && apt-get autoremove -y && apt-get autoclean -y
|
||||
RUN cd hipBLAS-common \
|
||||
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
|
||||
&& mkdir build \
|
||||
&& cd build \
|
||||
&& cmake .. \
|
||||
&& make package \
|
||||
&& dpkg -i ./*.deb
|
||||
RUN git clone https://github.com/ROCm/hipBLASLt
|
||||
RUN cd hipBLASLt \
|
||||
&& git checkout ${HIPBLASLT_BRANCH} \
|
||||
&& apt-get install -y llvm-dev \
|
||||
&& ./install.sh -dc --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
|
||||
&& cd build/release \
|
||||
&& make package
|
||||
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
|
||||
|
||||
FROM base AS build_triton
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
@ -121,13 +96,11 @@ RUN cd aiter \
|
||||
&& git checkout ${AITER_BRANCH} \
|
||||
&& git submodule update --init --recursive \
|
||||
&& pip install -r requirements.txt
|
||||
RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
|
||||
RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
|
||||
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
|
||||
|
||||
FROM base AS debs
|
||||
RUN mkdir /app/debs
|
||||
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
|
||||
cp /install/*.deb /app/debs
|
||||
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
@ -138,11 +111,6 @@ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
|
||||
cp /install/*.whl /app/debs
|
||||
|
||||
FROM base AS final
|
||||
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& perl -p -i -e 's/, hipblas-common-dev \([^)]*?\), /, /g' /var/lib/dpkg/status \
|
||||
&& perl -p -i -e 's/, hipblaslt-dev \([^)]*?\), /, /g' /var/lib/dpkg/status \
|
||||
&& perl -p -i -e 's/, hipblaslt \([^)]*?\), /, /g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
@ -153,9 +121,6 @@ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
|
||||
ARG BASE_IMAGE
|
||||
ARG HIPBLAS_COMMON_BRANCH
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
ARG PYTORCH_BRANCH
|
||||
@ -167,9 +132,6 @@ ARG FA_REPO
|
||||
ARG AITER_BRANCH
|
||||
ARG AITER_REPO
|
||||
RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
|
||||
&& echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
|
||||
@ -177,5 +139,6 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
|
||||
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt \
|
||||
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
|
||||
@@ -14,7 +14,7 @@ API documentation for vLLM's configuration classes.
- [vllm.config.LoRAConfig][]
- [vllm.config.MultiModalConfig][]
- [vllm.config.PoolerConfig][]
- [vllm.config.DecodingConfig][]
- [vllm.config.StructuredOutputsConfig][]
- [vllm.config.ObservabilityConfig][]
- [vllm.config.KVTransferConfig][]
- [vllm.config.CompilationConfig][]
@@ -46,7 +46,6 @@ Engine classes for offline and online inference.
Inference parameters for vLLM APIs.

[](){ #sampling-params }
[](){ #pooling-params }

- [vllm.SamplingParams][]
- [vllm.PoolingParams][]
@@ -175,6 +175,7 @@ Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to u
Known supported models:

- GLM-4.5V GLM-4.1V (<gh-pr:23168>)
- InternVL (<gh-pr:23909>)
- Kimi-VL (<gh-pr:23817>)
- Llama4 (<gh-pr:18368>)
- MiniCPM-V-2.5 or above (<gh-pr:23327>, <gh-pr:23948>)
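A minimal sketch of the engine-argument change described above, assuming the keyword is exposed directly on `LLM` as `mm_encoder_tp_mode`; the model name is a placeholder for one of the models listed above:

```python
# Hypothetical sketch: enable data-parallel sharding of the multimodal encoder.
from vllm import LLM

llm = LLM(
    model="<one-of-the-supported-models-above>",  # placeholder
    tensor_parallel_size=2,
    mm_encoder_tp_mode="data",  # engine argument referenced in the docs above
)
```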
@ -26,113 +26,123 @@ See <gh-file:LICENSE>.
|
||||
|
||||
## Developing
|
||||
|
||||
--8<-- "docs/getting_started/installation/python_env_setup.inc.md"
|
||||
|
||||
Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation.
|
||||
Check out the [building from source][build-from-source] documentation for details.
|
||||
|
||||
For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations.
|
||||
|
||||
### Building the docs with MkDocs
|
||||
|
||||
#### Introduction to MkDocs
|
||||
|
||||
[MkDocs](https://github.com/mkdocs/mkdocs) is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file.
|
||||
|
||||
#### Install MkDocs and Plugins
|
||||
|
||||
Install MkDocs along with the [plugins](https://github.com/vllm-project/vllm/blob/main/mkdocs.yaml) used in the vLLM documentation, as well as required dependencies:
|
||||
|
||||
```bash
|
||||
uv pip install -r requirements/docs.txt
|
||||
```
|
||||
|
||||
!!! note
|
||||
Ensure that your Python version is compatible with the plugins (e.g., `mkdocs-awesome-nav` requires Python 3.10+)
|
||||
|
||||
#### Verify Installation
|
||||
|
||||
Confirm that MkDocs is correctly installed:
|
||||
|
||||
```bash
|
||||
mkdocs --version
|
||||
```
|
||||
|
||||
Example output:
|
||||
|
||||
```console
|
||||
mkdocs, version 1.6.1 from /opt/miniconda3/envs/mkdoc/lib/python3.10/site-packages/mkdocs (Python 3.10)
|
||||
```
|
||||
|
||||
#### Clone the `vLLM` repository
|
||||
The first step of contributing to vLLM is to clone the GitHub repository:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vllm-project/vllm.git
|
||||
cd vllm
|
||||
```
|
||||
|
||||
#### Start the Development Server
|
||||
Then, configure your Python virtual environment.
|
||||
|
||||
MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it. Make sure you're in the same directory as the `mkdocs.yml` configuration file, and then start the server by running the `mkdocs serve` command:
|
||||
--8<-- "docs/getting_started/installation/python_env_setup.inc.md"
|
||||
|
||||
If you are only developing vLLM's Python code, install vLLM using:
|
||||
|
||||
```bash
|
||||
mkdocs serve
|
||||
VLLM_USE_PRECOMPILED=1 uv pip install -e .
|
||||
```
|
||||
|
||||
Example output:
|
||||
If you are developing vLLM's Python and CUDA/C++ code, install vLLM using:
|
||||
|
||||
```console
|
||||
INFO - Documentation built in 106.83 seconds
|
||||
INFO - [22:02:02] Watching paths for changes: 'docs', 'mkdocs.yaml'
|
||||
INFO - [22:02:02] Serving on http://127.0.0.1:8000/
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
#### View in Your Browser
|
||||
For more details about installing from source and installing for other hardware, check out the [installation instructions](../getting_started/installation/README.md) for your hardware and head to the "Build wheel from source" section.
|
||||
|
||||
Open up [http://127.0.0.1:8000/](http://127.0.0.1:8000/) in your browser to see a live preview:.
|
||||
|
||||
#### Learn More
|
||||
|
||||
For additional features and advanced configurations, refer to the official [MkDocs Documentation](https://www.mkdocs.org/).
|
||||
|
||||
## Testing
|
||||
|
||||
??? console "Commands"
|
||||
|
||||
```bash
|
||||
# These commands are only for Nvidia CUDA platforms.
|
||||
uv pip install -r requirements/common.txt -r requirements/dev.txt --torch-backend=auto
|
||||
|
||||
# Linting, formatting and static type checking
|
||||
pre-commit install
|
||||
|
||||
# You can manually run pre-commit with
|
||||
pre-commit run --all-files --show-diff-on-failure
|
||||
|
||||
# To manually run something from CI that does not run
|
||||
# locally by default, you can run:
|
||||
pre-commit run mypy-3.9 --hook-stage manual --all-files
|
||||
|
||||
# Unit tests
|
||||
pytest tests/
|
||||
|
||||
# Run tests for a single test file with detailed output
|
||||
pytest -s -v tests/test_logger.py
|
||||
```
|
||||
For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations.
|
||||
|
||||
!!! tip
|
||||
Since the <gh-file:docker/Dockerfile> ships with Python 3.12, all tests in CI (except `mypy`) are run with Python 3.12.
|
||||
vLLM is compatible with Python versions 3.9 to 3.12. However, vLLM's default [Dockerfile](gh-file:docker/Dockerfile) ships with Python 3.12 and tests in CI (except `mypy`) are run with Python 3.12.
|
||||
|
||||
Therefore, we recommend developing with Python 3.12 to minimise the chance of your local environment clashing with our CI environment.
|
||||
|
||||
!!! note "Install python3-dev if Python.h is missing"
|
||||
### Linting
|
||||
|
||||
vLLM uses `pre-commit` to lint and format the codebase. See <https://pre-commit.com/#usage> if `pre-commit` is new to you. Setting up `pre-commit` is as easy as:
|
||||
|
||||
```bash
|
||||
uv pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
vLLM's `pre-commit` hooks will now run automatically every time you commit.
|
||||
|
||||
!!! tip "Tips"
|
||||
You can manually run the `pre-commit` hooks using:
|
||||
|
||||
```bash
|
||||
pre-commit run # runs on staged files
|
||||
pre-commit run -a # runs on all files (short for --all-files)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
Some `pre-commit` hooks only run in CI. If you need to, you can run them locally with:
|
||||
|
||||
```bash
|
||||
pre-commit run --hook-stage manual markdownlint
|
||||
pre-commit run --hook-stage manual mypy-3.9
|
||||
```
|
||||
|
||||
### Documentation
|
||||
|
||||
MkDocs is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. Documentation source files are written in Markdown, and configured with a single YAML configuration file, <gh-file:mkdocs.yaml>.
|
||||
|
||||
Get started with:
|
||||
|
||||
```bash
|
||||
uv pip install -r requirements/docs.txt
|
||||
```
|
||||
|
||||
!!! tip
|
||||
Ensure that your Python version is compatible with the plugins
|
||||
(e.g., `mkdocs-awesome-nav` requires Python 3.10+)
|
||||
|
||||
MkDocs comes with a built-in dev-server that lets you preview your documentation as you work on it.
|
||||
From the root of the repository, run:
|
||||
|
||||
```bash
|
||||
mkdocs serve # with API ref (~10 minutes)
|
||||
API_AUTONAV_EXCLUDE=vllm mkdocs serve # API ref off (~15 seconds)
|
||||
```
|
||||
|
||||
Once you see `Serving on http://127.0.0.1:8000/` in the logs, the live preview is ready!
|
||||
Open <http://127.0.0.1:8000/> in your browser to see it.
|
||||
|
||||
For additional features and advanced configurations, refer to the:
|
||||
|
||||
- [MkDocs documentation](https://www.mkdocs.org/)
|
||||
- [Material for MkDocs documentation](https://squidfunk.github.io/mkdocs-material/) (the MkDocs theme we use)
|
||||
|
||||
### Testing
|
||||
|
||||
vLLM uses `pytest` to test the codebase.
|
||||
|
||||
```bash
|
||||
# Install the test dependencies used in CI (CUDA only)
|
||||
uv pip install -r requirements/common.txt -r requirements/dev.txt --torch-backend=auto
|
||||
|
||||
# Install some common test dependencies (hardware agnostic)
|
||||
uv pip install pytest pytest-asyncio
|
||||
|
||||
# Run all tests
|
||||
pytest tests/
|
||||
|
||||
# Run tests for a single test file with detailed output
|
||||
pytest -s -v tests/test_logger.py
|
||||
```
|
||||
|
||||
!!! tip "Install python3-dev if Python.h is missing"
|
||||
If any of the above commands fails with `Python.h: No such file or directory`, install
|
||||
`python3-dev` with `sudo apt install python3-dev`.
|
||||
|
||||
!!! note
|
||||
!!! warning "Warnings"
|
||||
Currently, the repository is not fully checked by `mypy`.
|
||||
|
||||
!!! note
|
||||
---
|
||||
|
||||
Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU
|
||||
platform to run unit tests locally, rely on the continuous integration system to run the tests for
|
||||
now.
|
||||
@ -194,8 +204,7 @@ appropriately to indicate the type of change. Please use one of the following:
|
||||
The PR needs to meet the following code quality standards:
|
||||
|
||||
- We adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
|
||||
- Pass all linter checks. Please use `pre-commit` to format your code. See
|
||||
<https://pre-commit.com/#usage> if `pre-commit` is new to you.
|
||||
- Pass all linter checks.
|
||||
- The code needs to be well-documented to ensure future contributors can easily
|
||||
understand the code.
|
||||
- Include sufficient tests to ensure the project stays correct and robust. This
|
||||
|
||||
@@ -156,7 +156,6 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
```bash
vllm bench serve \
--backend openai-chat \
--endpoint-type openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name hf \
@@ -230,7 +229,6 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
```bash
vllm bench serve \
--backend openai-chat \
--endpoint-type openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name hf \
@@ -245,7 +243,6 @@ vllm bench serve \
```bash
vllm bench serve \
--backend openai-chat \
--endpoint-type openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name hf \
@@ -10,12 +10,12 @@ vLLM currently supports the following reasoning models:

| Model Series | Parser Name | Structured Output Support | Tool Calling |
|--------------|-------------|------------------|-------------|
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ |
| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ |
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ |
| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ |
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ |
| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `guided_json`, `guided_regex` | ✅ |
| [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `guided_json`, `guided_regex` | ✅ |
| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `json`, `regex` | ✅ |
| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `json`, `regex` | ✅ |
| [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` | ✅ |

!!! note
IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
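A hedged client-side sketch of the Granite note above; the model id is illustrative, and `thinking=True` is passed through `chat_template_kwargs` as described:

```python
# Illustrative request enabling Granite 3.2 reasoning via chat_template_kwargs.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="ibm-granite/granite-3.2-8b-instruct",  # placeholder; use the model you serve
    messages=[{"role": "user", "content": "Solve: 17 * 24 = ?"}],
    extra_body={"chat_template_kwargs": {"thinking": True}},
)
print(completion.choices[0].message.content)
```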
@@ -12,23 +12,23 @@ You can generate structured outputs using the OpenAI's [Completions](https://pla

The following parameters are supported, which must be added as extra parameters:

- `guided_choice`: the output will be exactly one of the choices.
- `guided_regex`: the output will follow the regex pattern.
- `guided_json`: the output will follow the JSON schema.
- `guided_grammar`: the output will follow the context free grammar.
- `choice`: the output will be exactly one of the choices.
- `regex`: the output will follow the regex pattern.
- `json`: the output will follow the JSON schema.
- `grammar`: the output will follow the context free grammar.
- `structural_tag`: Follow a JSON schema within a set of specified tags within the generated text.

You can see the complete list of supported parameters on the [OpenAI-Compatible Server](../serving/openai_compatible_server.md) page.

Structured outputs are supported by default in the OpenAI-Compatible Server. You
may choose to specify the backend to use by setting the
`--guided-decoding-backend` flag to `vllm serve`. The default backend is `auto`,
`--structured-outputs-config.backend` flag to `vllm serve`. The default backend is `auto`,
which will try to choose an appropriate backend based on the details of the
request. You may also choose a specific backend, along with
some options. A full set of options is available in the `vllm serve --help`
text.
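For instance, to pin the backend explicitly instead of leaving it on `auto` (a sketch; the model is illustrative, and both the flag and the `outlines` backend value appear elsewhere in this diff):

```bash
vllm serve Qwen/Qwen2.5-3B-Instruct \
    --structured-outputs-config.backend outlines
```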
Now let´s see an example for each of the cases, starting with the `guided_choice`, as it´s the easiest one:
Now let´s see an example for each of the cases, starting with the `choice`, as it´s the easiest one:

??? code

@@ -45,12 +45,12 @@ Now let´s see an example for each of the cases, starting with the `guided_choic
messages=[
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
],
extra_body={"guided_choice": ["positive", "negative"]},
extra_body={"structured_outputs": {"choice": ["positive", "negative"]}},
)
print(completion.choices[0].message.content)
```

The next example shows how to use the `guided_regex`. The idea is to generate an email address, given a simple regex template:
The next example shows how to use the `regex`. The idea is to generate an email address, given a simple regex template:

??? code

@@ -63,18 +63,18 @@ The next example shows how to use the `guided_regex`. The idea is to generate an
"content": "Generate an example email address for Alan Turing, who works in Enigma. End in .com and new line. Example result: alan.turing@enigma.com\n",
}
],
extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]},
extra_body={"structured_outputs": {"regex": r"\w+@\w+\.com\n"}, "stop": ["\n"]},
)
print(completion.choices[0].message.content)
```

One of the most relevant features in structured text generation is the option to generate a valid JSON with pre-defined fields and formats.
For this we can use the `guided_json` parameter in two different ways:
For this we can use the `json` parameter in two different ways:

- Using directly a [JSON Schema](https://json-schema.org/)
- Defining a [Pydantic model](https://docs.pydantic.dev/latest/) and then extracting the JSON Schema from it (which is normally an easier option).

The next example shows how to use the `guided_json` parameter with a Pydantic model:
The next example shows how to use the `response_format` parameter with a Pydantic model:

??? code

@@ -119,7 +119,7 @@ The next example shows how to use the `guided_json` parameter with a Pydantic mo
JSON schema and how the fields should be populated. This can improve the
results notably in most cases.

Finally we have the `guided_grammar` option, which is probably the most
Finally we have the `grammar` option, which is probably the most
difficult to use, but it´s really powerful. It allows us to define complete
languages like SQL queries. It works by using a context free EBNF grammar.
As an example, we can use to define a specific format of simplified SQL queries:
@@ -149,7 +149,7 @@ As an example, we can use to define a specific format of simplified SQL queries:
"content": "Generate an SQL query to show the 'username' and 'email' from the 'users' table.",
}
],
extra_body={"guided_grammar": simplified_sql_grammar},
extra_body={"structured_outputs": {"grammar": simplified_sql_grammar}},
)
print(completion.choices[0].message.content)
```
@@ -292,8 +292,8 @@ An example of using `structural_tag` can be found here: <gh-file:examples/online
## Offline Inference

Offline inference allows for the same types of structured outputs.
To use it, we´ll need to configure the guided decoding using the class `GuidedDecodingParams` inside `SamplingParams`.
The main available options inside `GuidedDecodingParams` are:
To use it, we´ll need to configure the structured outputs using the class `StructuredOutputsParams` inside `SamplingParams`.
The main available options inside `StructuredOutputsParams` are:

- `json`
- `regex`
@@ -309,12 +309,12 @@ shown below:

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from vllm.sampling_params import StructuredOutputsParams

llm = LLM(model="HuggingFaceTB/SmolLM2-1.7B-Instruct")

guided_decoding_params = GuidedDecodingParams(choice=["Positive", "Negative"])
sampling_params = SamplingParams(guided_decoding=guided_decoding_params)
structured_outputs_params = StructuredOutputsParams(choice=["Positive", "Negative"])
sampling_params = SamplingParams(structured_outputs=structured_outputs_params)
outputs = llm.generate(
prompts="Classify this sentiment: vLLM is wonderful!",
sampling_params=sampling_params,
@@ -71,7 +71,7 @@ This example demonstrates:
* Making a request with `tool_choice="auto"`
* Handling the structured response and executing the corresponding function

You can also specify a particular function using named function calling by setting `tool_choice={"type": "function", "function": {"name": "get_weather"}}`. Note that this will use the guided decoding backend - so the first time this is used, there will be several seconds of latency (or more) as the FSM is compiled for the first time before it is cached for subsequent requests.
You can also specify a particular function using named function calling by setting `tool_choice={"type": "function", "function": {"name": "get_weather"}}`. Note that this will use the structured outputs backend - so the first time this is used, there will be several seconds of latency (or more) as the FSM is compiled for the first time before it is cached for subsequent requests.

Remember that it's the caller's responsibility to:

@@ -83,19 +83,18 @@ For more advanced usage, including parallel tool calls and different model-speci

## Named Function Calling

vLLM supports named function calling in the chat completion API by default. It does so using Outlines through guided decoding, so this is
enabled by default and will work with any supported model. You are guaranteed a validly-parsable function call - not a
vLLM supports named function calling in the chat completion API by default. This should work with most structured outputs backends supported by vLLM. You are guaranteed a validly-parsable function call - not a
high-quality one.

vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
For best results, we recommend ensuring that the expected output format / schema is specified in the prompt to ensure that the model's intended generation is aligned with the schema that it's being forced to generate by the guided decoding backend.
vLLM will use structured outputs to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
For best results, we recommend ensuring that the expected output format / schema is specified in the prompt to ensure that the model's intended generation is aligned with the schema that it's being forced to generate by the structured outputs backend.

To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and
specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request.
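A minimal client-side sketch of that flow, reusing the `get_weather` name from the example above; the tool schema and model name are illustrative:

```python
# Illustrative named function calling request; the tool schema is a made-up example.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

completion = client.chat.completions.create(
    model="<your-served-model>",  # placeholder
    messages=[{"role": "user", "content": "What's the weather in Berlin?"}],
    tools=tools,
    # Force the named function, as described above
    tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
print(completion.choices[0].message.tool_calls)
```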
## Required Function Calling

vLLM supports the `tool_choice='required'` option in the chat completion API. Similar to the named function calling, it also uses guided decoding, so this is enabled by default and will work with any supported model. The guided decoding features for `tool_choice='required'` (such as JSON schema with `anyOf`) are currently only supported in the V0 engine with the guided decoding backend `outlines`. However, support for alternative decoding backends are on the [roadmap](../usage/v1_guide.md#features) for the V1 engine.
vLLM supports the `tool_choice='required'` option in the chat completion API. Similar to the named function calling, it also uses structured outputs, so this is enabled by default and will work with any supported model. However, support for alternative decoding backends are on the [roadmap](../usage/v1_guide.md#features) for the V1 engine.

When tool_choice='required' is set, the model is guaranteed to generate one or more tool calls based on the specified tool list in the `tools` parameter. The number of tool calls depends on the user's query. The output format strictly follows the schema defined in the `tools` parameter.
@@ -1,4 +1,4 @@
It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment and install vLLM using the following commands:
It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment using the following commands:

```bash
uv venv --python 3.12 --seed
@@ -554,6 +554,17 @@ If your model is not in the above list, we will try to automatically convert the
For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.

#### Token Classification

These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode) API.

| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | [V1](gh-issue:8779) |
|--------------|--------|-------------------|-----------------------------|-----------------------------------------|---------------------|
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | | ✅︎ |

!!! note
Named Entity Recognition (NER) usage, please refer to <gh-file:examples/offline_inference/pooling/ner.py>, <gh-file:examples/online_serving/pooling/ner.py>.

[](){ #supported-mm-models }

## List of Multimodal Language Models
@ -133,7 +133,7 @@ completion = client.chat.completions.create(
|
||||
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
|
||||
],
|
||||
extra_body={
|
||||
"guided_choice": ["positive", "negative"]
|
||||
"structured_outputs": {"choice": ["positive", "negative"]}
|
||||
}
|
||||
)
|
||||
```
|
||||
@ -317,10 +317,11 @@ Full example: <gh-file:examples/online_serving/pooling/openai_chat_embedding_cli
|
||||
|
||||
#### Extra parameters
|
||||
|
||||
The following [pooling parameters][pooling-params] are supported.
|
||||
The following [pooling parameters][vllm.PoolingParams] are supported.
|
||||
|
||||
```python
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:embedding-pooling-params"
|
||||
--8<-- "vllm/pooling_params.py:common-pooling-params"
|
||||
--8<-- "vllm/pooling_params.py:embedding-pooling-params"
|
||||
```
|
||||
|
||||
The following extra parameters are supported by default:
|
||||
@ -374,7 +375,7 @@ The following extra parameters are supported:
|
||||
```python
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:transcription-extra-params"
|
||||
```
|
||||
|
||||
|
||||
[](){ #translations-api }
|
||||
|
||||
### Translations API
|
||||
@ -527,10 +528,11 @@ curl -v "http://127.0.0.1:8000/classify" \
|
||||
|
||||
#### Extra parameters
|
||||
|
||||
The following [pooling parameters][pooling-params] are supported.
|
||||
The following [pooling parameters][vllm.PoolingParams] are supported.
|
||||
|
||||
```python
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:classification-pooling-params"
|
||||
--8<-- "vllm/pooling_params.py:common-pooling-params"
|
||||
--8<-- "vllm/pooling_params.py:classification-pooling-params"
|
||||
```
|
||||
|
||||
The following extra parameters are supported:
|
||||
@ -733,10 +735,11 @@ Full example: <gh-file:examples/online_serving/openai_cross_encoder_score_for_mu
|
||||
|
||||
#### Extra parameters
|
||||
|
||||
The following [pooling parameters][pooling-params] are supported.
|
||||
The following [pooling parameters][vllm.PoolingParams] are supported.
|
||||
|
||||
```python
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:score-pooling-params"
|
||||
--8<-- "vllm/pooling_params.py:common-pooling-params"
|
||||
--8<-- "vllm/pooling_params.py:classification-pooling-params"
|
||||
```
|
||||
|
||||
The following extra parameters are supported:
|
||||
@ -815,10 +818,11 @@ Result documents will be sorted by relevance, and the `index` property can be us
|
||||
|
||||
#### Extra parameters
|
||||
|
||||
The following [pooling parameters][pooling-params] are supported.
|
||||
The following [pooling parameters][vllm.PoolingParams] are supported.
|
||||
|
||||
```python
|
||||
--8<-- "vllm/entrypoints/openai/protocol.py:rerank-pooling-params"
|
||||
--8<-- "vllm/pooling_params.py:common-pooling-params"
|
||||
--8<-- "vllm/pooling_params.py:classification-pooling-params"
|
||||
```
|
||||
|
||||
The following extra parameters are supported:
|
||||
|
||||
@@ -26,8 +26,14 @@ python examples/offline_inference/pooling/embed_jina_embeddings_v3.py
python examples/offline_inference/pooling/embed_matryoshka_fy.py
```

## Named Entity Recognition (NER) usage

```bash
python examples/offline_inference/pooling/ner.py
```

## Qwen3 reranker usage

```bash
python qwen3_reranker.py
python examples/offline_inference/pooling/qwen3_reranker.py
```

examples/offline_inference/pooling/ner.py (new file): 54 lines
@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
from vllm import LLM, EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
# Set example specific arguments
|
||||
parser.set_defaults(
|
||||
model="boltuix/NeuroBERT-NER",
|
||||
runner="pooling",
|
||||
enforce_eager=True,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(args: Namespace):
|
||||
# Sample prompts.
|
||||
prompts = [
|
||||
"Barack Obama visited Microsoft headquarters in Seattle on January 2025."
|
||||
]
|
||||
|
||||
# Create an LLM.
|
||||
llm = LLM(**vars(args))
|
||||
tokenizer = llm.get_tokenizer()
|
||||
label_map = llm.llm_engine.vllm_config.model_config.hf_config.id2label
|
||||
|
||||
# Run inference
|
||||
outputs = llm.encode(prompts)
|
||||
|
||||
for prompt, output in zip(prompts, outputs):
|
||||
logits = output.outputs.data
|
||||
predictions = logits.argmax(dim=-1)
|
||||
|
||||
# Map predictions to labels
|
||||
tokens = tokenizer.convert_ids_to_tokens(output.prompt_token_ids)
|
||||
labels = [label_map[p.item()] for p in predictions]
|
||||
|
||||
# Print results
|
||||
for token, label in zip(tokens, labels):
|
||||
if token not in tokenizer.all_special_tokens:
|
||||
print(f"{token:15} → {label}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -53,7 +53,6 @@ def parse_args():
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp"],
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -118,6 +117,11 @@ def main():
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method.endswith("mtp"):
speculative_config = {
"method": args.method,
"num_speculative_tokens": args.num_spec_tokens,
}
else:
raise ValueError(f"unknown method: {args.method}")
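Assuming the example script is invoked in the usual way (its path is not shown in this hunk, so it is written generically here), the newly added MTP branch is reachable with the flags defined in `parse_args` above:

```bash
# Hypothetical invocation exercising the new MTP branch; flags come from parse_args above
python <path-to-this-spec-decode-example>.py --method mtp --num-spec-tokens 2
```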
@ -1,11 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This file demonstrates the example usage of guided decoding
|
||||
to generate structured outputs using vLLM. It shows how to apply
|
||||
different guided decoding techniques such as Choice, Regex, JSON schema,
|
||||
and Grammar to produce structured and formatted results
|
||||
based on specific prompts.
|
||||
This file demonstrates the example usage of structured outputs
|
||||
in vLLM. It shows how to apply different constraints such as choice,
|
regex, json schema, and grammar to produce structured and formatted
results based on specific prompts.
"""

from enum import Enum
@@ -13,19 +12,23 @@ from enum import Enum
from pydantic import BaseModel

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
from vllm.sampling_params import StructuredOutputsParams

MAX_TOKENS = 50

# Guided decoding by Choice (list of possible options)
guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
# Structured outputs by Choice (list of possible options)
structured_outputs_params_choice = StructuredOutputsParams(
    choice=["Positive", "Negative"]
)
sampling_params_choice = SamplingParams(
    structured_outputs=structured_outputs_params_choice
)
prompt_choice = "Classify this sentiment: vLLM is wonderful!"

# Guided decoding by Regex
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
# Structured outputs by Regex
structured_outputs_params_regex = StructuredOutputsParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(
    guided_decoding=guided_decoding_params_regex,
    structured_outputs=structured_outputs_params_regex,
    stop=["\n"],
    max_tokens=MAX_TOKENS,
)
@@ -36,7 +39,7 @@ prompt_regex = (
)


# Guided decoding by JSON using Pydantic schema
# Structured outputs by JSON using Pydantic schema
class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
@@ -51,17 +54,16 @@ class CarDescription(BaseModel):


json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
structured_outputs_params_json = StructuredOutputsParams(json=json_schema)
sampling_params_json = SamplingParams(
    guided_decoding=guided_decoding_params_json,
    max_tokens=MAX_TOKENS,
    structured_outputs=structured_outputs_params_json, max_tokens=MAX_TOKENS
)
prompt_json = (
    "Generate a JSON with the brand, model and car_type of"
    "Generate a JSON with the brand, model and car_type of "
    "the most iconic car from the 90's"
)

# Guided decoding by Grammar
# Structured outputs by Grammar
simplified_sql_grammar = """
root ::= select_statement
select_statement ::= "SELECT " column " from " table " where " condition
@@ -70,13 +72,15 @@ table ::= "table_1 " | "table_2 "
condition ::= column "= " number
number ::= "1 " | "2 "
"""
guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
structured_outputs_params_grammar = StructuredOutputsParams(
    grammar=simplified_sql_grammar
)
sampling_params_grammar = SamplingParams(
    guided_decoding=guided_decoding_params_grammar,
    structured_outputs=structured_outputs_params_grammar,
    max_tokens=MAX_TOKENS,
)
prompt_grammar = (
    "Generate an SQL query to show the 'username' and 'email'from the 'users' table."
    "Generate an SQL query to show the 'username' and 'email' from the 'users' table."
)


@@ -93,16 +97,16 @@ def main():
    llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=100)

    choice_output = generate_output(prompt_choice, sampling_params_choice, llm)
    format_output("Guided decoding by Choice", choice_output)
    format_output("Structured outputs by Choice", choice_output)

    regex_output = generate_output(prompt_regex, sampling_params_regex, llm)
    format_output("Guided decoding by Regex", regex_output)
    format_output("Structured outputs by Regex", regex_output)

    json_output = generate_output(prompt_json, sampling_params_json, llm)
    format_output("Guided decoding by JSON", json_output)
    format_output("Structured outputs by JSON", json_output)

    grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm)
    format_output("Guided decoding by Grammar", grammar_output)
    format_output("Structured outputs by Grammar", grammar_output)


if __name__ == "__main__":

@@ -6,7 +6,7 @@ without any specific flags:

```bash
VLLM_USE_V1=0 vllm serve unsloth/Llama-3.2-1B-Instruct \
    --guided-decoding-backend outlines
    --structured-outputs-config.backend outlines
```
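With a server like the one above running, the same constraints can be supplied per request through the OpenAI-compatible API. The snippet below is a minimal sketch rather than part of this change set: the base URL, API key, model name, and choice list are illustrative assumptions, while the `structured_outputs` field in `extra_body` mirrors the request format used elsewhere in this diff.

```python
# Minimal sketch of a choice-constrained chat completion against a running
# vLLM OpenAI-compatible server. Host, port, model, and choices are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat = client.chat.completions.create(
    model="unsloth/Llama-3.2-1B-Instruct",
    messages=[
        {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
    ],
    extra_body={"structured_outputs": {"choice": ["Positive", "Negative"]}},
)
print(chat.choices[0].message.content)  # constrained to one of the two choices
```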

This example demonstrates how to generate chat completions

@@ -12,6 +12,12 @@ python examples/online_serving/pooling/cohere_rerank_client.py
python examples/online_serving/pooling/jinaai_rerank_client.py
```

## Named Entity Recognition (NER) usage

```bash
python examples/online_serving/pooling/ner.py
```

## Openai chat embedding for multimodal usage

```bash
examples/online_serving/pooling/ner.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/boltuix/NeuroBERT-NER

"""
Example online usage of Pooling API for Named Entity Recognition (NER).

Run `vllm serve <model> --runner pooling`
to start up the server in vLLM. e.g.

vllm serve boltuix/NeuroBERT-NER
"""

import argparse

import requests
import torch


def post_http_request(prompt: dict, api_url: str) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    response = requests.post(api_url, headers=headers, json=prompt)
    return response


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--model", type=str, default="boltuix/NeuroBERT-NER")

    return parser.parse_args()


def main(args):
    from transformers import AutoConfig, AutoTokenizer

    api_url = f"http://{args.host}:{args.port}/pooling"
    model_name = args.model

    # Load tokenizer and config
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name)
    label_map = config.id2label

    # Input text
    text = "Barack Obama visited Microsoft headquarters in Seattle on January 2025."
    prompt = {"model": model_name, "input": text}

    pooling_response = post_http_request(prompt=prompt, api_url=api_url)

    # Run inference
    output = pooling_response.json()["data"][0]
    logits = torch.tensor(output["data"])
    predictions = logits.argmax(dim=-1)
    inputs = tokenizer(text, return_tensors="pt")

    # Map predictions to labels
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    labels = [label_map[p.item()] for p in predictions]
    assert len(tokens) == len(predictions)

    # Print results
    for token, label in zip(tokens, labels):
        if token not in tokenizer.all_special_tokens:
            print(f"{token:15} → {label}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
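To try this client end to end, the commands below follow the docstring and the pooling README above; they sketch one possible local setup, with the host and port matching the script's defaults.

```bash
# Start a pooling server for the NER model named in the example
vllm serve boltuix/NeuroBERT-NER --runner pooling

# In another shell, run the client against the default endpoint
python examples/online_serving/pooling/ner.py --host localhost --port 8000
```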
@@ -86,7 +86,7 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = {
            "content": "Classify this sentiment: vLLM is wonderful!",
        }
    ],
    "extra_body": {"guided_choice": ["positive", "negative"]},
    "extra_body": {"structured_outputs": {"choice": ["positive", "negative"]}},
},
"regex": {
    "messages": [
@@ -96,7 +96,7 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = {
        }
    ],
    "extra_body": {
        "guided_regex": r"[a-z0-9.]{1,20}@\w{6,10}\.com\n",
        "structured_outputs": {"regex": r"[a-z0-9.]{1,20}@\w{6,10}\.com\n"},
    },
},
"json": {
@@ -122,7 +122,8 @@ PARAMS: dict[ConstraintsFormat, dict[str, Any]] = {
        }
    ],
    "extra_body": {
        "guided_grammar": """
        "structured_outputs": {
            "grammar": """
root ::= select_statement

select_statement ::= "SELECT " column " from " table " where " condition
@@ -135,6 +136,7 @@ condition ::= column "= " number

number ::= "1 " | "2 "
""",
        }
    },
},
"structural_tag": {

@@ -79,6 +79,7 @@ plugins:
        - "re:vllm\\._.*" # Internal modules
        - "vllm.third_party"
        - "vllm.vllm_flash_attn"
        - !ENV [API_AUTONAV_EXCLUDE, "re:^$"] # Match nothing by default
  - mkdocstrings:
      handlers:
        python:

@ -1,54 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""vllm.entrypoints.api_server with some extra logging for testing."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
import vllm.entrypoints.api_server
|
||||
import vllm.envs as envs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
app = vllm.entrypoints.api_server.app
|
||||
|
||||
|
||||
class AsyncLLMEngineWithStats(AsyncLLMEngine):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._num_aborts = 0
|
||||
|
||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||
ids = list(request_ids)
|
||||
self._num_aborts += len(ids)
|
||||
await super()._engine_abort(ids)
|
||||
|
||||
def testing_stats(self) -> dict[str, Any]:
|
||||
return {"num_aborted_requests": self._num_aborts}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
def stats() -> Response:
|
||||
"""Get the statistics of the engine."""
|
||||
return JSONResponse(engine.testing_stats())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
|
||||
vllm.entrypoints.api_server.engine = engine
|
||||
uvicorn.run(app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug",
|
||||
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE)
|
||||
@ -1,12 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
Since this module is V0 only, set VLLM_USE_V1=0 for
|
||||
all tests in the module.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
@ -1,139 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copyreg
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from multiprocessing import Pool
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import urllib3.exceptions
|
||||
|
||||
|
||||
def _pickle_new_connection_error(obj):
|
||||
"""Custom pickler for NewConnectionError to fix tblib compatibility."""
|
||||
# Extract the original message by removing the "conn: " prefix
|
||||
full_message = obj.args[0] if obj.args else ""
|
||||
if ': ' in full_message:
|
||||
# Split off the connection part and keep the actual message
|
||||
_, actual_message = full_message.split(': ', 1)
|
||||
else:
|
||||
actual_message = full_message
|
||||
return _unpickle_new_connection_error, (actual_message, )
|
||||
|
||||
|
||||
def _unpickle_new_connection_error(message):
|
||||
"""Custom unpickler for NewConnectionError."""
|
||||
# Create with None as conn and the actual message
|
||||
return urllib3.exceptions.NewConnectionError(None, message)
|
||||
|
||||
|
||||
# Register the custom pickle/unpickle functions for tblib compatibility
|
||||
copyreg.pickle(urllib3.exceptions.NewConnectionError,
|
||||
_pickle_new_connection_error)
|
||||
|
||||
|
||||
def _query_server(prompt: str, max_tokens: int = 5) -> dict:
|
||||
response = requests.post("http://localhost:8000/generate",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0,
|
||||
"ignore_eos": True
|
||||
})
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _query_server_long(prompt: str) -> dict:
|
||||
return _query_server(prompt, max_tokens=500)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_server(distributed_executor_backend: str):
|
||||
script_path = Path(__file__).parent.joinpath(
|
||||
"api_server_async_engine.py").absolute()
|
||||
commands = [
|
||||
sys.executable,
|
||||
"-u",
|
||||
str(script_path),
|
||||
"--model",
|
||||
"facebook/opt-125m",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--distributed-executor-backend",
|
||||
distributed_executor_backend,
|
||||
]
|
||||
|
||||
# API Server Test Requires V0.
|
||||
my_env = os.environ.copy()
|
||||
my_env["VLLM_USE_V1"] = "0"
|
||||
uvicorn_process = subprocess.Popen(commands, env=my_env)
|
||||
yield
|
||||
uvicorn_process.terminate()
|
||||
|
||||
|
||||
@pytest.mark.timeout(300)
|
||||
@pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"])
|
||||
def test_api_server(api_server, distributed_executor_backend: str):
|
||||
"""
|
||||
Run the API server and test it.
|
||||
|
||||
We run both the server and requests in separate processes.
|
||||
|
||||
We test that the server can handle incoming requests, including
|
||||
multiple requests at the same time, and that it can handle requests
|
||||
being cancelled without crashing.
|
||||
"""
|
||||
with Pool(32) as pool:
|
||||
# Wait until the server is ready
|
||||
prompts = ["warm up"] * 1
|
||||
result = None
|
||||
while not result:
|
||||
try:
|
||||
for r in pool.map(_query_server, prompts):
|
||||
result = r
|
||||
break
|
||||
except requests.exceptions.ConnectionError:
|
||||
time.sleep(1)
|
||||
|
||||
# Actual tests start here
|
||||
# Try with 1 prompt
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
||||
|
||||
num_aborted_requests = requests.get(
|
||||
"http://localhost:8000/stats").json()["num_aborted_requests"]
|
||||
assert num_aborted_requests == 0
|
||||
|
||||
# Try with 100 prompts
|
||||
prompts = ["test prompt"] * 100
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
||||
|
||||
with Pool(32) as pool:
|
||||
# Cancel requests
|
||||
prompts = ["canceled requests"] * 100
|
||||
pool.map_async(_query_server_long, prompts)
|
||||
time.sleep(0.01)
|
||||
pool.terminate()
|
||||
pool.join()
|
||||
|
||||
# check cancellation stats
|
||||
# give it some time to update the stats
|
||||
time.sleep(1)
|
||||
|
||||
num_aborted_requests = requests.get(
|
||||
"http://localhost:8000/stats").json()["num_aborted_requests"]
|
||||
assert num_aborted_requests > 0
|
||||
|
||||
# check that server still runs after cancellations
|
||||
with Pool(32) as pool:
|
||||
# Try with 100 prompts
|
||||
prompts = ["test prompt after canceled"] * 100
|
||||
for result in pool.map(_query_server, prompts):
|
||||
assert result
|
||||
@ -1,71 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.engine.async_llm_engine import RequestTracker
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_tracker():
|
||||
tracker = RequestTracker()
|
||||
stream_1 = tracker.add_request("1")
|
||||
assert tracker.new_requests_event.is_set()
|
||||
await tracker.wait_for_new_requests()
|
||||
new, aborted = tracker.get_new_and_aborted_requests()
|
||||
assert not tracker.new_requests_event.is_set()
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == "1"
|
||||
assert not aborted
|
||||
assert not stream_1.finished
|
||||
|
||||
stream_2 = tracker.add_request("2")
|
||||
stream_3 = tracker.add_request("3")
|
||||
assert tracker.new_requests_event.is_set()
|
||||
await tracker.wait_for_new_requests()
|
||||
new, aborted = tracker.get_new_and_aborted_requests()
|
||||
assert not tracker.new_requests_event.is_set()
|
||||
assert len(new) == 2
|
||||
assert new[0]["request_id"] == "2"
|
||||
assert new[1]["request_id"] == "3"
|
||||
assert not aborted
|
||||
assert not stream_2.finished
|
||||
assert not stream_3.finished
|
||||
|
||||
# request_ids must be unique
|
||||
with pytest.raises(KeyError):
|
||||
tracker.add_request("1")
|
||||
assert not tracker.new_requests_event.is_set()
|
||||
|
||||
tracker.abort_request("1")
|
||||
new, aborted = tracker.get_new_and_aborted_requests()
|
||||
assert len(aborted) == 1
|
||||
assert "1" in aborted
|
||||
assert not new
|
||||
assert stream_1.finished
|
||||
|
||||
stream_4 = tracker.add_request("4")
|
||||
tracker.abort_request("4")
|
||||
assert tracker.new_requests_event.is_set()
|
||||
await tracker.wait_for_new_requests()
|
||||
new, aborted = tracker.get_new_and_aborted_requests()
|
||||
# aborted new requests will cancel each other out -
|
||||
# there's no need for them to propagate into the
|
||||
# engine
|
||||
assert not aborted
|
||||
assert not new
|
||||
assert stream_4.finished
|
||||
|
||||
stream_5 = tracker.add_request("5")
|
||||
assert tracker.new_requests_event.is_set()
|
||||
tracker.process_request_output(
|
||||
RequestOutput("2", "output", [], [], [], finished=True))
|
||||
await tracker.wait_for_new_requests()
|
||||
new, aborted = tracker.get_new_and_aborted_requests()
|
||||
assert not tracker.new_requests_event.is_set()
|
||||
assert not aborted
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == "5"
|
||||
assert stream_2.finished
|
||||
assert not stream_5.finished
|
||||
@ -1,189 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Compare the short outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test.
|
||||
|
||||
Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1
|
||||
pytest tests/basic_correctness/test_preemption.py`.
|
||||
"""
|
||||
import pytest
|
||||
from prometheus_client import REGISTRY
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import SamplingParams
|
||||
from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT,
|
||||
ENABLE_ARTIFICIAL_PREEMPT)
|
||||
|
||||
from ..models.utils import check_outputs_equal
|
||||
|
||||
MODELS = [
|
||||
"distilbert/distilgpt2",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
We should enable this for V1, but VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT,
|
||||
so use VLLM_USE_V1=0 for all tests in the file.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def check_settings():
|
||||
assert ENABLE_ARTIFICIAL_PREEMPT is True, (
|
||||
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1."
|
||||
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 "
|
||||
"pytest tests/basic_correctness/test_preemption.py`")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def distributed_executor_backend() -> str:
|
||||
# When SPMD worker is used, use distributed_executor_backend="ray"
|
||||
# to test delta input optimization works with preemption.
|
||||
return "ray" if envs.VLLM_USE_RAY_SPMD_WORKER else "mp"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
|
||||
def test_chunked_prefill_recompute(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
distributed_executor_backend: str,
|
||||
) -> None:
|
||||
"""Ensure that chunked prefill works with preemption."""
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_seqs=max_num_seqs,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
disable_log_stats=False,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
< ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
def test_preemption(
|
||||
caplog_vllm,
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
distributed_executor_backend: str,
|
||||
) -> None:
|
||||
"""By default, recompute preemption is enabled"""
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
< ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
total_preemption = (
|
||||
vllm_model.llm.llm_engine.scheduler[0].num_cumulative_preemption)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
assert ("is preempted by PreemptionMode.RECOMPUTE mode because there "
|
||||
"is not enough KV cache space." in caplog_vllm.text)
|
||||
# Ensure the count bucket of request-level histogram metrics matches
|
||||
# the number of requests as a simple sanity check to ensure metrics are
|
||||
# generated
|
||||
preemption_metrics = None
|
||||
for m in REGISTRY.collect():
|
||||
if m.name == "vllm:num_preemptions":
|
||||
preemption_metrics = m
|
||||
assert preemption_metrics is not None
|
||||
total_recorded_preemption = 0
|
||||
for sample in preemption_metrics.samples:
|
||||
total_recorded_preemption += sample.value
|
||||
assert total_preemption == total_recorded_preemption
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@pytest.mark.parametrize("max_tokens", [96])
|
||||
def test_preemption_infeasible(
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
distributed_executor_backend: str,
|
||||
) -> None:
|
||||
"""Verify infeasible preemption request will be ignored."""
|
||||
BLOCK_SIZE = 16
|
||||
prefill_blocks = 2
|
||||
decode_blocks = max_tokens // BLOCK_SIZE
|
||||
with vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
block_size=BLOCK_SIZE,
|
||||
# Not enough gpu blocks to complete a single sequence.
|
||||
# preemption should happen, and the sequence should be
|
||||
# ignored instead of hanging forever.
|
||||
num_gpu_blocks_override=prefill_blocks + decode_blocks // 2,
|
||||
max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE),
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
) as vllm_model:
|
||||
sampling_params = SamplingParams(max_tokens=max_tokens,
|
||||
ignore_eos=True)
|
||||
req_outputs = vllm_model.llm.generate(
|
||||
example_prompts,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt
|
||||
< ARTIFICIAL_PREEMPTION_MAX_CNT)
|
||||
|
||||
# Verify the request is ignored and not hang.
|
||||
for req_output in req_outputs:
|
||||
outputs = req_output.outputs
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0].finish_reason == "length"
|
||||
@ -68,7 +68,7 @@ def test_bench_serve_chat(server):
|
||||
"5",
|
||||
"--endpoint",
|
||||
"/v1/chat/completions",
|
||||
"--endpoint-type",
|
||||
"--backend",
|
||||
"openai-chat",
|
||||
]
|
||||
result = subprocess.run(command, capture_output=True, text=True)
|
||||
|
||||
@ -1,11 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def v1(run_with_both_engines):
|
||||
# Simple autouse wrapper to run both engines for each test
|
||||
# This can be promoted up to conftest.py to run for every
|
||||
# test in a package
|
||||
pass
|
||||
@ -1,83 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.inputs import token_inputs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Logprob, Sequence, SequenceStatus
|
||||
|
||||
|
||||
def sequence_with_eos(text: str, eos_token: str,
|
||||
eos_token_id: int) -> Sequence:
|
||||
"""
|
||||
Create a Sequence that ends with an EOS token.
|
||||
"""
|
||||
seq = Sequence(
|
||||
seq_id=0,
|
||||
inputs=token_inputs([]),
|
||||
block_size=16,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
seq.output_text = text + eos_token
|
||||
|
||||
offset = eos_token_id + 1
|
||||
for i in range(offset, len(text) + offset):
|
||||
seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)})
|
||||
seq.append_token_id(token_id=eos_token_id,
|
||||
logprobs={eos_token_id: Logprob(0.0)})
|
||||
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
return seq
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
|
||||
("This text ends with EOS token", "</s>", 2),
|
||||
])
|
||||
@pytest.mark.parametrize("ignore_eos", [True, False])
|
||||
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
|
||||
ignore_eos: bool, include_stop_str_in_output: bool):
|
||||
"""
|
||||
Test the behavior of the StopChecker's maybe_stop_sequence method
|
||||
when an EOS token is encountered.
|
||||
|
||||
This test covers:
|
||||
- When the EOS token should stop the sequence and be removed from the output
|
||||
- When the EOS token should stop the sequence and be included in the output
|
||||
- When the EOS token should be ignored, and the sequence continues
|
||||
"""
|
||||
|
||||
stop_checker = StopChecker(max_model_len=1024)
|
||||
|
||||
seq = sequence_with_eos(
|
||||
text=text_wo_eos,
|
||||
eos_token=eos_token,
|
||||
eos_token_id=eos_token_id,
|
||||
)
|
||||
new_char_count = len(eos_token)
|
||||
|
||||
# Note that `stop` and `stop_token_ids` are not specified
|
||||
sampling_params = SamplingParams(
|
||||
min_tokens=1,
|
||||
ignore_eos=ignore_eos,
|
||||
include_stop_str_in_output=include_stop_str_in_output)
|
||||
|
||||
stop_checker.maybe_stop_sequence(
|
||||
seq=seq,
|
||||
new_char_count=new_char_count,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
if ignore_eos:
|
||||
assert seq.status == SequenceStatus.RUNNING
|
||||
assert seq.output_text == text_wo_eos + eos_token
|
||||
elif include_stop_str_in_output:
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
assert seq.output_text == text_wo_eos + eos_token
|
||||
else:
|
||||
assert seq.status == SequenceStatus.FINISHED_STOPPED
|
||||
assert seq.output_text == text_wo_eos
|
||||
@ -184,7 +184,7 @@ def sample_enum_json_schema():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_guided_choice():
|
||||
def sample_structured_outputs_choices():
|
||||
return [
|
||||
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
|
||||
"Ruby", "Swift", "Kotlin"
|
||||
|
||||
@ -1,82 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import sys
|
||||
from contextlib import nullcontext
|
||||
|
||||
from vllm_test_utils import BlameResult, blame
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
|
||||
def run_normal():
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
# Create an LLM without guided decoding as a baseline.
|
||||
llm = LLM(model="distilbert/distilgpt2",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.3)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# Destroy the LLM object and free up the GPU memory.
|
||||
del llm
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
def run_xgrammar(sample_regex):
|
||||
# Create an LLM with guided decoding enabled.
|
||||
llm = LLM(model="distilbert/distilgpt2",
|
||||
enforce_eager=True,
|
||||
guided_decoding_backend="xgrammar",
|
||||
gpu_memory_utilization=0.3)
|
||||
prompt = f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||
guided_decoding = GuidedDecodingParams(regex=sample_regex)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
guided_decoding=guided_decoding)
|
||||
outputs = llm.generate(
|
||||
prompts=[prompt] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
def test_lazy_outlines(sample_regex):
|
||||
"""If users don't use guided decoding, outlines should not be imported.
|
||||
"""
|
||||
# make sure outlines is not imported
|
||||
module_name = "outlines"
|
||||
# In CI, we only check finally if the module is imported.
|
||||
# If it is indeed imported, we can rerun the test with `use_blame=True`,
|
||||
# which will trace every function call to find the first import location,
|
||||
# and help find the root cause.
|
||||
# We don't run it in CI by default because it is slow.
|
||||
use_blame = False
|
||||
context = blame(
|
||||
lambda: module_name in sys.modules) if use_blame else nullcontext()
|
||||
with context as result:
|
||||
run_normal()
|
||||
run_xgrammar(sample_regex)
|
||||
if use_blame:
|
||||
assert isinstance(result, BlameResult)
|
||||
print(f"the first import location is:\n{result.trace_stack}")
|
||||
assert module_name not in sys.modules, (
|
||||
f"Module {module_name} is imported. To see the first"
|
||||
f" import location, run the test with `use_blame=True`.")
|
||||
@ -81,13 +81,3 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
|
||||
more_args = ["--max-num-seqs", "64"]
|
||||
|
||||
run_test(more_args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
|
||||
def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch,
|
||||
more_args):
|
||||
"""Run with the V0 Engine."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
run_test(more_args)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# imports for guided decoding tests
|
||||
# imports for structured outputs tests
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
@ -28,11 +28,9 @@ def monkeypatch_module():
|
||||
mpatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[False, True])
|
||||
def server(request, monkeypatch_module, zephyr_lora_files): #noqa: F811
|
||||
|
||||
use_v1 = request.param
|
||||
monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
|
||||
@pytest.fixture(scope="module")
|
||||
def server(monkeypatch_module, zephyr_lora_files): #noqa: F811
|
||||
monkeypatch_module.setenv('VLLM_USE_V1', '1')
|
||||
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
@ -57,13 +55,6 @@ def server(request, monkeypatch_module, zephyr_lora_files): #noqa: F811
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_v1_server(server):
|
||||
import os
|
||||
assert os.environ['VLLM_USE_V1'] in ['0', '1']
|
||||
return os.environ['VLLM_USE_V1'] == '1'
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
@ -480,10 +471,10 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guided_choice_chat(client: openai.AsyncOpenAI,
|
||||
sample_guided_choice, is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided decoding is only supported in v1 engine")
|
||||
async def test_structured_outputs_choice_chat(
|
||||
client: openai.AsyncOpenAI,
|
||||
sample_structured_outputs_choices,
|
||||
):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -498,9 +489,10 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
|
||||
messages=messages,
|
||||
max_completion_tokens=10,
|
||||
temperature=0.7,
|
||||
extra_body=dict(guided_choice=sample_guided_choice))
|
||||
extra_body=dict(
|
||||
structured_outputs={"choice": sample_structured_outputs_choices}))
|
||||
choice1 = chat_completion.choices[0].message.content
|
||||
assert choice1 in sample_guided_choice
|
||||
assert choice1 in sample_structured_outputs_choices
|
||||
|
||||
messages.append({"role": "assistant", "content": choice1})
|
||||
messages.append({
|
||||
@ -512,18 +504,18 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
|
||||
messages=messages,
|
||||
max_completion_tokens=10,
|
||||
temperature=0.7,
|
||||
extra_body=dict(guided_choice=sample_guided_choice))
|
||||
extra_body=dict(
|
||||
structured_outputs={"choice": sample_structured_outputs_choices}))
|
||||
choice2 = chat_completion.choices[0].message.content
|
||||
assert choice2 in sample_guided_choice
|
||||
assert choice2 in sample_structured_outputs_choices
|
||||
assert choice1 != choice2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided decoding is only supported in v1 engine")
|
||||
|
||||
async def test_structured_outputs_json_chat(
|
||||
client: openai.AsyncOpenAI,
|
||||
sample_json_schema,
|
||||
):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -538,7 +530,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema,
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=1000,
|
||||
extra_body=dict(guided_json=sample_json_schema))
|
||||
extra_body=dict(structured_outputs={"json": sample_json_schema}))
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None
|
||||
json1 = json.loads(message.content)
|
||||
@ -555,7 +547,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema,
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=1000,
|
||||
extra_body=dict(guided_json=sample_json_schema))
|
||||
extra_body=dict(structured_outputs={"json": sample_json_schema}))
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None
|
||||
json2 = json.loads(message.content)
|
||||
@ -565,10 +557,10 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI, sample_json_schema,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Guided decoding is only supported in v1 engine")
|
||||
async def test_structured_outputs_regex_chat(
|
||||
client: openai.AsyncOpenAI,
|
||||
sample_regex,
|
||||
):
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
@ -583,7 +575,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex,
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=20,
|
||||
extra_body=dict(guided_regex=sample_regex))
|
||||
extra_body=dict(structured_outputs={"regex": sample_regex}))
|
||||
ip1 = chat_completion.choices[0].message.content
|
||||
assert ip1 is not None
|
||||
assert re.fullmatch(sample_regex, ip1) is not None
|
||||
@ -594,7 +586,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex,
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=20,
|
||||
extra_body=dict(guided_regex=sample_regex))
|
||||
extra_body=dict(structured_outputs={"regex": sample_regex}))
|
||||
ip2 = chat_completion.choices[0].message.content
|
||||
assert ip2 is not None
|
||||
assert re.fullmatch(sample_regex, ip2) is not None
|
||||
@ -602,7 +594,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
|
||||
async def test_structured_outputs_type_error(client: openai.AsyncOpenAI):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -614,17 +606,19 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
|
||||
}]
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
_ = await client.chat.completions.create(model=MODEL_NAME,
|
||||
messages=messages,
|
||||
extra_body=dict(guided_regex={
|
||||
1: "Python",
|
||||
2: "C++"
|
||||
}))
|
||||
_ = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
extra_body=dict(
|
||||
structured_outputs={"regex": {
|
||||
1: "Python",
|
||||
2: "C++"
|
||||
}}))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
|
||||
sample_guided_choice):
|
||||
async def test_structured_outputs_choice_chat_logprobs(
|
||||
client: openai.AsyncOpenAI, sample_structured_outputs_choices):
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
@ -641,7 +635,8 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=10,
|
||||
logprobs=True,
|
||||
top_logprobs=5,
|
||||
extra_body=dict(guided_choice=sample_guided_choice))
|
||||
extra_body=dict(
|
||||
structured_outputs={"choice": sample_structured_outputs_choices}))
|
||||
|
||||
assert chat_completion.choices[0].logprobs is not None
|
||||
assert chat_completion.choices[0].logprobs.content is not None
|
||||
@ -653,20 +648,33 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip("Tool use is only supported in v1 engine")
|
||||
async def test_named_tool_use(
|
||||
client: openai.AsyncOpenAI,
|
||||
sample_json_schema,
|
||||
):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
f"Give an example JSON for an employee profile that "
|
||||
f"fits this schema: {sample_json_schema}"
|
||||
"content": ("Give an example JSON for an employee "
|
||||
"profile using the specified tool.")
|
||||
}]
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name",
|
||||
"description": "This is a dummy function",
|
||||
"parameters": sample_json_schema
|
||||
}
|
||||
}]
|
||||
tool_choice = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name"
|
||||
}
|
||||
}
|
||||
|
||||
# non-streaming
|
||||
|
||||
@ -674,20 +682,8 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema,
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=1000,
|
||||
tools=[{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name",
|
||||
"description": "This is a dummy function",
|
||||
"parameters": sample_json_schema
|
||||
}
|
||||
}],
|
||||
tool_choice={
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name"
|
||||
}
|
||||
},
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
message = chat_completion.choices[0].message
|
||||
assert len(message.content) == 0
|
||||
@ -705,25 +701,12 @@ async def test_named_tool_use(client: openai.AsyncOpenAI, sample_json_schema,
|
||||
|
||||
# streaming
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=1000,
|
||||
tools=[{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name",
|
||||
"description": "This is a dummy function",
|
||||
"parameters": sample_json_schema
|
||||
}
|
||||
}],
|
||||
tool_choice={
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name"
|
||||
}
|
||||
},
|
||||
stream=True)
|
||||
stream = await client.chat.completions.create(model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_completion_tokens=1000,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=True)
|
||||
|
||||
output = []
|
||||
finish_reason_count = 0
|
||||
@ -826,11 +809,7 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_format_json_schema(client: openai.AsyncOpenAI,
|
||||
is_v1_server: bool):
|
||||
if not is_v1_server:
|
||||
pytest.skip(
|
||||
"JSON schema response format is only supported in v1 engine")
|
||||
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
|
||||
prompt = 'what is 1+1? The format is "result": 2'
|
||||
# Check that this prompt cannot lead to a valid JSON without json_schema
|
||||
for _ in range(2):
|
||||
|
||||
@ -99,3 +99,26 @@ async def test_prompt_logprobs(client: openai.AsyncOpenAI):
|
||||
|
||||
assert completion.prompt_logprobs is not None
|
||||
assert len(completion.prompt_logprobs) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_top_logprobs(client: openai.AsyncOpenAI):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Beijing is the capital of which country?"
|
||||
}]
|
||||
|
||||
completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
extra_body={
|
||||
"top_logprobs": -1,
|
||||
"logprobs": "true",
|
||||
},
|
||||
)
|
||||
assert completion.choices[0].logprobs is not None
|
||||
assert completion.choices[0].logprobs.content is not None
|
||||
assert len(completion.choices[0].logprobs.content) > 0
|
||||
|
||||
@ -1,831 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# imports for guided decoding tests
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import jsonschema
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import regex as re
|
||||
import requests
|
||||
# downloading lora to test lora requests
|
||||
from openai import BadRequestError
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
# technically these adapters use a different base model,
|
||||
# but we're not testing generation quality here
|
||||
|
||||
GUIDED_DECODING_BACKENDS = ["outlines", "xgrammar", "guidance"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args(zephyr_lora_files):
|
||||
return [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--enforce-eager",
|
||||
# lora config
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
f"zephyr-lora={zephyr_lora_files}",
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module",
|
||||
params=["", "--disable-frontend-multiprocessing"])
|
||||
def server(default_server_args, request):
|
||||
if request.param:
|
||||
default_server_args.append(request.param)
|
||||
|
||||
original_value = os.environ.get('VLLM_USE_V1')
|
||||
os.environ['VLLM_USE_V1'] = '0'
|
||||
try:
|
||||
with RemoteOpenAIServer(MODEL_NAME,
|
||||
default_server_args) as remote_server:
|
||||
yield remote_server
|
||||
finally:
|
||||
# Restore original env value
|
||||
if original_value is None:
|
||||
os.environ.pop('VLLM_USE_V1', None)
|
||||
else:
|
||||
os.environ['VLLM_USE_V1'] = original_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def is_v1_server(server):
|
||||
import os
|
||||
|
||||
# For completion tests, we assume v0 since there's no explicit v1 setup
|
||||
return os.environ.get('VLLM_USE_V1', '0') == '1'
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
|
||||
choice = completion.choices[0]
|
||||
assert len(choice.text) >= 5
|
||||
assert choice.finish_reason == "length"
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=5, prompt_tokens=6, total_tokens=11)
|
||||
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 1
|
||||
assert completion.choices[0].prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI):
|
||||
# test using token IDs
|
||||
with pytest.raises(openai.BadRequestError, match="out of vocabulary"):
|
||||
# Added tokens should be rejected by the base model
|
||||
await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=[0, 0, 32000, 32001, 32002],
|
||||
echo=True,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=None,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=0,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.token_logprobs is not None
|
||||
assert choice.logprobs.top_logprobs is not None
|
||||
assert len(choice.logprobs.top_logprobs[0]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=5,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.token_logprobs is not None
|
||||
assert choice.logprobs.top_logprobs is not None
|
||||
assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
|
||||
with pytest.raises(
|
||||
(openai.BadRequestError, openai.APIError)): # test using token IDs
|
||||
await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
# vLLM has higher default max_logprobs (20 instead of 5) to support
|
||||
# both Completion API and Chat Completion API
|
||||
logprobs=21,
|
||||
)
|
||||
...
|
||||
with pytest.raises(
|
||||
(openai.BadRequestError, openai.APIError)): # test using token IDs
|
||||
stream = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
# vLLM has higher default max_logprobs (20 instead of 5) to support
|
||||
# both Completion API and Chat Completion API
|
||||
logprobs=30,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in stream:
|
||||
...
|
||||
|
||||
# the server should still work afterwards
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
|
||||
(MODEL_NAME, 0),
|
||||
(MODEL_NAME, 1),
|
||||
(MODEL_NAME, None)])
|
||||
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
prompt_logprobs: Optional[int]):
|
||||
params: dict = {
|
||||
"prompt": ["A robot may not injure another robot", "My name is"],
|
||||
"model": model_name,
|
||||
}
|
||||
if prompt_logprobs is not None:
|
||||
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
|
||||
|
||||
if prompt_logprobs is not None and prompt_logprobs < 0:
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(**params)
|
||||
else:
|
||||
completion = await client.completions.create(**params)
|
||||
if prompt_logprobs is not None:
|
||||
assert completion.choices[0].prompt_logprobs is not None
|
||||
assert len(completion.choices[0].prompt_logprobs) > 0
|
||||
|
||||
assert completion.choices[1].prompt_logprobs is not None
|
||||
assert len(completion.choices[1].prompt_logprobs) > 0
|
||||
|
||||
else:
|
||||
assert completion.choices[0].prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
prompt = "What is an LLM?"
|
||||
|
||||
single_completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
single_output = single_completion.choices[0].text
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True)
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk.choices[0].text)
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
# finish reason should only return in last block
|
||||
assert finish_reason_count == 1
|
||||
assert chunk.choices[0].finish_reason == "length"
|
||||
assert chunk.choices[0].text
|
||||
assert "".join(chunks) == single_output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
|
||||
"""Streaming for parallel sampling.
|
||||
The tokens from multiple samples, are flattened into a single stream,
|
||||
with an index to indicate which sample the token belongs to.
|
||||
"""
|
||||
|
||||
prompt = "What is an LLM?"
|
||||
n = 3
|
||||
max_tokens = 5
|
||||
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
stream=True)
|
||||
chunks: list[list[str]] = [[] for i in range(n)]
|
||||
finish_reason_count = 0
|
||||
async for chunk in stream:
|
||||
index = chunk.choices[0].index
|
||||
text = chunk.choices[0].text
|
||||
chunks[index].append(text)
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
assert finish_reason_count == n
|
||||
for chunk in chunks:
|
||||
assert len(chunk) == max_tokens
|
||||
print("".join(chunk))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_completion_stream_options(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
prompt = "What is the capital of France?"
|
||||
|
||||
    # Test stream=True, stream_options=
    #     {"include_usage": False, "continuous_usage_stats": False}
    stream = await client.completions.create(
        model=model_name, prompt=prompt, max_tokens=5, temperature=0.0,
        stream=True,
        stream_options={"include_usage": False,
                        "continuous_usage_stats": False})

    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options=
    #     {"include_usage": False, "continuous_usage_stats": True}
    stream = await client.completions.create(
        model=model_name, prompt=prompt, max_tokens=5, temperature=0.0,
        stream=True,
        stream_options={"include_usage": False,
                        "continuous_usage_stats": True})
    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options=
    #     {"include_usage": True, "continuous_usage_stats": False}
    stream = await client.completions.create(
        model=model_name, prompt=prompt, max_tokens=5, temperature=0.0,
        stream=True,
        stream_options={"include_usage": True,
                        "continuous_usage_stats": False})
    async for chunk in stream:
        if chunk.choices[0].finish_reason is None:
            assert chunk.usage is None
        else:
            assert chunk.usage is None
            final_chunk = await stream.__anext__()
            assert final_chunk.usage is not None
            assert final_chunk.usage.prompt_tokens > 0
            assert final_chunk.usage.completion_tokens > 0
            assert final_chunk.usage.total_tokens == (
                final_chunk.usage.prompt_tokens +
                final_chunk.usage.completion_tokens)
            assert final_chunk.choices == []

    # Test stream=True, stream_options=
    #     {"include_usage": True, "continuous_usage_stats": True}
    stream = await client.completions.create(
        model=model_name, prompt=prompt, max_tokens=5, temperature=0.0,
        stream=True,
        stream_options={"include_usage": True,
                        "continuous_usage_stats": True})
    async for chunk in stream:
        assert chunk.usage is not None
        assert chunk.usage.prompt_tokens > 0
        assert chunk.usage.completion_tokens > 0
        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
                                            chunk.usage.completion_tokens)
        if chunk.choices[0].finish_reason is not None:
            final_chunk = await stream.__anext__()
            assert final_chunk.usage is not None
            assert final_chunk.usage.prompt_tokens > 0
            assert final_chunk.usage.completion_tokens > 0
            assert final_chunk.usage.total_tokens == (
                final_chunk.usage.prompt_tokens +
                final_chunk.usage.completion_tokens)
            assert final_chunk.choices == []

    # Test stream=False, stream_options={"include_usage": None}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name, prompt=prompt, max_tokens=5, temperature=0.0,
            stream=False, stream_options={"include_usage": None})

    # Test stream=False, stream_options={"include_usage": True}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name, prompt=prompt, max_tokens=5, temperature=0.0,
            stream=False, stream_options={"include_usage": True})

    # Test stream=False, stream_options={"continuous_usage_stats": None}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name, prompt=prompt, max_tokens=5, temperature=0.0,
            stream=False, stream_options={"continuous_usage_stats": None})

    # Test stream=False, stream_options={"continuous_usage_stats": True}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name, prompt=prompt, max_tokens=5, temperature=0.0,
            stream=False, stream_options={"continuous_usage_stats": True})

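For readers outside the test suite, the continuous_usage_stats behavior exercised above can be consumed like this. This is a minimal sketch, assuming an OpenAI-compatible vLLM server is already running; the base URL, API key, and model name below are illustrative placeholders, not values from the fixtures in this file.

import asyncio

import openai


async def watch_usage() -> None:
    # Placeholder endpoint; point this at your own vLLM server.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    stream = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model name
        prompt="Hello, my name is",
        max_tokens=5,
        stream=True,
        stream_options={"include_usage": True,
                        "continuous_usage_stats": True})
    async for chunk in stream:
        # With continuous_usage_stats=True, every chunk carries usage counters.
        if chunk.usage is not None:
            print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens)


asyncio.run(watch_usage())
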
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
    # test both text and token IDs
    for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
        # test simple list
        batch = await client.completions.create(
            model=model_name, prompt=prompts, max_tokens=5, temperature=0.0)
        assert len(batch.choices) == 2
        assert batch.choices[0].text == batch.choices[1].text

        # test n = 2
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            n=2,
            max_tokens=5,
            temperature=0.0,
            extra_body=dict(
                # NOTE: this has to be true for n > 1 in vLLM, but
                # not necessary for official client.
                use_beam_search=True),
        )
        assert len(batch.choices) == 4
        assert batch.choices[0].text != batch.choices[
            1].text, "beam search should be different"
        assert batch.choices[0].text == batch.choices[
            2].text, "two copies of the same prompt should be the same"
        assert batch.choices[1].text == batch.choices[
            3].text, "two copies of the same prompt should be the same"

        # test streaming
        batch = await client.completions.create(
            model=model_name, prompt=prompts, max_tokens=5, temperature=0.0,
            stream=True)
        texts = [""] * 2
        async for chunk in batch:
            assert len(chunk.choices) == 1
            choice = chunk.choices[0]
            texts[choice.index] += choice.text
        assert texts[0] == texts[1]

@pytest.mark.asyncio
async def test_logits_bias(client: openai.AsyncOpenAI):
    prompt = "Hello, my name is"
    max_tokens = 5
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    token_id = 1000
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token_id): 100},
        seed=42,
    )
    assert len(completion.choices[0].text) >= 5
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
                                add_special_tokens=False)["input_ids"]
    assert all([
        response == expected
        for response, expected in zip(response_tokens, expected_tokens)
    ])

    # Test ban
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
    )
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    first_response = completion.choices[0].text
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token): -100
                    for token in response_tokens},
    )
    assert first_response != completion.choices[0].text

@pytest.mark.asyncio
async def test_allowed_token_ids(client: openai.AsyncOpenAI):
    prompt = "Hello, my name is"
    max_tokens = 1
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    allowed_ids = [21555, 21557, 21558]
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        seed=42,
        extra_body=dict(allowed_token_ids=allowed_ids),
        logprobs=1,
    )
    response_tokens = completion.choices[0].logprobs.tokens
    assert len(response_tokens) == 1
    assert tokenizer.convert_tokens_to_ids(response_tokens)[0] in allowed_ids

@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_json_completion(client: openai.AsyncOpenAI,
                                      guided_decoding_backend: str,
                                      sample_json_schema, is_v1_server: bool):
    if not is_v1_server:
        pytest.skip("Guided decoding is only supported in v1 engine")

    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example JSON for an employee profile "
        f"that fits this schema: {sample_json_schema}",
        n=3,
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_json=sample_json_schema,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert len(completion.choices) == 3
    for i in range(3):
        output_json = json.loads(completion.choices[i].text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)

@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
                                       guided_decoding_backend: str,
                                       sample_regex, is_v1_server: bool):
    if not is_v1_server:
        pytest.skip("Guided decoding is only supported in v1 engine")

    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
        n=3,
        temperature=1.0,
        max_tokens=20,
        extra_body=dict(guided_regex=sample_regex,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert len(completion.choices) == 3
    for i in range(3):
        assert re.fullmatch(sample_regex,
                            completion.choices[i].text) is not None

@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
                                        guided_decoding_backend: str,
                                        sample_guided_choice,
                                        is_v1_server: bool):
    if not is_v1_server:
        pytest.skip("Guided decoding is only supported in v1 engine")

    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="The best language for type-safe systems programming is ",
        n=2,
        temperature=1.0,
        max_tokens=10,
        extra_body=dict(guided_choice=sample_guided_choice,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert len(completion.choices) == 2
    for i in range(2):
        assert completion.choices[i].text in sample_guided_choice

@pytest.mark.asyncio
async def test_guided_grammar(client: openai.AsyncOpenAI,
                              sample_sql_statements, is_v1_server: bool):
    if not is_v1_server:
        pytest.skip("Guided grammar is only supported in v1 engine")

    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=("Generate a sql state that select col_1 from "
                "table_1 where it is equals to 1"),
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_grammar=sample_sql_statements))

    content = completion.choices[0].text

    # use Lark to parse the output, and make sure it's a valid parse tree
    from lark import Lark
    parser = Lark(sample_sql_statements)
    parser.parse(content)

    # remove spaces for comparison b/c we removed them in the grammar
    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")

    assert content.strip() == ground_truth

@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
                                       model_name: str, logprobs_arg: int):
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    # test using text and token IDs
    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
        completion = await client.completions.create(model=model_name,
                                                      prompt=prompt,
                                                      max_tokens=5,
                                                      temperature=0.0,
                                                      echo=True,
                                                      logprobs=logprobs_arg)

        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                             list) else prompt
        assert re.search(r"^" + prompt_text, completion.choices[0].text)
        logprobs = completion.choices[0].logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) > 5
        assert (len(logprobs.token_logprobs) > 5
                and logprobs.token_logprobs[0] is None)
        assert (len(logprobs.top_logprobs) > 5
                and logprobs.top_logprobs[0] is None)
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) > 5

@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
                                          guided_decoding_backend: str,
                                          sample_json_schema, sample_regex,
                                          is_v1_server: bool):
    if not is_v1_server:
        pytest.skip("Guided decoding is only supported in v1 engine")

    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example JSON that fits this schema: 42",
            extra_body=dict(guided_json=42,
                            guided_decoding_backend=guided_decoding_backend))

    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example string that fits this regex",
            extra_body=dict(guided_regex=sample_regex,
                            guided_json=sample_json_schema))

@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name,stream,echo",
    [
        (MODEL_NAME, False, False),
        (MODEL_NAME, False, True),
        (MODEL_NAME, True, False),
        (MODEL_NAME, True, True)  # should not raise BadRequestError
    ],
)
async def test_echo_stream_completion(client: openai.AsyncOpenAI,
                                      model_name: str, stream: bool,
                                      echo: bool):
    saying: str = "Hello, my name is"
    result = await client.completions.create(model=model_name,
                                             prompt=saying,
                                             max_tokens=10,
                                             temperature=0.0,
                                             echo=echo,
                                             stream=stream)

    stop_reason = "length"

    if not stream:
        completion = result
        assert completion.id is not None
        assert completion.choices is not None and len(completion.choices) == 1

        choice = completion.choices[0]
        assert len(choice.text) >= 5
        assert choice.finish_reason == stop_reason

        if echo:
            assert choice.text is not None and saying in choice.text
        else:
            assert choice.text is not None and saying not in choice.text

    else:
        chunks: list[str] = []
        final_finish_reason = None
        async for chunk in result:
            if chunk.choices and chunk.choices[0].text:
                chunks.append(chunk.choices[0].text)
            if chunk.choices and chunk.choices[0].finish_reason:
                final_finish_reason = chunk.choices[0].finish_reason

        assert final_finish_reason == stop_reason
        content = "".join(chunks)
        if echo:
            assert content is not None and saying in content
        else:
            assert content is not None and saying not in content

@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
                           client: openai.AsyncOpenAI):
    request_args = {
        "model": MODEL_NAME,
        "prompt": "Hello, my name is",
        "max_tokens": 5,
        "temperature": 0.0,
        "logprobs": None,
    }

    completion = await client.completions.create(**request_args)

    invocation_response = requests.post(server.url_for("invocations"),
                                        json=request_args)
    invocation_response.raise_for_status()

    completion_output = completion.model_dump()
    invocation_output = invocation_response.json()

    assert completion_output.keys() == invocation_output.keys()
    assert completion_output["choices"] == invocation_output["choices"]
@ -142,7 +142,7 @@ def server():  # noqa: F811
        "--dtype",
        "half",
        "--enable-auto-tool-choice",
        "--guided-decoding-backend",
        "--structured-outputs-config.backend",
        "xgrammar",
        "--tool-call-parser",
        "hermes",
@ -225,7 +225,7 @@ def k2_server():  # noqa: F811
        "--dtype",
        "half",
        "--enable-auto-tool-choice",
        "--guided-decoding-backend",
        "--structured-outputs-config.backend",
        "xgrammar",
        "--tool-call-parser",
        "hermes",

@ -14,6 +14,9 @@ from transformers import AutoConfig

from ...utils import RemoteOpenAIServer

pytest.skip("Skipping prompt_embeds test until V1 supports it.",
            allow_module_level=True)

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

@ -228,3 +231,20 @@ async def test_completions_with_logprobs_and_prompt_embeds(
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) == 5


@pytest.mark.asyncio
async def test_prompt_logprobs_raises_error(
        client_with_prompt_embeds: openai.AsyncOpenAI):
    with pytest.raises(BadRequestError, match="not compatible"):
        encoded_embeds = create_dummy_embeds()
        await client_with_prompt_embeds.completions.create(
            model=MODEL_NAME,
            prompt="",
            max_tokens=5,
            temperature=0.0,
            extra_body={
                "prompt_embeds": encoded_embeds,
                "prompt_logprobs": True
            },
        )

@ -53,12 +53,13 @@ def monkeypatch_module():
    mpatch.undo()


@pytest.fixture(scope="module", params=[False, True])
@pytest.fixture(scope="module", params=[True])
def server_with_lora_modules_json(request, monkeypatch_module,
                                  zephyr_lora_files):

    use_v1 = request.param
    monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
    assert use_v1
    monkeypatch_module.setenv('VLLM_USE_V1', '1')

    # Define the json format LoRA module configurations
    lora_module_1 = {

@ -22,7 +22,7 @@ MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
PREV_MINOR_VERSION = version._prev_minor_version()


@pytest.fixture(scope="module", params=[True, False])
@pytest.fixture(scope="module", params=[True])
def use_v1(request):
    # Module-scoped variant of run_with_both_engines
    #

@ -102,12 +102,14 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
            if "custom" in tool_call:
                return False

    # Sometimes guided_grammar is generated to be empty
    # Sometimes structured_outputs.grammar is generated to be empty
    # Causing a server error in EBNF grammar parsing
    # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421
    guided_grammar = case.body.get("guided_grammar")
    structured_outputs = case.body.get("structured_outputs", {})
    grammar = structured_outputs.get("grammar") if isinstance(
        structured_outputs, dict) else None

    if guided_grammar == '':
    if grammar == '':
        # Allow None (will be handled as no grammar)
        # But skip empty strings
        return False

@ -3,7 +3,7 @@

import io

# imports for guided decoding tests
# imports for structured outputs tests
import openai
import pybase64
import pytest

@ -10,8 +10,30 @@ import pytest
from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import RemoteOpenAIServer
from .test_completion import default_server_args  # noqa: F401
from .test_completion import MODEL_NAME

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def default_server_args(zephyr_lora_files):
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "128",
        "--enforce-eager",
        # lora config
        "--enable-lora",
        "--lora-modules",
        f"zephyr-lora={zephyr_lora_files}",
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
    ]


@pytest.fixture(scope="module")

@ -333,7 +333,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    with suppress(Exception):
@ -378,7 +377,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    with suppress(Exception):
@ -433,7 +431,6 @@ async def test_serving_chat_should_set_correct_max_tokens():
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    with suppress(Exception):
@ -489,7 +486,6 @@ async def test_serving_chat_could_load_correct_generation_config():
            "role": "user",
            "content": "what is 1+1?"
        }],
        guided_decoding_backend="outlines",
    )

    with suppress(Exception):

@ -15,14 +15,6 @@ MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16"


@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    # Simple autouse wrapper to run both engines for each test
    # This can be promoted up to conftest.py to run for every
    # test in a package
    pass


@pytest.fixture(scope="module")
def server():
    args = [

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# imports for guided decoding tests
# imports for structured outputs tests
import io
import json


@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import io
# imports for guided decoding tests
# imports for structured outputs tests
import json

import httpx

@ -102,9 +102,6 @@ def test_triton_unified_attn(
) -> None:
    torch.set_default_device("cuda")

    if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32:
        pytest.skip("block size must be at least 32 for fp8")

    current_platform.seed_everything(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]

@ -418,7 +418,9 @@ def test_full_cuda_graph(
@pytest.mark.parametrize("model", FP32_STATE_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_fp32_state(
@pytest.mark.parametrize("cache_dtype_param",
                         ["mamba_ssm_cache_dtype", "mamba_cache_dtype"])
def test_fp32_cache_state(
    hf_runner,
    vllm_runner,
    example_prompts,
@ -426,6 +428,7 @@ def test_fp32_state(
    model: str,
    max_tokens: int,
    num_logprobs: int,
    cache_dtype_param: str,
) -> None:

    try:
@ -443,13 +446,13 @@ def test_fp32_state(
        m.setenv("VLLM_USE_V1", "0")
        with vllm_runner(model,
                         max_num_seqs=MAX_NUM_SEQS,
                         mamba_ssm_cache_dtype="float32") as vllm_model:
                         **{cache_dtype_param: "float32"}) as vllm_model:
            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model,
                     max_num_seqs=MAX_NUM_SEQS,
                     mamba_ssm_cache_dtype="float32") as vllm_model:
                     **{cache_dtype_param: "float32"}) as vllm_model:
        vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

tests/models/language/pooling/test_token_classification.py (new file, 39 lines)
@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModelForTokenClassification

from tests.models.utils import softmax


@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"])
# The float32 is required for this tiny model to pass the test.
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    with vllm_runner(model, max_model_len=None, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.encode(example_prompts)

    with hf_runner(model,
                   dtype=dtype,
                   auto_cls=AutoModelForTokenClassification) as hf_model:
        tokenizer = hf_model.tokenizer
        hf_outputs = []
        for prompt in example_prompts:
            inputs = tokenizer([prompt], return_tensors="pt")
            inputs = hf_model.wrap_device(inputs)
            output = hf_model.model(**inputs)
            hf_outputs.append(softmax(output.logits[0]))

    # check logits difference
    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        hf_output = torch.tensor(hf_output).cpu().float()
        vllm_output = torch.tensor(vllm_output).cpu().float()
        assert torch.allclose(hf_output, vllm_output, 1e-2)

@ -414,6 +414,7 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {

    # [Cross-encoder]
    "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"),  # noqa: E501
    "BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
    "GteNewForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-multilingual-reranker-base",  # noqa: E501
                                                       trust_remote_code=True,
                                                       hf_overrides={

@ -22,7 +22,7 @@ class DataModuleConfig(TypedDict):

class ImagePrompt(BaseModel):

    data_format: Literal["b64_json", "bytes", "url"]
    data_format: Literal["b64_json", "bytes", "url", "path"]
    """
    This is the data type for the input image
    """

@ -1,182 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm import SamplingParams

from ..conftest import VllmRunner

MODELS = ["distilbert/distilgpt2"]


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    This module is V0 only since it uses dtype=float, so
    set VLLM_USE_V1=0 for all tests in the module.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype",
                         ["float"])  # needed for comparing logprobs with HF
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [0, 6])  # 32000 == vocab_size
@pytest.mark.parametrize("detokenize", [True, False])
def test_get_prompt_logprobs(
    hf_runner,
    vllm_runner,
    model,
    dtype,
    chunked_prefill_token_size: int,
    num_top_logprobs: int,
    detokenize: bool,
    example_prompts,
):
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size

    max_tokens = 5
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_logprobs = hf_model.generate_greedy_logprobs(
            example_prompts,
            max_tokens=max_tokens,
        )

    with vllm_runner(
            model,
            dtype=dtype,
            max_logprobs=num_top_logprobs,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
    ) as vllm_model:
        vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                              logprobs=num_top_logprobs,
                                              prompt_logprobs=num_top_logprobs,
                                              temperature=0.0,
                                              detokenize=detokenize)
        vllm_results = vllm_model.llm.generate(
            example_prompts, sampling_params=vllm_sampling_params)

    # Test whether logprobs are included in the results.
    for result in vllm_results:
        assert result.prompt_logprobs is not None
        assert result.outputs[0].logprobs is not None
        assert len(result.outputs[0].logprobs) == max_tokens
        for logprobs in result.outputs[0].logprobs:
            # If the output token is not included in the top X
            # logprob, it can return 1 more data
            assert (len(logprobs) == num_top_logprobs
                    or len(logprobs) == num_top_logprobs + 1)
        output_text = result.outputs[0].text
        output_string_from_most_likely_tokens_lst: list[str] = []
        for top_logprobs in result.outputs[0].logprobs:
            top_logprob = next(iter(top_logprobs.values()))
            output_string_from_most_likely_tokens_lst.append(
                top_logprob.decoded_token)

        if detokenize:
            output_string_from_most_likely_tokens = "".join(
                output_string_from_most_likely_tokens_lst)
            assert output_text == output_string_from_most_likely_tokens, (
                "The output text from the top logprob for each token position "
                "should be the same as the output text in the result.")
        else:
            assert output_text == ''
            assert output_string_from_most_likely_tokens_lst == ([None] *
                                                                 max_tokens)

        # The first prompt logprob is always None
        assert result.prompt_logprobs[0] is None
        for prompt_logprobs in result.prompt_logprobs[1:]:
            # If the prompt token is not included in the top X
            # logprob, it can return 1 more data
            assert (len(prompt_logprobs) == num_top_logprobs
                    or len(prompt_logprobs) == num_top_logprobs + 1)

    # Test whether prompt logprobs are consistent with HF
    for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
        # Check prompt logprobs
        # The first prompt logprob is always None, so we compare it from 1:.
        vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
        for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
            for token_id, logprob in vllm_prompt_logprob_dict.items():
                torch.testing.assert_close(logprob.logprob,
                                           hf_logprob[0][i][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
        vllm_sample_logprobs = vllm_result.outputs[0].logprobs
        for i, top_logprobs in enumerate(vllm_sample_logprobs):
            for token_id, sample_logprob in top_logprobs.items():
                logprob = sample_logprob.logprob
                torch.testing.assert_close(logprob,
                                           hf_logprob[i][-1][token_id].item(),
                                           atol=1e-2,
                                           rtol=1e-2)
                if detokenize:
                    assert isinstance(sample_logprob.decoded_token, str), (
                        "The token should be decoded by the time it is returned"
                        " to the user.")

    # Test if prompt logprobs are correctly set.
    for vllm_result in vllm_results:
        token_ids = vllm_result.prompt_token_ids
        prompt_logprobs = vllm_result.prompt_logprobs

        # The first token doesn't have logprob.
        assert prompt_logprobs[0] is None

        for token_id, logprob_dict in zip(token_ids[1:], prompt_logprobs[1:]):
            assert token_id in logprob_dict


def test_max_logprobs():
    runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
    vllm_sampling_params = SamplingParams(logprobs=1)
    # should pass
    runner.generate(["Hello world"], sampling_params=vllm_sampling_params)

    bad_sampling_params = SamplingParams(logprobs=2)
    with pytest.raises(ValueError):
        runner.generate(["Hello world"], sampling_params=bad_sampling_params)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("detokenize", [True, False])
def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
                       detokenize: bool, example_prompts):
    max_num_seqs = 256
    enable_chunked_prefill = False
    max_num_batched_tokens = None
    if chunked_prefill_token_size != -1:
        enable_chunked_prefill = True
        max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
        max_num_batched_tokens = chunked_prefill_token_size
    max_tokens = 5

    with vllm_runner(
            model,
            enable_chunked_prefill=enable_chunked_prefill,
            max_num_batched_tokens=max_num_batched_tokens,
            max_num_seqs=max_num_seqs,
    ) as vllm_model:
        sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
                                                       logprobs=None,
                                                       temperature=0.0,
                                                       detokenize=detokenize)
        results_logprobs_none = vllm_model.llm.generate(
            example_prompts, sampling_params=sampling_params_logprobs_none)

    for i in range(len(results_logprobs_none)):
        assert results_logprobs_none[i].outputs[0].logprobs is None
        assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
@ -1,84 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the SamplingParams class."""

import pytest

from vllm import SamplingParams
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

MODEL_NAME = "Qwen/Qwen1.5-7B"


def test_max_tokens_none():
    """max_tokens=None should be allowed"""
    SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None)


@pytest.fixture(scope="module")
def model_config():
    return ModelConfig(
        MODEL_NAME,
        seed=0,
        dtype="float16",
    )


@pytest.fixture(scope="module")
def default_max_tokens():
    return 4096


def test_sampling_params_from_request_with_no_guided_decoding_backend(
        model_config, default_max_tokens):
    # guided_decoding_backend is not present at request level
    request = ChatCompletionRequest.model_validate({
        'messages': [{
            'role': 'user',
            'content': 'Hello'
        }],
        'model': MODEL_NAME,
        'response_format': {
            'type': 'json_object',
        },
    })

    sampling_params = request.to_sampling_params(
        default_max_tokens,
        model_config.logits_processor_pattern,
    )
    # we do not expect any backend to be present and the default
    # guided_decoding_backend at engine level will be used.
    assert sampling_params.guided_decoding.backend is None


@pytest.mark.parametrize("request_level_guided_decoding_backend,expected",
                         [("xgrammar", "xgrammar"), ("guidance", "guidance"),
                          ("outlines", "outlines")])
def test_sampling_params_from_request_with_guided_decoding_backend(
        request_level_guided_decoding_backend: str, expected: str,
        model_config, default_max_tokens):

    request = ChatCompletionRequest.model_validate({
        'messages': [{
            'role': 'user',
            'content': 'Hello'
        }],
        'model': MODEL_NAME,
        'response_format': {
            'type': 'json_object',
        },
        'guided_decoding_backend': request_level_guided_decoding_backend,
    })

    sampling_params = request.to_sampling_params(
        default_max_tokens,
        model_config.logits_processor_pattern,
    )
    # backend correctly identified in resulting sampling_params
    assert sampling_params.guided_decoding.backend == expected
@ -68,7 +68,7 @@ EXAMPLE_TOOLS = [
def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
                       should_match: bool):
    self = MagicMock(tool_choice="required", tools=tools)
    schema = ChatCompletionRequest._get_guided_json_from_tool(self)
    schema = ChatCompletionRequest._get_json_schema_from_tool(self)
    assert isinstance(schema, dict)

    # use build_regex_from_schema used in JSONLogitsProcessor to create Guide
@ -218,7 +218,7 @@ VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
        }
    }, {}], False),
])
def test_guided_json(sample_output, should_match):
def test_structured_outputs_json(sample_output, should_match):
    _compile_and_check(tools=TypeAdapter(
        list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS),
                       sample_output=sample_output,
@ -273,8 +273,9 @@ def update_parameters_empty_dict(
@pytest.mark.parametrize(
    "update_parameters",
    [update_parameters_none, update_parameters_empty_dict])
def test_guided_json_without_parameters(sample_output, should_match,
                                        update_parameters):
def test_structured_outputs_json_without_parameters(sample_output,
                                                    should_match,
                                                    update_parameters):
    updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
    tools = TypeAdapter(
        list[ChatCompletionToolsParam]).validate_python(updated_tools)
@ -334,4 +335,4 @@ def test_streaming_output_valid(output, empty_params, delta_len):
        combined_messages += message.tool_calls[0].function.arguments
    combined_messages += "}]"
    assert json.loads(combined_messages) == output
    assert json.dumps(json.loads(combined_messages)) == output_json

@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                         SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
                                    MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@ -1796,11 +1796,11 @@ def test_schedule_skip_tokenizer_init():

def test_schedule_skip_tokenizer_init_structured_output_request():
    scheduler = create_scheduler(skip_tokenizer_init=True)
    guided_params = GuidedDecodingParams(regex="[0-9]+")
    structured_outputs_params = StructuredOutputsParams(regex="[0-9]+")
    sampling_params = SamplingParams(
        ignore_eos=False,
        max_tokens=16,
        guided_decoding=guided_params,
        structured_outputs=structured_outputs_params,
    )
    request = Request(
        request_id="0",

@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
import pytest

from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector

if TYPE_CHECKING:
@ -97,7 +97,7 @@ def _get_test_sampling_params(
            top_p=0.95,
            n=n,
            seed=seed,
            guided_decoding=GuidedDecodingParams(
            structured_outputs=StructuredOutputsParams(
                regex="[0-9]+") if structured_outputs else None,
        ) for n in n_list
    ], n_list

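As a quick reference for the rename applied throughout these hunks, here is a minimal offline-inference sketch of the new spelling. The model name is a placeholder and the snippet is illustrative only; the old guided_decoding=GuidedDecodingParams(...) form shown above is what it replaces.

from vllm import LLM
from vllm.sampling_params import SamplingParams, StructuredOutputsParams

# Minimal sketch, assuming a small local model is available; not a drop-in test.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)
params = SamplingParams(
    max_tokens=64,
    # previously: guided_decoding=GuidedDecodingParams(regex="[0-9]+")
    structured_outputs=StructuredOutputsParams(regex="[0-9]+"),
)
out = llm.generate(["The answer is "], sampling_params=params)
print(out[0].outputs[0].text)
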
@ -151,7 +151,7 @@ def sample_definition_json_schema():


@pytest.fixture
def sample_guided_choice():
def sample_structured_outputs_choices():
    return [
        "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
        "Ruby", "Swift", "Kotlin"

@ -15,12 +15,13 @@ import torch
from pydantic import BaseModel

from tests.reasoning.utils import run_reasoning_extraction
from vllm.config import StructuredOutputsConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.sampling_params import SamplingParams, StructuredOutputsParams

if TYPE_CHECKING:
    from vllm.config import TokenizerMode
@ -90,7 +91,7 @@ def _load_json(s: str, backend: str) -> str:

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
    "model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
    "model_name, backend, tokenizer_mode, speculative_config",
    PARAMS_MODELS_BACKENDS_TOKENIZER_MODE)
def test_structured_output(
    monkeypatch: pytest.MonkeyPatch,
@ -99,8 +100,8 @@ def test_structured_output(
    sample_sql_ebnf: str,
    sample_sql_lark: str,
    sample_regex: str,
    sample_guided_choice: str,
    guided_decoding_backend: str,
    sample_structured_outputs_choices: str,
    backend: str,
    tokenizer_mode: str,
    model_name: str,
    speculative_config: dict[str, Any],
@ -115,16 +116,15 @@ def test_structured_output(
    enforce_eager = bool(not current_platform.is_tpu())
    # Use a single LLM instance for several scenarios to
    # speed up the test suite.
    llm = LLM(
        model=model_name,
        enforce_eager=enforce_eager,
        max_model_len=1024,
        guided_decoding_backend=guided_decoding_backend,
        guided_decoding_disable_any_whitespace=(guided_decoding_backend
                                                in {"xgrammar", "guidance"}),
        seed=120,
        tokenizer_mode=tokenizer_mode,
        speculative_config=speculative_config)
    llm = LLM(model=model_name,
              enforce_eager=enforce_eager,
              max_model_len=1024,
              structured_outputs_config=dict(backend=backend,
                                             disable_any_whitespace=backend
                                             in {"xgrammar", "guidance"}),
              seed=120,
              tokenizer_mode=tokenizer_mode,
              speculative_config=speculative_config)

    #
    # Test 1: Generate JSON output based on a provided schema
@ -132,7 +132,7 @@ def test_structured_output(
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=4096,
        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
        structured_outputs=StructuredOutputsParams(json=sample_json_schema))

    prompt = ("Give an example JSON for an employee profile that fits this "
              "schema. Make the response as short as possible. Schema: "
@ -152,7 +152,7 @@ def test_structured_output(

        generated_text = output.outputs[0].text
        assert generated_text is not None
        if guided_decoding_backend != 'lm-format-enforcer':
        if backend != 'lm-format-enforcer':
            assert "\n" not in generated_text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
@ -161,12 +161,12 @@ def test_structured_output(
    #
    # Test 2: Generate JSON object without a schema
    #
    if guided_decoding_backend != "outlines":
    if backend != "outlines":
        sampling_params = SamplingParams(
            temperature=1.0,
            max_tokens=4096,
            n=2,
            guided_decoding=GuidedDecodingParams(json_object=True))
            structured_outputs=StructuredOutputsParams(json_object=True))

        outputs = llm.generate(prompts=(
            "Generate a JSON object with curly braces for a person with "
@ -195,8 +195,9 @@ def test_structured_output(
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=4096,
        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
    if guided_decoding_backend.startswith("xgrammar"):
        structured_outputs=StructuredOutputsParams(
            json=unsupported_json_schema))
    if backend.startswith("xgrammar"):
        with pytest.raises(ValueError,
                           match="The provided JSON schema contains features "
                           "not supported by xgrammar."):
@ -230,7 +231,7 @@ def test_structured_output(
            parsed_json = json.loads(generated_text)
            assert isinstance(parsed_json, dict)

    if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
    if backend not in ["outlines", "lm-format-enforcer"]:
        #
        # Test 4: Generate SQL statement using EBNF grammar
        #
@ -238,7 +239,8 @@ def test_structured_output(
            temperature=0.8,
            top_p=0.95,
            max_tokens=1000,
            guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
            structured_outputs=StructuredOutputsParams(
                grammar=sample_sql_ebnf))
        outputs = llm.generate(
            ("Generate a sql statement that selects col_1 from "
             "table_1 where it is equal to 1. Make the response as short as "
@ -271,7 +273,8 @@ def test_structured_output(
            temperature=0.8,
            top_p=0.95,
            max_tokens=1000,
            guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
            structured_outputs=StructuredOutputsParams(
                grammar=sample_sql_lark))
        outputs = llm.generate(
            ("Generate a sql statement that selects col_1 from "
             "table_1 where it is equal to 1. Make the response as short as "
@ -309,7 +312,8 @@ def test_structured_output(
            temperature=0.8,
            top_p=0.95,
            max_tokens=1000,
            guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
            structured_outputs=StructuredOutputsParams(
                grammar="not a grammar"))
        with pytest.raises(ValueError, match="Failed to convert the grammar "):
            llm.generate(
                ("Generate a sql statement that selects col_1 from "
@ -325,7 +329,7 @@ def test_structured_output(
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_decoding=GuidedDecodingParams(regex=sample_regex))
        structured_outputs=StructuredOutputsParams(regex=sample_regex))

    prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. "
              f"Make the response as short as possible.")
@ -352,7 +356,8 @@ def test_structured_output(
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
        structured_outputs=StructuredOutputsParams(
            choice=sample_structured_outputs_choices))

    outputs = llm.generate(
        ("The best language for type-safe systems programming is "
@ -368,7 +373,7 @@ def test_structured_output(
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert generated_text in sample_guided_choice
        assert generated_text in sample_structured_outputs_choices
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    #
@ -378,7 +383,7 @@ def test_structured_output(
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=json_schema))
        structured_outputs=StructuredOutputsParams(json=json_schema))

    outputs = llm.generate(
        ("Generate a JSON with the brand, model and car_type of the most "
@ -422,7 +427,7 @@ def test_structured_output(
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=4096,
        guided_decoding=GuidedDecodingParams(json=json_schema))
        structured_outputs=StructuredOutputsParams(json=json_schema))

    outputs = llm.generate(
        ("Generate a description of a frog using 50 characters. "
@ -444,7 +449,7 @@ def test_structured_output(
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=json_schema)

    if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
    if backend not in ["outlines", "lm-format-enforcer"]:
        #
        # Test 11: Generate structured output using structural_tag format
        #
@ -470,7 +475,7 @@ def test_structured_output(
        sampling_params = SamplingParams(
            temperature=0.0,
            max_tokens=4096,
            guided_decoding=GuidedDecodingParams(
            structured_outputs=StructuredOutputsParams(
                structural_tag=json.dumps(structural_tag_config)))

        prompt = """
@ -547,7 +552,7 @@ Make the response as short as possible.

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
    "model_name, guided_decoding_backend, tokenizer_mode, reasoning_parser, speculative_config",  # noqa: E501
    "model_name, backend, tokenizer_mode, reasoning_parser, speculative_config",  # noqa: E501
    [
        ("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "xgrammar", "auto",
         "deepseek_r1", NGRAM_SPEC_CONFIG),
@ -556,7 +561,7 @@ Make the response as short as possible.
)
def test_structured_output_with_reasoning_matrices(
    monkeypatch: pytest.MonkeyPatch,
    guided_decoding_backend: str,
    backend: str,
    tokenizer_mode: TokenizerMode,
    reasoning_parser: str,
    model_name: str,
@ -576,10 +581,11 @@ def test_structured_output_with_reasoning_matrices(
        enforce_eager=bool(not current_platform.is_tpu()),
        max_model_len=1024,
        max_num_seqs=16,
        guided_decoding_backend=guided_decoding_backend,
        guided_decoding_disable_any_whitespace=True,
        structured_outputs_config=dict(backend=backend,
                                       disable_any_whitespace=backend
                                       in {"xgrammar", "guidance"},
                                       reasoning_parser=reasoning_parser),
        tokenizer_mode=tokenizer_mode,
        reasoning_parser=reasoning_parser,
        speculative_config=speculative_config,
    )
    tokenizer = llm.get_tokenizer()
@ -603,7 +609,7 @@ def test_structured_output_with_reasoning_matrices(
    sampling_params = SamplingParams(
        temperature=0.1,
        max_tokens=8192,
        guided_decoding=GuidedDecodingParams(json=reasoning_schema),
        structured_outputs=StructuredOutputsParams(json=reasoning_schema),
    )
    outputs = llm.generate(
        [reasoning_prompt],
@ -640,13 +646,14 @@ def test_structured_output_auto_mode(

    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend="auto",
              structured_outputs_config=dict(backend="auto"),
              tokenizer_mode=tokenizer_mode)

    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
        structured_outputs=StructuredOutputsParams(
            json=unsupported_json_schema))

    prompts = (
        "Give an example JSON object for a grade "
@ -681,9 +688,10 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
              max_model_len=1024,
              guided_decoding_backend="guidance",
              guided_decoding_disable_any_whitespace=True,
              guided_decoding_disable_additional_properties=True)
              structured_outputs_config=dict(
                  backend="guidance",
                  disable_any_whitespace=True,
                  disable_additional_properties=True))

    schema = {
        'type': 'object',
@ -709,14 +717,15 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
              "<|im_end|>\n<|im_start|>assistant\n")

    def generate_with_backend(backend):
        guided_params = GuidedDecodingParams(
        structured_outputs_params = StructuredOutputsParams(
            json=schema,
            backend=backend,
            disable_any_whitespace=True,
            disable_additional_properties=True)
        sampling_params = SamplingParams(temperature=0,
                                         max_tokens=256,
                                         guided_decoding=guided_params)
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=256,
            structured_outputs=structured_outputs_params)

        outputs = llm.generate(prompt, sampling_params=sampling_params)
        assert outputs is not None
@ -736,12 +745,11 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
    assert "a6" not in generated


@pytest.mark.parametrize("guided_decoding_backend",
                         ["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_guided_requests(
@pytest.mark.parametrize("backend", ["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_structured_outputs_requests(
    monkeypatch: pytest.MonkeyPatch,
    sample_json_schema: dict[str, Any],
    guided_decoding_backend: str,
    backend: str,
):
    monkeypatch.setenv("VLLM_USE_V1", "1")

@ -753,24 +761,25 @@ def test_structured_output_batched_with_non_guided_requests(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        enforce_eager=enforce_eager,
        max_model_len=1024,
        guided_decoding_backend=guided_decoding_backend,
        guided_decoding_disable_any_whitespace=(guided_decoding_backend
                                                in {"xgrammar", "guidance"}),
        structured_outputs_config=StructuredOutputsConfig(
            backend=backend,
            disable_any_whitespace=backend in {"xgrammar", "guidance"},
        ),
    )

    guided_prompt = (
    structured_outputs_prompt = (
        "Give an example JSON for an employee profile that fits this "
        "schema. Make the response as short as possible. Schema: "
        f"{sample_json_schema}")

    non_guided_prompt = "The diameter of the Earth in kilometers is "
    non_structured_outputs_prompt = "The diameter of the Earth in kilometers is "

    prompts = [guided_prompt, non_guided_prompt]
    prompts = [structured_outputs_prompt, non_structured_outputs_prompt]
    sampling_params = [
        SamplingParams(
            temperature=1.0,
            max_tokens=400,
            guided_decoding=GuidedDecodingParams(json=sample_json_schema)),
        SamplingParams(temperature=1.0,
                       max_tokens=400,
                       structured_outputs=StructuredOutputsParams(
                           json=sample_json_schema)),
        # No max tokens, temp=0 to assert on contents
        SamplingParams(
            seed=42,
@ -801,16 +810,16 @@ def test_structured_output_batched_with_non_guided_requests(
        print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")

        if index == 0:
            # First prompt is guided, expect valid JSON
            # First prompt is structured outputs, expect valid JSON
            assert "\n" not in generated_text
            output_json = json.loads(generated_text)
            jsonschema.validate(instance=output_json,
                                schema=sample_json_schema)
        else:
            # Second prompt is not guided, expect valid output
            # Second prompt is not structured outputs, expect valid output
            # Cannot assert on exact output, but we can expect it to be factual
            assert "12,742" in generated_text

            # non-guided requests should not return a valid JSON here
            # non-structured outputs requests should not return a valid JSON here
            with pytest.raises(ValueError):
                output_json = json.loads(generated_text)

@ -77,7 +77,9 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI,
                "role": "user",
                "content": prompt,
            }],
            extra_body={"guided_json": invalid_json_schema},
            extra_body={"structured_outputs": {
                "json": invalid_json_schema
            }},
        )


@ -99,7 +101,9 @@ async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str):
                "content": prompt,
            }],
            extra_body={
                "guided_regex": r"[.*",
                "structured_outputs": {
                    "regex": r"[.*"
                },
                "stop": ["\n"]
            },
        )
@ -134,5 +138,9 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
                "role": "user",
                "content": prompt,
            }],
            extra_body={"guided_grammar": invalid_simplified_sql_grammar},
            extra_body={
                "structured_outputs": {
                    "grammar": invalid_simplified_sql_grammar
                }
            },
        )

@ -627,7 +627,9 @@ async def test_invalid_json_schema(client: openai.AsyncOpenAI,
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            extra_body={"guided_json": invalid_json_schema},
            extra_body={"structured_outputs": {
                "json": invalid_json_schema
            }},
        )


@ -646,7 +648,9 @@ async def test_invalid_regex(client: openai.AsyncOpenAI, model_name: str):
            model=model_name,
            prompt=prompt,
            extra_body={
                "guided_regex": r"[.*",
                "structured_outputs": {
                    "regex": r"[.*"
                },
                "stop": ["\n"]
            },
        )
@ -678,7 +682,11 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            extra_body={"guided_grammar": invalid_simplified_sql_grammar},
            extra_body={
                "structured_outputs": {
                    "grammar": invalid_simplified_sql_grammar
                }
            },
        )

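The server-side counterpart of the same migration: HTTP clients now nest constraints under a structured_outputs key in extra_body instead of the flat guided_* fields. A minimal client sketch, assuming a running OpenAI-compatible vLLM server; the URL, API key, and model name are placeholders.

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model name
    prompt="Give an example IPv4 address: ",
    max_tokens=20,
    # previously: extra_body={"guided_regex": r"\d+\.\d+\.\d+\.\d+"}
    extra_body={"structured_outputs": {"regex": r"\d+\.\d+\.\d+\.\d+"}},
)
print(completion.choices[0].text)
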
tests/v1/kv_offload/test_worker.py (new file, 152 lines)
@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.kv_offload.abstract import LoadStoreSpec
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
                                              OffloadingWorker, TransferResult,
                                              TransferSpec)


class LoadStoreSpec1(LoadStoreSpec):

    def __init__(self,
                 submit_success: bool = True,
                 async_success: bool = True,
                 exception: bool = False):
        self.finished = False
        self.submit_success = submit_success
        self.async_success = async_success
        self.exception = exception

    @staticmethod
    def medium() -> str:
        return "1"

    def __repr__(self):
        return f"{self.medium()}: {id(self)}"


class LoadStoreSpec2(LoadStoreSpec):

    @staticmethod
    def medium() -> str:
        return "2"

    def __repr__(self):
        return f"{self.medium()}: {id(self)}"


class OffloadingHandler1To2(OffloadingHandler):

    def __init__(self):
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        src, dst = spec
        assert isinstance(src, LoadStoreSpec1)
        assert isinstance(dst, LoadStoreSpec2)

        if src.exception:
            raise Exception("An expected exception. Don't worry!")
        if not src.submit_success:
            return False

        self.transfers[job_id] = src
        return True

    def get_finished(self) -> list[TransferResult]:
        finished = []
        for job_id, spec in list(self.transfers.items()):
            if spec.finished:
                finished.append((job_id, spec.async_success))
                del self.transfers[job_id]
        return finished


class OffloadingHandler2To1(OffloadingHandler):

    def __init__(self):
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        src, dst = spec
        assert isinstance(src, LoadStoreSpec2)
        assert isinstance(dst, LoadStoreSpec1)

        self.transfers[job_id] = dst
        return True

    def get_finished(self) -> list[TransferResult]:
        finished = []
        for job_id, spec in list(self.transfers.items()):
            if spec.finished:
                finished.append((job_id, spec.async_success))
                del self.transfers[job_id]
        return finished


def test_offloading_worker():
    """
    Tests OffloadingWorker with 2 handlers.
    One handler performs 1->2 transfers, and the other handles 2->1.
    """
    worker = OffloadingWorker()
    handler1to2 = OffloadingHandler1To2()
    handler2to1 = OffloadingHandler2To1()
    worker.register_handler(LoadStoreSpec1, LoadStoreSpec2, handler1to2)
    worker.register_handler(LoadStoreSpec2, LoadStoreSpec1, handler2to1)

    # 1st transfer 1->2 (exception)
    src1 = LoadStoreSpec1(exception=True)
    dst1 = LoadStoreSpec2()
    assert not worker.transfer_async(1, (src1, dst1))

    # 2nd transfer 1->2 (failure to submit)
    src2 = LoadStoreSpec1(submit_success=False)
    dst2 = LoadStoreSpec2()
    assert not worker.transfer_async(2, (src2, dst2))

    # 3rd transfer 1->2 (failure)
    src3 = LoadStoreSpec1(async_success=False)
    dst3 = LoadStoreSpec2()
    assert worker.transfer_async(3, (src3, dst3))

    # 4th transfer 1->2 (success)
    src4 = LoadStoreSpec1()
    dst4 = LoadStoreSpec2()
    worker.transfer_async(4, (src4, dst4))
    assert set(handler1to2.transfers.keys()) == {3, 4}

    # 5th transfer 2->1
    src5 = LoadStoreSpec2()
    dst5 = LoadStoreSpec1()
    worker.transfer_async(5, (src5, dst5))
    assert set(handler2to1.transfers.keys()) == {5}

    # no transfer completed yet
    assert worker.get_finished() == []

    # complete 3rd, 4th
    src3.finished = True
    src4.finished = True

    # 6th transfer 1->2
    src6 = LoadStoreSpec1()
    dst6 = LoadStoreSpec2()
    worker.transfer_async(6, (src6, dst6))

    # 7th transfer 2->1
    src7 = LoadStoreSpec2()
    dst7 = LoadStoreSpec1()
|
||||
worker.transfer_async(7, (src7, dst7))
|
||||
|
||||
# 6th and 7th transfers started
|
||||
assert 6 in handler1to2.transfers
|
||||
assert 7 in handler2to1.transfers
|
||||
|
||||
# verify result of 3rd and 4th transfers
|
||||
assert (sorted(worker.get_finished()) == [(3, False), (4, True)])
|
||||
|
||||
# complete 6th and 7th transfers
|
||||
src6.finished = True
|
||||
dst7.finished = True
|
||||
assert (sorted(worker.get_finished()) == [(6, True), (7, True)])
|
||||
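The new test shows the core contract: the worker dispatches each transfer to the handler registered for the (source medium, destination medium) pair, and completion is reported asynchronously through get_finished() as (job_id, success) tuples. A minimal polling sketch built on the same API exercised above (the helper name and sleep interval are illustrative, not part of this diff):

import time


def wait_for_transfer(worker: OffloadingWorker, job_id: int,
                      poll_interval_s: float = 0.01) -> bool:
    # Poll until the handler reports the job as finished, then return
    # whether the asynchronous transfer succeeded.
    while True:
        for finished_job_id, success in worker.get_finished():
            if finished_job_id == job_id:
                return success
        time.sleep(poll_interval_s)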
@ -19,6 +19,8 @@ from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
@ -64,6 +66,86 @@ def _create_proposer(
        device=current_platform.device_type)


def test_prepare_next_token_ids():
    """
    Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
    Each will produce a device tensor of next_token_ids, taking as input
    either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
    or the CPU python list[list[int]] with the rejected tokens removed.
    """
    device = torch.device(current_platform.device_type)

    num_requests = 4
    num_speculative_tokens = 4
    batch_spec = BatchSpec(
        seq_lens=[num_speculative_tokens + 1] * num_requests,
        query_lens=[num_speculative_tokens + 1] * num_requests,
    )

    req_ids = [f"req_{i+1}" for i in range(num_requests)]
    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100

    mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
    mock_requests = {}
    for req_id in req_ids:
        mock_request = mock.MagicMock(spec=CachedRequestState)
        # Each request will have a backup next token id of 10, 20, 30, 40
        mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
        mock_request.num_computed_tokens = 0
        mock_requests[req_id] = mock_request

    sampled_token_ids = [
        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
        [-1, -1, -1, -1, -1]  # this request will be discarded
    ]
    sampled_token_ids_tensor = torch.tensor(sampled_token_ids,
                                            dtype=torch.int32,
                                            device=device)
    sampled_token_ids_cpu = [[i for i in seq if i != -1]
                             for seq in sampled_token_ids]

    expected_next_token_ids_cpu = [1, 4, 30, 40]
    expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu,
                                                  dtype=torch.int32,
                                                  device=device)

    proposer = _create_proposer("eagle", num_speculative_tokens)

    next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
        sampled_token_ids_cpu, mock_requests, mock_input_batch,
        mock_num_scheduled_tokens)

    assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
    num_discarded_reqs = 1

    expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0],
                                                       dtype=torch.int32,
                                                       device=device)

    next_token_ids_from_padded, valid_sampled_tokens_count = \
        proposer.prepare_next_token_ids_padded(
            common_attn_metadata, sampled_token_ids_tensor, mock_requests,
            mock_input_batch, discarded_req_indices, num_discarded_reqs)

    assert torch.equal(next_token_ids_from_padded,
                       expected_next_token_ids_tensor)
    assert torch.equal(valid_sampled_tokens_count,
                       expected_valid_sampled_tokens_count)
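For readers following the test data above, the selection rule being exercised is: take the last accepted (non-negative) token in each row; if every position was rejected or sampling was skipped, fall back to the request's backup token id. A standalone sketch of that rule in plain Python (this is an illustration, not the proposer's tensor implementation):

def pick_next_token_ids(sampled_token_ids: list[list[int]],
                        backup_token_ids: list[int]) -> list[int]:
    # -1 marks a rejected (or skipped) position; anything else is accepted.
    next_ids = []
    for row, backup in zip(sampled_token_ids, backup_token_ids):
        accepted = [t for t in row if t != -1]
        next_ids.append(accepted[-1] if accepted else backup)
    return next_ids


# Mirrors expected_next_token_ids_cpu = [1, 4, 30, 40] from the test above.
assert pick_next_token_ids(
    [[0, 1, -1, -1, -1], [0, 1, 2, 3, 4], [-1] * 5, [-1] * 5],
    [10, 20, 30, 40]) == [1, 4, 30, 40]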
|
||||
|
||||
def test_prepare_inputs():
    """
    cu_target_query_lens: [0, a, a + b, a + b + c]
@ -90,10 +172,24 @@ def test_prepare_inputs():
        device=device,
    )

    # Rejected tokens per request: [1, 3, 2]
    num_rejected_tokens = torch.tensor([1, 3, 2],
                                       dtype=torch.int32,
                                       device=device)
    # If there are `k` sampled tokens, then `k-1` tokens are draft tokens
    # from the previous iteration, and the last token is the bonus token sampled
    # from the base model.
    num_draft_tokens = [3, 6, 4]  # one less than query_lens
    # num rejected tokens is [1, 3, 2]
    ACCEPT_TOKEN = 0
    BONUS_TOKEN = 1
    REJECT_TOKEN = -1
    sampled_token_ids = [
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
        [
            ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN,
            REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN
        ],
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN]
    ]
    sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN]
                         for seq in sampled_token_ids]

    # Expected calculations:
    # query_len_per_req = [4, 7, 5]
@ -125,7 +221,7 @@ def test_prepare_inputs():
    proposer = _create_proposer("eagle", 1)

    updated_metadata, token_indices = proposer.prepare_inputs(
        common_attn_metadata, num_rejected_tokens.cpu())
        common_attn_metadata, sampled_token_ids, num_draft_tokens)

    assert torch.equal(updated_metadata.query_start_loc,
                       expected_cu_num_tokens)
@ -133,6 +229,77 @@ def test_prepare_inputs():
    assert torch.equal(token_indices, expected_token_indices)


def test_prepare_inputs_padded():
    """
    Input scenario is 3 requests with num_speculative_tokens == 2 and:
    - Request 1: query_len = 3, rejected = 1
    - Request 2: query_len = 3, rejected = 0
    - Request 3: query_len = 3, rejected = 2

    Expected outputs:
    token_indices: [0, 1, 2,
                    3, 4, 5,
                    6, 7, 8]
    Reason: Deferred computation should not disturb the original indices.

    token_indices_to_sample: [1, 5, 6]
    Reason: After accounting for rejections, these are the valid token positions
    from the original indices to sample from.
    """

    device = torch.device(current_platform.device_type)

    expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8],
                                          dtype=torch.int32,
                                          device=device)
    expected_token_indices_to_sample = torch.tensor([1, 5, 6],
                                                    dtype=torch.int32,
                                                    device=device)

    num_speculative_tokens = 2
    batch_spec = BatchSpec(
        seq_lens=[3, 3, 3],
        query_lens=[3, 3, 3],
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9]
    expected_query_start_loc = torch.tensor([0, 3, 6, 9],
                                            dtype=torch.int32,
                                            device=device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids=[[0] * num_speculative_tokens] * 3,
        device=device,
    )

    # num_rejected_tokens = [1, 0, 2]
    # num_draft_tokens = [2, 2, 2]
    # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
    valid_sampled_tokens_count = torch.tensor([2, 3, 1],
                                              dtype=torch.int32,
                                              device=device)

    proposer = _create_proposer("eagle", num_speculative_tokens)

    output_metadata, token_indices, token_indices_to_sample = \
        proposer.prepare_inputs_padded(
            common_attn_metadata,
            spec_decode_metadata,
            valid_sampled_tokens_count)

    assert output_metadata.max_query_len == 3
    assert torch.equal(output_metadata.query_start_loc,
                       expected_query_start_loc)
    assert torch.equal(token_indices, expected_token_indices)
    assert torch.equal(token_indices_to_sample,
                       expected_token_indices_to_sample)


@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
@ -373,6 +540,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata)

@ -526,6 +694,7 @@ def test_propose_tree(spec_token_tree):
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata)
    assert result.shape == (batch_size, num_speculative_tokens)
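The comments in test_prepare_inputs_padded pin down the key arithmetic: with query_start_loc [0, 3, 6, 9] and valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens = [2, 3, 1], the position to sample for each request is the last valid token of its slice. A small sketch of that relationship in plain Python (illustrative only, independent of the proposer code), reproducing the expected [1, 5, 6]:

def compute_token_indices_to_sample(
        query_start_loc: list[int],
        valid_sampled_tokens_count: list[int]) -> list[int]:
    # Request i owns positions [query_start_loc[i], query_start_loc[i + 1]);
    # the token to sample from is its last valid one: start + count - 1.
    return [
        start + count - 1
        for start, count in zip(query_start_loc[:-1], valid_sampled_tokens_count)
    ]


assert compute_token_indices_to_sample([0, 3, 6, 9], [2, 3, 1]) == [1, 5, 6]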
|
||||
@ -7,7 +7,6 @@ import pytest
import vllm.envs as envs
from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

MODEL = "meta-llama/Llama-3.2-1B-Instruct"

@ -96,20 +95,3 @@ def test_v1_attn_backend(monkeypatch):
        _ = AsyncEngineArgs(model=MODEL).create_engine_config()
        assert envs.VLLM_USE_V1
        m.delenv("VLLM_USE_V1")


def test_reject_using_constructor_directly(monkeypatch):
    with monkeypatch.context() as m:
        if os.getenv("VLLM_USE_V1", None):
            m.delenv("VLLM_USE_V1")

        # Sets VLLM_USE_V1=1.
        vllm_config = AsyncEngineArgs(model=MODEL).create_engine_config()

        # This uses the V0 constructor directly.
        with pytest.raises(ValueError):
            AsyncLLMEngine(vllm_config,
                           AsyncLLMEngine._get_executor_cls(vllm_config),
                           log_stats=True)

        m.delenv("VLLM_USE_V1")

@ -1,11 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    This module tests V0 internals, so set VLLM_USE_V1=0.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')
@ -1,113 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
|
||||
class MockAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls():
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return AttentionMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["AttentionMetadataBuilder"]:
|
||||
return AttentionMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: list[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def test_model_runner_input():
|
||||
sampling_metadata = SamplingMetadata(
|
||||
["seq_group"],
|
||||
"selected_token_indices",
|
||||
"categorized_sample_indices",
|
||||
"num_prompts",
|
||||
)
|
||||
attn_metadata = AttentionMetadata(
|
||||
num_prefills=1,
|
||||
num_prefill_tokens=2,
|
||||
num_decode_tokens=3,
|
||||
slot_mapping=torch.zeros(1),
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
)
|
||||
model_input = ModelInputForGPUWithSamplingMetadata(
|
||||
input_tokens=torch.ones(10),
|
||||
input_positions=torch.ones(10),
|
||||
sampling_metadata=sampling_metadata,
|
||||
attn_metadata=attn_metadata)
|
||||
|
||||
assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)
|
||||
|
||||
# Test round trip serialization.
|
||||
tensor_dict = model_input.as_broadcastable_tensor_dict()
|
||||
attn_backend = MockAttentionBackend()
|
||||
received_model_input = (
|
||||
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
|
||||
tensor_dict, attn_backend=attn_backend))
|
||||
# Check that received copy has correct values.
|
||||
assert isinstance(received_model_input,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
assert received_model_input.input_tokens is not None
|
||||
assert (
|
||||
received_model_input.input_tokens == model_input.input_tokens).all()
|
||||
assert received_model_input.input_positions is not None
|
||||
assert (received_model_input.input_positions == model_input.input_positions
|
||||
).all()
|
||||
assert received_model_input.multi_modal_kwargs is None
|
||||
assert (received_model_input.multi_modal_kwargs ==
|
||||
model_input.multi_modal_kwargs)
|
||||
assert received_model_input.lora_requests is None
|
||||
assert received_model_input.lora_requests == model_input.lora_requests
|
||||
assert received_model_input.lora_mapping is None
|
||||
assert received_model_input.lora_mapping == model_input.lora_mapping
|
||||
for field in dataclasses.fields(AttentionMetadata):
|
||||
assert getattr(received_model_input.attn_metadata, field.name,
|
||||
None) == getattr(attn_metadata, field.name, None)
|
||||
# For sampling metadata, only selected_token_indices is copied.
|
||||
assert (received_model_input.sampling_metadata.selected_token_indices ==
|
||||
sampling_metadata.selected_token_indices)
|
||||
assert received_model_input.sampling_metadata.seq_groups is None
|
||||
@ -1,462 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
|
||||
|
||||
def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
|
||||
engine_args = EngineArgs(model, *args, **kwargs)
|
||||
engine_config = engine_args.create_engine_config()
|
||||
model_runner = ModelRunner(
|
||||
vllm_config=engine_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
return model_runner
|
||||
|
||||
|
||||
def test_deepseek_mla_attn_backend_module():
|
||||
model_runner = _create_model_runner(
|
||||
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
|
||||
trust_remote_code=True,
|
||||
enable_chunked_prefill=False,
|
||||
)
|
||||
assert model_runner.attn_backend.__name__ == "TritonMLABackend"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
|
||||
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
|
||||
def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
|
||||
if use_prompt_embeds:
|
||||
# Prompt Embeddings is only currently supported on V0
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/opt-125m",
|
||||
max_num_batched_tokens=100000,
|
||||
max_num_seqs=100000,
|
||||
enable_chunked_prefill=False,
|
||||
enable_prompt_embeds=True,
|
||||
)
|
||||
|
||||
seq_lens: list[int] = []
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata] = []
|
||||
block_tables = {0: [1]}
|
||||
expected_input_embeds_len = 0
|
||||
for i in range(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
if use_prompt_embeds:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=[0] * seq_len,
|
||||
prompt_embeds=torch.rand(seq_len, 10),
|
||||
)
|
||||
expected_input_embeds_len += seq_len
|
||||
else:
|
||||
seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len))
|
||||
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables=block_tables,
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
expected_selected_token_indices = []
|
||||
selected_token_start_idx = 0
|
||||
for seq_len in seq_lens:
|
||||
expected_selected_token_indices.append(selected_token_start_idx +
|
||||
seq_len - 1)
|
||||
selected_token_start_idx += seq_len
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
input_embeds = model_input.inputs_embeds
|
||||
attn_metadata = model_input.attn_metadata
|
||||
return_seq_lens = model_input.seq_lens
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
assert return_seq_lens == seq_lens
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
|
||||
# Verify input metadata is correct for prompts.
|
||||
device = model_runner.device
|
||||
assert attn_metadata.num_prefills > 0
|
||||
assert attn_metadata.num_decode_tokens == 0
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.seq_lens_tensor,
|
||||
torch.tensor(seq_lens, device=device, dtype=torch.int))
|
||||
assert attn_metadata.seq_lens == seq_lens
|
||||
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
|
||||
assert attn_metadata.max_decode_seq_len == 0
|
||||
|
||||
# Test subquery start locs.
|
||||
start_idx = 0
|
||||
start_loc = [start_idx]
|
||||
for seq_len in seq_lens:
|
||||
start_idx += seq_len
|
||||
start_loc.append(start_idx)
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.query_start_loc,
|
||||
torch.tensor(start_loc, dtype=torch.int32, device=device))
|
||||
|
||||
# Test seq start locs. Note that for normal prefill it is
|
||||
# equivalent to query_start_loc.
|
||||
start_idx = 0
|
||||
seq_start_loc = [start_idx]
|
||||
for seq_len in seq_lens:
|
||||
start_idx += seq_len
|
||||
seq_start_loc.append(start_idx)
|
||||
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.seq_start_loc,
|
||||
torch.tensor(start_loc, dtype=torch.int32, device=device))
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.context_lens_tensor,
|
||||
torch.zeros(attn_metadata.context_lens_tensor.shape[0],
|
||||
dtype=torch.int,
|
||||
device=device))
|
||||
|
||||
expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
|
||||
dtype=torch.int32,
|
||||
device=model_runner.device)
|
||||
torch.testing.assert_close(attn_metadata.block_tables, expected)
|
||||
# Cuda graph should not be used for prerill.
|
||||
assert attn_metadata.use_cuda_graph is False
|
||||
|
||||
assert len(input_tokens) == sum(seq_lens)
|
||||
assert len(input_positions) == sum(seq_lens)
|
||||
if expected_input_embeds_len == 0:
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
assert input_embeds is None
|
||||
else:
|
||||
assert len(input_embeds) == expected_input_embeds_len
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
query_lens=seq_lens,
|
||||
device=model_runner.device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
assert len(input_tokens) == sum(seq_lens)
|
||||
assert len(input_positions) == sum(seq_lens)
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
dtype=actual.dtype)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
torch.allclose(input_tokens, input_positions)
|
||||
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
dtype=actual.dtype)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
|
||||
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
|
||||
def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
|
||||
if use_prompt_embeds:
|
||||
# Prompt Embeddings is only currently supported on V0
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/opt-125m",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
enforce_eager=False,
|
||||
max_num_batched_tokens=100000,
|
||||
max_num_seqs=100000,
|
||||
enable_chunked_prefill=False,
|
||||
enable_prompt_embeds=True,
|
||||
)
|
||||
|
||||
context_lens: list[int] = []
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata] = []
|
||||
# Assume each seq group finishes prefill.
|
||||
for i in range(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
context_len = i % (model_runner.block_size - 1) + 1
|
||||
context_lens.append(context_len)
|
||||
if use_prompt_embeds:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=[0] * context_len,
|
||||
prompt_embeds=torch.rand(context_len, 10),
|
||||
)
|
||||
output_embed = torch.rand(10)
|
||||
else:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=range(context_len))
|
||||
output_embed = None
|
||||
seq_data.update_num_computed_tokens(context_len)
|
||||
# Append one token ID since prefill is finished.
|
||||
seq_data.append_token_id(1, 0, output_embed)
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=False,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables={0: [1]},
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
input_embeds = model_input.inputs_embeds
|
||||
attn_metadata = model_input.attn_metadata
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
|
||||
expected_bs = model_runner.vllm_config.pad_for_cudagraph(
|
||||
len(seq_group_metadata_list))
|
||||
# Verify input metadata is correct for prompts.
|
||||
device = model_runner.device
|
||||
assert attn_metadata.num_prefills == 0
|
||||
assert attn_metadata.num_prefill_tokens == 0
|
||||
seq_lens = [context_len + 1 for context_len in context_lens]
|
||||
# seq_lens are padded to expected_bs
|
||||
for _ in range(expected_bs - len(seq_lens)):
|
||||
seq_lens.append(1)
|
||||
assert attn_metadata.seq_lens == seq_lens
|
||||
assert attn_metadata.num_decode_tokens == len(seq_lens)
|
||||
start_idx = 0
|
||||
start_loc = [start_idx]
|
||||
for _ in context_lens:
|
||||
# decode has only 1 token for query.
|
||||
start_idx += 1
|
||||
start_loc.append(start_idx)
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.query_start_loc,
|
||||
torch.tensor(start_loc, dtype=torch.int32, device=device))
|
||||
|
||||
start_idx = 0
|
||||
seq_start_loc = [start_idx]
|
||||
for seq_len in seq_lens:
|
||||
start_idx += seq_len
|
||||
seq_start_loc.append(start_idx)
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.seq_start_loc,
|
||||
torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
|
||||
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.context_lens_tensor,
|
||||
torch.tensor(context_lens, dtype=torch.int, device=device))
|
||||
assert attn_metadata.max_decode_seq_len == max(seq_lens)
|
||||
torch.testing.assert_close(
|
||||
attn_metadata.seq_lens_tensor[:len(seq_lens)],
|
||||
torch.tensor(seq_lens, dtype=torch.int, device=device))
|
||||
|
||||
# block table's first index corresponds to each batch, meaning in
|
||||
# decoding it is each token.
|
||||
assert attn_metadata.block_tables.shape[0] == len(input_tokens)
|
||||
# Block table's second dim corresponds to each token's block number.
|
||||
# It is padded up to
|
||||
assert attn_metadata.block_tables.shape[1] == (
|
||||
model_runner.get_max_block_per_batch())
|
||||
assert attn_metadata.use_cuda_graph is True
|
||||
|
||||
assert len(input_tokens) == expected_bs
|
||||
assert len(input_positions) == expected_bs
|
||||
if use_prompt_embeds:
|
||||
expected_input_embeds_length = start_loc[-1]
|
||||
assert len(input_embeds) == expected_input_embeds_length
|
||||
assert expected_input_embeds_length <= expected_bs
|
||||
else:
|
||||
assert input_embeds is None
|
||||
|
||||
# Verify Sampling
|
||||
expected_selected_token_indices = []
|
||||
for selected_token_start_idx, _ in enumerate(context_lens):
|
||||
expected_selected_token_indices.append(selected_token_start_idx)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
seq_lens,
|
||||
# query lens is all 1 for decode.
|
||||
query_lens=[1 for _ in range(len(context_lens))],
|
||||
device=model_runner.device,
|
||||
pin_memory=model_runner.pin_memory)
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
dtype=actual.dtype)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_empty_seq_group():
|
||||
"""Verify prepare prompt and decode returns empty output."""
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/opt-125m",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
enforce_eager=False,
|
||||
)
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata] = []
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
attn_metadata = model_input.attn_metadata
|
||||
|
||||
assert input_tokens is None
|
||||
assert input_positions is None
|
||||
assert attn_metadata is None
|
||||
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
input_embeds = model_input.inputs_embeds
|
||||
attn_metadata = model_input.attn_metadata
|
||||
return_seq_lens = model_input.seq_lens
|
||||
|
||||
assert input_tokens is None
|
||||
assert input_positions is None
|
||||
assert input_embeds is None
|
||||
assert attn_metadata is None
|
||||
assert return_seq_lens is None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def distributed_init():
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
|
||||
local_rank=0)
|
||||
ensure_model_parallel_initialized(1, 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
@pytest.mark.parametrize('use_prompt_embeds', [True, False])
|
||||
def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
|
||||
distributed_init, monkeypatch):
|
||||
if use_prompt_embeds:
|
||||
# Prompt Embeddings is only currently supported on V0
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/opt-125m",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_batched_tokens=100000,
|
||||
max_num_seqs=100000,
|
||||
enable_chunked_prefill=True,
|
||||
enable_prompt_embeds=True,
|
||||
)
|
||||
|
||||
# Add prefill requests.
|
||||
seq_lens: list[int] = []
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata] = []
|
||||
prefill_metadata_list: list[SequenceGroupMetadata] = []
|
||||
decode_metadata_list: list[SequenceGroupMetadata] = []
|
||||
block_tables = {0: [1]}
|
||||
prefill_batch_size = batch_size // 2
|
||||
decode_batch_size = batch_size - prefill_batch_size
|
||||
expected_input_embeds_len = 0
|
||||
for i in range(prefill_batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
if use_prompt_embeds:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=[0] * seq_len,
|
||||
prompt_embeds=torch.rand(seq_len, 10),
|
||||
)
|
||||
expected_input_embeds_len += seq_len
|
||||
else:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=range(seq_len), )
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables=block_tables,
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
prefill_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# Add decode requests
|
||||
for i in range(prefill_batch_size, batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
context_len = i % (model_runner.block_size - 1) + 1
|
||||
if use_prompt_embeds:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=[0] * context_len,
|
||||
prompt_embeds=torch.rand(context_len, 10),
|
||||
)
|
||||
output_embed = torch.rand(10)
|
||||
# This also iterates the expected input_embeds, because the model
|
||||
# needs both the input and output embeddings passed into together
|
||||
expected_input_embeds_len += 1
|
||||
else:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=range(context_len), )
|
||||
output_embed = None
|
||||
assert len(seq_data.prompt_token_ids) == context_len
|
||||
seq_data.append_token_id(1, 0, output_embed)
|
||||
seq_data.update_num_computed_tokens(context_len)
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=False,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables={0: [1]},
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
decode_metadata_list.append(seq_group_metadata)
|
||||
|
||||
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
|
||||
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
input_embeds = model_input.inputs_embeds
|
||||
attn_metadata = model_input.attn_metadata
|
||||
|
||||
prefill_meta_actual = attn_metadata.prefill_metadata
|
||||
decode_meta_actual = attn_metadata.decode_metadata
|
||||
|
||||
assert len(attn_metadata.slot_mapping) == len(input_tokens)
|
||||
assert len(input_positions) == len(input_tokens)
|
||||
assert attn_metadata.num_prefills == prefill_batch_size
|
||||
assert attn_metadata.num_decode_tokens == decode_batch_size
|
||||
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
|
||||
if expected_input_embeds_len == 0:
|
||||
assert input_embeds is None
|
||||
else:
|
||||
assert len(input_embeds) == expected_input_embeds_len
|
||||
|
||||
# Verify attn metadata is consistent. We don't need to test individual
|
||||
# values here because they are tested above.
|
||||
attn_metadata = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list).attn_metadata
|
||||
|
||||
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
|
||||
vars(prefill_meta_actual)):
|
||||
assert attr_expected[1] == attr_actual[1]
|
||||
for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
|
||||
vars(decode_meta_actual)):
|
||||
assert attr_expected[1] == attr_actual[1]
|
||||
@ -1,68 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
def test_gpu_memory_profiling():
|
||||
# Tests the gpu profiling that happens in order to determine the number of
|
||||
# KV cache blocks that we can allocate on the GPU.
|
||||
# This test mocks the maximum available gpu memory so that it can run on
|
||||
# any gpu setup.
|
||||
|
||||
# Set up engine args to build a worker.
|
||||
engine_args = EngineArgs(model="facebook/opt-125m",
|
||||
dtype="half",
|
||||
load_format="dummy")
|
||||
engine_config = engine_args.create_engine_config()
|
||||
engine_config.cache_config.num_gpu_blocks = 1000
|
||||
engine_config.cache_config.num_cpu_blocks = 1000
|
||||
|
||||
# Create the worker.
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
worker = Worker(
|
||||
vllm_config=engine_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
# Set 10GiB as the total gpu ram to be device-agnostic
|
||||
def mock_mem_info():
|
||||
current_usage = torch.cuda.memory_stats(
|
||||
)["allocated_bytes.all.current"]
|
||||
mock_total_bytes = 10 * 1024**3
|
||||
free = mock_total_bytes - current_usage
|
||||
|
||||
return (free, mock_total_bytes)
|
||||
|
||||
from unittest.mock import patch
|
||||
with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
|
||||
# Load the model so we can profile it
|
||||
worker.init_device()
|
||||
worker.load_model()
|
||||
gpu_blocks, _ = worker.determine_num_available_blocks()
|
||||
|
||||
# Peak vram usage by torch should be 0.47 GiB
|
||||
# Model weights take 0.25 GiB
|
||||
# No memory should be allocated outside of torch
|
||||
# 9.0 GiB should be the utilization target
|
||||
# 8.28 GiB should be available for the KV cache
|
||||
block_size = CacheEngine.get_cache_block_size(
|
||||
engine_config.cache_config, engine_config.model_config,
|
||||
engine_config.parallel_config)
|
||||
|
||||
expected_blocks = (8.28 * 1024**3) // block_size
|
||||
|
||||
# Check within a small tolerance for portability
|
||||
# Hardware, kernel, or dependency changes could all affect memory
|
||||
# utilization.
|
||||
# A 100 block tolerance here should be about 60MB of wiggle room.
|
||||
assert abs(gpu_blocks - expected_blocks) < 100
|
||||
@ -1,87 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
def test_swap() -> None:
|
||||
# Configure the engine.
|
||||
engine_args = EngineArgs(model="distilbert/distilgpt2",
|
||||
dtype="half",
|
||||
load_format="dummy")
|
||||
engine_config = engine_args.create_engine_config()
|
||||
engine_config.cache_config.num_gpu_blocks = 1000
|
||||
engine_config.cache_config.num_cpu_blocks = 1000
|
||||
|
||||
# Create the worker.
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
worker = Worker(
|
||||
vllm_config=engine_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
# Initialize the worker.
|
||||
worker.init_device()
|
||||
worker.load_model()
|
||||
worker.initialize_cache(
|
||||
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks,
|
||||
num_cpu_blocks=engine_config.cache_config.num_cpu_blocks)
|
||||
|
||||
# Randomly initialize the cache.
|
||||
gpu_cache = worker.cache_engine[0].gpu_cache
|
||||
cpu_cache = worker.cache_engine[0].cpu_cache
|
||||
num_layers = len(gpu_cache)
|
||||
for i in range(num_layers):
|
||||
gpu_key_cache, gpu_value_cache = gpu_cache[i]
|
||||
gpu_key_cache.random_()
|
||||
gpu_value_cache.random_()
|
||||
cpu_key_cache, cpu_value_cache = cpu_cache[i]
|
||||
cpu_key_cache.random_()
|
||||
cpu_value_cache.random_()
|
||||
|
||||
allclose = lambda a, b: torch.allclose(
|
||||
a.cuda(), b.cuda(), rtol=0.0, atol=0.0)
|
||||
|
||||
# Test swap out.
|
||||
blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)]
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=[],
|
||||
blocks_to_swap_in=[],
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=[],
|
||||
)
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
for i in range(num_layers):
|
||||
gpu_key_cache, gpu_value_cache = gpu_cache[i]
|
||||
cpu_key_cache, cpu_value_cache = cpu_cache[i]
|
||||
for src, dst in blocks_to_swap_out:
|
||||
assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
|
||||
assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
|
||||
|
||||
# Test swap in.
|
||||
execute_model_req.blocks_to_swap_out = []
|
||||
execute_model_req.blocks_to_swap_in = [
|
||||
(19, 45),
|
||||
(67, 23),
|
||||
(12, 78),
|
||||
(40, 99),
|
||||
(1, 71),
|
||||
]
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
for i in range(num_layers):
|
||||
gpu_key_cache, gpu_value_cache = gpu_cache[i]
|
||||
cpu_key_cache, cpu_value_cache = cpu_cache[i]
|
||||
for src, dst in execute_model_req.blocks_to_swap_in:
|
||||
assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
|
||||
assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
|
||||
@ -50,8 +50,8 @@ ALLOWED_FILES = set([
    # cloudpickle
    'vllm/worker/worker_base.py',
    'vllm/executor/mp_distributed_executor.py',
    'vllm/executor/ray_distributed_executor.py',
    'vllm/entrypoints/llm.py',
    'vllm/v1/executor/ray_distributed_executor.py',
    'tests/utils.py',
    # pickle and cloudpickle
    'vllm/utils/__init__.py',

@ -9,6 +9,8 @@ import ast
import inspect
import sys

import regex as re


def get_attr_docs(cls_node: ast.ClassDef) -> dict[str, str]:
    """
@ -88,11 +90,12 @@ def validate_class(class_node: ast.ClassDef):
    for stmt in class_node.body:
        # A field is defined as a class variable that has a type annotation.
        if isinstance(stmt, ast.AnnAssign):
            # Skip ClassVar
            # Skip ClassVar and InitVar
            # see https://docs.python.org/3/library/dataclasses.html#class-variables
            if isinstance(stmt.annotation, ast.Subscript) and isinstance(
                    stmt.annotation.value,
                    ast.Name) and stmt.annotation.value.id == "ClassVar":
            # and https://docs.python.org/3/library/dataclasses.html#init-only-variables
            if (isinstance(stmt.annotation, ast.Subscript)
                    and isinstance(stmt.annotation.value, ast.Name)
                    and stmt.annotation.value.id in {"ClassVar", "InitVar"}):
                continue

            if isinstance(stmt.target, ast.Name):
@ -132,7 +135,7 @@ def validate_ast(tree: ast.stmt):

def validate_file(file_path: str):
    try:
        print(f"validating {file_path} config dataclasses ", end="")
        print(f"Validating {file_path} config dataclasses ", end="")
        with open(file_path, encoding="utf-8") as f:
            source = f.read()

@ -140,7 +143,7 @@ def validate_file(file_path: str):
        validate_ast(tree)
    except ValueError as e:
        print(e)
        SystemExit(2)
        raise SystemExit(1) from e
    else:
        print("✅")

@ -151,7 +154,13 @@ def fail(message: str, node: ast.stmt):

def main():
    for filename in sys.argv[1:]:
        validate_file(filename)
        # Only run for Python files in vllm/ or tests/
        if not re.match(r"^(vllm|tests)/.*\.py$", filename):
            continue
        # Only run if the file contains @config
        with open(filename, encoding="utf-8") as f:
            if "@config" in f.read():
                validate_file(filename)


if __name__ == "__main__":
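For context on the ClassVar/InitVar change above: the validator now skips both kinds of pseudo-fields when walking a config dataclass, since neither becomes a real dataclass field. A hypothetical config illustrating what gets skipped (the class and field names are made up for illustration, not taken from vLLM):

from dataclasses import InitVar, dataclass
from typing import ClassVar, Optional


@dataclass
class ExampleConfig:
    # A real field: an annotated class variable the validator inspects.
    max_batch_size: int = 8
    """Maximum number of requests batched together."""

    # Skipped: ClassVar is shared class state, not a dataclass field.
    SUPPORTED_DTYPES: ClassVar[tuple[str, ...]] = ("float16", "bfloat16")

    # Skipped: InitVar is consumed by __post_init__ and never stored.
    overrides: InitVar[Optional[dict]] = None

    def __post_init__(self, overrides):
        if overrides:
            for key, value in overrides.items():
                setattr(self, key, value)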
@ -391,8 +391,8 @@ class MultiHeadAttention(nn.Module):
            backend = _Backend.FLASH_ATTN
            use_upstream_fa = True

        if current_platform.is_rocm():
            # currently, only torch_sdpa is supported on rocm
        if current_platform.is_rocm() or current_platform.is_xpu():
            # currently, only torch_sdpa is supported on rocm/xpu
            self.attn_backend = _Backend.TORCH_SDPA
        else:
||||
@ -73,6 +73,7 @@ def kernel_unified_attention_2d(
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
qq_bias_stride_0: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
TILE_SIZE: tl.constexpr, # int must be power of 2
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
@ -118,6 +119,7 @@ def kernel_unified_attention_2d(
|
||||
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
offs_t = tl.arange(0, TILE_SIZE)
|
||||
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
|
||||
|
||||
query_offset_0 = cur_batch_in_all_start_index + query_pos
|
||||
@ -177,31 +179,32 @@ def kernel_unified_attention_2d(
|
||||
# actual sequence length
|
||||
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
|
||||
|
||||
# calculate the number of tiles (blocks) that need to be processed to
|
||||
# cover the longest sequence prefix (due to causal masking, blocks beyond
|
||||
# calculate the number of tiles that need to be processed to
|
||||
# cover the longest sequence prefix (due to causal masking, tiles beyond
|
||||
# this prefix can be skipped)
|
||||
num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
|
||||
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
|
||||
|
||||
# iterate through tiles
|
||||
for j in range(0, num_blocks):
|
||||
for j in range(0, num_tiles):
|
||||
seq_offset = j * TILE_SIZE + offs_t
|
||||
tile_mask = seq_offset < max_seq_prefix_len
|
||||
|
||||
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
|
||||
physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
|
||||
seq_offset // BLOCK_SIZE).to(tl.int64)
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
v_offset = (physical_block_idx * stride_v_cache_0 +
|
||||
v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
|
||||
kv_head_idx * stride_v_cache_2 +
|
||||
offs_d[None, :] * stride_v_cache_3 +
|
||||
offs_n[:, None] * stride_v_cache_1)
|
||||
(seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
|
||||
|
||||
k_offset = (physical_block_idx * stride_k_cache_0 +
|
||||
k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
|
||||
kv_head_idx * stride_k_cache_2 +
|
||||
offs_d[:, None] * stride_k_cache_3 +
|
||||
offs_n[None, :] * stride_k_cache_1)
|
||||
(seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
|
||||
|
||||
# K : (HEAD_SIZE, BLOCK_SIZE)
|
||||
# K : (HEAD_SIZE, TILE_SIZE)
|
||||
K_load = tl.load(key_cache_ptr + k_offset,
|
||||
mask=dim_mask[:, None],
|
||||
mask=dim_mask[:, None] & tile_mask[None, :],
|
||||
other=0.0)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
@ -212,9 +215,9 @@ def kernel_unified_attention_2d(
|
||||
else:
|
||||
K = K_load
|
||||
|
||||
# V : (BLOCK_SIZE, HEAD_SIZE)
|
||||
# V : (TILE_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(value_cache_ptr + v_offset,
|
||||
mask=dim_mask[None, :],
|
||||
mask=dim_mask[None, :] & tile_mask[:, None],
|
||||
other=0.0)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
@ -225,12 +228,10 @@ def kernel_unified_attention_2d(
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_offset = j * BLOCK_SIZE + offs_n
|
||||
|
||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
||||
|
||||
# S : (BLOCK_M, BLOCK_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
|
||||
# S : (BLOCK_M, TILE_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
@ -262,11 +263,12 @@ def kernel_unified_attention_2d(
|
||||
# compute running maximum
|
||||
# m_j : (BLOCK_M,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
|
||||
# For sliding window there's a chance the max is -inf due to masking of
|
||||
# the entire row. In this case we need to set m_j 0 to avoid NaN
|
||||
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
|
||||
|
||||
# P : (BLOCK_M, BLOCK_SIZE)
|
||||
# P : (BLOCK_M, TILE_SIZE)
|
||||
P = tl.exp(S - m_j[:, None])
|
||||
|
||||
# l_j : (BLOCK_M,)
|
||||
@ -327,6 +329,7 @@ def kernel_unified_attention_3d(
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
qq_bias_stride_0: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
TILE_SIZE: tl.constexpr, # int, must be power of 2
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
@ -374,20 +377,19 @@ def kernel_unified_attention_3d(
|
||||
|
||||
# number of segments for this particular sequence
|
||||
num_segments = NUM_SEGMENTS_PER_SEQ
|
||||
blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
|
||||
tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
|
||||
|
||||
if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len:
|
||||
if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
|
||||
return
|
||||
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
|
||||
offs_t = tl.arange(0, TILE_SIZE)
|
||||
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
|
||||
|
||||
query_offset_0 = cur_batch_in_all_start_index + query_pos
|
||||
query_offset_1 = kv_head_idx * num_queries_per_kv + \
|
||||
offs_m % num_queries_per_kv
|
||||
|
||||
query_offset = (query_offset_0[:, None] * query_stride_0 +
|
||||
query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
|
||||
|
||||
@ -433,30 +435,44 @@ def kernel_unified_attention_3d(
|
||||
qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
|
||||
) # shape: [BLOCK_M]
|
||||
|
||||
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
|
||||
# compute the length of the longest sequence prefix spanned by any
|
||||
# query token in the current q_block (q_block_local_idx)
|
||||
max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
|
||||
BLOCK_M - 1) // num_queries_per_kv + 1
|
||||
|
||||
# adjust for potential padding in the last q_block by considering the
|
||||
# actual sequence length
|
||||
max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
|
||||
|
||||
# calculate the number of tiles that need to be processed to
|
||||
# cover the longest sequence prefix (due to causal masking, tiles beyond
|
||||
# this prefix can be skipped)
|
||||
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
|
||||
|
||||
# iterate through tiles within current segment
|
||||
for j in range(
|
||||
segm_idx * blocks_per_segment,
|
||||
min((segm_idx + 1) * blocks_per_segment, num_blocks),
|
||||
segm_idx * tiles_per_segment,
|
||||
min((segm_idx + 1) * tiles_per_segment, num_tiles),
|
||||
):
|
||||
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
|
||||
seq_offset = j * TILE_SIZE + offs_t
|
||||
tile_mask = seq_offset < max_seq_prefix_len
|
||||
|
||||
offs_n = tl.arange(0, BLOCK_SIZE)
|
||||
physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
|
||||
seq_offset // BLOCK_SIZE).to(tl.int64)
|
||||
|
||||
v_offset = (physical_block_idx * stride_v_cache_0 +
|
||||
v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
|
||||
kv_head_idx * stride_v_cache_2 +
|
||||
offs_d[None, :] * stride_v_cache_3 +
|
||||
offs_n[:, None] * stride_v_cache_1)
|
||||
(seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
|
||||
|
||||
k_offset = (physical_block_idx * stride_k_cache_0 +
|
||||
k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
|
||||
kv_head_idx * stride_k_cache_2 +
|
||||
offs_d[:, None] * stride_k_cache_3 +
|
||||
offs_n[None, :] * stride_k_cache_1)
|
||||
(seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
|
||||
|
||||
# K : (HEAD_SIZE, BLOCK_SIZE)
|
||||
# K : (HEAD_SIZE, TILE_SIZE)
|
||||
K_load = tl.load(key_cache_ptr + k_offset,
|
||||
mask=dim_mask[:, None],
|
||||
mask=dim_mask[:, None] & tile_mask[None, :],
|
||||
other=0.0)
|
||||
|
||||
if K_load.dtype.is_fp8():
|
||||
@ -467,9 +483,9 @@ def kernel_unified_attention_3d(
|
||||
else:
|
||||
K = K_load
|
||||
|
||||
# V : (BLOCK_SIZE, HEAD_SIZE)
|
||||
# V : (TILE_SIZE, HEAD_SIZE)
|
||||
V_load = tl.load(value_cache_ptr + v_offset,
|
||||
mask=dim_mask[None, :],
|
||||
mask=dim_mask[None, :] & tile_mask[:, None],
|
||||
other=0.0)
|
||||
|
||||
if V_load.dtype.is_fp8():
|
||||
@ -480,13 +496,10 @@ def kernel_unified_attention_3d(
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_offset = j * BLOCK_SIZE + offs_n
|
||||
|
||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
||||
|
||||
# S : (BLOCK_M, BLOCK_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
|
||||
|
||||
# S : (BLOCK_M, TILE_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
if USE_SOFTCAP:
|
||||
@ -517,11 +530,12 @@ def kernel_unified_attention_3d(
|
||||
# compute running maximum
|
||||
# m_j : (BLOCK_M,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
|
||||
# For sliding window there's a chance the max is -inf due to masking of
|
||||
# the entire row. In this case we need to set m_j 0 to avoid NaN
|
||||
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
|
||||
|
||||
# P : (BLOCK_M, BLOCK_SIZE,)
|
||||
# P : (BLOCK_M, TILE_SIZE,)
|
||||
P = tl.exp(S - m_j[:, None])
|
||||
|
||||
# l_j : (BLOCK_M,)
|
||||
@ -573,7 +587,7 @@ def reduce_segments(
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
block_table_stride: tl.int64, # int
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
TILE_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int, must be power of 2
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
@ -594,10 +608,10 @@ def reduce_segments(
|
||||
|
||||
# number of segments for this particular sequence
|
||||
num_segments = NUM_SEGMENTS_PER_SEQ
|
||||
blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
|
||||
tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
|
||||
|
||||
# create masks for subsequent loads
|
||||
act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE)
|
||||
act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
|
||||
segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
|
||||
[NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32)
|
||||
dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
|
||||
@ -671,13 +685,10 @@ def unified_attention(
# Optional tensor for sinks
sinks=None,
):

assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"

block_size = v.shape[1]
assert q.element_size() >= 2 or block_size >= 32, \
"Block size must be at least 32 for fp8"

if sinks is not None:
assert sinks.shape[0] == q.shape[1], \
"Sinks must be num_query_heads size"
@ -707,6 +718,12 @@ def unified_attention(
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs

# Assigning default tile sizes for prefill and decode.
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
# and at least 16 for all other data types.
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32

# if batch contains a prefill
if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
kernel_unified_attention_2d[(
@ -736,6 +753,7 @@ def unified_attention(
output_stride_1=out.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_PREFILL,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
@ -809,6 +827,7 @@ def unified_attention(
query_stride_1=q.stride(1),
qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_DECODE,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
@ -830,7 +849,6 @@ def unified_attention(
BLOCK_M=BLOCK_M,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
)

reduce_segments[(q.shape[0], num_query_heads)](
output_ptr=out,
segm_output_ptr=segm_output,
@ -844,7 +862,7 @@ def unified_attention(
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
block_table_stride=block_table.stride(0),
BLOCK_SIZE=block_size,
TILE_SIZE=TILE_SIZE_DECODE,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
query_start_len_ptr=cu_seqlens_q,

@ -171,7 +171,8 @@ class BenchmarkDataset(ABC):
|
||||
If `None`, LoRA is not used.
|
||||
|
||||
Returns:
|
||||
A new [LoRARequest][] (or `None` if not applicable).
|
||||
A new [`LoRARequest`][vllm.lora.request.LoRARequest]
|
||||
(or `None` if not applicable).
|
||||
"""
|
||||
if max_loras is None or lora_path is None:
|
||||
return None
|
||||
@ -1357,7 +1358,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
elif args.dataset_name == "sonnet":
|
||||
dataset = SonnetDataset(dataset_path=args.dataset_path)
|
||||
# For the "sonnet" dataset, formatting depends on the backend.
|
||||
if args.endpoint_type == "openai-chat":
|
||||
if args.backend == "openai-chat":
|
||||
input_requests = dataset.sample(
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.sonnet_input_len,
|
||||
@ -1461,7 +1462,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
"Please consider contributing if you would "
|
||||
"like to add support for additional dataset formats.")
|
||||
|
||||
if dataset_class.IS_MULTIMODAL and args.endpoint_type not in [
|
||||
if dataset_class.IS_MULTIMODAL and args.backend not in [
|
||||
"openai-chat",
|
||||
"openai-audio",
|
||||
]:
|
||||
@ -1469,7 +1470,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
# endpoint-type.
|
||||
raise ValueError(
|
||||
"Multi-modal content is only supported on 'openai-chat' and "
|
||||
"'openai-audio' endpoint-type.")
|
||||
"'openai-audio' backends.")
|
||||
input_requests = dataset_class(
|
||||
dataset_path=args.dataset_path,
|
||||
dataset_subset=args.hf_subset,
|
||||
@ -1562,7 +1563,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
||||
|
||||
try:
|
||||
# Enforce endpoint compatibility for multimodal datasets.
|
||||
if args.dataset_name == "random-mm" and args.endpoint_type not in [
|
||||
if args.dataset_name == "random-mm" and args.backend not in [
|
||||
"openai-chat"]:
|
||||
raise ValueError(
|
||||
"Multi-modal content (images) is only supported on "
|
||||
|
||||
@ -89,6 +89,7 @@ class RequestFuncOutput:
|
||||
tpot: float = 0.0 # avg next-token latencies
|
||||
prompt_len: int = 0
|
||||
error: str = ""
|
||||
start_time: float = 0.0
|
||||
|
||||
|
||||
async def async_request_openai_completions(
|
||||
@ -140,6 +141,7 @@ async def async_request_openai_completions(
|
||||
|
||||
generated_text = ""
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
@ -272,6 +274,7 @@ async def async_request_openai_chat_completions(
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
@ -396,6 +399,7 @@ async def async_request_openai_audio(
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url,
|
||||
@ -475,6 +479,7 @@ async def async_request_openai_embeddings(
|
||||
|
||||
output = RequestFuncOutput()
|
||||
st = time.perf_counter()
|
||||
output.start_time = st
|
||||
try:
|
||||
async with session.post(
|
||||
url=api_url,
|
||||
|
||||
@ -8,8 +8,8 @@ to launch the vLLM OpenAI API server:

On the client side, run:
vllm bench serve \
--endpoint-type <endpoint_type. Default 'openai'> \
--label <benchmark result label. Default using endpoint_type> \
--backend <backend or endpoint type. Default 'openai'> \
--label <benchmark result label. Default using backend> \
--model <your_model> \
--dataset-name <dataset_name. Default 'random'> \
--request-rate <request_rate. Default inf> \
@ -18,9 +18,11 @@ On the client side, run:
import argparse
import asyncio
import gc
import importlib.util
import json
import os
import random
import shutil
import time
import warnings
from collections.abc import AsyncGenerator, Iterable
@ -46,6 +48,24 @@ from vllm.transformers_utils.tokenizer import get_tokenizer

MILLISECONDS_TO_SECONDS_CONVERSION = 1000

TERM_PLOTLIB_AVAILABLE = ((importlib.util.find_spec("termplotlib") is not None)
and (shutil.which("gnuplot") is not None))


# TODO: Remove this in v0.11.0
class DeprecatedEndpointTypeAction(argparse.Action):
"""Argparse action for the deprecated --endpoint-type flag.
"""

def __call__(self, _, namespace, values, option_string=None):
warnings.warn(
"'--endpoint-type' is deprecated and will be removed in v0.11.0. "
"Please use '--backend' instead or remove this argument if you "
"have already set it.",
stacklevel=1,
)
setattr(namespace, self.dest, values)

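For context, the deprecated-flag pattern above just warns and then stores the value like a normal `store` action. A minimal standalone sketch with hypothetical flag names (not part of the diff):

```python
import argparse
import warnings


class DeprecatedFlagAction(argparse.Action):
    """Warn when a deprecated flag is used, then store its value normally."""

    def __call__(self, parser, namespace, values, option_string=None):
        warnings.warn(f"{option_string} is deprecated; use --new-flag instead.",
                      stacklevel=1)
        setattr(namespace, self.dest, values)


parser = argparse.ArgumentParser()
parser.add_argument("--new-flag", type=str, default="openai")
parser.add_argument("--old-flag", type=str, default=None,
                    action=DeprecatedFlagAction)

# The old spelling still works but emits a deprecation warning.
args = parser.parse_args(["--old-flag", "openai-chat"])
print(args.old_flag)  # openai-chat
```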
class TaskType(Enum):
GENERATION = "generation"
@ -80,18 +100,23 @@ class BenchmarkMetrics:
median_e2el_ms: float
std_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]
# Max output tokens per second and concurrent requests at that peak
max_output_tokens_per_s: float
max_concurrent_requests: int


@dataclass
class EmbedBenchmarkMetrics:
completed: int
total_input: int
request_throughput: float
total_token_throughput :float
total_token_throughput: float
mean_e2el_ms: float
std_e2el_ms: float
median_e2el_ms: float
percentiles_e2el_ms: float

def _get_current_request_rate(
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
|
||||
ramp_up_start_rps: Optional[int],
|
||||
@ -150,8 +175,8 @@ async def get_request(
|
||||
assert burstiness > 0, (
|
||||
f"A positive burstiness factor is expected, but given {burstiness}.")
|
||||
# Convert to list to get length for ramp-up calculations
|
||||
if isinstance(input_requests, Iterable) and not isinstance(
|
||||
input_requests, list):
|
||||
if isinstance(input_requests,
|
||||
Iterable) and not isinstance(input_requests, list):
|
||||
input_requests = list(input_requests)
|
||||
|
||||
total_requests = len(input_requests)
|
||||
@ -161,12 +186,9 @@ async def get_request(
|
||||
request_rates = []
|
||||
delay_ts = []
|
||||
for request_index, request in enumerate(input_requests):
|
||||
current_request_rate = _get_current_request_rate(ramp_up_strategy,
|
||||
ramp_up_start_rps,
|
||||
ramp_up_end_rps,
|
||||
request_index,
|
||||
total_requests,
|
||||
request_rate)
|
||||
current_request_rate = _get_current_request_rate(
|
||||
ramp_up_strategy, ramp_up_start_rps, ramp_up_end_rps,
|
||||
request_index, total_requests, request_rate)
|
||||
request_rates.append(current_request_rate)
|
||||
if current_request_rate == float("inf"):
|
||||
delay_ts.append(0)
|
||||
@ -206,10 +228,8 @@ async def get_request(
|
||||
|
||||
|
||||
def calculate_metrics_for_embeddings(
|
||||
outputs: list[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
selected_percentiles: list[float]
|
||||
) -> EmbedBenchmarkMetrics:
|
||||
outputs: list[RequestFuncOutput], dur_s: float,
|
||||
selected_percentiles: list[float]) -> EmbedBenchmarkMetrics:
|
||||
"""Calculate the metrics for the embedding requests.
|
||||
|
||||
Args:
|
||||
@ -242,10 +262,8 @@ def calculate_metrics_for_embeddings(
|
||||
mean_e2el_ms=np.mean(e2els or 0) * 1000,
|
||||
std_e2el_ms=np.std(e2els or 0) * 1000,
|
||||
median_e2el_ms=np.median(e2els or 0) * 1000,
|
||||
percentiles_e2el_ms=[
|
||||
(p, np.percentile(e2els or 0, p) * 1000)
|
||||
for p in selected_percentiles
|
||||
],
|
||||
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
)
|
||||
return metrics
|
||||
|
||||
@ -336,6 +354,67 @@ def calculate_metrics(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)

# Calculate max output tokens per second metric
max_output_tokens_per_s = 0.0
max_concurrent_requests = 0

# Find the time range across all successful requests
successful_outputs = [output for output in outputs if output.success]
if successful_outputs:
min_start_time = min(output.start_time
for output in successful_outputs)
max_end_time = max(output.start_time + output.latency
for output in successful_outputs)

# Create second buckets (ceiling to ensure we capture all time)
duration_seconds = int(np.ceil(max_end_time - min_start_time)) + 1
tokens_per_second = np.zeros(duration_seconds)
concurrent_requests_per_second = np.zeros(duration_seconds)

for i, output in enumerate(successful_outputs):
# Calculate token generation timestamp using
# start_time, ttft, and itl
token_times = [output.start_time + output.ttft]
current_time = token_times[0]
for itl_value in output.itl:
current_time += itl_value
token_times.append(current_time)

# Add tokens to second buckets
for token_time in token_times:
second_bucket = int(token_time - min_start_time)
if 0 <= second_bucket < duration_seconds:
tokens_per_second[second_bucket] += 1

# Track concurrent requests for each second this request was active
request_start_second = int(output.start_time - min_start_time)
request_end_second = int((output.start_time + output.latency) -
min_start_time)

for second in range(request_start_second, request_end_second + 1):
concurrent_requests_per_second[second] += 1

# Find the maximum tokens per second and corresponding
# concurrent requests
if len(tokens_per_second) > 0:
max_output_tokens_per_s = float(np.max(tokens_per_second))
max_concurrent_requests = int(
np.max(concurrent_requests_per_second))

if TERM_PLOTLIB_AVAILABLE:
import termplotlib as tpl
fig = tpl.figure()
fig.plot(np.arange(len(tokens_per_second)),
tokens_per_second,
title="Output tokens per second")
fig.plot(np.arange(len(concurrent_requests_per_second)),
concurrent_requests_per_second,
title="Concurrent requests per second")
fig.show()
else:
print("tip: install termplotlib and gnuplot to plot the metrics")

metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
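The peak-throughput metric added above reconstructs per-token timestamps from `start_time`, `ttft`, and the inter-token latencies, then bins them into one-second buckets. A minimal standalone sketch of the same bucketing idea, using made-up request timings (not taken from the diff):

```python
import numpy as np

# (start_time, ttft, inter-token latencies) for three hypothetical requests
requests = [
    (0.0, 0.2, [0.05] * 10),
    (0.5, 0.3, [0.10] * 5),
    (1.0, 0.1, [0.02] * 20),
]

min_start = min(start for start, _, _ in requests)
end_times = [start + ttft + sum(itl) for start, ttft, itl in requests]
duration = int(np.ceil(max(end_times) - min_start)) + 1
tokens_per_second = np.zeros(duration)

for start, ttft, itl in requests:
    t = start + ttft  # timestamp of the first output token
    token_times = [t]
    for gap in itl:
        t += gap
        token_times.append(t)
    for token_time in token_times:
        bucket = int(token_time - min_start)
        if 0 <= bucket < duration:
            tokens_per_second[bucket] += 1

print(float(tokens_per_second.max()))  # peak output tokens/s over 1s windows
```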
@ -365,6 +444,8 @@ def calculate_metrics(
|
||||
median_e2el_ms=np.median(e2els or 0) * 1000,
|
||||
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
max_output_tokens_per_s=max_output_tokens_per_s,
|
||||
max_concurrent_requests=max_concurrent_requests,
|
||||
)
|
||||
|
||||
return metrics, actual_output_lens
|
||||
@ -396,18 +477,15 @@ async def benchmark(
|
||||
ramp_up_end_rps: Optional[int] = None,
|
||||
ready_check_timeout_sec: int = 600,
|
||||
):
|
||||
task_type = (
|
||||
TaskType.EMBEDDING
|
||||
if api_url.endswith("/v1/embeddings")
|
||||
else TaskType.GENERATION
|
||||
)
|
||||
task_type = (TaskType.EMBEDDING if api_url.endswith("/v1/embeddings") else
|
||||
TaskType.GENERATION)
|
||||
if endpoint_type in ASYNC_REQUEST_FUNCS:
|
||||
if task_type == TaskType.EMBEDDING:
|
||||
request_func = ASYNC_REQUEST_FUNCS["openai-embeddings"]
|
||||
else:
|
||||
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
|
||||
else:
|
||||
raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
|
||||
raise ValueError(f"Unknown backend: {endpoint_type}")
|
||||
|
||||
# Reuses connections across requests to reduce TLS handshake overhead.
|
||||
connector = aiohttp.TCPConnector(
|
||||
@ -435,14 +513,10 @@ async def benchmark(
|
||||
input_requests[0].multi_modal_data,
|
||||
)
|
||||
|
||||
assert (
|
||||
test_mm_content is None
|
||||
or isinstance(test_mm_content, dict)
|
||||
or (
|
||||
isinstance(test_mm_content, list)
|
||||
and all(isinstance(item, dict) for item in test_mm_content)
|
||||
)
|
||||
), "multi_modal_data must be a dict or list[dict]"
|
||||
assert (test_mm_content is None or isinstance(test_mm_content, dict)
|
||||
or (isinstance(test_mm_content, list)
|
||||
and all(isinstance(item, dict) for item in test_mm_content))
|
||||
), "multi_modal_data must be a dict or list[dict]"
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
model_name=model_name,
|
||||
@ -488,13 +562,13 @@ async def benchmark(
|
||||
ignore_eos=ignore_eos,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body)
|
||||
profile_output = await request_func(
|
||||
request_func_input=profile_input, session=session)
|
||||
profile_output = await request_func(request_func_input=profile_input,
|
||||
session=session)
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
|
||||
distribution = ("Poisson process" if burstiness == 1.0
|
||||
else "Gamma distribution")
|
||||
distribution = ("Poisson process"
|
||||
if burstiness == 1.0 else "Gamma distribution")
|
||||
|
||||
if ramp_up_strategy is not None:
|
||||
print(f"Traffic ramp-up strategy: {ramp_up_strategy}.")
|
||||
@ -562,18 +636,20 @@ async def benchmark(
|
||||
req_lora_module = next(lora_modules)
|
||||
req_model_id, req_model_name = req_lora_module, req_lora_module
|
||||
|
||||
request_func_input = RequestFuncInput(model=req_model_id,
|
||||
model_name=req_model_name,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
request_id=request_id,)
|
||||
request_func_input = RequestFuncInput(
|
||||
model=req_model_id,
|
||||
model_name=req_model_name,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
logprobs=logprobs,
|
||||
multi_modal_content=mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_headers=extra_headers,
|
||||
extra_body=extra_body,
|
||||
request_id=request_id,
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input,
|
||||
@ -615,19 +691,21 @@ async def benchmark(
|
||||
benchmark_duration))
|
||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||
if isinstance(metrics, BenchmarkMetrics):
|
||||
print("{:<40} {:<10}".format(
|
||||
"Total generated tokens:", metrics.total_output))
|
||||
print("{:<40} {:<10}".format("Total generated tokens:",
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
if goodput_config_dict:
|
||||
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
|
||||
metrics.request_goodput))
|
||||
if isinstance(metrics, BenchmarkMetrics):
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
"Output token throughput (tok/s):", metrics.output_throughput
|
||||
)
|
||||
)
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
metrics.output_throughput))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
"Peak output token throughput (tok/s):",
|
||||
metrics.max_output_tokens_per_s))
|
||||
print("{:<40} {:<10.2f}".format("Peak concurrent requests:",
|
||||
metrics.max_concurrent_requests))
|
||||
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||
metrics.total_token_throughput))
|
||||
|
||||
@ -648,6 +726,8 @@ async def benchmark(
|
||||
"itls": [output.itl for output in outputs],
|
||||
"generated_texts": [output.generated_text for output in outputs],
|
||||
"errors": [output.error for output in outputs],
|
||||
"max_output_tokens_per_s": metrics.max_output_tokens_per_s,
|
||||
"max_concurrent_requests": metrics.max_concurrent_requests,
|
||||
}
|
||||
else:
|
||||
result = {
|
||||
@ -697,8 +777,8 @@ async def benchmark(
|
||||
|
||||
if task_type == TaskType.GENERATION:
|
||||
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||
process_one_metric(
|
||||
"tpot", "TPOT", "Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("tpot", "TPOT",
|
||||
"Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||
|
||||
@ -714,8 +794,8 @@ async def benchmark(
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
profile_output = await request_func(
|
||||
request_func_input=profile_input, session=session)
|
||||
profile_output = await request_func(request_func_input=profile_input,
|
||||
session=session)
|
||||
if profile_output.success:
|
||||
print("Profiler stopped")
|
||||
|
||||
@ -785,24 +865,28 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace,
|
||||
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
add_dataset_parser(parser)
|
||||
parser.add_argument(
|
||||
"--endpoint-type",
|
||||
type=str,
|
||||
default="openai",
|
||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--label",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The label (prefix) of the benchmark results. If not specified, "
|
||||
"the endpoint type will be used as the label.",
|
||||
"the value of '--backend' will be used as the label.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="vllm",
|
||||
default="openai",
|
||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||
help="The type of backend or endpoint to use for the benchmark."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--endpoint-type",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||
action=DeprecatedEndpointTypeAction,
|
||||
help="'--endpoint-type' is deprecated and will be removed in v0.11.0. "
|
||||
"Please use '--backend' instead.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
@ -851,7 +935,8 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--tokenizer",
|
||||
type=str,
|
||||
help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument(
|
||||
@ -982,7 +1067,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
help="Specify the prefix of request id.",
|
||||
)
|
||||
|
||||
|
||||
sampling_group = parser.add_argument_group("sampling parameters")
|
||||
sampling_group.add_argument(
|
||||
"--top-p",
|
||||
@ -1047,8 +1131,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
help="The ramp-up strategy. This would be used to "
|
||||
"ramp up the request rate from initial RPS to final "
|
||||
"RPS rate (specified by --ramp-up-start-rps and "
|
||||
"--ramp-up-end-rps.) over the duration of the benchmark."
|
||||
)
|
||||
"--ramp-up-end-rps.) over the duration of the benchmark.")
|
||||
parser.add_argument(
|
||||
"--ramp-up-start-rps",
|
||||
type=int,
|
||||
@ -1087,13 +1170,11 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
raise ValueError(
|
||||
"When using ramp-up, do not specify --request-rate. "
|
||||
"The request rate will be controlled by ramp-up parameters. "
|
||||
"Please remove the --request-rate argument."
|
||||
)
|
||||
"Please remove the --request-rate argument.")
|
||||
if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
|
||||
raise ValueError(
|
||||
"When using --ramp-up-strategy, both --ramp-up-start-rps and "
|
||||
"--ramp-up-end-rps must be specified"
|
||||
)
|
||||
"--ramp-up-end-rps must be specified")
|
||||
if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
|
||||
raise ValueError("Ramp-up start and end RPS must be non-negative")
|
||||
if args.ramp_up_start_rps > args.ramp_up_end_rps:
|
||||
@ -1103,7 +1184,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
raise ValueError(
|
||||
"For exponential ramp-up, the start RPS cannot be 0.")
|
||||
|
||||
endpoint_type = args.endpoint_type
|
||||
label = args.label
|
||||
model_id = args.model
|
||||
model_name = args.served_model_name
|
||||
@ -1127,8 +1207,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
headers[kvstring[0].strip()] = kvstring[1].strip()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid header format. Please use KEY=VALUE format."
|
||||
)
|
||||
"Invalid header format. Please use KEY=VALUE format.")
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
@ -1167,7 +1246,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
gc.freeze()
|
||||
|
||||
benchmark_result = await benchmark(
|
||||
endpoint_type=args.endpoint_type,
|
||||
endpoint_type=args.backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
@ -1201,7 +1280,8 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
# Setup
|
||||
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
result_json["date"] = current_dt
|
||||
result_json["endpoint_type"] = args.endpoint_type
|
||||
result_json["endpoint_type"] = args.backend # for backward compatibility
|
||||
result_json["backend"] = args.backend
|
||||
result_json["label"] = label
|
||||
result_json["model_id"] = model_id
|
||||
result_json["tokenizer_id"] = tokenizer_id
|
||||
@ -1215,8 +1295,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
result_json[kvstring[0].strip()] = kvstring[1].strip()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid metadata format. Please use KEY=VALUE format."
|
||||
)
|
||||
"Invalid metadata format. Please use KEY=VALUE format.")
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
@ -1252,7 +1331,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
max_concurrency_str = (f"-concurrency{args.max_concurrency}"
|
||||
if args.max_concurrency is not None else "")
|
||||
label = label or endpoint_type
|
||||
label = label or args.backend
|
||||
if args.ramp_up_strategy is not None:
|
||||
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
else:
|
||||
|
||||
@ -42,6 +42,7 @@ from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||
ParallelConfig)
|
||||
from vllm.config.scheduler import SchedulerConfig, SchedulerPolicy
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.config.structured_outputs import StructuredOutputsConfig
|
||||
from vllm.config.utils import ConfigType, config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -449,6 +450,8 @@ class ModelConfig:
|
||||
|
||||
# Multimodal config and init vars
|
||||
multimodal_config: Optional[MultiModalConfig] = None
|
||||
"""Configuration for multimodal model. If `None`, this will be inferred
|
||||
from the architecture of `self.model`."""
|
||||
limit_mm_per_prompt: InitVar[Optional[dict[str, int]]] = None
|
||||
media_io_kwargs: InitVar[Optional[dict[str, dict[str, Any]]]] = None
|
||||
mm_processor_kwargs: InitVar[Optional[dict[str, Any]]] = None
|
||||
@ -2277,66 +2280,6 @@ def get_served_model_name(model: str,
|
||||
return served_model_name
|
||||
|
||||
|
||||
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines",
|
||||
"lm-format-enforcer"]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class DecodingConfig:
|
||||
"""Dataclass which contains the decoding strategy of the engine."""
|
||||
|
||||
backend: GuidedDecodingBackend = "auto"
|
||||
"""Which engine will be used for guided decoding (JSON schema / regex etc)
|
||||
by default. With "auto", we will make opinionated choices based on request
|
||||
contents and what the backend libraries currently support, so the behavior
|
||||
is subject to change in each release."""
|
||||
|
||||
disable_fallback: bool = False
|
||||
"""If `True`, vLLM will not fallback to a different backend on error."""
|
||||
|
||||
disable_any_whitespace: bool = False
|
||||
"""If `True`, the model will not generate any whitespace during guided
|
||||
decoding. This is only supported for xgrammar and guidance backends."""
|
||||
|
||||
disable_additional_properties: bool = False
|
||||
"""If `True`, the `guidance` backend will not use `additionalProperties`
|
||||
in the JSON schema. This is only supported for the `guidance` backend and
|
||||
is used to better align its behaviour with `outlines` and `xgrammar`."""
|
||||
|
||||
reasoning_backend: str = ""
|
||||
"""Select the reasoning parser depending on the model that you're using.
|
||||
This is used to parse the reasoning content into OpenAI API format."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
# no factors to consider.
|
||||
# this config will not affect the computation graph.
|
||||
factors: list[Any] = []
|
||||
hash_str = hashlib.md5(str(factors).encode(),
|
||||
usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
if (self.disable_any_whitespace
|
||||
and self.backend not in ("xgrammar", "guidance")):
|
||||
raise ValueError("disable_any_whitespace is only supported for "
|
||||
"xgrammar and guidance backends.")
|
||||
if (self.disable_additional_properties and self.backend != "guidance"):
|
||||
raise ValueError("disable_additional_properties is only supported "
|
||||
"for the guidance backend.")
|
||||
|
||||
|
||||
DetailedTraceModules = Literal["model", "worker", "all"]
|
||||
|
||||
|
||||
@ -2451,8 +2394,9 @@ class VllmConfig:
|
||||
"""LoRA configuration."""
|
||||
speculative_config: Optional[SpeculativeConfig] = None
|
||||
"""Speculative decoding configuration."""
|
||||
decoding_config: DecodingConfig = field(default_factory=DecodingConfig)
|
||||
"""Decoding configuration."""
|
||||
structured_outputs_config: StructuredOutputsConfig = field(
|
||||
default_factory=StructuredOutputsConfig)
|
||||
"""Structured outputs configuration."""
|
||||
observability_config: Optional[ObservabilityConfig] = None
|
||||
"""Observability configuration."""
|
||||
quant_config: Optional[QuantizationConfig] = None
|
||||
@ -2543,8 +2487,8 @@ class VllmConfig:
|
||||
vllm_factors.append(self.speculative_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.decoding_config:
|
||||
vllm_factors.append(self.decoding_config.compute_hash())
|
||||
if self.structured_outputs_config:
|
||||
vllm_factors.append(self.structured_outputs_config.compute_hash())
|
||||
else:
|
||||
vllm_factors.append("None")
|
||||
if self.observability_config:
|
||||
@ -3029,6 +2973,18 @@ class VllmConfig:
|
||||
SequenceClassificationConfig)
|
||||
SequenceClassificationConfig.verify_and_update_config(self)
|
||||
|
||||
if hasattr(self.model_config, "model_weights") and is_runai_obj_uri(
|
||||
self.model_config.model_weights):
|
||||
if self.load_config.load_format == "auto":
|
||||
logger.info("Detected Run:ai model config. "
|
||||
"Overriding `load_format` to 'runai_streamer'")
|
||||
self.load_config.load_format = "runai_streamer"
|
||||
elif self.load_config.load_format != "runai_streamer":
|
||||
raise ValueError(f"To load a model from S3, 'load_format' "
|
||||
f"must be 'runai_streamer', "
|
||||
f"but got '{self.load_config.load_format}'. "
|
||||
f"Model: {self.model_config.model}")
|
||||
|
||||
def __str__(self):
|
||||
return (
|
||||
f"model={self.model_config.model!r}, "
|
||||
@ -3051,7 +3007,7 @@ class VllmConfig:
|
||||
f"enforce_eager={self.model_config.enforce_eager}, "
|
||||
f"kv_cache_dtype={self.cache_config.cache_dtype}, "
|
||||
f"device_config={self.device_config.device}, "
|
||||
f"decoding_config={self.decoding_config!r}, "
|
||||
f"structured_outputs_config={self.structured_outputs_config!r}, "
|
||||
f"observability_config={self.observability_config!r}, "
|
||||
f"seed={self.model_config.seed}, "
|
||||
f"served_model_name={self.model_config.served_model_name}, "
|
||||
|
||||
@ -563,18 +563,6 @@ class CompilationConfig:
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []

if envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput":
# exclude MoE dispatch/combine from capture by ensuring
# piecewise splitting includes them, so communication remains
# outside CUDA graphs while compute can still be graphed.
moe_ops = [
"vllm.moe_forward",
"vllm.moe_forward_shared",
]
for op in moe_ops:
if op not in self.splitting_ops:
self.splitting_ops.append(op)

def splitting_ops_contain_attention(self) -> bool:
return self.splitting_ops is not None and all(
op in self.splitting_ops for op in self._attention_ops)
||||
@ -31,7 +31,7 @@ logger = init_logger(__name__)
|
||||
|
||||
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
|
||||
"mlp_speculator", "draft_model", "deepseek_mtp",
|
||||
"ernie_mtp", "qwen3_next_mtp"]
|
||||
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp"]
|
||||
|
||||
|
||||
@config
|
||||
@ -83,6 +83,11 @@ class SpeculativeConfig:
|
||||
disable_by_batch_size: Optional[int] = None
|
||||
"""Disable speculative decoding for new incoming requests when the number
|
||||
of enqueued requests is larger than this value, if provided."""
|
||||
disable_padded_drafter_batch: bool = False
|
||||
"""Disable input padding for speculative decoding. If set to True,
|
||||
speculative input batches can contain sequences of different lengths,
|
||||
which may only be supported by certain attention backends. This currently
|
||||
only affects the EAGLE method of speculation."""
|
||||
|
||||
# Ngram proposer configuration
|
||||
prompt_lookup_max: Optional[int] = None
|
||||
|
||||
vllm/config/structured_outputs.py (new file, 64 lines)
@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from typing import Any, Literal

from pydantic.dataclasses import dataclass

from vllm.config.utils import config

StructuredOutputsBackend = Literal["auto", "xgrammar", "guidance", "outlines",
"lm-format-enforcer"]


@config
@dataclass
class StructuredOutputsConfig:
"""Dataclass which contains structured outputs config for the engine."""

backend: StructuredOutputsBackend = "auto"
"""Which engine will be used for structured outputs (e.g. JSON schema,
regex, etc) by default. With "auto", we will make opinionated choices
based on request contents and what the backend libraries currently support,
so the behavior is subject to change in each release."""
disable_fallback: bool = False
"""If `True`, vLLM will not fallback to a different backend on error."""
disable_any_whitespace: bool = False
"""If `True`, the model will not generate any whitespace during structured
outputs. This is only supported for xgrammar and guidance backends."""
disable_additional_properties: bool = False
"""If `True`, the `guidance` backend will not use `additionalProperties`
in the JSON schema. This is only supported for the `guidance` backend and
is used to better align its behaviour with `outlines` and `xgrammar`."""
reasoning_parser: str = ""
"""Select the reasoning parser depending on the model that you're using.
This is used to parse the reasoning content into OpenAI API format."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.

Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# no factors to consider.
# this config will not affect the computation graph.
factors: list[Any] = []
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()
return hash_str

def __post_init__(self):
if (self.disable_any_whitespace
and self.backend not in ("xgrammar", "guidance")):
raise ValueError("disable_any_whitespace is only supported for "
"xgrammar and guidance backends.")
if (self.disable_additional_properties and self.backend != "guidance"):
raise ValueError("disable_additional_properties is only supported "
"for the guidance backend.")
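As a quick illustration of how the `__post_init__` validation in the new config behaves, a hypothetical construction (not shown anywhere in the diff) would look like:

```python
from vllm.config.structured_outputs import StructuredOutputsConfig

# Defaults: backend="auto", no extra restrictions.
cfg = StructuredOutputsConfig()

# Valid: disable_any_whitespace is allowed for the xgrammar backend.
cfg = StructuredOutputsConfig(backend="xgrammar", disable_any_whitespace=True)

# Raises ValueError: disable_additional_properties requires the guidance backend.
cfg = StructuredOutputsConfig(backend="outlines",
                              disable_additional_properties=True)
```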
@ -5,6 +5,7 @@ from typing import Any
import torch
import torch.distributed as dist

from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
@ -69,6 +70,44 @@ class NaiveAll2AllManager(All2AllManagerBase):
pass


class AgRsAll2AllManager(All2AllManagerBase):
"""
An implementation of all2all communication based on
all-gather (dispatch) and reduce-scatter (combine).
"""

def __init__(self, cpu_group):
super().__init__(cpu_group)

def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)
return hidden_states, router_logits

def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
return hidden_states

def destroy(self):
pass

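For intuition, the dispatch/combine pair above is the familiar all-gather then reduce-scatter pattern. A minimal `torch.distributed` sketch of the same idea, assuming equal per-rank chunk sizes and none of the vLLM wrappers:

```python
import torch
import torch.distributed as dist


def dispatch(hidden_states: torch.Tensor) -> torch.Tensor:
    """All-gather the per-rank activations so every rank sees the full batch."""
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(hidden_states) for _ in range(world_size)]
    dist.all_gather(gathered, hidden_states)
    return torch.cat(gathered, dim=0)


def combine(hidden_states: torch.Tensor) -> torch.Tensor:
    """Reduce-scatter the full batch back to this rank's slice (summed)."""
    world_size = dist.get_world_size()
    chunks = list(hidden_states.chunk(world_size, dim=0))
    out = torch.empty_like(chunks[0])
    dist.reduce_scatter(out, chunks)
    return out
```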
class PPLXAll2AllManager(All2AllManagerBase):
"""
All2All communication based on PPLX kernels.

@ -87,6 +87,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
elif all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
logger.info("Using AllGather-ReduceScatter all2all manager.")
elif all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)

@ -30,7 +30,7 @@ class SingleWriterShmRingBuffer:
- Maintains metadata for each allocated buffer chunk in the writer process
- Supports custom "is_free_fn" functions to determine when buffers can be
reused
- Each buffer chunk contains: [4-byte id][4-byte size][actual_data]
- Each buffer chunk contains: `[4-byte id][4-byte size][actual_data]`

Key Concepts:
- monotonic_id_start/end: Track the range of active buffer IDs
@ -99,7 +99,7 @@ class SingleWriterShmRingBuffer:
- Writer handles garbage collection (free_buf) based on reader feedback

Memory Layout per Buffer Chunk:
[4-byte monotonic_id][4-byte chunk_size][actual_data...]
`[4-byte monotonic_id][4-byte chunk_size][actual_data...]`
^metadata_start ^data_start

The monotonic_id ensures data integrity - readers can verify they're
@ -185,7 +185,7 @@ class SingleWriterShmRingBuffer:
'''
Allocate a buffer `MD_SIZE` + `size` bytes in the shared memory.
Memory layout:
[4-byte monotonic_id][4-byte size][buffer data...]
`[4-byte monotonic_id][4-byte size][buffer data...]`
'''
assert self.is_writer, "Only the writer can allocate buffers."
assert size > 0, "Size must be greater than 0"
@ -253,7 +253,7 @@ class SingleWriterShmRingBuffer:

Args:
nbytes (int, optional): The size of the buffer to free. If None,
frees the maximum size of the ring buffer.
frees the maximum size of the ring buffer.
'''

assert self.is_writer, "Only the writer can free buffers."
@ -413,7 +413,7 @@ class SingleWriterShmObjectStorage:
allocation

Memory Layout per Object:
[4-byte reference_count][metadata_size][serialized_object_data]
`[4-byte reference_count][metadata_size][serialized_object_data]`

Thread Safety:
- Writer operations (put, clear) are single-threaded by design

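The chunk header format referenced in these docstrings (a 4-byte id followed by a 4-byte size, then the payload) can be packed and unpacked with `struct`; a small illustrative sketch, not taken from the vLLM implementation:

```python
import struct

MD_SIZE = 8  # 4-byte monotonic_id + 4-byte chunk_size


def write_chunk(buf: bytearray, offset: int, monotonic_id: int, data: bytes) -> int:
    """Write [4-byte id][4-byte size][data] at offset; return bytes written."""
    struct.pack_into("<ii", buf, offset, monotonic_id, len(data))
    buf[offset + MD_SIZE:offset + MD_SIZE + len(data)] = data
    return MD_SIZE + len(data)


def read_chunk(buf: bytes, offset: int) -> tuple[int, bytes]:
    """Read the header at offset and return (monotonic_id, payload)."""
    monotonic_id, size = struct.unpack_from("<ii", buf, offset)
    start = offset + MD_SIZE
    return monotonic_id, bytes(buf[start:start + size])


buf = bytearray(64)
write_chunk(buf, 0, 42, b"hello")
print(read_chunk(buf, 0))  # (42, b'hello')
```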
@ -22,17 +22,16 @@ from typing_extensions import TypeIs, deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||
ConfigType, ConvertOption, DecodingConfig,
|
||||
DetailedTraceModules, Device, DeviceConfig,
|
||||
DistributedExecutorBackend, EPLBConfig,
|
||||
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
|
||||
ConfigType, ConvertOption, DetailedTraceModules,
|
||||
Device, DeviceConfig, DistributedExecutorBackend,
|
||||
EPLBConfig, HfOverrides, KVEventsConfig,
|
||||
KVTransferConfig, LoadConfig, LogprobsMode,
|
||||
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
|
||||
ModelDType, ModelImpl, ObservabilityConfig,
|
||||
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
|
||||
RunnerOption, SchedulerConfig, SchedulerPolicy,
|
||||
SpeculativeConfig, TaskOption, TokenizerMode,
|
||||
VllmConfig, get_attr_docs)
|
||||
SpeculativeConfig, StructuredOutputsConfig,
|
||||
TaskOption, TokenizerMode, VllmConfig, get_attr_docs)
|
||||
from vllm.config.multimodal import MMCacheType, MultiModalConfig
|
||||
from vllm.config.parallel import ExpertPlacementStrategy
|
||||
from vllm.config.utils import get_field
|
||||
@ -418,12 +417,15 @@ class EngineArgs:
|
||||
disable_hybrid_kv_cache_manager: bool = (
|
||||
SchedulerConfig.disable_hybrid_kv_cache_manager)
|
||||
|
||||
guided_decoding_backend: GuidedDecodingBackend = DecodingConfig.backend
|
||||
guided_decoding_disable_fallback: bool = DecodingConfig.disable_fallback
|
||||
guided_decoding_disable_any_whitespace: bool = \
|
||||
DecodingConfig.disable_any_whitespace
|
||||
guided_decoding_disable_additional_properties: bool = \
|
||||
DecodingConfig.disable_additional_properties
|
||||
structured_outputs_config: StructuredOutputsConfig = get_field(
|
||||
VllmConfig, "structured_outputs_config")
|
||||
reasoning_parser: str = StructuredOutputsConfig.reasoning_parser
|
||||
# Deprecated guided decoding fields
|
||||
guided_decoding_backend: Optional[str] = None
|
||||
guided_decoding_disable_fallback: Optional[bool] = None
|
||||
guided_decoding_disable_any_whitespace: Optional[bool] = None
|
||||
guided_decoding_disable_additional_properties: Optional[bool] = None
|
||||
|
||||
logits_processor_pattern: Optional[
|
||||
str] = ModelConfig.logits_processor_pattern
|
||||
|
||||
@ -462,7 +464,6 @@ class EngineArgs:
|
||||
|
||||
additional_config: dict[str, Any] = \
|
||||
get_field(VllmConfig, "additional_config")
|
||||
reasoning_parser: str = DecodingConfig.reasoning_backend
|
||||
|
||||
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
|
||||
pt_load_map_location: str = LoadConfig.pt_load_map_location
|
||||
@ -618,28 +619,29 @@ class EngineArgs:
|
||||
load_group.add_argument('--pt-load-map-location',
|
||||
**load_kwargs["pt_load_map_location"])
|
||||
|
||||
# Guided decoding arguments
|
||||
guided_decoding_kwargs = get_kwargs(DecodingConfig)
|
||||
guided_decoding_group = parser.add_argument_group(
|
||||
title="DecodingConfig",
|
||||
description=DecodingConfig.__doc__,
|
||||
# Structured outputs arguments
|
||||
structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig)
|
||||
structured_outputs_group = parser.add_argument_group(
|
||||
title="StructuredOutputsConfig",
|
||||
description=StructuredOutputsConfig.__doc__,
|
||||
)
|
||||
guided_decoding_group.add_argument("--guided-decoding-backend",
|
||||
**guided_decoding_kwargs["backend"])
|
||||
guided_decoding_group.add_argument(
|
||||
"--guided-decoding-disable-fallback",
|
||||
**guided_decoding_kwargs["disable_fallback"])
|
||||
guided_decoding_group.add_argument(
|
||||
"--guided-decoding-disable-any-whitespace",
|
||||
**guided_decoding_kwargs["disable_any_whitespace"])
|
||||
guided_decoding_group.add_argument(
|
||||
"--guided-decoding-disable-additional-properties",
|
||||
**guided_decoding_kwargs["disable_additional_properties"])
|
||||
guided_decoding_group.add_argument(
|
||||
structured_outputs_group.add_argument(
|
||||
"--reasoning-parser",
|
||||
# This choice is a special case because it's not static
|
||||
choices=list(ReasoningParserManager.reasoning_parsers),
|
||||
**guided_decoding_kwargs["reasoning_backend"])
|
||||
**structured_outputs_kwargs["reasoning_parser"])
|
||||
# Deprecated guided decoding arguments
|
||||
for arg, type in [
|
||||
("--guided-decoding-backend", str),
|
||||
("--guided-decoding-disable-fallback", bool),
|
||||
("--guided-decoding-disable-any-whitespace", bool),
|
||||
("--guided-decoding-disable-additional-properties", bool),
|
||||
]:
|
||||
structured_outputs_group.add_argument(
|
||||
arg,
|
||||
type=type,
|
||||
help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."),
|
||||
deprecated=True)
|
||||
|
||||
# Parallel arguments
|
||||
parallel_kwargs = get_kwargs(ParallelConfig)
|
||||
@ -934,6 +936,8 @@ class EngineArgs:
|
||||
**vllm_kwargs["compilation_config"])
|
||||
vllm_group.add_argument("--additional-config",
|
||||
**vllm_kwargs["additional_config"])
|
||||
vllm_group.add_argument('--structured-outputs-config',
|
||||
**vllm_kwargs["structured_outputs_config"])
|
||||
|
||||
# Other arguments
|
||||
parser.add_argument('--disable-log-stats',
|
||||
@ -959,7 +963,6 @@ class EngineArgs:
|
||||
if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
|
||||
and self.model in MODELS_ON_S3 and self.load_format == "auto"):
|
||||
self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
|
||||
self.load_format = "runai_streamer"
|
||||
|
||||
if self.disable_mm_preprocessor_cache:
|
||||
logger.warning(
|
||||
@ -1422,14 +1425,25 @@ class EngineArgs:
|
||||
|
||||
load_config = self.create_load_config()
|
||||
|
||||
decoding_config = DecodingConfig(
|
||||
backend=self.guided_decoding_backend,
|
||||
disable_fallback=self.guided_decoding_disable_fallback,
|
||||
disable_any_whitespace=self.guided_decoding_disable_any_whitespace,
|
||||
disable_additional_properties=\
|
||||
self.guided_decoding_disable_additional_properties,
|
||||
reasoning_backend=self.reasoning_parser
|
||||
)
|
||||
# Pass reasoning_parser into StructuredOutputsConfig
|
||||
if self.reasoning_parser:
|
||||
self.structured_outputs_config.reasoning_parser = \
|
||||
self.reasoning_parser
|
||||
|
||||
# Forward the deprecated CLI args to the StructuredOutputsConfig
|
||||
so_config = self.structured_outputs_config
|
||||
if self.guided_decoding_backend is not None:
|
||||
so_config.guided_decoding_backend = \
|
||||
self.guided_decoding_backend
|
||||
if self.guided_decoding_disable_fallback is not None:
|
||||
so_config.guided_decoding_disable_fallback = \
|
||||
self.guided_decoding_disable_fallback
|
||||
if self.guided_decoding_disable_any_whitespace is not None:
|
||||
so_config.guided_decoding_disable_any_whitespace = \
|
||||
self.guided_decoding_disable_any_whitespace
|
||||
if self.guided_decoding_disable_additional_properties is not None:
|
||||
so_config.guided_decoding_disable_additional_properties = \
|
||||
self.guided_decoding_disable_additional_properties
|
||||
|
||||
observability_config = ObservabilityConfig(
|
||||
show_hidden_metrics_for_version=(
|
||||
@ -1447,7 +1461,7 @@ class EngineArgs:
|
||||
lora_config=lora_config,
|
||||
speculative_config=speculative_config,
|
||||
load_config=load_config,
|
||||
decoding_config=decoding_config,
|
||||
structured_outputs_config=self.structured_outputs_config,
|
||||
observability_config=observability_config,
|
||||
compilation_config=self.compilation_config,
|
||||
kv_transfer_config=self.kv_transfer_config,
|
||||
|
||||
File diff suppressed because it is too large
@ -1,173 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Workaround for https://github.com/python/cpython/issues/86296
|
||||
#
|
||||
# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py
|
||||
# Licensed under the Apache License (Apache-2.0)
|
||||
|
||||
import asyncio
|
||||
import enum
|
||||
import sys
|
||||
from types import TracebackType
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
if sys.version_info[:2] >= (3, 11):
|
||||
from asyncio import timeout as asyncio_timeout
|
||||
else:
|
||||
|
||||
def asyncio_timeout(delay: Optional[float]) -> "Timeout":
|
||||
"""timeout context manager.
|
||||
Useful in cases when you want to apply timeout logic around block
|
||||
of code or in cases when asyncio.wait_for is not suitable. For example:
|
||||
>>> async with timeout(0.001):
|
||||
... async with aiohttp.get('https://github.com') as r:
|
||||
... await r.text()
|
||||
delay - value in seconds or None to disable timeout logic
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + delay if delay is not None else None
|
||||
return Timeout(deadline, loop)
|
||||
|
||||
class _State(enum.Enum):
|
||||
INIT = "INIT"
|
||||
ENTER = "ENTER"
|
||||
TIMEOUT = "TIMEOUT"
|
||||
EXIT = "EXIT"
|
||||
|
||||
class Timeout:
|
||||
# Internal class, please don't instantiate it directly
|
||||
# Use timeout() and timeout_at() public factories instead.
|
||||
#
|
||||
# Implementation note: `async with timeout()` is preferred
|
||||
# over `with timeout()`.
|
||||
# While technically the Timeout class implementation
|
||||
# doesn't need to be async at all,
|
||||
# the `async with` statement explicitly points that
|
||||
# the context manager should be used from async function context.
|
||||
#
|
||||
# This design allows to avoid many silly misusages.
|
||||
#
|
||||
# TimeoutError is raised immediately when scheduled
|
||||
# if the deadline is passed.
|
||||
# The purpose is to time out as soon as possible
|
||||
# without waiting for the next await expression.
|
||||
|
||||
__slots__ = ("_deadline", "_loop", "_state", "_timeout_handler")
|
||||
|
||||
def __init__(self, deadline: Optional[float],
|
||||
loop: asyncio.AbstractEventLoop) -> None:
|
||||
self._loop = loop
|
||||
self._state = _State.INIT
|
||||
|
||||
self._timeout_handler = None # type: Optional[asyncio.Handle]
|
||||
if deadline is None:
|
||||
self._deadline = None # type: Optional[float]
|
||||
else:
|
||||
self.update(deadline)
|
||||
|
||||
async def __aenter__(self) -> "Timeout":
|
||||
self._do_enter()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> Optional[bool]:
|
||||
self._do_exit(exc_type)
|
||||
return None
|
||||
|
||||
@property
|
||||
def expired(self) -> bool:
|
||||
"""Is timeout expired during execution?"""
|
||||
return self._state == _State.TIMEOUT
|
||||
|
||||
@property
|
||||
def deadline(self) -> Optional[float]:
|
||||
return self._deadline
|
||||
|
||||
def reject(self) -> None:
|
||||
"""Reject scheduled timeout if any."""
|
||||
# cancel is maybe better name but
|
||||
# task.cancel() raises CancelledError in asyncio world.
|
||||
if self._state not in (_State.INIT, _State.ENTER):
|
||||
raise RuntimeError(f"invalid state {self._state.value}")
|
||||
self._reject()
|
||||
|
||||
def _reject(self) -> None:
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
self._timeout_handler = None
|
||||
|
||||
def shift(self, delay: float) -> None:
|
||||
"""Advance timeout on delay seconds.
|
||||
The delay can be negative.
|
||||
Raise RuntimeError if shift is called when deadline is not scheduled
|
||||
"""
|
||||
deadline = self._deadline
|
||||
if deadline is None:
|
||||
raise RuntimeError(
|
||||
"cannot shift timeout if deadline is not scheduled")
|
||||
self.update(deadline + delay)
|
||||
|
||||
def update(self, deadline: float) -> None:
|
||||
"""Set deadline to absolute value.
|
||||
deadline argument points on the time in the same clock system
|
||||
as loop.time().
|
||||
If new deadline is in the past the timeout is raised immediately.
|
||||
Please note: it is not POSIX time but a time with
|
||||
undefined starting base, e.g. the time of the system power on.
|
||||
"""
|
||||
if self._state == _State.EXIT:
|
||||
raise RuntimeError(
|
||||
"cannot reschedule after exit from context manager")
|
||||
if self._state == _State.TIMEOUT:
|
||||
raise RuntimeError("cannot reschedule expired timeout")
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
self._deadline = deadline
|
||||
if self._state != _State.INIT:
|
||||
self._reschedule()
|
||||
|
||||
def _reschedule(self) -> None:
|
||||
assert self._state == _State.ENTER
|
||||
deadline = self._deadline
|
||||
if deadline is None:
|
||||
return
|
||||
|
||||
now = self._loop.time()
|
||||
if self._timeout_handler is not None:
|
||||
self._timeout_handler.cancel()
|
||||
|
||||
task = asyncio.current_task()
|
||||
if deadline <= now:
|
||||
self._timeout_handler = self._loop.call_soon(
|
||||
self._on_timeout, task)
|
||||
else:
|
||||
self._timeout_handler = self._loop.call_at(
|
||||
deadline, self._on_timeout, task)
|
||||
|
||||
def _do_enter(self) -> None:
|
||||
if self._state != _State.INIT:
|
||||
raise RuntimeError(f"invalid state {self._state.value}")
|
||||
self._state = _State.ENTER
|
||||
self._reschedule()
|
||||
|
||||
def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None:
|
||||
if exc_type is asyncio.CancelledError and \
|
||||
self._state == _State.TIMEOUT:
|
||||
self._timeout_handler = None
|
||||
raise asyncio.TimeoutError
|
||||
# timeout has not expired
|
||||
self._state = _State.EXIT
|
||||
self._reject()
|
||||
return None
|
||||
|
||||
def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None:
|
||||
if task:
|
||||
task.cancel()
|
||||
self._state = _State.TIMEOUT
|
||||
# drop the reference early
|
||||
self._timeout_handler = None
|
||||
@ -16,9 +16,8 @@ import torch
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (DecodingConfig, ModelConfig, ObservabilityConfig,
|
||||
from vllm.config import (LoRAConfig, ModelConfig, ObservabilityConfig,
|
||||
ParallelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.metrics_types import StatLoggerBase, Stats
|
||||
@ -213,8 +212,7 @@ class LLMEngine:
|
||||
self.device_config = vllm_config.device_config
|
||||
self.speculative_config = vllm_config.speculative_config # noqa
|
||||
self.load_config = vllm_config.load_config
|
||||
self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa
|
||||
)
|
||||
self.structured_outputs_config = vllm_config.structured_outputs_config
|
||||
self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa
|
||||
)
|
||||
|
||||
@ -364,10 +362,9 @@ class LLMEngine:
|
||||
self.observability_config.otlp_traces_endpoint)
|
||||
|
||||
# Initialize reasoning parser if reasoning backend is set.
|
||||
if self.decoding_config.reasoning_backend and \
|
||||
self.tokenizer:
|
||||
if self.structured_outputs_config.reasoning_parser and self.tokenizer:
|
||||
reasoner_class = ReasoningParserManager.get_reasoning_parser(
|
||||
self.decoding_config.reasoning_backend)
|
||||
self.structured_outputs_config.reasoning_parser)
|
||||
self.reasoner: ReasoningParser = reasoner_class(
|
||||
self.tokenizer.get_lora_tokenizer())
|
||||
|
||||
@ -381,7 +378,8 @@ class LLMEngine:
|
||||
self.seq_counter,
|
||||
stop_checker=StopChecker(
|
||||
self.scheduler_config.max_model_len,
|
||||
self.reasoner if self.decoding_config.reasoning_backend
|
||||
self.reasoner
|
||||
if self.structured_outputs_config.reasoning_parser
|
||||
and self.tokenizer else None,
|
||||
),
|
||||
))
|
||||

@ -435,9 +433,9 @@ class LLMEngine:
                    f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
            raise RuntimeError(
                "The Ray distributed executor is only available in the v1 "
                "engine. Enable it by setting 'VLLM_USE_V1=1'.")
        elif distributed_executor_backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
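
Per the new error message above, selecting the Ray backend on this code path now raises unless the v1 engine is enabled. A rough sketch of the intended combination follows; VLLM_USE_V1 is taken from the error message, while the model name and the exact constructor call are assumptions for illustration.

import os

# Enable the v1 engine before importing vLLM; env var name from the error above.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM

# Hypothetical invocation: pick the Ray distributed executor on the v1 engine.
llm = LLM(model="facebook/opt-125m",  # placeholder model
          distributed_executor_backend="ray")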

@ -671,10 +669,13 @@ class LLMEngine:
            arrival_time = time.time()

        if (isinstance(prompt, dict)
                and prompt.get("prompt_embeds", None) is not None
                and not prompt.get("prompt_token_ids", None)):
            seq_len = prompt["prompt_embeds"].shape[0]
            prompt["prompt_token_ids"] = [0] * seq_len
                and prompt.get("prompt_embeds", None) is not None):
            if not prompt.get("prompt_token_ids", None):
                seq_len = prompt["prompt_embeds"].shape[0]
                prompt["prompt_token_ids"] = [0] * seq_len
            if params.prompt_logprobs is not None:
                raise ValueError(
                    "prompt_logprobs is not compatible with prompt embeds.")

        processed_inputs = self.input_preprocessor.preprocess(
            prompt,
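
The reworked branch above pads prompt_token_ids with placeholder zeros when only prompt_embeds is provided, and now rejects prompt_logprobs for embedding inputs. Below is a small standalone restatement of that logic for clarity; the helper name is hypothetical, while the dict keys and the error message are taken verbatim from the diff.

import torch


def fill_dummy_token_ids(prompt: dict, prompt_logprobs=None) -> dict:
    # Same check as the diff: a dict prompt carrying embeddings but no token ids.
    if (isinstance(prompt, dict)
            and prompt.get("prompt_embeds", None) is not None):
        if not prompt.get("prompt_token_ids", None):
            # One placeholder token id per embedding row.
            seq_len = prompt["prompt_embeds"].shape[0]
            prompt["prompt_token_ids"] = [0] * seq_len
        if prompt_logprobs is not None:
            raise ValueError(
                "prompt_logprobs is not compatible with prompt embeds.")
    return prompt


prompt = fill_dummy_token_ids({"prompt_embeds": torch.zeros(4, 16)})
assert prompt["prompt_token_ids"] == [0, 0, 0, 0]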

@ -769,10 +770,6 @@ class LLMEngine:
        """Gets the parallel configuration."""
        return self.parallel_config

    def get_decoding_config(self) -> DecodingConfig:
        """Gets the decoding configuration."""
        return self.decoding_config

    def get_scheduler_config(self) -> SchedulerConfig:
        """Gets the scheduler configuration."""
        return self.scheduler_config

@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Iterable, Mapping, Optional, Union

from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt

@ -248,11 +248,6 @@ class EngineClient(ABC):
        """Get the model configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_decoding_config(self) -> DecodingConfig:
        """Get the decoding configuration of the vLLM engine."""
        ...

    @abstractmethod
    async def get_input_preprocessor(self) -> InputPreprocessor:
        """Get the input processor of the vLLM engine."""

@ -45,6 +45,28 @@ def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
    return model_name, openai_client


def _print_chat_stream(stream) -> str:
    output = ""
    for chunk in stream:
        delta = chunk.choices[0].delta
        if delta.content:
            output += delta.content
            print(delta.content, end="", flush=True)
    print()
    return output


def _print_completion_stream(stream) -> str:
    output = ""
    for chunk in stream:
        text = chunk.choices[0].text
        if text is not None:
            output += text
            print(text, end="", flush=True)
    print()
    return output


def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
    conversation: list[ChatCompletionMessageParam] = []
    if system_prompt is not None:
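
The two helpers above consume OpenAI-style streaming responses chunk by chunk. A self-contained sketch of the same streaming loop against a locally served model follows; the base URL, API key, and model name are placeholders and assume a server is already running.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
stream = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)

output = ""
for chunk in stream:
    # Mirrors _print_chat_stream: print each delta as it arrives.
    delta = chunk.choices[0].delta
    if delta.content:
        output += delta.content
        print(delta.content, end="", flush=True)
print()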

@ -58,14 +80,11 @@ def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
            break
        conversation.append({"role": "user", "content": input_message})

        chat_completion = client.chat.completions.create(model=model_name,
                                                         messages=conversation)

        response_message = chat_completion.choices[0].message
        output = response_message.content

        conversation.append(response_message)  # type: ignore
        print(output)
        stream = client.chat.completions.create(model=model_name,
                                                messages=conversation,
                                                stream=True)
        output = _print_chat_stream(stream)
        conversation.append({"role": "assistant", "content": output})


def _add_query_options(

@ -108,9 +127,11 @@ class ChatCommand(CLISubcommand):
        if args.quick:
            conversation.append({"role": "user", "content": args.quick})

            chat_completion = client.chat.completions.create(
                model=model_name, messages=conversation)
            print(chat_completion.choices[0].message.content)
            stream = client.chat.completions.create(model=model_name,
                                                    messages=conversation,
                                                    stream=True)
            output = _print_chat_stream(stream)
            conversation.append({"role": "assistant", "content": output})
            return

        print("Please enter a message for the chat model:")

@ -121,14 +142,11 @@ class ChatCommand(CLISubcommand):
                break
            conversation.append({"role": "user", "content": input_message})

            chat_completion = client.chat.completions.create(
                model=model_name, messages=conversation)

            response_message = chat_completion.choices[0].message
            output = response_message.content

            conversation.append(response_message)  # type: ignore
            print(output)
            stream = client.chat.completions.create(model=model_name,
                                                    messages=conversation,
                                                    stream=True)
            output = _print_chat_stream(stream)
            conversation.append({"role": "assistant", "content": output})

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

@ -168,9 +186,10 @@ class CompleteCommand(CLISubcommand):
        model_name, client = _interactive_cli(args)

        if args.quick:
            completion = client.completions.create(model=model_name,
                                                   prompt=args.quick)
            print(completion.choices[0].text)
            stream = client.completions.create(model=model_name,
                                               prompt=args.quick,
                                               stream=True)
            _print_completion_stream(stream)
            return

        print("Please enter prompt to complete:")

@ -179,10 +198,10 @@ class CompleteCommand(CLISubcommand):
                input_prompt = input("> ")
            except EOFError:
                break
            completion = client.completions.create(model=model_name,
                                                   prompt=input_prompt)
            output = completion.choices[0].text
            print(output)
            stream = client.completions.create(model=model_name,
                                               prompt=input_prompt,
                                               stream=True)
            _print_completion_stream(stream)

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

@ -11,7 +11,6 @@ import uvicorn
from fastapi import FastAPI, Request, Response

from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.constants import (H11_MAX_HEADER_COUNT_DEFAULT,
                                        H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT)

@ -154,7 +153,6 @@ def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
    """

    @app.exception_handler(RuntimeError)
    @app.exception_handler(AsyncEngineDeadError)
    @app.exception_handler(EngineDeadError)
    @app.exception_handler(EngineGenerateError)
    async def runtime_exception_handler(request: Request, __):

Some files were not shown because too many files have changed in this diff.