Compare commits
90 Commits
copilot/fi ... copilot/fi

| SHA1 | Author | Date |
|---|---|---|
| c6efc2afba | |||
| d95d55443b | |||
| 2a97ffc33d | |||
| efc88cf64a | |||
| 7b6a837275 | |||
| c34c82b7fe | |||
| 8a044754bd | |||
| 9188ae7cb5 | |||
| 8a3cd90af5 | |||
| 2a167b2eeb | |||
| 0ff902f3b4 | |||
| a9082a4d14 | |||
| e0329ed4b4 | |||
| 6879cd80ae | |||
| e269be2ba2 | |||
| 5c4b6e66fe | |||
| d0a4a3f645 | |||
| ebafb0936d | |||
| 0cb7b065c3 | |||
| 2da02dd0d8 | |||
| d765cf01fe | |||
| 712d0f88d8 | |||
| 49ab23b3cc | |||
| c9abb10489 | |||
| 787cdb3829 | |||
| a5203d04df | |||
| 99f8094400 | |||
| 170e8ea9ea | |||
| a71e4765cc | |||
| 39971db3aa | |||
| 504d914314 | |||
| 47455c424f | |||
| c7fc6b1354 | |||
| ad78868450 | |||
| e2db1164a1 | |||
| 416f05929a | |||
| 5e021b4981 | |||
| 1b9b16649c | |||
| e76e233540 | |||
| a75277285b | |||
| 9dc30b7068 | |||
| 053278a5dc | |||
| c55c028998 | |||
| 65197a5fb3 | |||
| b8f17f5d98 | |||
| d9a55204ba | |||
| b4e9fd811f | |||
| 308fa287a8 | |||
| fa78de9dc3 | |||
| f6818a92cb | |||
| 23c939fd30 | |||
| add1adfec7 | |||
| c80c53a30f | |||
| 24d0c9e6ed | |||
| cc7ae5e7ca | |||
| 0313cf854d | |||
| 0483fabc74 | |||
| da65bec309 | |||
| 4645024d3a | |||
| cd7a3df26f | |||
| 32d2b4064f | |||
| 22cf679aad | |||
| b6d7d34fc6 | |||
| 341923b982 | |||
| 424fb7a5d2 | |||
| 88491c1b6b | |||
| 613a23b57f | |||
| 51a215300b | |||
| ebe14621e3 | |||
| 325aa3dee9 | |||
| a073be6d87 | |||
| 695e7adcd2 | |||
| 281710ef9a | |||
| 808d2e9aa0 | |||
| 285178b3b8 | |||
| 88016c372a | |||
| 998720859c | |||
| 0ba1b54ac6 | |||
| 53415653ff | |||
| 17373dcd93 | |||
| 5964069367 | |||
| de9c085e17 | |||
| 111692bb8c | |||
| 394591e343 | |||
| 3ac849665d | |||
| 0b9cc56fac | |||
| 8896eb72eb | |||
| 19fe1a0510 | |||
| 480bdf5a7b | |||
| 5368f76855 |
@@ -2,7 +2,7 @@
# We can use this script to compute baseline accuracy on GSM for transformers.
#
# Make sure you have lm-eval-harness installed:
# pip install lm-eval==0.4.4
# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]

usage() {
echo``

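For reference, the GSM8K baseline that this script computes can also be reproduced through the lm-eval-harness Python API. This is a hedged sketch, not part of the script itself: the checkpoint name is a placeholder, and only `lm_eval.simple_evaluate` from lm-eval 0.4.x is assumed.

```python
# Hedged sketch: GSM8K baseline via the lm-eval-harness Python API.
# The model checkpoint below is a placeholder, not something this repo pins.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",  # plain HuggingFace transformers backend
    model_args="pretrained=Qwen/Qwen2.5-1.5B-Instruct,dtype=bfloat16",
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size=8,
)
print(results["results"]["gsm8k"])  # accuracy metrics for the GSM8K task
```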
@@ -3,7 +3,7 @@
# We use this for fp8, which HF does not support.
#
# Make sure you have lm-eval-harness installed:
# pip install lm-eval==0.4.4
# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]

usage() {
echo``

@@ -17,7 +17,7 @@ Latest reproduction guide: [github issue link](https://github.com/vllm-project/
- SGLang: `lmsysorg/sglang:v0.3.2-cu121`
- LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12`
- TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3`
- *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.*
- *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.*
- Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark.
- Hardware
- 8x Nvidia A100 GPUs

@@ -382,7 +382,7 @@ run_genai_perf_tests() {
client_command="genai-perf profile \
-m $model \
--service-kind openai \
--backend vllm \
--backend "$backend" \
--endpoint-type chat \
--streaming \
--url localhost:$port \

@@ -61,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR"
echo "--- Installing Python dependencies ---"
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
&& python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
&& python3 -m pip install --progress-bar off hf-transfer
echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1

@@ -61,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR"
echo "--- Installing Python dependencies ---"
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
&& python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
&& python3 -m pip install --progress-bar off hf-transfer
echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1

@@ -244,6 +244,7 @@ steps:
- pytest -v -s v1/core
- pytest -v -s v1/engine
- pytest -v -s v1/entrypoints
- pytest -v -s v1/executor
- pytest -v -s v1/sample
- pytest -v -s v1/logits_processors
- pytest -v -s v1/worker
@@ -842,3 +843,10 @@ steps:
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4

- label: Qwen MoE EP Test # optional
gpu: h200
optional: true
num_gpus: 2
commands:
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048

.github/PULL_REQUEST_TEMPLATE.md: 3 changes (vendored)

@@ -7,8 +7,6 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT

## Test Result

## (Optional) Documentation Update

---
<details>
<summary> Essential Elements of an Effective PR Description Checklist </summary>
@@ -17,6 +15,7 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT
- [ ] The test plan, such as providing test command.
- [ ] The test results, such as pasting the results comparison before and after, or e2e results
- [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.
- [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the [Google Doc](https://docs.google.com/document/d/1YyVqrgX4gHTtrstbq8oWUImOyPCKSGnJ7xtTpmXzlRs/edit?tab=t.0).
</details>

**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing>** (anything written below this line will be removed by GitHub Actions)

@@ -750,6 +750,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"found in CUDA target architectures")
endif()
endif()

# Only build W4A8 kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
set(SRCS
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu")

set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${W4A8_ARCHS}")

list(APPEND VLLM_EXT_SRC "${SRCS}")

message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
AND W4A8_ARCHS)
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building W4A8 kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()

# if CUDA endif
endif()

@@ -790,7 +817,9 @@ set(VLLM_MOE_EXT_SRC
"csrc/moe/topk_softmax_kernels.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu")
endif()

if(VLLM_GPU_LANG STREQUAL "CUDA")

@@ -18,14 +18,15 @@ Easy, fast, and cheap LLM serving for everyone

*Latest News* 🔥

- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH).
- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152).
- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).

<details>
<summary>Previous News</summary>

- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).

@@ -59,6 +59,12 @@ become available.
<td style="text-align: center;">✅</td>
<td><code>synthetic</code></td>
</tr>
<tr>
<td><strong>RandomMultiModal (Image/Video)</strong></td>
<td style="text-align: center;">🟡</td>
<td style="text-align: center;">🚧</td>
<td><code>synthetic</code> </td>
</tr>
<tr>
<td><strong>Prefix Repetition</strong></td>
<td style="text-align: center;">✅</td>
@@ -722,4 +728,75 @@ python benchmarks/benchmark_serving.py \
--endpoint /v1/chat/completion
```

### Synthetic Random Images (random-mm)

Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.

Notes:

- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`.
- Video sampling is not yet implemented.

Start the server (example):

```bash
vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
--dtype bfloat16 \
--max-model-len 16384 \
--limit-mm-per-prompt '{"image": 3, "video": 0}' \
--mm-processor-kwargs max_pixels=1003520
```

Run the benchmark. It is recommended to use the `--ignore-eos` flag to simulate real responses; you can set the output size via `--random-output-len`.

Ex.1: Fixed number of items and a single image resolution, enforcing generation of approximately 40 tokens:

```bash
vllm bench serve \
--backend openai-chat \
--model Qwen/Qwen2.5-VL-3B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name random-mm \
--num-prompts 100 \
--max-concurrency 10 \
--random-prefix-len 25 \
--random-input-len 300 \
--random-output-len 40 \
--random-range-ratio 0.2 \
--random-mm-base-items-per-request 2 \
--random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
--random-mm-bucket-config '{(224, 224, 1): 1.0}' \
--request-rate inf \
--ignore-eos \
--seed 42
```

The number of items per request can be controlled by passing multiple image buckets:

```bash
--random-mm-base-items-per-request 2 \
--random-mm-num-mm-items-range-ratio 0.5 \
--random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
--random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
```

Flags specific to `random-mm`:

- `--random-mm-base-items-per-request`: base number of multimodal items per request.
- `--random-mm-num-mm-items-range-ratio`: vary the item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items.
- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; the remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling is not yet supported).

Behavioral notes:

- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.

How sampling works (a Python sketch follows this list):

- Determine the per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of the per-modality limits.
- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. This should be seen as an edge case; it can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number, though that may in turn trigger errors from the engine config `--limit-mm-per-prompt`.
- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.

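The sampling steps above can be summarized with a small Python sketch. This is not the benchmark's actual implementation; the helper name and signature are hypothetical and only mirror the procedure described in the list.

```python
import math
import random

def sample_mm_items(base_items, range_ratio, limits, bucket_config, rng=random):
    """Hypothetical sketch of random-mm item sampling.

    limits:        e.g. {"image": 3, "video": 0}
    bucket_config: e.g. {(224, 224, 1): 1.0}  # (H, W, T) -> probability
    """
    def modality(hwt):
        return "image" if hwt[2] == 1 else "video"

    # 1) Pick k uniformly from [floor(n*(1-r)), ceil(n*(1+r))], then clamp it
    #    to the sum of the per-modality limits.
    lo = math.floor(base_items * (1 - range_ratio))
    hi = math.ceil(base_items * (1 + range_ratio))
    k = min(rng.randint(lo, hi), sum(limits.values()))

    counts = {m: 0 for m in limits}
    items = []
    for _ in range(k):
        # 2) Keep only buckets whose modality has not hit its cap,
        #    then renormalize the remaining probabilities.
        live = {b: p for b, p in bucket_config.items()
                if p > 0 and counts[modality(b)] < limits[modality(b)]}
        if not live:
            break
        total = sum(live.values())
        buckets = list(live)
        probs = [live[b] / total for b in buckets]
        # 3) Sample one (H, W, T) bucket and record its modality.
        chosen = rng.choices(buckets, weights=probs, k=1)[0]
        counts[modality(chosen)] += 1
        items.append(chosen)
    return items
```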
</details>
@ -284,6 +284,25 @@ def machete_create_bench_fn(
|
||||
)
|
||||
|
||||
|
||||
def cutlass_w4a8_create_bench_fn(
|
||||
bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
|
||||
) -> Callable:
|
||||
w_q = bt.w_q.t().contiguous().t() # make col major
|
||||
w_q = ops.cutlass_encode_and_reorder_int4b(w_q)
|
||||
# expects fp8 scales
|
||||
w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn))
|
||||
|
||||
return lambda: ops.cutlass_w4a8_mm(
|
||||
a=bt.a,
|
||||
b_q=w_q,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=bt.group_size,
|
||||
b_channel_scales=bt.w_ch_s,
|
||||
a_token_scales=bt.w_tok_s,
|
||||
maybe_schedule=schedule,
|
||||
)
|
||||
|
||||
|
||||
# impl
|
||||
|
||||
# bench
|
||||
@ -385,6 +404,20 @@ def bench(
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass w4a8
|
||||
if types.act_type == torch.float8_e4m3fn and group_size == 128:
|
||||
timers.append(
|
||||
bench_fns(
|
||||
label,
|
||||
sub_label,
|
||||
f"cutlass w4a8 ({name_type_string})",
|
||||
[
|
||||
cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type)
|
||||
for bt in benchmark_tensors
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if sweep_schedules:
|
||||
global _SWEEP_SCHEDULES_RESULTS
|
||||
|
||||
|
||||
@ -9,8 +9,11 @@ from typing import Optional
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
from vllm.utils import round_up
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@ -61,13 +64,13 @@ def benchmark_decode(
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
|
||||
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
q_scale = 1.0
|
||||
ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
query, _ = to_float8(ref_query)
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
query = ref_query
|
||||
|
||||
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_seq_len
|
||||
@ -75,14 +78,13 @@ def benchmark_decode(
|
||||
seq_lens = kv_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
k_scale = v_scale = 1.0
|
||||
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
kv_cache, _ = to_float8(ref_kv_cache)
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
kv_cache = ref_kv_cache
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
@ -142,11 +144,31 @@ def benchmark_decode(
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
o_scale = 1.0
|
||||
o_sf_scale = None
|
||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = 500.0
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
def baseline_decode():
|
||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
||||
return wrapper.run(
|
||||
ref_query,
|
||||
ref_kv_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out=output_baseline,
|
||||
)
|
||||
|
||||
def trtllm_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
@ -158,6 +180,7 @@ def benchmark_decode(
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
|
||||
@ -237,6 +260,7 @@ if __name__ == "__main__":
|
||||
(None, None, None),
|
||||
(None, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
|
||||
for quant_dtype in quant_dtypes:
|
||||
|
||||
@ -9,8 +9,11 @@ from typing import Optional
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
from vllm.utils import round_up
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@ -72,13 +75,15 @@ def benchmark_prefill(
|
||||
]
|
||||
)
|
||||
|
||||
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
q_scale = 1.0
|
||||
ref_query = torch.randn(
|
||||
torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
|
||||
)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
query, _ = to_float8(ref_query)
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
query = ref_query
|
||||
|
||||
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
@ -86,14 +91,13 @@ def benchmark_prefill(
|
||||
seq_lens = kv_lens + q_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
k_scale = v_scale = 1.0
|
||||
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
kv_cache, _ = to_float8(ref_kv_cache)
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
kv_cache = ref_kv_cache
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
@ -152,11 +156,31 @@ def benchmark_prefill(
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
o_scale = 1.0
|
||||
o_sf_scale = None
|
||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = 500.0
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
def baseline_prefill():
|
||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
||||
return wrapper.run(
|
||||
ref_query,
|
||||
ref_kv_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out=output_baseline,
|
||||
)
|
||||
|
||||
def trtllm_prefill():
|
||||
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||
@ -172,6 +196,7 @@ def benchmark_prefill(
|
||||
batch_size=batch_size,
|
||||
cum_seq_lens_q=q_indptr,
|
||||
cum_seq_lens_kv=kv_indptr,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
|
||||
@ -250,6 +275,7 @@ if __name__ == "__main__":
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
|
||||
for quant_dtype in quant_dtypes:
|
||||
|
||||
@ -11,8 +11,8 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
import triton
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_w8a8_block_fp8_matmul,
|
||||
|
||||
@ -95,4 +95,10 @@ WEIGHT_SHAPES = {
|
||||
([2048, 2816], 1),
|
||||
([1408, 2048], 0),
|
||||
],
|
||||
"CohereLabs/c4ai-command-a-03-2025": [
|
||||
([12288, 14336], 1),
|
||||
([12288, 12288], 0),
|
||||
([12288, 73728], 1),
|
||||
([36864, 12288], 0),
|
||||
],
|
||||
}
|
||||
|
||||
@ -19,7 +19,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
|
||||
GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
|
||||
GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
|
||||
set(FlashMLA_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
|
||||
|
||||
set(FlashMLA_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc/include)
|
||||
${flashmla_SOURCE_DIR}/csrc)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_SOURCES}"
|
||||
|
||||
@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
const double scale, const std::string& kv_cache_dtype);
|
||||
|
||||
void gather_cache(
|
||||
void gather_and_maybe_dequant_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||
int64_t batch_size, const std::string& kv_cache_dtype,
|
||||
torch::Tensor const& scale,
|
||||
std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||
@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
namespace vllm {
|
||||
|
||||
// grid is launched with dimensions (batch, num_splits)
|
||||
template <typename scalar_t>
|
||||
__global__ void gather_cache(
|
||||
const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void gather_and_maybe_dequant_cache(
|
||||
const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
|
||||
// ENTRIES...]
|
||||
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
|
||||
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
|
||||
@ -634,6 +634,7 @@ __global__ void gather_cache(
|
||||
const int32_t block_size, const int32_t entry_size,
|
||||
const int64_t block_table_stride, const int64_t cache_block_stride,
|
||||
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
|
||||
const float* __restrict__ scale,
|
||||
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
|
||||
// batch
|
||||
|
||||
@ -675,10 +676,16 @@ __global__ void gather_cache(
|
||||
if (partial_block_size) full_blocks_end -= 1;
|
||||
}
|
||||
|
||||
auto copy_entry = [&](const scalar_t* __restrict__ _src,
|
||||
auto copy_entry = [&](const cache_t* __restrict__ _src,
|
||||
scalar_t* __restrict__ _dst) {
|
||||
for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
|
||||
_dst[i] = _src[i];
|
||||
for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
_dst[i] = static_cast<scalar_t>(_src[i]);
|
||||
} else {
|
||||
_dst[i] =
|
||||
fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for (int pid = split_start; pid < full_blocks_end; ++pid) {
|
||||
@ -705,25 +712,31 @@ __global__ void gather_cache(
|
||||
} // namespace vllm
|
||||
|
||||
// Macro to dispatch the kernel based on the data type.
|
||||
#define CALL_GATHER_CACHE(CPY_DTYPE) \
|
||||
vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
|
||||
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
|
||||
block_size, entry_size, block_table_stride, cache_block_stride, \
|
||||
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
|
||||
// SCALAR_T is the data type of the destination tensor.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<SCALAR_T*>(dst.data_ptr()), \
|
||||
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
|
||||
block_size, entry_size, block_table_stride, cache_block_stride, \
|
||||
cache_entry_stride, dst_entry_stride, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);
|
||||
|
||||
// Gather sequences from the cache into the destination tensor.
|
||||
// - cu_seq_lens contains the cumulative sequence lengths for each batch
|
||||
// - block_table contains the cache block indices for each sequence
|
||||
// - Optionally, seq_starts (if provided) offsets the starting block index by
|
||||
// (seq_starts[bid] / page_size)
|
||||
void gather_cache(
|
||||
void gather_and_maybe_dequant_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size,
|
||||
int64_t batch_size, const std::string& kv_cache_dtype,
|
||||
torch::Tensor const& scale,
|
||||
std::optional<torch::Tensor> seq_starts = std::nullopt) {
|
||||
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@ -761,20 +774,8 @@ void gather_cache(
|
||||
dim3 grid(batch_size, num_splits);
|
||||
dim3 block(1024);
|
||||
|
||||
TORCH_CHECK(src_cache.dtype() == dst.dtype(),
|
||||
"src_cache and dst must have the same dtype");
|
||||
|
||||
const int dtype_bits = src_cache.element_size() * 8;
|
||||
const int32_t* seq_starts_ptr =
|
||||
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
|
||||
|
||||
if (dtype_bits == 32) {
|
||||
CALL_GATHER_CACHE(uint32_t);
|
||||
} else if (dtype_bits == 16) {
|
||||
CALL_GATHER_CACHE(uint16_t);
|
||||
} else if (dtype_bits == 8) {
|
||||
CALL_GATHER_CACHE(uint8_t);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
|
||||
}
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
|
||||
}
|
||||
|
||||
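The semantics of `gather_and_maybe_dequant_cache` in the hunk above are easier to see as a plain tensor loop. The sketch below is a hedged PyTorch re-statement of the kernel's documented behavior (function name and fp8 handling are simplified assumptions), not the CUDA code path itself.

```python
import torch

def gather_and_maybe_dequant_cache_ref(src_cache, dst, block_table, cu_seq_lens,
                                        batch_size, scale=None, seq_starts=None):
    """Reference sketch: copy each sequence's tokens out of the paged KV cache.

    src_cache: [NUM_BLOCKS, BLOCK_SIZE, *ENTRIES] (possibly a quantized dtype)
    dst:       [TOT_TOKENS, *ENTRIES] (higher-precision output)
    """
    block_size = src_cache.shape[1]
    for b in range(batch_size):
        tot = int(cu_seq_lens[b + 1] - cu_seq_lens[b])   # tokens for this sequence
        # seq_starts (if given) offsets the starting block index by seq_starts[b] / page_size
        offset = int(seq_starts[b]) // block_size if seq_starts is not None else 0
        for tok in range(tot):
            blk = int(block_table[b, offset + tok // block_size])
            entry = src_cache[blk, tok % block_size].to(dst.dtype)
            if scale is not None:                         # "maybe dequant": scaled fp8 -> float
                entry = entry * scale
            dst[int(cu_seq_lens[b]) + tok] = entry
    return dst
```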
csrc/moe/grouped_topk_kernels.cu (new file, 757 lines)
@@ -0,0 +1,757 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
|
||||
* Copyright (c) 2025, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
|
||||
constexpr int32_t WARP_SIZE = 32;
|
||||
constexpr int32_t BLOCK_SIZE = 512;
|
||||
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
|
||||
|
||||
namespace warp_topk {
|
||||
|
||||
template <int size, typename T>
|
||||
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
|
||||
if (len == 0) {
|
||||
return 0;
|
||||
}
|
||||
return ((len - 1) / size + 1) * size;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
constexpr __host__ __device__ bool isPowerOf2(T v) {
|
||||
return (v && !(v & (v - 1)));
|
||||
}
|
||||
|
||||
template <bool greater, typename T>
|
||||
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
|
||||
return (val > baseline && greater) || (val < baseline && !greater);
|
||||
}
|
||||
|
||||
template <bool greater, typename T, typename idxT>
|
||||
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
|
||||
idxT baseline_index) {
|
||||
bool res = (val > baseline && greater) || (val < baseline && !greater);
|
||||
if (val == baseline) {
|
||||
res = (index < baseline_index && greater) ||
|
||||
(index < baseline_index && !greater);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T, typename idxT>
|
||||
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
|
||||
int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
|
||||
int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
|
||||
return max(cache_topk,
|
||||
round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
|
||||
}
|
||||
|
||||
template <int size, bool ascending, bool reverse, typename T, typename idxT,
|
||||
bool is_stable>
|
||||
struct BitonicMerge {
|
||||
// input should be a bitonic sequence, and sort it to be a monotonic sequence
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
static_assert(isPowerOf2(size));
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
constexpr int stride = arr_len / 2;
|
||||
for (int i = 0; i < stride; ++i) {
|
||||
int const other_i = i + stride;
|
||||
T& val = val_arr[i];
|
||||
T& other_val = val_arr[other_i];
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
|
||||
idx_arr[other_i]);
|
||||
} else {
|
||||
is_better = is_better_than<ascending>(val, other_val);
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
T tmp = val;
|
||||
val = other_val;
|
||||
other_val = tmp;
|
||||
|
||||
idxT tmp2 = idx_arr[i];
|
||||
idx_arr[i] = idx_arr[other_i];
|
||||
idx_arr[other_i] = tmp2;
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
||||
val_arr, idx_arr);
|
||||
BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
|
||||
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <int size, bool ascending, typename T, typename idxT, bool is_stable>
|
||||
struct BitonicSort {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
static_assert(isPowerOf2(size));
|
||||
static_assert(size >= 2 * WARP_SIZE);
|
||||
constexpr int arr_len = size / WARP_SIZE;
|
||||
|
||||
BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
|
||||
BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
|
||||
val_arr + arr_len / 2, idx_arr + arr_len / 2);
|
||||
BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
|
||||
val_arr, idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, typename T, typename idxT, bool is_stable>
|
||||
struct BitonicSort<32, ascending, T, idxT, is_stable> {
|
||||
__device__ static void sort(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
|
||||
// ascending doesn't matter before merging since all we need is a bitonic
|
||||
// sequence
|
||||
for (int stage = 0; stage < 4; ++stage) {
|
||||
for (int stride = (1 << stage); stride > 0; stride /= 2) {
|
||||
bool reverse = (lane >> stage) & 2;
|
||||
bool is_second = lane & stride;
|
||||
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
if constexpr (ascending) {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr < other_idx))) !=
|
||||
(reverse != is_second);
|
||||
} else {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr > other_idx))) !=
|
||||
(reverse != is_second);
|
||||
}
|
||||
} else {
|
||||
is_better = (*val_arr != other &&
|
||||
(*val_arr > other) != (reverse != is_second));
|
||||
}
|
||||
if (is_better) {
|
||||
*val_arr = other;
|
||||
*idx_arr = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
|
||||
idx_arr);
|
||||
}
|
||||
};
|
||||
|
||||
template <bool ascending, bool reverse, typename T, typename idxT,
|
||||
bool is_stable>
|
||||
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
|
||||
__device__ static void merge(T* __restrict__ val_arr,
|
||||
idxT* __restrict__ idx_arr) {
|
||||
int const lane = threadIdx.x % WARP_SIZE;
|
||||
for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
|
||||
bool is_second = lane & stride;
|
||||
T& val = *val_arr;
|
||||
T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
|
||||
idxT& idx = *idx_arr;
|
||||
idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
if constexpr (ascending) {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr < other_idx))) ==
|
||||
(reverse != is_second); // for min
|
||||
} else {
|
||||
is_better = ((*val_arr > other) ||
|
||||
((*val_arr == other) && (*idx_arr > other_idx))) ==
|
||||
(reverse != is_second); // for max
|
||||
}
|
||||
} else {
|
||||
is_better =
|
||||
(val != other && ((val > other) == (ascending != is_second)));
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
val = other;
|
||||
idx = other_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
||||
class WarpSort {
|
||||
public:
|
||||
__device__ WarpSort(idxT k, T dummy)
|
||||
: lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
|
||||
static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
|
||||
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
val_arr_[i] = dummy_;
|
||||
idx_arr_[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// load and merge k sorted values
|
||||
__device__ void load_sorted(T const* __restrict__ in,
|
||||
idxT const* __restrict__ in_idx, idxT start) {
|
||||
idxT idx = start + WARP_SIZE - 1 - lane_;
|
||||
for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
|
||||
if (idx < start + k_) {
|
||||
T t = in[idx];
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better =
|
||||
is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
|
||||
} else {
|
||||
is_better = is_better_than<greater>(t, val_arr_[i]);
|
||||
}
|
||||
if (is_better) {
|
||||
val_arr_[i] = t;
|
||||
idx_arr_[i] = in_idx[idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
||||
val_arr_, idx_arr_);
|
||||
}
|
||||
|
||||
__device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
idxT out_i = i * WARP_SIZE + lane_;
|
||||
if (out_i < k_) {
|
||||
out[out_i] = val_arr_[i];
|
||||
out_idx[out_i] = idx_arr_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void dumpIdx(idxT* __restrict__ out_idx) const {
|
||||
for (int i = 0; i < max_arr_len_; ++i) {
|
||||
idxT out_i = i * WARP_SIZE + lane_;
|
||||
if (out_i < k_) {
|
||||
out_idx[out_i] = idx_arr_[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
static constexpr int max_arr_len_ = capacity / WARP_SIZE;
|
||||
|
||||
T val_arr_[max_arr_len_];
|
||||
idxT idx_arr_[max_arr_len_];
|
||||
|
||||
int const lane_;
|
||||
idxT const k_;
|
||||
T const dummy_;
|
||||
|
||||
}; // end class WarpSort
|
||||
|
||||
template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
|
||||
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
|
||||
public:
|
||||
__device__ WarpSelect(idxT k, T dummy)
|
||||
: WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
|
||||
k_th_(dummy),
|
||||
k_th_lane_((k - 1) % WARP_SIZE) {
|
||||
extern __shared__ char smem_buf[]; // extern __shared__ T smem_buf[];
|
||||
|
||||
int const num_of_warp = blockDim.x / WARP_SIZE;
|
||||
int const warp_id = threadIdx.x / WARP_SIZE;
|
||||
val_smem_ = reinterpret_cast<T*>(smem_buf);
|
||||
val_smem_ += warp_id * WARP_SIZE;
|
||||
idx_smem_ = reinterpret_cast<idxT*>(
|
||||
smem_buf +
|
||||
round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
|
||||
idx_smem_ += warp_id * WARP_SIZE;
|
||||
}
|
||||
|
||||
__device__ void add(T const* in, idxT start, idxT end) {
|
||||
idxT const end_for_fullwarp =
|
||||
round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
|
||||
for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
|
||||
T val = (i < end) ? in[i] : dummy_;
|
||||
add(val, i);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void add(T val, idxT idx) {
|
||||
bool do_add;
|
||||
if constexpr (is_stable) {
|
||||
do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
|
||||
} else {
|
||||
do_add = is_better_than<greater>(val, k_th_);
|
||||
}
|
||||
|
||||
uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
|
||||
if (mask == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
|
||||
if (do_add && pos < WARP_SIZE) {
|
||||
val_smem_[pos] = val;
|
||||
idx_smem_[pos] = idx;
|
||||
do_add = false;
|
||||
}
|
||||
smem_buf_len_ += __popc(mask);
|
||||
if (smem_buf_len_ >= WARP_SIZE) {
|
||||
__syncwarp();
|
||||
merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
|
||||
smem_buf_len_ -= WARP_SIZE;
|
||||
}
|
||||
if (do_add) {
|
||||
pos -= WARP_SIZE;
|
||||
val_smem_[pos] = val;
|
||||
idx_smem_[pos] = idx;
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
__device__ void done() {
|
||||
if (smem_buf_len_) {
|
||||
T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
|
||||
idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
|
||||
merge_buf_(val, idx);
|
||||
}
|
||||
|
||||
// after done(), smem is used for merging results among warps
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
private:
|
||||
__device__ void set_k_th_() {
|
||||
k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
if constexpr (is_stable) {
|
||||
k_th_idx_ =
|
||||
__shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void merge_buf_(T val, idxT idx) {
|
||||
BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);
|
||||
|
||||
T& old = val_arr_[max_arr_len_ - 1];
|
||||
|
||||
bool is_better;
|
||||
if constexpr (is_stable) {
|
||||
is_better =
|
||||
is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
|
||||
} else {
|
||||
is_better = is_better_than<greater>(val, old);
|
||||
}
|
||||
|
||||
if (is_better) {
|
||||
old = val;
|
||||
idx_arr_[max_arr_len_ - 1] = idx;
|
||||
}
|
||||
|
||||
BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
|
||||
val_arr_, idx_arr_);
|
||||
|
||||
set_k_th_();
|
||||
}
|
||||
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
|
||||
using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
|
||||
|
||||
T* val_smem_;
|
||||
idxT* idx_smem_;
|
||||
int smem_buf_len_ = 0;
|
||||
|
||||
T k_th_;
|
||||
idxT k_th_idx_;
|
||||
int const k_th_lane_;
|
||||
}; // end class WarpSelect
|
||||
} // namespace warp_topk
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__device__ inline T_OUT cuda_cast(T_IN val) {
|
||||
return val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void topk_with_k2(T* output, T const* input,
|
||||
cg::thread_block_tile<32> const& tile,
|
||||
int32_t const lane_id,
|
||||
int const num_experts_per_group) {
|
||||
// Get the top2 per thread
|
||||
T largest = -INFINITY;
|
||||
T second_largest = -INFINITY;
|
||||
|
||||
if (num_experts_per_group > WARP_SIZE) {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
T value = input[i];
|
||||
if (value > largest) {
|
||||
second_largest = largest;
|
||||
largest = value;
|
||||
} else if (value > second_largest) {
|
||||
second_largest = value;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
|
||||
largest = input[i];
|
||||
}
|
||||
}
|
||||
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
// Get the top2 warpwise
|
||||
T max1 = cg::reduce(tile, largest, cg::greater<T>());
|
||||
|
||||
T max2 = max1;
|
||||
bool equal_to_max1 = (max1 == largest);
|
||||
|
||||
int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));
|
||||
|
||||
if (count_max1 == 1) {
|
||||
largest = (largest == max1) ? second_largest : largest;
|
||||
max2 = cg::reduce(tile, largest, cg::greater<T>());
|
||||
}
|
||||
|
||||
if (lane_id == 0) {
|
||||
*output = max1 + max2;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void topk_with_k2_kernel(T* output, T* input,
|
||||
int64_t const num_tokens,
|
||||
int64_t const num_cases,
|
||||
int64_t const n_group,
|
||||
int64_t const num_experts_per_group) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
|
||||
if (case_id < num_cases) {
|
||||
input += case_id * num_experts_per_group;
|
||||
output += case_id;
|
||||
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
__global__ void group_idx_and_topk_idx_kernel(
|
||||
T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices,
|
||||
T* scores_with_bias, int64_t const num_tokens, int64_t const n_group,
|
||||
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
|
||||
int64_t const num_experts_per_group, bool renormalize,
|
||||
double routed_scaling_factor) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
int32_t case_id =
|
||||
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
|
||||
scores_with_bias += case_id * num_experts;
|
||||
scores += case_id * num_experts;
|
||||
group_scores += case_id * n_group;
|
||||
topk_values += case_id * topk;
|
||||
topk_indices += case_id * topk;
|
||||
|
||||
int32_t align_num_experts_per_group =
|
||||
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
|
||||
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
|
||||
// store the target topk idx
|
||||
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
|
||||
T* s_topk_value =
|
||||
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
|
||||
warp_id * topk;
|
||||
s_topk_idx += warp_id * topk;
|
||||
|
||||
T value = cuda::std::numeric_limits<T>::min();
|
||||
T topk_group_value = cuda::std::numeric_limits<T>::min();
|
||||
int32_t num_equalto_topkth_group;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
|
||||
// acqbulk because it's ptr arithmetic
|
||||
#endif
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
if (lane_id < n_group &&
|
||||
(isfinite(cuda_cast<float, T>(
|
||||
group_scores[lane_id])))) // The check is necessary to avoid
|
||||
// abnormal input
|
||||
{
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
int count_equal_to_top_value = WARP_SIZE - n_group;
|
||||
int pre_count_equal_to_top_value = 0;
|
||||
// Use a loop to find the largest top_group
|
||||
while (count_equal_to_top_value < target_num_min) {
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value = __popc(__ballot_sync(
|
||||
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
|
||||
/* is_stable */ true>
|
||||
queue((int32_t)topk, -INFINITY);
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
bool if_proceed_next_topk =
|
||||
(topk_group_value != cuda::std::numeric_limits<T>::min());
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
((group_scores[i_group] == topk_group_value) &&
|
||||
(count_equalto_topkth_group < num_equalto_topkth_group))) {
|
||||
int32_t offset = i_group * num_experts_per_group;
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates =
|
||||
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
|
||||
scores_with_bias[offset + i]))
|
||||
? scores_with_bias[offset + i]
|
||||
: cuda::std::numeric_limits<T>::min();
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
count_equalto_topkth_group++;
|
||||
}
|
||||
}
|
||||
}
|
||||
queue.done();
|
||||
__syncwarp();
|
||||
// Get the topk_idx
|
||||
queue.dumpIdx(s_topk_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Load the valid score value
|
||||
// Calculate the summation
|
||||
float topk_sum = 1e-20;
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i = lane_id;
|
||||
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
|
||||
i += WARP_SIZE) {
|
||||
T value =
|
||||
i < topk
|
||||
? scores[s_topk_idx[i]]
|
||||
: cuda_cast<T, float>(0.0f); // Load the valid value of expert
|
||||
if (i < topk) {
|
||||
s_topk_value[i] = value;
|
||||
}
|
||||
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
if (if_proceed_next_topk) {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float value;
|
||||
if (renormalize) {
|
||||
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
|
||||
routed_scaling_factor;
|
||||
} else {
|
||||
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
|
||||
}
|
||||
topk_indices[i] = s_topk_idx[i];
|
||||
topk_values[i] = cuda_cast<T, float>(value);
|
||||
}
|
||||
} else {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
topk_indices[i] = i;
|
||||
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
|
||||
}
|
||||
}
|
||||
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
|
||||
// default result.
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
                   IdxT* topk_indices, T* scores_with_bias,
                   int64_t const num_tokens, int64_t const num_experts,
                   int64_t const n_group, int64_t const topk_group,
                   int64_t const topk, bool const renormalize,
                   double const routed_scaling_factor, bool enable_pdl = false,
                   cudaStream_t const stream = 0) {
  int64_t num_cases = num_tokens * n_group;
  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
  auto* kernel_instance1 = &topk_with_k2_kernel<T>;
  cudaLaunchConfig_t config;
  config.gridDim = topk_with_k2_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
                     num_tokens, num_cases, n_group, num_experts / n_group);

  int64_t topk_with_k_group_num_blocks =
      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
  size_t dynamic_smem_in_bytes =
      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                           topk);
  auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
  config.gridDim = topk_with_k_group_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = dynamic_smem_in_bytes;
  config.stream = stream;
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
                     topk_values, topk_indices, scores_with_bias, num_tokens,
                     n_group, topk_group, topk, num_experts,
                     num_experts / n_group, renormalize, routed_scaling_factor);
}

#define INSTANTIATE_NOAUX_TC(T, IdxT)                                       \
  template void invokeNoAuxTc<T, IdxT>(                                     \
      T * scores, T * group_scores, T * topk_values, IdxT * topk_indices,   \
      T * scores_with_bias, int64_t const num_tokens,                       \
      int64_t const num_experts, int64_t const n_group,                     \
      int64_t const topk_group, int64_t const topk, bool const renormalize, \
      double const routed_scaling_factor, bool enable_pdl,                  \
      cudaStream_t const stream);

INSTANTIATE_NOAUX_TC(float, int32_t);
INSTANTIATE_NOAUX_TC(half, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
}  // end namespace moe
}  // namespace vllm

std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
    torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
    int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
    double routed_scaling_factor) {
  auto data_type = scores_with_bias.scalar_type();
  auto input_size = scores_with_bias.sizes();
  int64_t num_tokens = input_size[0];
  int64_t num_experts = input_size[1];
  TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor");
  TORCH_CHECK(num_experts % n_group == 0,
              "num_experts should be divisible by n_group");
  TORCH_CHECK(n_group <= 32,
              "n_group should be smaller than or equal to 32 for now");
  TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");

  torch::Tensor group_scores = torch::empty(
      {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
  torch::Tensor topk_values = torch::empty(
      {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA));
  torch::Tensor topk_indices = torch::empty(
      {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));

  auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device());

  switch (data_type) {
    case torch::kFloat16:
      // Handle Float16
      vllm::moe::invokeNoAuxTc<half, int32_t>(
          reinterpret_cast<half*>(scores.mutable_data_ptr()),
          reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<half*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<half*>(scores_with_bias.data_ptr()), num_tokens,
          num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, false, stream);
      break;
    case torch::kFloat32:
      // Handle Float32
      vllm::moe::invokeNoAuxTc<float, int32_t>(
          reinterpret_cast<float*>(scores.mutable_data_ptr()),
          reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<float*>(scores_with_bias.data_ptr()), num_tokens,
          num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, false, stream);
      break;
    case torch::kBFloat16:
      // Handle BFloat16
      vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
          reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()),
          num_tokens, num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, false, stream);
      break;
    default:
      // Handle other data types
      throw std::invalid_argument(
          "Invalid dtype, only supports float16, float32, and bfloat16");
      break;
  }
  return {topk_values, topk_indices};
}
@ -22,6 +22,11 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor num_tokens_post_pad, int64_t top_k,
                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                             int64_t BLOCK_SIZE_K, int64_t bit);

std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
    torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
    int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
    double routed_scaling_factor);
#endif

bool moe_permute_unpermute_supported();

@ -78,6 +78,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
      "output_tensor) -> ()");
  m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);

  // Apply grouped topk routing to select experts.
  m.def(
      "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int "
      "topk_group, int topk, bool renormalize, float "
      "routed_scaling_factor) -> (Tensor, Tensor)");
  m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
#endif
}

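For orientation, here is a minimal sketch of how the `grouped_topk` op registered above might be called from Python once the extension is built. The operator name and argument order come from the schema in this hunk; the `torch.ops._moe_C` namespace, the tensor shapes, and the scaling factor are illustrative assumptions, not taken from this change.

```python
# Hypothetical usage of the grouped_topk custom op registered above.
# The namespace (_moe_C) and the concrete sizes are assumptions for illustration.
import torch

num_tokens, num_experts, n_group, topk_group, topk = 4, 256, 8, 4, 8

scores = torch.rand(num_tokens, num_experts, dtype=torch.float16, device="cuda")
scores_with_bias = scores + 0.01  # per-expert bias added upstream (assumed)

topk_values, topk_indices = torch.ops._moe_C.grouped_topk(
    scores,
    scores_with_bias,
    n_group,
    topk_group,
    topk,
    True,  # renormalize
    2.5,   # routed_scaling_factor
)
# topk_values: [num_tokens, topk] in the score dtype
# topk_indices: [num_tokens, topk] int32 expert ids
```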
418 csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu Normal file
@ -0,0 +1,418 @@
|
||||
//
|
||||
// Based off of:
|
||||
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
|
||||
//
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
#include "cutlass_extensions/torch_utils.hpp"
|
||||
|
||||
#include "core/registration.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/mixed_dtype_utils.hpp"
|
||||
|
||||
#include "cutlass_extensions/common.hpp"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm::cutlass_w4a8 {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Static configuration shared across all instantiations
|
||||
// -------------------------------------------------------------------------------------
|
||||
using MmaType = cutlass::float_e4m3_t; // A/scale element type
|
||||
using QuantType = cutlass::int4b_t; // B element type (packed int4)
|
||||
|
||||
static int constexpr TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
|
||||
static int constexpr ScalePackSize = 8; // pack 8 scale elements together
|
||||
static int constexpr PackFactor = 8; // 8 4-bit packed into int32
|
||||
|
||||
// A matrix configuration
|
||||
using ElementA = MmaType; // Element type for A matrix operand
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
using LayoutA_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementA>::value; // Memory access granularity/alignment of A
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
|
||||
|
||||
// B matrix configuration
|
||||
using ElementB = QuantType; // Element type for B matrix operand
|
||||
using LayoutB =
|
||||
cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
using LayoutB_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementB>::value; // Memory access granularity/alignment of B
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;
|
||||
|
||||
// Define the CuTe layout for reordered quantized tensor B
|
||||
// LayoutAtomQuant places values that will be read by the same thread in
|
||||
// contiguous locations in global memory. It specifies the reordering within a
|
||||
// single warp's fragment
|
||||
using LayoutAtomQuant =
|
||||
decltype(cutlass::compute_memory_reordering_atom<MmaType>());
|
||||
using LayoutB_Reordered = decltype(cute::tile_to_shape(
|
||||
LayoutAtomQuant{}, Layout<Shape<int, int, int>, StrideB>{}));
|
||||
|
||||
// Group-wise scales
|
||||
using ElementScale = MmaType;
|
||||
using LayoutScale = cutlass::layout::RowMajor;
|
||||
|
||||
// Per-tok, per-chan scales
|
||||
using ElementSChannel = float;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementC =
|
||||
cutlass::bfloat16_t; // Element type for C and D matrix operands
|
||||
using LayoutC =
|
||||
cutlass::layout::RowMajor; // Layout type for C and D matrix operands
|
||||
constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<
|
||||
ElementC>::value; // Memory access granularity/alignment of C
|
||||
// matrix in units of elements (up to 16 bytes)
|
||||
|
||||
using ElementD = ElementC;
|
||||
using LayoutD = LayoutC;
|
||||
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
|
||||
// supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch
|
||||
// based on the default
|
||||
// setting in the
|
||||
// Collective Builder
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Kernel template — Tile/Cluster shapes
|
||||
// ----------------------------------------------------------------------------
|
||||
template <class TileShape_MN, class ClusterShape_MNK>
|
||||
struct W4A8GemmKernel {
|
||||
using TileShape =
|
||||
decltype(cute::append(TileShape_MN{}, cute::Int<TileShapeK>{}));
|
||||
using ClusterShape = ClusterShape_MNK;
|
||||
|
||||
// Epilogue per-tok, per-chan scales
|
||||
using ChTokScalesEpilogue =
|
||||
typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
|
||||
TileShape>;
|
||||
using EVTCompute = typename ChTokScalesEpilogue::EVTCompute;
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
|
||||
ElementAccumulator, ElementSChannel,
|
||||
// Transpose layout of D here since we use explicit swap + transpose
|
||||
// the void type for C tells the builder to allocate 0 smem for the C
|
||||
// matrix. We can enable this if beta == 0 by changing ElementC to
|
||||
// void below.
|
||||
ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type,
|
||||
AlignmentC, ElementD,
|
||||
typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
|
||||
EpilogueSchedule, // This is the only epi supporting the required
|
||||
// swap + transpose.
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
// The Scale information must get paired with the operand that will be scaled.
|
||||
// In this example, B is scaled so we make a tuple of B's information and the
|
||||
// scale information.
|
||||
using CollectiveMainloopShuffled =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass,
|
||||
cute::tuple<ElementB, cutlass::Array<ElementScale, ScalePackSize>>,
|
||||
LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose,
|
||||
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloopShuffled, CollectiveEpilogue>;
|
||||
using GemmShuffled =
|
||||
cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;
|
||||
|
||||
using StrideC = typename GemmKernelShuffled::StrideC;
|
||||
using StrideD = typename GemmKernelShuffled::StrideD;
|
||||
using StrideS = typename CollectiveMainloopShuffled::StrideScale;
|
||||
|
||||
static torch::Tensor mm(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size,
|
||||
torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type) {
|
||||
// TODO: param validation
|
||||
int m = A.size(0);
|
||||
int k = A.size(1);
|
||||
int n = B.size(1);
|
||||
|
||||
// Allocate output
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
auto device = A.device();
|
||||
auto stream = at::cuda::getCurrentCUDAStream(device.index());
|
||||
torch::Tensor D =
|
||||
torch::empty({m, n}, torch::TensorOptions()
|
||||
.dtype(equivalent_scalar_type_v<ElementD>)
|
||||
.device(device));
|
||||
// prepare arg pointers
|
||||
auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
|
||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||
auto D_ptr = static_cast<ElementD*>(D.data_ptr());
|
||||
// can we avoid hardcoding the 8 (ScalePackSize) here?
|
||||
auto S_ptr =
|
||||
static_cast<cutlass::Array<ElementScale, ScalePackSize> const*>(
|
||||
group_scales.const_data_ptr());
|
||||
|
||||
// runtime layout for B
|
||||
auto shape_B = cute::make_shape(n, k, 1);
|
||||
LayoutB_Reordered layout_B_reordered =
|
||||
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
|
||||
// strides
|
||||
int const scale_k = cutlass::ceil_div(k, group_size);
|
||||
StrideA stride_A =
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
// Reverse stride here due to swap and transpose
|
||||
StrideD stride_D =
|
||||
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
|
||||
StrideS stride_S = cutlass::make_cute_packed_stride(
|
||||
StrideS{}, cute::make_shape(n, scale_k, 1));
|
||||
|
||||
// Create a structure of gemm kernel arguments suitable for invoking an
|
||||
// instance of Gemm auto arguments =
|
||||
// args_from_options<GemmShuffled>(options);
|
||||
/// Populates a Gemm::Arguments structure from the given arguments
|
||||
/// Swap the A and B tensors, as well as problem shapes here.
|
||||
using Args = typename GemmShuffled::Arguments;
|
||||
using MainloopArguments = typename GemmKernelShuffled::MainloopArguments;
|
||||
using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;
|
||||
|
||||
MainloopArguments mainloop_arguments{
|
||||
B_ptr, layout_B_reordered, A_ptr, stride_A,
|
||||
S_ptr, stride_S, group_size};
|
||||
|
||||
EpilogueArguments epilogue_arguments{
|
||||
ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
|
||||
nullptr,
|
||||
{}, // no C
|
||||
D_ptr,
|
||||
stride_D};
|
||||
|
||||
Args arguments{cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{n, m, k, 1}, // shape
|
||||
mainloop_arguments,
|
||||
epilogue_arguments};
|
||||
|
||||
// Workspace
|
||||
size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
|
||||
torch::Tensor workspace =
|
||||
torch::empty(workspace_size,
|
||||
torch::TensorOptions().dtype(torch::kU8).device(device));
|
||||
|
||||
// Run GEMM
|
||||
GemmShuffled gemm;
|
||||
CUTLASS_CHECK(gemm.can_implement(arguments));
|
||||
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
|
||||
CUTLASS_CHECK(gemm.run(stream));
|
||||
|
||||
return D;
|
||||
}
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Kernel instantiations and dispatch logic
|
||||
// ----------------------------------------------------------------------------
|
||||
using Kernel_256x128_1x1x1 =
|
||||
W4A8GemmKernel<Shape<_256, _128>, Shape<_1, _1, _1>>;
|
||||
using Kernel_256x64_1x1x1 = W4A8GemmKernel<Shape<_256, _64>, Shape<_1, _1, _1>>;
|
||||
using Kernel_256x32_1x1x1 = W4A8GemmKernel<Shape<_256, _32>, Shape<_1, _1, _1>>;
|
||||
using Kernel_256x16_1x1x1 = W4A8GemmKernel<Shape<_256, _16>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x256_2x1x1 =
|
||||
W4A8GemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>>;
|
||||
using Kernel_128x256_1x1x1 =
|
||||
W4A8GemmKernel<Shape<_128, _256>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x128_1x1x1 =
|
||||
W4A8GemmKernel<Shape<_128, _128>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
|
||||
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;
|
||||
|
||||
torch::Tensor mm_dispatch(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size,
|
||||
torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type,
|
||||
const std::string& schedule) {
|
||||
if (schedule == "256x128_1x1x1") {
|
||||
return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "256x64_1x1x1") {
|
||||
return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "256x32_1x1x1") {
|
||||
return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "256x16_1x1x1") {
|
||||
return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x256_2x1x1") {
|
||||
return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x256_1x1x1") {
|
||||
return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x128_1x1x1") {
|
||||
return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x64_1x1x1") {
|
||||
return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x32_1x1x1") {
|
||||
return Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
} else if (schedule == "128x16_1x1x1") {
|
||||
return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size,
|
||||
channel_scales, token_scales,
|
||||
maybe_out_type);
|
||||
}
|
||||
TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
|
||||
return {};
|
||||
}
|
||||
|
||||
torch::Tensor mm(torch::Tensor const& A,
|
||||
torch::Tensor const& B, // already packed
|
||||
torch::Tensor const& group_scales, // already packed
|
||||
int64_t group_size, torch::Tensor const& channel_scales,
|
||||
torch::Tensor const& token_scales,
|
||||
std::optional<at::ScalarType> const& maybe_out_type,
|
||||
std::optional<std::string> maybe_schedule) {
|
||||
// requested a specific schedule
|
||||
if (maybe_schedule) {
|
||||
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
|
||||
token_scales, maybe_out_type, *maybe_schedule);
|
||||
}
|
||||
std::string schedule;
|
||||
int M = A.size(0);
|
||||
int K = A.size(1);
|
||||
int N = B.size(1);
|
||||
// heuristic
|
||||
if (M <= 16) {
|
||||
schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1";
|
||||
} else if (M <= 32) {
|
||||
schedule = (K == 16384 && N == 18432) ? "256x32_1x1x1" : "128x32_1x1x1";
|
||||
} else if (M <= 64) {
|
||||
if (K == 16384 && N == 18432)
|
||||
schedule = "256x64_1x1x1";
|
||||
else if (N <= 8192 && K <= 8192)
|
||||
schedule = "128x32_1x1x1";
|
||||
else
|
||||
schedule = "128x64_1x1x1";
|
||||
} else if (M <= 128) {
|
||||
if (K == 16384 && N == 18432)
|
||||
schedule = "256x128_1x1x1";
|
||||
else if (N <= 8192)
|
||||
schedule = "128x64_1x1x1";
|
||||
else
|
||||
schedule = "128x128_1x1x1";
|
||||
} else if (M <= 256) {
|
||||
if (N <= 4096)
|
||||
schedule = "128x64_1x1x1";
|
||||
else if (N <= 8192)
|
||||
schedule = "128x128_1x1x1";
|
||||
else
|
||||
schedule = "128x256_1x1x1";
|
||||
} else if (M <= 512 && N <= 4096) {
|
||||
schedule = "128x128_1x1x1";
|
||||
} else if (M <= 1024) {
|
||||
schedule = "128x256_1x1x1";
|
||||
} else {
|
||||
schedule = "128x256_2x1x1";
|
||||
}
|
||||
return mm_dispatch(A, B, group_scales, group_size, channel_scales,
|
||||
token_scales, maybe_out_type, schedule);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Pre-processing utils
|
||||
// ----------------------------------------------------------------------------
|
||||
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
|
||||
TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(scales.is_contiguous());
|
||||
TORCH_CHECK(scales.is_cuda());
|
||||
|
||||
auto packed_scales = torch::empty(
|
||||
{scales.numel() * ScalePackSize},
|
||||
torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
|
||||
auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
|
||||
auto packed_scales_ptr =
|
||||
static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
|
||||
packed_scales.data_ptr());
|
||||
|
||||
cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel());
|
||||
|
||||
return packed_scales;
|
||||
}
|
||||
|
||||
torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
|
||||
TORCH_CHECK(B.dtype() == torch::kInt32);
|
||||
TORCH_CHECK(B.dim() == 2);
|
||||
|
||||
torch::Tensor B_packed = torch::empty_like(B);
|
||||
|
||||
int k = B.size(0) * PackFactor; // logical k
|
||||
int n = B.size(1);
|
||||
|
||||
auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
|
||||
auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
|
||||
auto shape_B = cute::make_shape(n, k, 1);
|
||||
auto layout_B = make_layout(shape_B, LayoutRight{}); // row major
|
||||
LayoutB_Reordered layout_B_reordered =
|
||||
cute::tile_to_shape(LayoutAtomQuant{}, shape_B);
|
||||
|
||||
cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
|
||||
cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);
|
||||
|
||||
return B_packed;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_w4a8_mm", &mm);
|
||||
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
|
||||
m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
|
||||
}
|
||||
|
||||
} // namespace vllm::cutlass_w4a8
|
||||
@ -309,6 +309,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
|
||||
"SymInt size_n, int num_bits) -> Tensor");
|
||||
// conditionally compiled so impl registrations are in source file
|
||||
|
||||
// CUTLASS w4a8 GEMM
|
||||
ops.def(
|
||||
"cutlass_w4a8_mm("
|
||||
" Tensor A,"
|
||||
" Tensor B,"
|
||||
" Tensor group_scales,"
|
||||
" int group_size,"
|
||||
" Tensor channel_scales,"
|
||||
" Tensor token_scales,"
|
||||
" ScalarType? out_type,"
|
||||
" str? maybe_schedule"
|
||||
") -> Tensor",
|
||||
{stride_tag});
|
||||
// pack scales
|
||||
ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
|
||||
// encode and reorder weight matrix
|
||||
ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
#endif
|
||||
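Taken together, the three ops defined above suggest the following Python-side flow: reorder the int4 weights and pack the fp8 group scales once at load time, then call the GEMM per layer. This is only a sketch; the `torch.ops._C` namespace, tensor shapes, and group size are assumptions, while the op names and argument order come from the schemas above.

```python
# Hypothetical Python-side flow for the CUTLASS w4a8 ops registered above.
# Namespace and shapes are assumptions; the op names/signatures come from the defs.
import torch

m, k, n, group_size = 128, 4096, 8192, 128

A = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)  # fp8 activations
# Raw bits standing in for 8 packed int4 weight values per int32 word.
B_int4 = torch.randint(-2**31, 2**31 - 1, (k // 8, n), dtype=torch.int32, device="cuda")
group_scales = torch.rand(k // group_size, n, device="cuda").to(torch.float8_e4m3fn)
channel_scales = torch.rand(n, dtype=torch.float32, device="cuda")
token_scales = torch.rand(m, dtype=torch.float32, device="cuda")

# One-time weight / scale preprocessing.
B_packed = torch.ops._C.cutlass_encode_and_reorder_int4b(B_int4)
scales_packed = torch.ops._C.cutlass_pack_scale_fp8(group_scales)

# Per-call GEMM; out_type and schedule left to the defaults / built-in heuristic.
D = torch.ops._C.cutlass_w4a8_mm(
    A, B_packed, scales_packed, group_size, channel_scales, token_scales, None, None
)
print(D.shape)  # (m, n)
```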
|
||||
// Dequantization for GGML.
|
||||
@ -672,11 +692,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"str kv_cache_dtype) -> ()");
|
||||
cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
|
||||
|
||||
// Gather cache blocks from src_cache to dst.
|
||||
// Gather cache blocks from src_cache to dst, dequantizing from
|
||||
// src_cache's dtype to dst's dtype if necessary.
|
||||
cache_ops.def(
|
||||
"gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
|
||||
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
|
||||
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
|
||||
" Tensor block_table, Tensor cu_seq_lens, "
|
||||
" int batch_size, "
|
||||
" str kv_cache_dtype, "
|
||||
" Tensor scale, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
|
||||
&gather_and_maybe_dequant_cache);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
|
||||
@ -432,31 +432,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# Install DeepGEMM from source
|
||||
ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
|
||||
ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
CUDA_MAJOR="${CUDA_VERSION%%.*}"
|
||||
CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}"
|
||||
CUDA_MINOR="${CUDA_MINOR%%.*}"
|
||||
if [ "$CUDA_MAJOR" -ge 12 ] && [ "$CUDA_MINOR" -ge 8 ]; then
|
||||
git clone --recursive --shallow-submodules \
|
||||
${DEEPGEMM_GIT_REPO} deepgemm
|
||||
echo "🏗️ Building DeepGEMM"
|
||||
pushd deepgemm
|
||||
git checkout ${DEEPGEMM_GIT_REF}
|
||||
# Build DeepGEMM
|
||||
# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh)
|
||||
rm -rf build dist
|
||||
rm -rf *.egg-info
|
||||
python3 setup.py bdist_wheel
|
||||
uv pip install --system dist/*.whl
|
||||
popd
|
||||
rm -rf deepgemm
|
||||
else
|
||||
echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
|
||||
fi
|
||||
BASH
|
||||
COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \
|
||||
&& rm /tmp/install_deepgemm.sh
|
||||
|
||||
# Install EP kernels(pplx-kernels and DeepEP), NixL
|
||||
COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh
|
||||
COPY tools/install_nixl.sh install_nixl.sh
|
||||
ENV CUDA_HOME=/usr/local/cuda
|
||||
RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \
|
||||
&& bash install_python_libraries.sh \
|
||||
&& bash install_nixl.sh --force
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
|
||||
@ -71,7 +71,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
|
||||
RUN cd /vllm-workspace \
|
||||
&& rm -rf vllm \
|
||||
&& python3 -m pip install -e tests/vllm_test_utils \
|
||||
&& python3 -m pip install lm-eval[api]==0.4.4 \
|
||||
&& python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \
|
||||
&& python3 -m pip install pytest-shard
|
||||
|
||||
# -----------------------
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
We host regular meetups in the San Francisco Bay Area every 2 months. We share project updates from the vLLM team and have guest speakers from the industry share their experience and insights. Please find the materials of our previous meetups below:
|
||||
|
||||
- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH)
|
||||
- [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152).
|
||||
- [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing)
|
||||
- [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
|
||||
|
||||
@ -172,6 +172,7 @@ The availablilty of batch-level DP is based on model implementation.
|
||||
Currently, the following models support `mm_encoder_tp_mode="data"`:
|
||||
|
||||
- Llama4 (<gh-pr:18368>)
|
||||
- MiniCPM-V-4 (<gh-pr:23327>)
|
||||
- Qwen2.5-VL (<gh-pr:22742>)
|
||||
- Step3 (<gh-pr:22697>)
|
||||
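For reference, a minimal offline-inference sketch of turning this mode on; the model choice and parallel size below are illustrative, while `mm_encoder_tp_mode="data"` is the setting this list documents.

```python
# Sketch: run the multi-modal encoder with batch-level data parallelism
# for one of the supported models above. Sizes here are illustrative.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    tensor_parallel_size=2,
    mm_encoder_tp_mode="data",
)
```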
|
||||
@ -195,6 +196,13 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
|
||||
!!! note
|
||||
API server scale-out is only available for online inference.
|
||||
|
||||
!!! warning
|
||||
By default, 8 CPU threads are used in each API server to load media items (e.g. images)
|
||||
from request data.
|
||||
|
||||
If you apply API server scale-out, consider adjusting `VLLM_MEDIA_LOADING_THREAD_COUNT`
|
||||
to avoid CPU resource exhaustion.
|
||||
|
||||
!!! note
|
||||
[Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
|
||||
because it requires a one-to-one correspondence between API and engine core processes.
|
||||
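As a rough sketch of the adjustment suggested in the warning, the environment variable can be lowered before launching the scaled-out server; the thread count of 2 and the use of `subprocess` here are arbitrary choices for illustration.

```python
# Sketch: cap media-loading threads per API server before a scaled-out launch,
# so api_server_count x threads does not exhaust host CPUs. "2" is arbitrary.
import os
import subprocess

env = dict(os.environ, VLLM_MEDIA_LOADING_THREAD_COUNT="2")
subprocess.run(
    ["vllm", "serve", "Qwen/Qwen2.5-VL-3B-Instruct",
     "--api-server-count", "4", "-dp", "2"],
    env=env,
    check=True,
)
```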
|
||||
@ -18,7 +18,7 @@ vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096
|
||||
|
||||
- Download and install [Anything LLM desktop](https://anythingllm.com/desktop).
|
||||
|
||||
- On the bottom left of open settings, AI Prooviders --> LLM:
|
||||
- On the bottom left of open settings, AI Providers --> LLM:
|
||||
- LLM Provider: Generic OpenAI
|
||||
- Base URL: http://{vllm server host}:{vllm server port}/v1
|
||||
- Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ`
|
||||
|
||||
@ -226,7 +226,7 @@ Doing this will add the new implementation to the test suite.
|
||||
|
||||
The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script.
|
||||
Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts`
|
||||
As a side-effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked
|
||||
As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked
|
||||
with incompatible types, the script will error.
|
||||
|
||||
### How To Profile
|
||||
|
||||
@ -565,7 +565,7 @@ model and then validate those tokens with the larger model.
|
||||
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)
|
||||
|
||||
There is a PR under review (<gh-pr:12193>) to add "prompt lookup (ngram)"
|
||||
seculative decoding to v1. Other techniques will follow. We should
|
||||
speculative decoding to v1. Other techniques will follow. We should
|
||||
revisit the v0 metrics in this context.
|
||||
|
||||
!!! note
|
||||
|
||||
@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`.
|
||||
|
||||
There are other miscellaneous places hard-coding the use of `spawn`:
|
||||
|
||||
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
|
||||
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135>
|
||||
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>
|
||||
|
||||
Related PRs:
|
||||
|
||||
@ -422,7 +422,7 @@ a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle
|
||||
a whole block of value tokens. And each `accs` in each thread
|
||||
contains 8 elements that accumulated at 8 different head positions.
|
||||
For the thread 0, the `accs` variable will have 8 elements, which
|
||||
are 0th, 32th … 224th elements of a value head that are accumulated
|
||||
are 0th, 32nd … 224th elements of a value head that are accumulated
|
||||
from all assigned 8 tokens.
|
||||
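The index pattern above follows directly from the numbers in this section (8 accumulated elements per thread, spaced 32 head positions apart); the couple of lines below just reproduce that arithmetic and are not taken from the kernel.

```python
# Head positions accumulated into `accs` by thread 0: 8 elements, 32 apart.
positions = [i * 32 for i in range(8)]
print(positions)  # [0, 32, 64, 96, 128, 160, 192, 224]
```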
|
||||
## LV
|
||||
|
||||
@ -79,7 +79,7 @@ Since simple RTN does not require data for weight quantization and the activatio
|
||||
Install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||
|
||||
```bash
|
||||
pip install vllm lm-eval==0.4.4
|
||||
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
|
||||
```
|
||||
|
||||
Load and run the model in `vllm`:
|
||||
|
||||
@ -7,7 +7,7 @@ Intel Gaudi supports quantization of various modules and functions, including, b
|
||||
[Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules).
|
||||
|
||||
!!! note
|
||||
Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.
|
||||
Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.
|
||||
|
||||
!!! note
|
||||
`QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options).
|
||||
|
||||
@ -18,7 +18,7 @@ pip install llmcompressor
|
||||
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||
|
||||
```bash
|
||||
pip install vllm lm-eval==0.4.4
|
||||
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
|
||||
```
|
||||
|
||||
## Quantization Process
|
||||
|
||||
@ -19,7 +19,7 @@ pip install llmcompressor
|
||||
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||
|
||||
```bash
|
||||
pip install vllm lm-eval==0.4.4
|
||||
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
|
||||
```
|
||||
|
||||
## Quantization Process
|
||||
|
||||
@ -20,7 +20,7 @@ for more installation details.
|
||||
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
|
||||
|
||||
```bash
|
||||
pip install vllm lm-eval==0.4.4
|
||||
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
|
||||
```
|
||||
|
||||
## Quantization Process
|
||||
|
||||
@ -284,6 +284,14 @@ Supported models:
|
||||
|
||||
Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}`
|
||||
|
||||
### DeepSeek-V3.1 Models (`deepseek_v31`)
|
||||
|
||||
Supported models:
|
||||
|
||||
* `deepseek-ai/DeepSeek-V3.1` (use with <gh-file:examples/tool_chat_template_deepseekv31.jinja>)
|
||||
|
||||
Flags: `--tool-call-parser deepseek_v31 --chat-template {see_above}`
|
||||
|
||||
### Kimi-K2 Models (`kimi_k2`)
|
||||
|
||||
Supported models:
|
||||
|
||||
@ -170,7 +170,7 @@ This value is 4GB by default. Larger space can support more concurrent requests,
|
||||
|
||||
First of all, please make sure the thread binding and KV cache space are properly set and take effect. You can check the thread binding by running a vLLM benchmark and observing CPU core usage via `htop`.
|
||||
|
||||
Inference batch size is a important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM:
|
||||
Inference batch size is an important parameter for performance. A larger batch usually provides higher throughput, while a smaller batch provides lower latency. Tuning the max batch size, starting from the default value, to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM:
|
||||
|
||||
- `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as:
|
||||
- Offline Inference: `4096 * world_size`
|
||||
@ -179,7 +179,7 @@ Inference batch size is a important parameter for the performance. Larger batch
|
||||
- Offline Inference: `256 * world_size`
|
||||
- Online Serving: `128 * world_size`
|
||||
|
||||
vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more detials of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP togther if there are enough CPU sockets and memory nodes.
|
||||
vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use TP and PP together if there are enough CPU sockets and memory nodes.
|
||||
|
||||
### Which quantization configs does vLLM CPU support?
|
||||
|
||||
@ -190,6 +190,6 @@ vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage mu
|
||||
|
||||
### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`?
|
||||
|
||||
- Both of them requires `amx` CPU flag.
|
||||
- Both of them require `amx` CPU flag.
|
||||
- `VLLM_CPU_MOE_PREPACK` can provide better performance for MoE models.
|
||||
- `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios.
|
||||
|
||||
@ -261,13 +261,13 @@ Lower value corresponds to less usable graph memory reserved for prefill stage,
|
||||
|
||||
Users can also configure the strategy for capturing HPU Graphs for the prompt and decode stages separately. The strategy affects the order in which graphs are captured. There are two strategies implemented (see the sketch after the list):
|
||||
|
||||
- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode
|
||||
- `max_bs` - graph capture queue will be sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode
|
||||
- `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt
|
||||
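A short sketch of the two orderings, applied to the example buckets from the `max_bs` bullet; this only mirrors the sort keys described above and is not the HPU backend's actual implementation.

```python
# Capture-order strategies on example (batch_size, seq_len) buckets.
buckets = [(32, 128), (1, 256), (64, 256), (64, 128), (32, 256), (1, 128)]

# "max_bs": descending batch size, ties broken by ascending sequence length.
max_bs_order = sorted(buckets, key=lambda b: (-b[0], b[1]))
# -> [(64, 128), (64, 256), (32, 128), (32, 256), (1, 128), (1, 256)]

# "min_tokens": ascending number of processed tokens (batch_size * seq_len).
min_tokens_order = sorted(buckets, key=lambda b: b[0] * b[1])
# -> [(1, 128), (1, 256), (32, 128), (64, 128), (32, 256), (64, 256)]
#    (64*128 == 32*256; the tie keeps input order)
```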
|
||||
When there is a large number of requests pending, the vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request finishes, the decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size back to its previous state. This means that in a full-load scenario, the decode batch size is often at its maximum, which makes large-batch-size HPU Graphs crucial to capture, as reflected by the `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in the `min_tokens` strategy.
|
||||
|
||||
!!! note
|
||||
`VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below.
|
||||
`VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt to do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below.
|
||||
|
||||
Each described step is logged by vLLM server, as follows (negative values correspond to memory being released):
|
||||
|
||||
|
||||
@ -328,11 +328,11 @@ th {
|
||||
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
|
||||
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ |
|
||||
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
|
||||
| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | |
|
||||
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R, Command-A | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`, `CohereLabs/command-a-reasoning-08-2025`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ |
|
||||
@ -401,6 +401,7 @@ th {
|
||||
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
|
||||
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
@ -614,6 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
|
||||
| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | |
|
||||
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
|
||||
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
|
||||
|
||||
@ -166,7 +166,7 @@ Processed means the values after applying all processors, including temperature
|
||||
|
||||
##### Prompt Logprobs with Prefix Caching
|
||||
|
||||
Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](gh-issue:13414).
|
||||
Logprobs are not cached. For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of full prompt to generate the logprobs.
|
||||
|
||||
#### Deprecated Features
|
||||
|
||||
|
||||
311 examples/offline_inference/dolphin.py Normal file
@ -0,0 +1,311 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import regex as re
|
||||
from PIL import Image
|
||||
from transformers import DonutProcessor
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
@dataclass
|
||||
class ImageDimensions:
|
||||
original_w: int
|
||||
original_h: int
|
||||
padded_w: int
|
||||
padded_h: int
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
def map_to_original_coordinates(
|
||||
x1, y1, x2, y2, dims: ImageDimensions
|
||||
) -> tuple[int, int, int, int]:
|
||||
try:
|
||||
top = (dims.padded_h - dims.original_h) // 2
|
||||
left = (dims.padded_w - dims.original_w) // 2
|
||||
orig_x1 = max(0, x1 - left)
|
||||
orig_y1 = max(0, y1 - top)
|
||||
orig_x2 = min(dims.original_w, x2 - left)
|
||||
orig_y2 = min(dims.original_h, y2 - top)
|
||||
if orig_x2 <= orig_x1:
|
||||
orig_x2 = min(orig_x1 + 1, dims.original_w)
|
||||
if orig_y2 <= orig_y1:
|
||||
orig_y2 = min(orig_y1 + 1, dims.original_h)
|
||||
return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
|
||||
except Exception as e:
|
||||
print(f"map_to_original_coordinates error: {str(e)}")
|
||||
return 0, 0, min(100, dims.original_w), min(100, dims.original_h)
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2):
|
||||
if isinstance(image, str):
|
||||
image = cv2.imread(image)
|
||||
img_h, img_w = image.shape[:2]
|
||||
new_boxes = []
|
||||
for box in boxes:
|
||||
best_box = copy.deepcopy(box)
|
||||
|
||||
def check_edge(img, current_box, i, is_vertical):
|
||||
edge = current_box[i]
|
||||
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
_, binary = cv2.threshold(
|
||||
gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
|
||||
)
|
||||
if is_vertical:
|
||||
line = binary[current_box[1] : current_box[3] + 1, edge]
|
||||
else:
|
||||
line = binary[edge, current_box[0] : current_box[2] + 1]
|
||||
transitions = np.abs(np.diff(line))
|
||||
return np.sum(transitions) / len(transitions)
|
||||
|
||||
edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
|
||||
current_box = copy.deepcopy(box)
|
||||
current_box[0] = min(max(current_box[0], 0), img_w - 1)
|
||||
current_box[1] = min(max(current_box[1], 0), img_h - 1)
|
||||
current_box[2] = min(max(current_box[2], 0), img_w - 1)
|
||||
current_box[3] = min(max(current_box[3], 0), img_h - 1)
|
||||
|
||||
for i, direction, is_vertical in edges:
|
||||
best_score = check_edge(image, current_box, i, is_vertical)
|
||||
if best_score <= threshold:
|
||||
continue
|
||||
for step in range(max_pixels):
|
||||
current_box[i] += direction
|
||||
if i == 0 or i == 2:
|
||||
current_box[i] = min(max(current_box[i], 0), img_w - 1)
|
||||
else:
|
||||
current_box[i] = min(max(current_box[i], 0), img_h - 1)
|
||||
score = check_edge(image, current_box, i, is_vertical)
|
||||
if score < best_score:
|
||||
best_score = score
|
||||
best_box = copy.deepcopy(current_box)
|
||||
if score <= threshold:
|
||||
break
|
||||
new_boxes.append(best_box)
|
||||
return new_boxes
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
|
||||
try:
|
||||
x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
|
||||
x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
|
||||
x1, y1, x2, y2 = (
|
||||
max(0, min(x1, dims.padded_w - 1)),
|
||||
max(0, min(y1, dims.padded_h - 1)),
|
||||
max(0, min(x2, dims.padded_w)),
|
||||
max(0, min(y2, dims.padded_h)),
|
||||
)
|
||||
if x2 <= x1:
|
||||
x2 = min(x1 + 1, dims.padded_w)
|
||||
if y2 <= y1:
|
||||
y2 = min(y1 + 1, dims.padded_h)
|
||||
new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
|
||||
x1, y1, x2, y2 = new_boxes[0]
|
||||
x1, y1, x2, y2 = (
|
||||
max(0, min(x1, dims.padded_w - 1)),
|
||||
max(0, min(y1, dims.padded_h - 1)),
|
||||
max(0, min(x2, dims.padded_w)),
|
||||
max(0, min(y2, dims.padded_h)),
|
||||
)
|
||||
if x2 <= x1:
|
||||
x2 = min(x1 + 1, dims.padded_w)
|
||||
if y2 <= y1:
|
||||
y2 = min(y1 + 1, dims.padded_h)
|
||||
if previous_box is not None:
|
||||
prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
|
||||
if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
|
||||
y1 = prev_y2
|
||||
y1 = min(y1, dims.padded_h - 1)
|
||||
if y2 <= y1:
|
||||
y2 = min(y1 + 1, dims.padded_h)
|
||||
new_previous_box = [x1, y1, x2, y2]
|
||||
orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
|
||||
x1, y1, x2, y2, dims
|
||||
)
|
||||
return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
|
||||
except Exception as e:
|
||||
print(f"process_coordinates error: {str(e)}")
|
||||
orig_x1, orig_y1, orig_x2, orig_y2 = (
|
||||
0,
|
||||
0,
|
||||
min(100, dims.original_w),
|
||||
min(100, dims.original_h),
|
||||
)
|
||||
return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]:
|
||||
try:
|
||||
image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
||||
original_h, original_w = image_cv.shape[:2]
|
||||
max_size = max(original_h, original_w)
|
||||
top = (max_size - original_h) // 2
|
||||
bottom = max_size - original_h - top
|
||||
left = (max_size - original_w) // 2
|
||||
right = max_size - original_w - left
|
||||
padded_image = cv2.copyMakeBorder(
|
||||
image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
|
||||
)
|
||||
padded_h, padded_w = padded_image.shape[:2]
|
||||
dimensions = ImageDimensions(
|
||||
original_w=original_w,
|
||||
original_h=original_h,
|
||||
padded_w=padded_w,
|
||||
padded_h=padded_h,
|
||||
)
|
||||
return padded_image, dimensions
|
||||
except Exception as e:
|
||||
print(f"prepare_image error: {str(e)}")
|
||||
h, w = image.height, image.width
|
||||
dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h)
|
||||
return np.zeros((h, w, 3), dtype=np.uint8), dimensions
|
||||
|
||||
|
||||
# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
|
||||
def parse_layout_string(bbox_str):
|
||||
"""Parse layout string using regular expressions"""
|
||||
pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
|
||||
matches = re.finditer(pattern, bbox_str)
|
||||
|
||||
parsed_results = []
|
||||
for match in matches:
|
||||
coords = [float(match.group(i)) for i in range(1, 5)]
|
||||
label = match.group(5).strip()
|
||||
parsed_results.append((coords, label))
|
||||
|
||||
return parsed_results
|
||||
|
||||
|
||||
model_id = "ByteDance/Dolphin"
|
||||
|
||||
# The input image size for Dolphin is 896 x 896,
|
||||
# and the patch_size is 4 x 4.
|
||||
# Therefore, the initial number of patches is:
|
||||
# Height: 896 / 4 = 224 patches
|
||||
# Width: 896 / 4 = 224 patches
|
||||
|
||||
# The Dolphin model uses a staged downsampling approach,
|
||||
# defined by the "depths": [2, 2, 14, 2] configuration.
|
||||
# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed,
|
||||
# which halves the feature map's dimensions (dividing both height and width by 2).
|
||||
# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112.
|
||||
# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56.
|
||||
# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28.
|
||||
|
||||
# Because vLLM needs to fill the image features with an encoder_prompt,
|
||||
# and the encoder_prompt will have `<pad>` tokens added when tokenized,
|
||||
# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783.
|
||||
encoder_prompt = "".join(["0"] * 783)
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=2048,
)

processor = DonutProcessor.from_pretrained(model_id)
llm = LLM(
    model=model_id,
    dtype="float16",
    max_num_seqs=8,
    hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--image_path", type=str, default=None, help="Path to a local image file."
)
args = parser.parse_args()

if args.image_path:
    if not os.path.exists(args.image_path):
        raise FileNotFoundError(f"Error: File not found at {args.image_path}")
    image = Image.open(args.image_path).convert("RGB")
else:
    image = fetch_image(
        "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
    )


prompt = "Parse the reading order of this document. "
decoder_prompt = f"<s>{prompt}<Answer/>"
decoder_prompt_tokens = TokensPrompt(
    prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[
        "input_ids"
    ]
)
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
    decoder_prompt=decoder_prompt_tokens,
)
layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params)
layout_result_str = layout_outputs[0].outputs[0].text
print(f"Layout analysis output:\n{layout_result_str}")

padded_image, dims = prepare_image(image)
layout_results = parse_layout_string(layout_result_str)
text_table_elements = []
previous_box = None
reading_order = 0
for bbox_coords, label in layout_results:
    if label == "fig":
        continue
    try:
        x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = (
            process_coordinates(bbox_coords, padded_image, dims, previous_box)
        )
        cropped = padded_image[y1:y2, x1:x2]
        if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
            pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
            prompt_ocr = (
                "Parse the table in the image. "
                if label == "tab"
                else "Read text in the image. "
            )
            text_table_elements.append(
                {
                    "crop": pil_crop,
                    "prompt": prompt_ocr,
                    "reading_order": reading_order,
                }
            )
            reading_order += 1
    except Exception as e:
        print(f"Error processing bbox (label: {label}): {str(e)}")
        continue

if text_table_elements:
    batch_prompts = []
    for elem in text_table_elements:
        decoder_prompt_str = f"<s>{elem['prompt']}<Answer/>"
        decoder_prompt_tokens = TokensPrompt(
            prompt_token_ids=processor.tokenizer(
                decoder_prompt_str, add_special_tokens=False
            )["input_ids"]
        )
        enc_dec_prompt = ExplicitEncoderDecoderPrompt(
            encoder_prompt=TextPrompt(
                prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]}
            ),
            decoder_prompt=decoder_prompt_tokens,
        )
        batch_prompts.append(enc_dec_prompt)
    batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params)
    for i, output in enumerate(batch_outputs):
        text_table_elements[i]["text"] = output.outputs[0].text.strip()

print("------" * 8)
text_table_elements.sort(key=lambda x: x["reading_order"])
for elem in text_table_elements:
    print(elem.get("text", ""))
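# Example invocation (the script filename here is illustrative):
#   python dolphin_example.py                        # uses the sample HF document image
#   python dolphin_example.py --image_path page.png  # layout analysis + OCR on a local file
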
@ -13,6 +13,7 @@ from typing import NamedTuple

from vllm import LLM, EngineArgs, PromptType, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser


@ -21,6 +22,50 @@ class ModelRequestData(NamedTuple):
    prompts: Sequence[PromptType]


def run_donut():
    engine_args = EngineArgs(
        model="naver-clova-ix/donut-base-finetuned-docvqa",
        max_num_seqs=2,
        limit_mm_per_prompt={"image": 1},
        dtype="float16",
        hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
    )

    # The input image size for donut-base-finetuned-docvqa is 2560 x 1920,
    # and the patch_size is 4 x 4.
    # Therefore, the initial number of patches is:
    # Height: 1920 / 4 = 480 patches
    # Width: 2560 / 4 = 640 patches
    # The Swin model uses a staged downsampling approach,
    # defined by the "depths": [2, 2, 14, 2] configuration.
    # Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed,
    # which halves the feature map's dimensions (dividing both height and width by 2).
    # Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320.
    # Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160.
    # Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80.
    # Because vLLM needs to fill the image features with an encoder_prompt,
    # and the encoder_prompt will have `<pad>` tokens added when tokenized,
    # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799.
    prompts = [
        {
            "encoder_prompt": {
                "prompt": "".join(["$"] * 4799),
                "multi_modal_data": {
                    "image": fetch_image(
                        "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
                    )  # noqa: E501
                },
            },
            "decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>",  # noqa: E501
        },
    ]

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )


def run_florence2():
    engine_args = EngineArgs(
        model="microsoft/Florence-2-large",
@ -118,6 +163,7 @@ def run_whisper():


model_example_map = {
    "donut": run_donut,
    "florence2": run_florence2,
    "mllama": run_mllama,
    "whisper": run_whisper,

@ -5,6 +5,7 @@ from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
from vllm.inputs import TokensPrompt
from vllm.v1.metrics.reader import Counter, Vector

try:

@ -137,7 +138,8 @@ def main():
    sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
    if not args.custom_mm_prompts:
        outputs = llm.generate(
            prompt_token_ids=prompt_ids, sampling_params=sampling_params
            TokensPrompt(prompt_token_ids=prompt_ids),
            sampling_params=sampling_params,
        )
    else:
        outputs = llm.chat(prompts, sampling_params=sampling_params)

@ -85,7 +85,7 @@ def format_output(title: str, output: str):


def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
    outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
    outputs = llm.generate(prompt, sampling_params=sampling_params)
    return outputs[0].outputs[0].text

91
examples/tool_chat_template_deepseekv31.jinja
Normal file
@ -0,0 +1,91 @@
|
||||
{% if not add_generation_prompt is defined %}
|
||||
{% set add_generation_prompt = false %}
|
||||
{% endif %}
|
||||
{% if not thinking is defined %}
|
||||
{% set thinking = false %}
|
||||
{% endif %}
|
||||
{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'system' %}
|
||||
{%- if ns.is_first_sp %}
|
||||
{% set ns.system_prompt = ns.system_prompt + message['content'] %}
|
||||
{% set ns.is_first_sp = false %}
|
||||
{%- else %}
|
||||
{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{% if tools is defined and tools is not none %}
|
||||
{% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %}
|
||||
{% for tool in tools %}
|
||||
{% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %}
|
||||
{% endfor %}
|
||||
{% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %}
|
||||
{% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %}
|
||||
{% endif %}
|
||||
|
||||
{{ bos_token }}{{ ns.system_prompt }}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- set ns.is_first = false -%}
|
||||
{%- set ns.is_last_user = true -%}
|
||||
{{'<|User|>' + message['content']}}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
|
||||
{%- if ns.is_last_user %}
|
||||
{{'<|Assistant|></think>'}}
|
||||
{%- endif %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- set ns.is_first = false %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls'] %}
|
||||
{%- if not ns.is_first %}
|
||||
{%- if message['content'] is none %}
|
||||
{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
|
||||
{%- else %}
|
||||
{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else %}
|
||||
{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}
|
||||
{%- if ns.is_last_user %}
|
||||
{{'<|Assistant|>'}}
|
||||
{%- if message['prefix'] is defined and message['prefix'] and thinking %}
|
||||
{{'<think>'}}
|
||||
{%- else %}
|
||||
{{'</think>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- if ns.is_tool %}
|
||||
{{message['content'] + '<|end▁of▁sentence|>'}}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- else %}
|
||||
{%- set content = message['content'] -%}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set content = content.split('</think>', 1)[1] -%}
|
||||
{%- endif %}
|
||||
{{content + '<|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'tool' %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- set ns.is_tool = true -%}
|
||||
{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}
|
||||
{{'<|Assistant|>'}}
|
||||
{%- if not thinking %}
|
||||
{{'</think>'}}
|
||||
{%- else %}
|
||||
{{'<think>'}}
|
||||
{%- endif %}
|
||||
{% endif %}
|
||||
123
examples/tool_chat_template_gemma3_pythonic.jinja
Normal file
@ -0,0 +1,123 @@
|
||||
{#- Begin-of-sequence token to start the model prompt -#}
|
||||
{{ bos_token }}
|
||||
{#- Extracts the system message. Gemma does not support system messages so it will be prepended to first user message. -#}
|
||||
{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- if messages[0]['content'] is string -%}
|
||||
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
|
||||
{%- endif -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = "" -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- endif -%}
|
||||
{#- Set tools to none if not defined for this ChatCompletion request (helps avoid errors later) -#}
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
{#- Validate alternating user/assistant messages (excluding 'tool' messages and ones with tool_calls) -#}
|
||||
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | selectattr("tool_calls", "undefined") -%}
|
||||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
|
||||
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- Main loop over all messages in the conversation history -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{#- Normalize roles for model prompt formatting -#}
|
||||
{%- if (message['role'] == 'assistant') -%}
|
||||
{%- set role = "model" -%}
|
||||
{%- elif (message['role'] == 'tool') -%}
|
||||
{%- set role = "user" -%}
|
||||
{%- else -%}
|
||||
{%- set role = message['role'] -%}
|
||||
{%- endif -%}
|
||||
{#- Mark the start of a message block with the appropriate role -#}
|
||||
{{ '<start_of_turn>' + role + '\n' -}}
|
||||
|
||||
{#- Insert system message content (if present) at the beginning of the first message. -#}
|
||||
{%- if loop.first -%}
|
||||
{{ first_user_prefix }}
|
||||
{#- Append system message with tool information if using tools in message request. -#}
|
||||
{%- if tools is not none -%}
|
||||
{{- "Tools (functions) are available. If you decide to invoke one or more of the tools, you must respond with a python list of the function calls.\n" -}}
|
||||
{{- "Example Format: [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] \n" -}}
|
||||
{{- "Do not use variables. DO NOT USE MARKDOWN SYNTAX. You SHOULD NOT include any other text in the response if you call a function. If none of the functions can be used, point it out. If you lack the parameters required by the function, also point it out.\n" -}}
|
||||
{{- "Here is a list of functions in JSON format that you can invoke.\n" -}}
|
||||
{{- tools | tojson(indent=4) -}}
|
||||
{{- "\n\n" -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Format model tool calls (turns where model indicates they want to call a tool) -#}
|
||||
{%- if 'tool_calls' in message -%}
|
||||
{#- Opening bracket for tool call list. -#}
|
||||
{{- '[' -}}
|
||||
{#- For each tool call -#}
|
||||
{%- for tool_call in message.tool_calls -%}
|
||||
{#- Get tool call function. -#}
|
||||
{%- if tool_call.function is defined -%}
|
||||
{%- set tool_call = tool_call.function -%}
|
||||
{%- endif -%}
|
||||
{#- Function name & opening parenthesis. -#}
|
||||
{{- tool_call.name + '(' -}}
|
||||
|
||||
{#-- Handle arguments as list (positional) or dict (named) --#}
|
||||
{#-- Named arguments (dict) --#}
|
||||
{%- if tool_call.arguments is iterable and tool_call.arguments is mapping -%}
|
||||
{%- set first = true -%}
|
||||
{%- for key, val in tool_call.arguments.items() -%}
|
||||
{%- if not first %}, {% endif -%}
|
||||
{{ key }}={{ val | tojson }}
|
||||
{%- set first = false -%}
|
||||
{%- endfor -%}
|
||||
{#-- Positional arguments (list) --#}
|
||||
{%- elif tool_call.arguments is iterable -%}
|
||||
{{- tool_call.arguments | map('tojson') | join(', ') -}}
|
||||
{#-- Fallback: single positional value --#}
|
||||
{%- else -%}
|
||||
{{- tool_call.arguments | tojson -}}
|
||||
{#-- Closing parenthesis. --#}
|
||||
{%- endif -%}
|
||||
{{- ')' -}}
|
||||
{#-- If more than one tool call, place comma and move to formatting next tool call --#}
|
||||
{%- if not loop.last -%}, {% endif -%}
|
||||
{%- endfor -%}
|
||||
{#- Closing bracket for tool call list. -#}
|
||||
{{- ']' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Tool response start tag (for messages from a tool) -#}
|
||||
{%- if (message['role'] == 'tool') -%}
|
||||
{{ '<tool_response>\n' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Render the message content: handle plain string or multimodal content like image/text -#}
|
||||
{%- if message['content'] is string -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'image' -%}
|
||||
{{ '<start_of_image>' }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type") }}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Tool response end tag -#}
|
||||
{%- if (message['role'] == 'tool') -%}
|
||||
{{ '</tool_response>' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Mark end of a single turn -#}
|
||||
{{ '<end_of_turn>\n' }}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- If generation is to be triggered, add model prompt prefix -#}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{'<start_of_turn>model\n'}}
|
||||
{%- endif -%}
|
||||
@ -13,12 +13,12 @@ protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp
openai >= 1.99.1 # For Responses API with reasoning content
pydantic >= 2.10
pydantic >= 2.11.7
prometheus_client >= 0.18.0
pillow # Required for image processing
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
lm-format-enforcer == 0.11.3
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines_core == 0.2.10 ; platform_machine != "s390x"
outlines == 0.1.11 ; platform_machine == "s390x"

@ -27,7 +27,7 @@ mistral_common[image,audio] >= 1.8.2 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
mteb>=1.38.11, <2 # required for mteb test
transformers==4.52.4
tokenizers==0.21.1

@ -6,7 +6,7 @@ torch==2.7.0
torchvision==0.22.0
torchaudio==2.7.0

triton==3.2
triton==3.3.0
cmake>=3.26.1,<4
packaging>=24.2
setuptools>=77.0.3,<80.0.0

@ -17,4 +17,4 @@ setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
conch-triton-kernels==1.2.1
conch-triton-kernels==1.2.1

@ -32,7 +32,8 @@ num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
# TODO: Use lm-eval[api]==0.4.10 once released
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
mteb[bm25s]>=1.38.11, <2 # required for mteb test
transformers==4.55.2
tokenizers==0.21.1

@ -408,7 +408,7 @@ lightning-utilities==0.14.3
    # torchmetrics
llvmlite==0.44.0
    # via numba
lm-eval==0.4.8
lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d
    # via -r requirements/test.in
lxml==5.3.0
    # via

@ -742,7 +742,7 @@ pycparser==2.22
    # via cffi
pycryptodomex==3.22.0
    # via blobfile
pydantic==2.11.5
pydantic==2.11.7
    # via
    #   -r requirements/test.in
    #   albumentations

2
setup.py
@ -695,6 +695,8 @@ setup(
        "video": [],  # Kept for backwards compatibility
        # FlashInfer should be updated together with the Dockerfile
        "flashinfer": ["flashinfer-python==0.2.12"],
        # Optional deps for AMD FP4 quantization support
        "petit-kernel": ["petit-kernel"],
    },
    cmdclass=cmdclass,
    package_data=package_data,

@ -177,3 +177,34 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):

    # cmp output
    assert output[0].outputs[0].text == output3[0].outputs[0].text


@create_new_process_for_each_test()
def test_deep_sleep():
    model = "Qwen/Qwen3-0.6B"
    free, total = torch.cuda.mem_get_info()
    used_bytes_baseline = total - free  # in case another process is running
    llm = LLM(model, enable_sleep_mode=True)
    prompt = "How are you?"
    sampling_params = SamplingParams(temperature=0, max_tokens=10)
    output = llm.generate(prompt, sampling_params)

    # Put the engine to deep sleep
    llm.sleep(level=2)

    free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
    used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
    assert used_bytes < 3 * GiB_bytes

    llm.wake_up(tags=["weights"])
    llm.collective_rpc("reload_weights")
    free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
    used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
    assert used_bytes < 4 * GiB_bytes

    # now allocate kv cache and cuda graph memory
    llm.wake_up(tags=["kv_cache"])
    output2 = llm.generate(prompt, sampling_params)

    # cmp output
    assert output[0].outputs[0].text == output2[0].outputs[0].text

344
tests/benchmarks/test_random_dataset.py
Normal file
@ -0,0 +1,344 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
from typing import Any, NamedTuple, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
|
||||
SampleRequest)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_tokenizer() -> PreTrainedTokenizerBase:
|
||||
# Use a small, commonly available tokenizer
|
||||
return AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
|
||||
class Params(NamedTuple):
|
||||
num_requests: int
|
||||
prefix_len: int
|
||||
range_ratio: float
|
||||
input_len: int
|
||||
output_len: int
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def random_dataset_params() -> Params:
|
||||
return Params(num_requests=16,
|
||||
prefix_len=7,
|
||||
range_ratio=0.3,
|
||||
input_len=50,
|
||||
output_len=20)
|
||||
|
||||
|
||||
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
||||
"""Project a SampleRequest into a comparable tuple."""
|
||||
return (req.prompt, req.prompt_len, req.expected_output_len)
|
||||
|
||||
|
||||
def _collect_samples(dataset: RandomDataset,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int = 16,
|
||||
prefix_len: int = 7,
|
||||
range_ratio: float = 0.3,
|
||||
input_len: int = 50,
|
||||
output_len: int = 20) -> list[tuple[str, int, int]]:
|
||||
samples = dataset.sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=num_requests,
|
||||
prefix_len=prefix_len,
|
||||
range_ratio=range_ratio,
|
||||
input_len=input_len,
|
||||
output_len=output_len,
|
||||
)
|
||||
return [_fingerprint_sample(s) for s in samples]
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_dataset_same_seed(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
random_dataset_params: Params) -> None:
|
||||
"""Same seed should yield identical outputs, even if global RNGs change.
|
||||
|
||||
This guards against accidental reliance on Python's random or np.random
|
||||
in RandomDataset after moving to numpy.default_rng.
|
||||
"""
|
||||
p = random_dataset_params
|
||||
common_seed = 123
|
||||
dataset_a = RandomDataset(random_seed=common_seed)
|
||||
dataset_b = RandomDataset(random_seed=common_seed)
|
||||
a = _collect_samples(dataset_a,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
|
||||
# Perturb global RNG state to ensure isolation
|
||||
random.seed(999)
|
||||
_ = [random.random() for _ in range(100)]
|
||||
np.random.seed(888)
|
||||
_ = [np.random.random() for _ in range(100)]
|
||||
|
||||
b = _collect_samples(dataset_b,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
assert a == b
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_dataset_different_seeds(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
random_dataset_params: Params) -> None:
|
||||
"""Different seeds should change outputs with overwhelming likelihood."""
|
||||
p = random_dataset_params
|
||||
seed_a = 0
|
||||
dataset_a = RandomDataset(random_seed=seed_a)
|
||||
a = _collect_samples(dataset_a,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
|
||||
seed_b = 999
|
||||
dataset_b = RandomDataset(random_seed=seed_b)
|
||||
# Perturb global RNG with same seed as dataset_a to ensure isolation
|
||||
random.seed(seed_a)
|
||||
np.random.seed(seed_a)
|
||||
b = _collect_samples(dataset_b,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
assert a != b
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# RandomMultiModalDataset tests
|
||||
# -----------------------------
|
||||
|
||||
def _mm_fingerprint_sample(
|
||||
req: SampleRequest,
|
||||
) -> tuple[str, int, int, int, list[str]]:
|
||||
"""Create a compact fingerprint for multimodal samples.
|
||||
|
||||
Includes:
|
||||
- prompt string
|
||||
- prompt_len
|
||||
- expected_output_len
|
||||
- count of multimodal items
|
||||
- per-item type and URL prefix (e.g., 'data:image/jpeg;base64,')
|
||||
"""
|
||||
items = req.multi_modal_data or []
|
||||
item_prefixes: list[str] = []
|
||||
for it in items:
|
||||
if isinstance(it, dict) and it.get("type") == "image_url":
|
||||
url = it.get("image_url", {}).get("url", "")
|
||||
# Only keep a short identifying prefix to avoid huge strings
|
||||
item_prefixes.append(f"image:{url[:22]}")
|
||||
elif isinstance(it, dict) and it.get("type") == "video_url":
|
||||
url = it.get("video_url", {}).get("url", "")
|
||||
item_prefixes.append(f"video:{url[:22]}")
|
||||
else:
|
||||
item_prefixes.append("unknown:")
|
||||
return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
|
||||
item_prefixes)
|
||||
|
||||
|
||||
def _collect_mm_samples(
|
||||
dataset: RandomMultiModalDataset,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
*,
|
||||
num_requests: int = 8,
|
||||
prefix_len: int = 3,
|
||||
range_ratio: float = 0.0,
|
||||
input_len: int = 20,
|
||||
output_len: int = 5,
|
||||
base_items_per_request: int = 2,
|
||||
num_mm_items_range_ratio: float = 0.0,
|
||||
limit_mm_per_prompt: Optional[dict[str, int]] = None,
|
||||
bucket_config: Optional[dict[tuple[int, int, int], float]] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
) -> list[SampleRequest]:
|
||||
if limit_mm_per_prompt is None:
|
||||
limit_mm_per_prompt = {"image": 5, "video": 0}
|
||||
if bucket_config is None:
|
||||
bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5}
|
||||
return dataset.sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=num_requests,
|
||||
prefix_len=prefix_len,
|
||||
range_ratio=range_ratio,
|
||||
input_len=input_len,
|
||||
output_len=output_len,
|
||||
base_items_per_request=base_items_per_request,
|
||||
num_mm_items_range_ratio=num_mm_items_range_ratio,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
bucket_config=bucket_config,
|
||||
enable_multimodal_chat=enable_multimodal_chat,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
seed = 42
|
||||
ds_a = RandomMultiModalDataset(random_seed=seed)
|
||||
ds_b = RandomMultiModalDataset(random_seed=seed)
|
||||
a = _collect_mm_samples(ds_a, hf_tokenizer)
|
||||
b = _collect_mm_samples(ds_b, hf_tokenizer)
|
||||
fa = [_mm_fingerprint_sample(s) for s in a]
|
||||
fb = [_mm_fingerprint_sample(s) for s in b]
|
||||
assert fa == fb
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_different_seeds(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
) -> None:
|
||||
ds_a = RandomMultiModalDataset(random_seed=0)
|
||||
ds_b = RandomMultiModalDataset(random_seed=999)
|
||||
a = _collect_mm_samples(ds_a, hf_tokenizer)
|
||||
b = _collect_mm_samples(ds_b, hf_tokenizer)
|
||||
fa = [_mm_fingerprint_sample(s) for s in a]
|
||||
fb = [_mm_fingerprint_sample(s) for s in b]
|
||||
assert fa != fb
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_respects_limits(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
) -> None:
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
# Requesting 3 items with a per-prompt limit of 1 should error per current
|
||||
# design (dataset refuses to silently clamp below the requested baseline).
|
||||
with pytest.raises(ValueError):
|
||||
_collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=12,
|
||||
base_items_per_request=3,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt={"image": 1, "video": 0},
|
||||
bucket_config={(32, 32, 1): 1.0},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_zero_prob_entries_are_removed(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
) -> None:
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
# Second bucket has zero probability and should be ignored after
|
||||
# normalization
|
||||
samples = _collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=6,
|
||||
base_items_per_request=2,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt={"image": 10, "video": 0},
|
||||
bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0},
|
||||
)
|
||||
for s in samples:
|
||||
assert isinstance(s.multi_modal_data, list)
|
||||
typed_mm = cast(list[dict[str, Any]], s.multi_modal_data)
|
||||
for it in typed_mm:
|
||||
assert it.get("type") == "image_url"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
samples = _collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=5,
|
||||
base_items_per_request=0,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt={"image": 5, "video": 0},
|
||||
bucket_config={(32, 32, 1): 1.0},
|
||||
)
|
||||
for s in samples:
|
||||
assert s.multi_modal_data == []
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_num_items_per_prompt(
|
||||
hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
# Fixed number of images per prompt
|
||||
# set num_mm_items_range_ratio to 0.0
|
||||
# TODO: modify video values when video sampling is implemented
|
||||
samples_fixed_items = _collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=5,
|
||||
base_items_per_request=3,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt={"image": 3, "video": 0},
|
||||
bucket_config={(32, 32, 1): 1.0},
|
||||
)
|
||||
# Must have 5 requests each with 3 mm items per prompt
|
||||
assert len(samples_fixed_items) == 5
|
||||
for s in samples_fixed_items:
|
||||
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
|
||||
assert len(mm_data) == 3
|
||||
for it in mm_data:
|
||||
assert it.get("type") == "image_url"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_bucket_config_not_mutated(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
) -> None:
|
||||
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
# This bucket config is not normalized to sum to 1
|
||||
# and has more buckets than requested images
|
||||
original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3}
|
||||
# Keep a snapshot to compare after sampling
|
||||
snapshot = dict(original)
|
||||
|
||||
_ = _collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=4,
|
||||
base_items_per_request=1,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt={"image": 1, "video": 0},
|
||||
bucket_config=original,
|
||||
)
|
||||
|
||||
# Ensure the original dict content is unchanged
|
||||
assert original == snapshot
|
||||
|
||||
|
||||
# Vary number of mm items per prompt
|
||||
# set num_mm_items_range_ratio to 0.5
|
||||
samples_varying_items = _collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=5,
|
||||
base_items_per_request=2,
|
||||
num_mm_items_range_ratio=0.5,
|
||||
limit_mm_per_prompt={"image": 4, "video": 0},
|
||||
bucket_config={(32, 32, 1): 1.0},
|
||||
)
|
||||
# Must have 5 requests, each with at most 4 mm items per prompt
# but at least 1 mm item per prompt
|
||||
assert len(samples_varying_items) == 5
|
||||
for s in samples_varying_items:
|
||||
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
|
||||
assert len(mm_data) <= 4
|
||||
assert len(mm_data) >= 1
|
||||
for it in mm_data:
|
||||
assert it.get("type") == "image_url"
|
||||
@ -8,11 +8,12 @@ import vllm.envs as envs
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
|
||||
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
||||
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
|
||||
kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||
from vllm.compilation.fusion import FUSED_OPS, FusionPass
|
||||
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import CompilationConfig, PassConfig, VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
|
||||
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
@ -7,11 +7,13 @@ import torch
|
||||
import vllm.envs as envs
|
||||
import vllm.plugins
|
||||
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
|
||||
FusionPass, GroupShape, QuantKey)
|
||||
FusionPass)
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
|
||||
VllmConfig)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape, QuantKey, ScaleDesc)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
|
||||
from vllm.platforms import current_platform
|
||||
@ -30,10 +32,8 @@ class TestModel(torch.nn.Module):
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
self.key = QuantKey(dtype=FP8_DTYPE,
|
||||
static=static,
|
||||
group_shape=group_shape,
|
||||
symmetric=True)
|
||||
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
if static:
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
|
||||
else:
|
||||
|
||||
@ -11,9 +11,10 @@ from tests.models.utils import check_outputs_equal
|
||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_common_attn_metadata)
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention.selector import global_force_attn_backend_context_manager
|
||||
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
|
||||
from vllm.compilation.fusion import QUANT_OPS
|
||||
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.noop_elimination import NoOpEliminationPass
|
||||
@ -22,13 +23,14 @@ from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||
set_current_vllm_config)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
# globals needed for string-import custom Dynamo backend field
|
||||
backend: Optional[TestBackend] = None
|
||||
@ -105,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
|
||||
# check support
|
||||
attn_fusion_supported = [
|
||||
layer.impl.fused_output_quant_supported(quant_key.dtype,
|
||||
quant_key.static,
|
||||
quant_key.group_shape)
|
||||
layer.impl.fused_output_quant_supported(quant_key)
|
||||
for key, layer in compile_config.static_forward_context.items()
|
||||
]
|
||||
|
||||
@ -149,12 +149,12 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
|
||||
backend = None
|
||||
|
||||
|
||||
class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
||||
"""Test model for AttentionStaticQuantPattern fusion."""
|
||||
class AttentionQuantPatternModel(torch.nn.Module):
|
||||
"""Base model for AttentionQuantPattern fusion."""
|
||||
|
||||
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
|
||||
kv_cache_dtype: torch.dtype, device: torch.device,
|
||||
vllm_config: VllmConfig):
|
||||
vllm_config: VllmConfig, **kwargs):
|
||||
super().__init__()
|
||||
self.num_qo_heads = num_qo_heads
|
||||
self.num_kv_heads = num_kv_heads
|
||||
@ -172,11 +172,6 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
||||
prefix="model.layers.0.self_attn.attn",
|
||||
)
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
|
||||
self.wscale = torch.tensor([1.0], dtype=torch.float32)
|
||||
self.scale = torch.tensor([1.0], dtype=torch.float32)
|
||||
|
||||
self.block_size = 16
|
||||
|
||||
# Initialize attn MetadataBuilder
|
||||
@ -230,23 +225,86 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
||||
|
||||
return self.attn_metadata
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
w: torch.Tensor):
|
||||
|
||||
class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||
"""Test model for AttentionFp8StaticQuantPattern fusion."""
|
||||
|
||||
quant_key = kFp8StaticTensorSym
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.quant_key.scale.static,
|
||||
act_quant_group_shape=self.quant_key.scale.group_shape)
|
||||
|
||||
hidden_size = self.num_qo_heads * self.head_size
|
||||
self.w = kwargs.get(
|
||||
"w", {
|
||||
"weight":
|
||||
torch.randn(hidden_size, hidden_size).to(
|
||||
dtype=FP8_DTYPE, device=self.device).t(),
|
||||
"wscale":
|
||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
"scale":
|
||||
torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||
})
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Forward pass that creates the pattern to be fused."""
|
||||
attn_output = self.attn(q, k, v)
|
||||
return self.fp8_linear.apply(input=attn_output,
|
||||
weight=w,
|
||||
weight_scale=self.wscale,
|
||||
input_scale=self.scale)
|
||||
weight=self.w["weight"],
|
||||
weight_scale=self.w["wscale"],
|
||||
input_scale=self.w["scale"])
|
||||
|
||||
|
||||
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||
"""Test model for AttentionNvfp4QuantPattern fusion."""
|
||||
|
||||
quant_key = kNvfp4Quant
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
hidden_size = self.num_qo_heads * self.head_size
|
||||
self.w = kwargs.get(
|
||||
"w", {
|
||||
"weight":
|
||||
torch.randint(256, (hidden_size, hidden_size // 2),
|
||||
dtype=FP4_DTYPE,
|
||||
device=self.device),
|
||||
"wscale_swizzled":
|
||||
torch.randn(hidden_size, hidden_size // 16).to(
|
||||
dtype=FP8_DTYPE, device=self.device),
|
||||
"wscale":
|
||||
torch.tensor([500], dtype=torch.float32, device=self.device),
|
||||
"scale":
|
||||
torch.tensor([0.002], dtype=torch.float32, device=self.device),
|
||||
})
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Forward pass that creates the pattern to be fused."""
|
||||
attn_output = self.attn(q, k, v)
|
||||
quant_output, output_block_scale = scaled_fp4_quant(
|
||||
attn_output, 1 / self.w["scale"])
|
||||
return cutlass_scaled_fp4_mm(a=quant_output,
|
||||
b=self.w["weight"],
|
||||
block_scale_a=output_block_scale,
|
||||
block_scale_b=self.w["wscale_swizzled"],
|
||||
alpha=self.w["scale"] * self.w["wscale"],
|
||||
out_dtype=attn_output.dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
|
||||
@pytest.mark.parametrize("head_size", [128])
|
||||
@pytest.mark.parametrize("batch_size", [7, 256, 533])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, quant_key",
|
||||
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)])
|
||||
@pytest.mark.parametrize("model_name, model_class",
|
||||
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
|
||||
TestAttentionFp8StaticQuantPatternModel),
|
||||
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
|
||||
TestAttentionNvfp4QuantPatternModel)])
|
||||
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
|
||||
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
|
||||
@ -255,8 +313,8 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
|
||||
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
head_size: int, batch_size: int,
|
||||
dtype: torch.dtype, model_name: str,
|
||||
quant_key: QuantKey, backend: _Backend,
|
||||
monkeypatch, dist_init):
|
||||
model_class: type[AttentionQuantPatternModel],
|
||||
backend: _Backend, monkeypatch, dist_init):
|
||||
"""Test AttentionStaticQuantPattern fusion pass"""
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
@ -277,8 +335,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
cache_config=CacheConfig(cache_dtype="fp8"))
|
||||
|
||||
# Create test inputs
|
||||
hidden_size = num_qo_heads * head_size
|
||||
q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
|
||||
q = torch.randn(batch_size,
|
||||
num_qo_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k = torch.randn(batch_size,
|
||||
num_kv_heads * head_size,
|
||||
dtype=dtype,
|
||||
@ -287,7 +347,6 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
num_kv_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t()
|
||||
|
||||
# Mark first dimension as dynamic for realistic testing
|
||||
torch._dynamo.mark_dynamic(q, 0)
|
||||
@ -299,9 +358,12 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
|
||||
attn_metadata=None, vllm_config=vllm_config_unfused
|
||||
), global_force_attn_backend_context_manager(backend):
|
||||
model_unfused = TestAttentionStaticQuantPatternModel(
|
||||
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
|
||||
vllm_config_unfused)
|
||||
model_unfused = model_class(num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
kv_cache_dtype=FP8_DTYPE,
|
||||
device=device,
|
||||
vllm_config=vllm_config_unfused)
|
||||
model_unfused = model_unfused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
@ -309,7 +371,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
batch_size)
|
||||
|
||||
# Run model directly without compilation and fusion
|
||||
result_unfused = model_unfused(q, k, v, linear_w)
|
||||
result_unfused = model_unfused(q, k, v)
|
||||
|
||||
# Run model with attn fusion enabled
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
@ -317,9 +379,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
with set_current_vllm_config(vllm_config), set_forward_context(
|
||||
attn_metadata=None, vllm_config=vllm_config
|
||||
), global_force_attn_backend_context_manager(backend):
|
||||
model_fused = TestAttentionStaticQuantPatternModel(
|
||||
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
|
||||
vllm_config)
|
||||
model_fused = model_class(num_qo_heads=num_qo_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
kv_cache_dtype=FP8_DTYPE,
|
||||
device=device,
|
||||
vllm_config=vllm_config,
|
||||
w=model_unfused.w)
|
||||
model_fused = model_fused.to(device)
|
||||
|
||||
forward_ctx = get_forward_context()
|
||||
@ -336,21 +402,20 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
backend=test_backend,
|
||||
fullgraph=True)
|
||||
assert model_compiled.attn._o_scale_float is None
|
||||
result_fused_1 = model_compiled(q, k, v, linear_w)
|
||||
result_fused_1 = model_compiled(q, k, v)
|
||||
|
||||
# After the 1st round of the forward pass, output quant scale should be
|
||||
# loaded into the attn layer's _o_scale_float, the 2nd round should
|
||||
# reuse the loaded _o_scale_float
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
result_fused_2 = model_compiled(q, k, v, linear_w)
|
||||
result_fused_2 = model_compiled(q, k, v)
|
||||
assert model_compiled.attn._o_scale_float is not None
|
||||
|
||||
# Check attn fusion support
|
||||
quant_key = model_class.quant_key
|
||||
attn_fusion_supported = [
|
||||
layer.impl.fused_output_quant_supported(quant_key.dtype,
|
||||
quant_key.static,
|
||||
quant_key.group_shape) for key,
|
||||
layer in vllm_config.compilation_config.static_forward_context.items()
|
||||
layer.impl.fused_output_quant_supported(quant_key) for key, layer in
|
||||
vllm_config.compilation_config.static_forward_context.items()
|
||||
]
|
||||
if any(attn_fusion_supported):
|
||||
# Check quantization ops in the graph before and after fusion
|
||||
@ -370,6 +435,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
|
||||
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
|
||||
"Attention should have output_scale after fusion"
|
||||
|
||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
|
||||
"Attention should not have output_block_scale before fusion"
|
||||
if quant_key.dtype == FP8_DTYPE:
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
|
||||
"Attention should not have output_block_scale after FP8 fusion"
|
||||
elif quant_key.dtype == FP4_DTYPE:
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
|
||||
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
|
||||
|
||||
# Check that results are close
|
||||
torch.testing.assert_close(result_unfused,
|
||||
result_fused_1,
|
||||
|
||||
@ -233,6 +233,7 @@ MULTIMODAL_MODELS = {
|
||||
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
|
||||
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
|
||||
"AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
|
||||
"AIDC-AI/Ovis2.5-2B": PPTestSettings.fast(),
|
||||
"microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
|
||||
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
|
||||
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(),
|
||||
|
||||
108
tests/distributed/test_symm_mem_allreduce.py
Normal file
@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import random
|
||||
import typing
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
||||
CudaCommunicator)
|
||||
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
|
||||
get_tp_group,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
torch.manual_seed(42)
|
||||
random.seed(44)
|
||||
|
||||
test_size_elements = 4 * 1024 * 1024
|
||||
|
||||
|
||||
def symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
with monkeypatch.context() as m:
|
||||
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
|
||||
dtype = torch.bfloat16
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
cuda_communicator = typing.cast(CudaCommunicator,
|
||||
get_tp_group().device_communicator)
|
||||
symm_mem_comm = cuda_communicator.symm_mem_comm
|
||||
if symm_mem_comm is None or symm_mem_comm.disabled:
|
||||
pytest.skip("SymmMemCommunicator is not available or disabled.")
|
||||
|
||||
inp_direct_symm_mem = torch.randint(1,
|
||||
23, (test_size_elements, ),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
|
||||
pytest.skip(
|
||||
"SymmMemCommunicator isn't used for this world and input size."
|
||||
)
|
||||
|
||||
original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
|
||||
out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
|
||||
assert out_direct_symm_mem is not None
|
||||
|
||||
group = get_tensor_model_parallel_group().device_group
|
||||
dist.all_reduce(original_inp_direct_symm_mem, group=group)
|
||||
torch.testing.assert_close(out_direct_symm_mem,
|
||||
original_inp_direct_symm_mem,
|
||||
atol=2.5,
|
||||
rtol=0.1)
|
||||
|
||||
# Test tensor_model_parallel_all_reduce which should use symm_mem
|
||||
inp_tensor_parallel = torch.randint(-23,
|
||||
1, (test_size_elements, ),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
original_inp_tensor_parallel = inp_tensor_parallel.clone()
|
||||
out_tensor_parallel = tensor_model_parallel_all_reduce(
|
||||
inp_tensor_parallel)
|
||||
dist.all_reduce(original_inp_tensor_parallel, group=group)
|
||||
torch.testing.assert_close(out_tensor_parallel,
|
||||
original_inp_tensor_parallel,
|
||||
atol=2.5,
|
||||
rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(),
|
||||
reason="SymmMemAllreduce is only available for CUDA platforms.")
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize("pipeline_parallel_size", [1])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
|
||||
pipeline_parallel_size):
|
||||
world_size = tp_size * pipeline_parallel_size
|
||||
if world_size > torch.cuda.device_count():
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
|
||||
# Enable SymmMemCommunicator
|
||||
monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
|
||||
|
||||
mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
|
||||
cleanup_dist_env_and_memory()
|
||||
@ -18,10 +18,9 @@ def text_llm():
|
||||
enforce_eager=True,
|
||||
seed=0)
|
||||
|
||||
with llm.deprecate_legacy_api():
|
||||
yield weakref.proxy(llm)
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
del llm
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
@ -88,10 +87,9 @@ def vision_llm():
|
||||
seed=0,
|
||||
)
|
||||
|
||||
with llm.deprecate_legacy_api():
|
||||
yield weakref.proxy(llm)
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
del llm
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
@ -158,10 +156,9 @@ def thinking_llm():
|
||||
seed=0,
|
||||
)
|
||||
|
||||
with llm.deprecate_legacy_api():
|
||||
yield weakref.proxy(llm)
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
del llm
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@ -35,10 +35,9 @@ def llm():
|
||||
enforce_eager=True,
|
||||
seed=0)
|
||||
|
||||
with llm.deprecate_legacy_api():
|
||||
yield weakref.proxy(llm)
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
del llm
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@ -26,10 +26,9 @@ def llm():
|
||||
enforce_eager=True,
|
||||
seed=0)
|
||||
|
||||
with llm.deprecate_legacy_api():
|
||||
yield weakref.proxy(llm)
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
del llm
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@ -5,11 +5,9 @@ import weakref
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, PoolingParams, PoolingRequestOutput
|
||||
from vllm import LLM, PoolingParams
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
|
||||
from ...models.utils import check_embeddings_close
|
||||
|
||||
MODEL_NAME = "intfloat/multilingual-e5-small"
|
||||
|
||||
PROMPTS = [
|
||||
@ -48,57 +46,13 @@ def llm():
|
||||
enforce_eager=True,
|
||||
seed=0)
|
||||
|
||||
with llm.deprecate_legacy_api():
|
||||
yield weakref.proxy(llm)
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
del llm
|
||||
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
def assert_outputs_match(o1: list[PoolingRequestOutput],
|
||||
o2: list[PoolingRequestOutput]):
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[o.outputs.data for o in o1],
|
||||
embeddings_1_lst=[o.outputs.data for o in o2],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
tol=1e-2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
|
||||
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
|
||||
prompt_token_ids):
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
|
||||
v1_output = llm.encode(prompt_token_ids=prompt_token_ids,
|
||||
pooling_params=pooling_params)
|
||||
|
||||
v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
|
||||
pooling_params=pooling_params)
|
||||
assert_outputs_match(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
|
||||
v1_output = llm.encode(prompt_token_ids=TOKEN_IDS,
|
||||
pooling_params=pooling_params)
|
||||
|
||||
v2_output = llm.encode(
|
||||
[{
|
||||
"prompt_token_ids": p
|
||||
} for p in TOKEN_IDS],
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
assert_outputs_match(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_multiple_pooling_params(llm: LLM):
|
||||
pooling_params = [
|
||||
|
||||
@@ -5,7 +5,7 @@ import weakref
import pytest

from vllm import LLM, RequestOutput, SamplingParams
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

MODEL_NAME = "distilbert/distilgpt2"

@@ -41,50 +41,13 @@ def llm():
gpu_memory_utilization=0.10,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)

del llm

cleanup_dist_env_and_memory()

def assert_outputs_equal(o1: list[RequestOutput], o2: list[RequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]

@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
prompt_token_ids):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.generate(prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)

v2_output = llm.generate({"prompt_token_ids": prompt_token_ids},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)

@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.generate(prompt_token_ids=TOKEN_IDS,
sampling_params=sampling_params)

v2_output = llm.generate(
[{
"prompt_token_ids": p
} for p in TOKEN_IDS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)

@pytest.mark.skip_global_cleanup
def test_multiple_sampling_params(llm: LLM):
sampling_params = [

@@ -48,10 +48,9 @@ def llm(request, monkeypatch_module):
max_num_seqs=128,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)

del llm

cleanup_dist_env_and_memory()

@@ -36,10 +36,9 @@ def llm():
trust_remote_code=True,
seed=0)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)

del llm

cleanup_dist_env_and_memory()

@@ -33,10 +33,9 @@ def llm():
enforce_eager=True,
seed=0)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)

del llm

cleanup_dist_env_and_memory()
@@ -64,6 +64,28 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI):
assert response["usage"]["prompt_tokens"] == truncation_size

@pytest.mark.asyncio
async def test_zero_truncation_size(client: openai.AsyncOpenAI):
truncation_size = 0
kwargs: dict[str, Any] = {
"model": MODEL_NAME,
"input": input,
"truncate_prompt_tokens": truncation_size
}

with pytest.raises(openai.BadRequestError) as err:
await client.post(path="embeddings", cast_to=object, body={**kwargs})

assert err.value.status_code == 400
error_details = err.value.response.json()["error"]

assert error_details["type"] == "BadRequestError"
assert "This model's maximum context length is" in error_details["message"]
assert "tokens in the input for embedding generation" in error_details[
"message"]
assert "Please reduce the length of the input" in error_details["message"]

@pytest.mark.asyncio
async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
truncation_size = max_model_len + 1

@@ -74,18 +96,15 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
}

with pytest.raises(openai.BadRequestError) as err:
err = await client.post(path="embeddings",
cast_to=object,
body={**kwargs})
await client.post(path="embeddings", cast_to=object, body={**kwargs})

assert str(err) == f"""openai.BadRequestError:
Error code: 400 - {{'object': 'error',
'message': 'truncate_prompt_tokens value
({truncation_size})
is greater than max_model_len ({max_model_len}).
Please, select a smaller truncation size.',
'type': 'BadRequestError',
'param': None, 'code': 400}}"""
assert err.value.status_code == 400
error_details = err.value.response.json()["error"]
assert error_details["type"] == "BadRequestError"
expected_message = ("truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
assert error_details["message"] == expected_message

@pytest.mark.asyncio
@@ -709,14 +709,15 @@ def test_swap_blocks_mla(
@pytest.mark.parametrize("max_seq_len", [512])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype",
["auto"])  # You can also test "fp8" if needed.
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
num_blocks, max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
block_size, num_blocks,
max_seq_len, batch_size, dtype,
kv_cache_dtype, device):
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)

@@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
perm = torch.randperm(num_blocks, device=device)
block_table[b, :] = perm

dst = torch.zeros((total_tokens, entry_size),
dtype=src_cache.dtype,
device=device)
dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

expected_batches = []
for b in range(batch_size):

@@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
gathered_rows = []
for i in range(tot - 1):
gathered_rows.append(src_cache[blocks[i]])
block_data = src_cache[blocks[i]]
if kv_cache_dtype == "fp8":
dequantized_block = torch.empty_like(block_data, dtype=dtype)
ops.convert_fp8(dequantized_block, block_data, scale.item())
gathered_rows.append(dequantized_block)
else:
gathered_rows.append(block_data)
remaining = s - (tot - 1) * block_size
gathered_rows.append(src_cache[blocks[-1], :remaining, :])
last_block_data = src_cache[blocks[-1], :remaining, :]
if kv_cache_dtype == "fp8":
dequantized_last_block = torch.empty_like(last_block_data,
dtype=dtype)
ops.convert_fp8(dequantized_last_block, last_block_data,
scale.item())
gathered_rows.append(dequantized_last_block)
else:
gathered_rows.append(last_block_data)

batch_expected = torch.cat(gathered_rows, dim=0)
expected_batches.append(batch_expected)
expected = torch.cat(expected_batches, dim=0)

opcheck(
torch.ops._C_cache_ops.gather_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, None),
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
(src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
scale, None),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)

ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, kv_cache_dtype,
scale, None)
torch.testing.assert_close(dst, expected)
@@ -6,7 +6,11 @@ import flashinfer
import pytest
import torch

from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from vllm.platforms import current_platform
from vllm.utils import round_up

if not current_platform.is_device_capability(100):
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",

@@ -14,6 +18,7 @@ if not current_platform.is_device_capability(100):
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8

def to_float8(x, dtype=torch.float8_e4m3fn):

@@ -29,7 +34,9 @@ DTYPE = [torch.bfloat16]
QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
BATCH_SIZE = [4, 12]
MAX_SEQ_LENS = [(1024, 4096)]

@@ -153,11 +160,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
o_sf_scale = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)

# TRTLLM Decode
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query,
kv_cache=kv_cache,

@@ -167,15 +188,27 @@ def test_flashinfer_trtllm_decode_with_baseline(
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])

if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 3e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2
rtol, atol = 1e-2, 2e-2

torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}"

@@ -211,6 +244,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype

if q_quant_dtype != kv_quant_dtype:
pytest.skip("Skipped mixed QKV dtypes for prefill")

max_q_len, max_kv_len = max_seq_lens

num_qo_heads, num_kv_heads = num_heads

@@ -303,11 +339,25 @@ def test_flashinfer_trtllm_prefill_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
o_sf_scale = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)

# TRTLLM Prefill
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=query,
kv_cache=kv_cache,

@@ -321,12 +371,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])

if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 4e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2
@@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
from vllm.triton_utils import triton

def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
def cal_diff(x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12)
assert cos_diff < 1e-5
if (use_fp8):
assert cos_diff < 1e-4
else:
assert cos_diff < 1e-5

FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
if not is_flashmla_supported()[0] else "FlashMLA is supported"

@@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1, 2])
@pytest.mark.parametrize("mean_sk", [4096, 8192])
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
@pytest.mark.parametrize("h_kv", [1])
@pytest.mark.parametrize("d", [576])

@@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("torch_dtype",
[torch.bfloat16, torch.float16, torch.float8_e4m3fn])
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen, dtype):
varlen, torch_dtype):
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)

print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")

use_fp8 = torch_dtype == torch.float8_e4m3fn
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):

@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv)

init_dtype = q.dtype
if use_fp8:
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)

q = q.to(fp8_dtype)
blocked_k = blocked_k.to(fp8_dtype)
blocked_v = blocked_v.to(fp8_dtype)
else:
descale_q = None
descale_k = None

def flash_mla():
return flash_mla_with_kvcache(
q,

@@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k,
)

def scaled_dot_product_attention(query, key, value, is_causal=False):

@@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
return attn_weight @ value, lse

def ref_mla():
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_v
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
ref_O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
out_i, lse_i = scaled_dot_product_attention(
q_[i].transpose(0, 1),
blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
is_causal=causal,
)
out[i] = ref_O.transpose(0, 1)
lse[i] = LSE
out[i] = out_i.transpose(0, 1)
lse[i] = lse_i
return out, lse

out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out")
cal_diff(out_flash, out_torch, "out", use_fp8)
cal_diff(lse_flash, lse_torch, "lse")

t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
bytes = (total_seqlens * h_kv * d +
b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
f"{bytes / 10 ** 6 / t:.0f} GB/s")
76  tests/kernels/moe/test_grouped_topk.py  Normal file
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the MoE grouped topk kernel

Run `pytest tests/kernels/moe/test_grouped_topk.py`.
"""
import pytest
import torch

from vllm.model_executor.layers.fused_moe.fused_moe import (fused_grouped_topk,
grouped_topk)
from vllm.platforms import current_platform

@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
@pytest.mark.parametrize("n_token", [1, 33, 64])
@pytest.mark.parametrize("n_hidden", [1024, 2048])
@pytest.mark.parametrize("n_expert", [16])
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("num_expert_group", [8])
@pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype",
[torch.float16, torch.bfloat16, torch.float32])
def test_grouped_topk(monkeypatch: pytest.MonkeyPatch, n_token: int,
n_hidden: int, n_expert: int, topk: int,
renormalize: bool, num_expert_group: int,
topk_group: int, scoring_func: str,
routed_scaling_factor: float, dtype: torch.dtype):
current_platform.seed_everything(0)
hidden_states = torch.randn((n_token, n_hidden),
dtype=dtype,
device="cuda")
gating_output = torch.randn((n_token, n_expert),
dtype=dtype,
device="cuda")
e_score_correction_bias = torch.randn((n_expert, ),
dtype=torch.float32,
device="cuda")

with monkeypatch.context() as m:
m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
baseline_topk_weights, baseline_topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)

test_topk_weights, test_topk_ids = fused_grouped_topk(
hidden_states=hidden_states,
gating_output=gating_output,
topk=topk,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias)

if renormalize:
torch.testing.assert_close(baseline_topk_weights,
test_topk_weights,
atol=2e-2,
rtol=0)
torch.testing.assert_close(baseline_topk_ids,
test_topk_ids,
atol=0,
rtol=0)
@@ -429,11 +429,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
torch.cuda.empty_cache()
vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
torch.cuda.synchronize()
torch.cuda.empty_cache()

# Run forward passes for both MoE blocks
259  tests/kernels/quantization/test_cutlass_w4a8.py  Normal file
@@ -0,0 +1,259 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the CUTLASS W4A8 kernel.

Run `pytest tests/kernels/test_cutlass_w4a8.py`.
"""

from dataclasses import dataclass
from typing import Optional

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods
# an assumption which is breaking down as quantizations methods can have
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9

MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672),
(13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096),
(64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096),
(1024, 4096, 8192), (1024, 8192, 4096)]

# TODO(czhu): get supported schedules from fn
SCHEDULES = [
'128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1',
'128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1',
'128x256_1x1x1', '128x256_2x1x1'
]

@dataclass
class TypeConfig:
act_type: torch.dtype
weight_type: ScalarType
output_type: Optional[torch.dtype]
group_scale_type: Optional[torch.dtype]
channel_scale_type: Optional[torch.dtype]
token_scale_type: Optional[torch.dtype]

@dataclass
class Tensors:
w_ref: torch.Tensor
a_ref: torch.Tensor
a: torch.Tensor
w_q: torch.Tensor
w_g_s: torch.Tensor
w_ch_s: torch.Tensor
w_tok_s: torch.Tensor

# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
# Ch Scales Type, Tok Scales Type)
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool]
TEST_TYPES = [
*(
TypeConfig(act_type=torch.float8_e4m3fn,
weight_type=w_type,
output_type=o_type,
group_scale_type=torch.float8_e4m3fn,
channel_scale_type=torch.float32,
token_scale_type=torch.float32)
for w_type in [scalar_types.int4]
# TODO(czhu): fp16 out type
for o_type in [torch.bfloat16]),
]

# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods
# an assumption which is breaking down as quantizations methods can have
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)

# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return tensor.clamp(min=finfo.min,
                        max=finfo.max).to(dtype=torch.float8_e4m3fn)

def cutlass_quantize_and_pack(atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False):
assert wtype.is_integer(), "TODO: support floating point weights"

w_ref, w_q, w_s, w_zp = quantize_weights(w,
wtype,
group_size=group_size,
zero_points=zero_points)

# since scales are cast to fp8, we need to compute w_ref this way
w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to(
torch.float32).repeat_interleave(group_size, dim=0)).to(atype)

# bit mask prevents sign extending int4 when packing
w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t()  # convert to col major

w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q)
w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype))

return w_ref, w_q_packed, w_s_packed, w_zp

def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig,
group_size: Optional[int]) -> Tensors:
m, n, k = shape

print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
group_size)

a = to_fp8(torch.randn((m, k), device="cuda"))
w = to_fp8(torch.randn((k, n), device="cuda"))

if types.group_scale_type is not None:
w = w.to(types.group_scale_type)
if w.dtype.itemsize == 1:
w = w.to(torch.float16)

w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
False)

a_ref = a.to(torch.float32)
w_ref = w_ref.to(torch.float32)

# for the practical use case we need per-tok scales for fp8 activations
w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type)
# weights are already per-group quantized, use placeholder here
w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type)

return Tensors(w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_ch_s=w_ch_s,
w_tok_s=w_tok_s)

def mm_test_helper(types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None):
# CUTLASS upstream uses fp8 with fastaccum as reference
# https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
output_ref = torch._scaled_mm(
tensors.a_ref.to(types.act_type),
tensors.w_ref.to(types.act_type).t().contiguous().t(),  # col major
tensors.w_tok_s.unsqueeze(1),
tensors.w_ch_s.unsqueeze(0),
out_dtype=types.output_type,
use_fast_accum=True)

output = ops.cutlass_w4a8_mm(
a=tensors.a,
b_q=tensors.w_q,
b_group_scales=tensors.w_g_s,
b_group_size=group_size,
b_channel_scales=tensors.w_ch_s,
a_token_scales=tensors.w_tok_s,
)

print(output)
print(output_ref)

torch.testing.assert_close(output,
output_ref.to(output.dtype),
rtol=1e-3,
atol=1e-3)

@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="CUTLASS W4A8 is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
@pytest.mark.parametrize("schedule", SCHEDULES)
def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
group_sizes = [128]
for group_size in group_sizes:
tensors = create_test_tensors(shape, types, group_size)
mm_test_helper(types, tensors, group_size, schedule)

# Test to make sure cuda graphs work
class W4A8Layer(torch.nn.Module):

def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs

def forward(self, a):
return ops.cutlass_w4a8_mm(a=a, **self.kwargs)

@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="CUTLASS W4A8 is not supported on this GPU type.")
def test_w4a8_cuda_graph():
m, n, k = 512, 4096, 4096

a = to_fp8(torch.randn((m, k), device="cuda"))
b = to_fp8(torch.randn((k, n), device="cuda"))

wtype = scalar_types.int4
stype = torch.float8_e4m3fn
group_size = 128
zero_points = False

w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points)

w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32)
w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32)

# Construct a trivial model with a single layer that calls the kernel
model = W4A8Layer(
b_q=w_q_packed,
b_group_scales=w_s,
b_group_size=group_size,
b_channel_scales=w_ch_s,
a_token_scales=w_tok_s,
)

output_ref = torch._scaled_mm(
a,
w_ref.to(a.dtype).t().contiguous().t(),  # col major
w_tok_s.unsqueeze(1),
w_ch_s.unsqueeze(0),
out_dtype=torch.bfloat16,
use_fast_accum=True)

# Run the model with a cuda graph
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
output = model(a)

output.zero_()
g.replay()

torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3)
@@ -9,12 +9,17 @@ import pytest
import torch
from packaging import version

from vllm import SamplingParams
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config)
from vllm.v1.attention.backends.flex_attention import (
FlexAttentionMetadataBuilder)

from ..models.utils import check_embeddings_close
from ..models.utils import check_embeddings_close, check_logprobs_close

TORCH_VERSION = version.parse(torch.__version__)
MINIMUM_TORCH_VERSION = version.parse("2.7.0")
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")

def set_seed(seed):

@@ -34,22 +39,18 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
"""Test that FlexAttention produces the same outputs as the default backend.

This test compares the outputs from the FlexAttention backend with
the default backend, ensuring they are identical when using the same seed.
the default backend, ensuring they are similar when using the same seed.
"""
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
seed = 42
max_tokens = 24
num_logprobs = 5
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
]

sampling_params = SamplingParams(temperature=0.0,
top_p=1.0,
seed=seed,
max_tokens=max_tokens)

# Run with flex attention
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

@@ -61,7 +62,8 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True) as llm_flex:
output_flex = llm_flex.generate(prompts, sampling_params)
output_flex = llm_flex.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs)

# Run with default backend
with monkeypatch.context() as m:

@@ -71,20 +73,17 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
runner="generate",
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True) as llm_default:
output_default = llm_default.generate(prompts, sampling_params)
enforce_eager=True,
gpu_memory_utilization=0.85) as llm_default:
output_default = llm_default.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs)

# Compare outputs from both backends
for i, (flex_result,
default_result) in enumerate(zip(output_flex, output_default)):
prompt = prompts[i]
flex_text = flex_result[1][0]
default_text = default_result[1][0]

assert flex_text == default_text, (
f"FlexAttention output doesn't match default for: {prompt!r}\n"
f"FlexAttention: {flex_text!r}\n"
f"Default: {default_text!r}")
check_logprobs_close(
outputs_0_lst=output_flex,
outputs_1_lst=output_default,
name_0="flex",
name_1="default",
)

@pytest.mark.skipif(

@@ -136,5 +135,70 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
)

@pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
reason="CUDA not available or PyTorch version < 2.7",
)
def test_block_mask_direct_vs_slow_path():
"""Test that direct path block mask is a superset of slow path.

The direct path may include extra blocks for performance (over-estimation),
but must include all blocks that the slow path determines are necessary.
"""
device = torch.device("cuda")

vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B",
block_size=16,
max_model_len=1024)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

# Use a mixed batch that will create groups spanning multiple sequences
batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256],
query_lens=[33, 5, 32, 64],
name="test_mixed_batch")

common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device)

builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config,
device)

metadata_direct = builder.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
builder.direct_build = False
metadata_slow = builder.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)

assert metadata_direct.block_mask is not None
assert metadata_slow.block_mask is not None

# Extract block indices for comparison, B, H are the same
direct_indices = metadata_direct.block_mask.kv_indices[0, 0]
slow_indices = metadata_slow.block_mask.kv_indices[0, 0]
direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0]
slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0]

# main test: every block needed by slow path must be in direct path
num_groups = direct_num.shape[0]
all_contained = True
missing_details = []

for group_idx in range(num_groups):
direct_blocks = set(
direct_indices[group_idx, :direct_num[group_idx]].tolist())
slow_blocks = set(
slow_indices[group_idx, :slow_num[group_idx]].tolist())

missing_blocks = slow_blocks - direct_blocks
if missing_blocks:
all_contained = False
missing_details.append(
f"Group {group_idx}: missing {sorted(missing_blocks)}")

assert all_contained, (
"Direct path is missing blocks required by slow path:\n" +
"\n".join(missing_details))

if __name__ == "__main__":
pytest.main([__file__])
@@ -3,15 +3,13 @@
import tempfile
from collections import OrderedDict
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock

import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import vllm
from vllm.config import LoRAConfig
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)

@@ -21,7 +19,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.platforms import current_platform

@@ -104,6 +101,7 @@ def dummy_model() -> nn.Module:
]))
model.config = MagicMock()
model.embedding_modules = {"lm_head": "lm_head"}
model.unpadded_vocab_size = 32000
return model

@@ -137,6 +135,8 @@ def dummy_model_gate_up() -> nn.Module:
],
}
model.embedding_modules = {"lm_head": "lm_head"}
model.unpadded_vocab_size = 32000

return model

@@ -221,29 +221,6 @@ def phi2_lora_files():
return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")

@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
cleanup_dist_env_and_memory(shutdown_ray=True)
get_model_old = get_model

def get_model_patched(**kwargs):
kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
max_lora_rank=8)
return get_model_old(**kwargs)

with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
yield engine.llm_engine
del engine
cleanup_dist_env_and_memory(shutdown_ray=True)

@pytest.fixture
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)

@pytest.fixture
def reset_default_device():
"""
@@ -5,7 +5,6 @@ import time
import pytest

import vllm.envs as env
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)

@@ -98,12 +97,10 @@ async def test_add_lora(chatglm3_lora_files):
# Run with warmup
add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests]
add_lora_results = await asyncio.gather(*add_lora_tasks)
if env.VLLM_USE_V1:
# Test that all all_lora calls are successful.
assert all(add_lora_results)
else:
# No way to check V0 engine results as the calls just return None.
pass

# Test that all all_lora calls are successful.
assert all(add_lora_results)

time_with_add_lora = await requests_processing_time(
llm, warmup_run_requests)

@@ -113,8 +113,7 @@ def test_llama_lora(sql_lora_files):
enable_lora=True,
# also test odd max_num_seqs
max_num_seqs=13,
max_loras=4,
enable_chunked_prefill=True)
max_loras=4)
generate_and_test(llm, sql_lora_files)

@@ -128,7 +127,6 @@ def test_llama_lora_tp4(sql_lora_files):
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)

@@ -144,7 +142,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)
@@ -21,6 +21,8 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.platforms import current_platform

from .utils import create_peft_lora

EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",

@@ -35,17 +37,6 @@ DEVICES = ([
DEFAULT_DTYPE = torch.get_default_dtype()

@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
"""
Some tests depend on V0 internals. Since both V0 and V1 use the same
LoRAModelManager it is okay to just test V0.
"""
with monkeypatch.context() as m:
m.setenv('VLLM_USE_V1', '0')
yield

@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(

@@ -326,7 +317,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)

assert all(x is None for x in manager.lora_index_to_id)

# Add up to capacity

@@ -430,32 +420,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):

@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
tmp_path):
lora_config = LoRAConfig(max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)

dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)
worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
4, 2,
dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(dummy_model)

mapping = LoRAMapping([], [])
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("3", 3, dummy_lora_files),
LoRARequest("4", 4, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1

@@ -464,9 +462,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files),
LoRARequest("5", 5, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1

@@ -475,9 +473,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1

@@ -486,9 +484,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
LoRARequest("6", 6, dummy_lora_files),
LoRARequest("7", 7, dummy_lora_files),
LoRARequest("8", 8, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1

@@ -499,11 +497,11 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
# Over capacity
with pytest.raises(RuntimeError):
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
LoRARequest("13", 13, sql_lora_files),
LoRARequest("14", 14, sql_lora_files)
LoRARequest("10", 10, dummy_lora_files),
LoRARequest("11", 11, dummy_lora_files),
LoRARequest("12", 12, dummy_lora_files),
LoRARequest("13", 13, dummy_lora_files),
LoRARequest("14", 14, dummy_lora_files)
], mapping)

assert worker_adapter_manager.device == device

@@ -512,33 +510,41 @@

@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
tmp_path):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)
worker_adapter_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
4, 2, dummy_model_gate_up.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)

mapping = LoRAMapping([], [])
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("3", 3, dummy_lora_files),
LoRARequest("4", 4, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1

@@ -546,9 +552,9 @@
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files),
LoRARequest("5", 5, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1

@@ -556,9 +562,9 @@
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1

@@ -566,9 +572,9 @@
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
LoRARequest("6", 6, dummy_lora_files),
LoRARequest("7", 7, dummy_lora_files),
LoRARequest("8", 8, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8

@@ -578,11 +584,11 @@
# Over capacity
with pytest.raises(RuntimeError):
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
LoRARequest("13", 13, sql_lora_files),
LoRARequest("14", 14, sql_lora_files)
LoRARequest("10", 10, dummy_lora_files),
LoRARequest("11", 11, dummy_lora_files),
LoRARequest("12", 12, dummy_lora_files),
LoRARequest("13", 13, dummy_lora_files),
LoRARequest("14", 14, dummy_lora_files)
], mapping)

assert worker_adapter_manager.device == device
@@ -50,7 +50,6 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
enable_chunked_prefill=True,
)

expected_lora_output = [
@@ -4,17 +4,14 @@
import os
import random
import tempfile
from typing import Union
from unittest.mock import patch

import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu_worker import Worker as V1Worker
from vllm.worker.worker import Worker
from vllm.v1.worker.gpu_worker import Worker

NUM_LORAS = 16

@@ -22,18 +19,11 @@ NUM_LORAS = 16
@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):

def set_active_loras(worker: Union[Worker, V1Worker],
lora_requests: list[LoRARequest]):
def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
lora_mapping = LoRAMapping([], [])
if isinstance(worker, Worker):
# v0 case
worker.model_runner.set_active_loras(lora_requests, lora_mapping)
else:
# v1 case
worker.model_runner.lora_manager.set_active_adapters(
lora_requests, lora_mapping)

worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker
worker.model_runner.lora_manager.set_active_adapters(
lora_requests, lora_mapping)

vllm_config = VllmConfig(
model_config=ModelConfig(

@@ -62,7 +52,7 @@ def test_worker_apply_lora(sql_lora_files):
max_cpu_loras=NUM_LORAS,
max_loras=NUM_LORAS),
)
worker = worker_cls(
worker = Worker(
vllm_config=vllm_config,
local_rank=0,
rank=0,
@ -1,10 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
|
||||
@ -340,3 +343,76 @@ def generate_data_for_nslices(
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
)
|
||||
|
||||
|
||||
def create_peft_lora(
|
||||
model: torch.nn.Module,
|
||||
save_dir: str,
|
||||
target_modules: list[str],
|
||||
rank: int = 8,
|
||||
alpha: int = 16,
|
||||
dropout: float = 0.1,
|
||||
lora_dtype: torch.dtype = torch.float16,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
lora_weights = {}
|
||||
adapter_config = {
|
||||
"peft_type": "LORA",
|
||||
"auto_mapping": None,
|
||||
"base_model_name_or_path": "dummy_model",
|
||||
"revision": None,
|
||||
"task_type": "CAUSAL_LM",
|
||||
"inference_mode": False,
|
||||
"r": rank,
|
||||
"lora_alpha": alpha,
|
||||
"lora_dropout": dropout,
|
||||
"fan_in_fan_out": False,
|
||||
"bias": "none",
|
||||
"modules_to_save": None,
|
||||
"init_lora_weights": True,
|
||||
"layers_to_transform": None,
|
||||
"layers_pattern": None,
|
||||
"target_modules": target_modules,
|
||||
"exclude_modules": None,
|
||||
"use_rslora": False,
|
||||
"use_dora": False,
|
||||
"loftq_config": None,
|
||||
}
|
||||
|
||||
for module_name in target_modules:
|
||||
|
||||
module = model
|
||||
for attr in module_name.split("."):
|
||||
module = getattr(module, attr)
|
||||
|
||||
if hasattr(module, "input_size") and hasattr(module, "output_size"):
|
||||
|
||||
in_features = module.input_size
|
||||
out_features = module.output_size
|
||||
|
||||
elif hasattr(module, "embedding_dim") and hasattr(
|
||||
module, "num_embeddings"):
|
||||
# ParallelLMHead
|
||||
in_features = module.embedding_dim
|
||||
out_features = module.num_embeddings
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unable to determine dimensions for module {module_name}")
|
||||
|
||||
lora_A = torch.randn(rank, in_features, dtype=lora_dtype)
|
||||
|
||||
torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5)
|
||||
|
||||
lora_B = torch.zeros(out_features, rank, dtype=lora_dtype)
|
||||
|
||||
# PEFT style
|
||||
lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A
|
||||
lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B
|
||||
|
||||
config_path = os.path.join(save_dir, "adapter_config.json")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(adapter_config, f, indent=2, ensure_ascii=False)
|
||||
|
||||
weights_path = os.path.join(save_dir, "adapter_model.safetensors")
|
||||
save_file(lora_weights, weights_path)
|
||||
|
||||
return lora_weights
|
||||
|
||||
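For context, a minimal usage sketch of the create_peft_lora helper added above, assuming it runs in the same module where the helper is defined. The DummyProj/DummyModel classes and the qkv_proj name are placeholders invented for illustration, not vLLM classes; the helper only requires that each target module exposes input_size/output_size (or embedding_dim/num_embeddings), and it writes adapter_config.json plus adapter_model.safetensors into save_dir.

import tempfile

import torch


class DummyProj(torch.nn.Module):
    # Placeholder layer exposing the input_size/output_size attributes that
    # create_peft_lora inspects; real tests would pass a vLLM model instead.

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size


class DummyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.qkv_proj = DummyProj(64, 192)


with tempfile.TemporaryDirectory() as tmp_dir:
    weights = create_peft_lora(
        DummyModel(),
        save_dir=tmp_dir,
        target_modules=["qkv_proj"],
        rank=8,
        alpha=16,
    )
    # lora_A is (rank, in_features); lora_B is (out_features, rank), zero-init.
    assert weights["base_model.model.qkv_proj.lora_A.weight"].shape == (8, 64)
    assert weights["base_model.model.qkv_proj.lora_B.weight"].shape == (192, 8)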
22
tests/models/language/pooling/test_st_projector.py
Normal file
@ -0,0 +1,22 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
|
||||
from .mteb_utils import mteb_test_embed_models
|
||||
|
||||
# ST models with projector (Dense) layers
|
||||
ST_PROJECTOR_MODELS = [
|
||||
CLSPoolingEmbedModelInfo(
|
||||
"TencentBAC/Conan-embedding-v1",
|
||||
architecture="BertModel",
|
||||
enable_test=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
|
||||
def test_embed_models_mteb(hf_runner, vllm_runner,
|
||||
model_info: EmbedModelInfo) -> None:
|
||||
|
||||
mteb_test_embed_models(hf_runner, vllm_runner, model_info)
|
||||
@ -11,7 +11,6 @@ from pathlib import PosixPath
|
||||
import pytest
|
||||
from transformers import (AutoModel, AutoModelForImageTextToText,
|
||||
AutoModelForTextToWaveform, AutoModelForVision2Seq)
|
||||
from transformers.utils import is_flash_attn_2_available
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import identity
|
||||
@ -637,10 +636,7 @@ VLM_TEST_SETTINGS = {
|
||||
dtype="half",
|
||||
num_logprobs=10,
|
||||
patch_hf_runner=model_utils.ovis2_5_patch_hf_runner,
|
||||
marks=[pytest.mark.skipif(
|
||||
not is_flash_attn_2_available(),
|
||||
reason="HF model needs `flash_attn` installed"
|
||||
)],
|
||||
hf_model_kwargs={"revision": "refs/pr/5"},
|
||||
),
|
||||
"phi3v": VLMTestInfo(
|
||||
models=["microsoft/Phi-3.5-vision-instruct"],
|
||||
|
||||
@ -160,6 +160,7 @@ def _test_processing_correctness(
|
||||
# incorrect token ids. So we need to use `add_special_tokens=False` here
|
||||
# to leave bos_token to be added by the processor.
|
||||
_ADD_SPECIAL_TOKENS_OVERRIDES = {
|
||||
"donut": False,
|
||||
"mllama": False,
|
||||
"ovis": False,
|
||||
"ovis2_5": False,
|
||||
@ -270,6 +271,7 @@ def _test_processing_correctness_one(
|
||||
"facebook/chameleon-7b",
|
||||
"CohereLabs/command-a-vision-07-2025",
|
||||
"deepseek-ai/deepseek-vl2-tiny",
|
||||
"naver-clova-ix/donut-base-finetuned-docvqa",
|
||||
"microsoft/Florence-2-base",
|
||||
"adept/fuyu-8b",
|
||||
"google/gemma-3-4b-it",
|
||||
|
||||
@ -292,6 +292,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
|
||||
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
|
||||
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
|
||||
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
|
||||
trust_remote_code=True,
|
||||
is_available_online=False),
|
||||
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
|
||||
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
|
||||
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
|
||||
@ -413,8 +416,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"Idefics3ForConditionalGeneration": _HfExamplesInfo("HuggingFaceM4/Idefics3-8B-Llama3", # noqa: E501
|
||||
{"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"}, # noqa: E501
|
||||
min_transformers_version="4.55.1",
|
||||
transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501
|
||||
min_transformers_version="4.56",
|
||||
transformers_version_reason="HF model broken in 4.55"), # noqa: E501
|
||||
"InternS1ForConditionalGeneration": _HfExamplesInfo("internlm/Intern-S1",
|
||||
trust_remote_code=True), # noqa: E501
|
||||
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
|
||||
@ -465,9 +468,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
|
||||
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
|
||||
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B",
|
||||
trust_remote_code=True,
|
||||
max_transformers_version="4.53",
|
||||
transformers_version_reason="HF model is not compatible"), # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
|
||||
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
|
||||
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",
|
||||
@ -496,8 +497,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
|
||||
trust_remote_code=True),
|
||||
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501
|
||||
min_transformers_version="4.55.1",
|
||||
transformers_version_reason="HF model broken in 4.55.0"), # noqa: E501
|
||||
min_transformers_version="4.56",
|
||||
transformers_version_reason="HF model broken in 4.55"), # noqa: E501
|
||||
"Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3",
|
||||
trust_remote_code=True),
|
||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
|
||||
@ -512,6 +513,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
is_available_online=False,
|
||||
),
|
||||
# [Encoder-decoder]
|
||||
"DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501
|
||||
hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501
|
||||
extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501
|
||||
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
|
||||
# Therefore, we borrow the BartTokenizer from the original Bart model
|
||||
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
|
||||
|
||||
@ -17,13 +17,11 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
|
||||
PromptReplacement, apply_text_matches,
|
||||
apply_token_matches,
|
||||
find_mm_placeholders,
|
||||
find_text_matches, find_token_matches,
|
||||
iter_token_matches,
|
||||
replace_token_matches)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import MultiModalProfiler
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import full_groupby
|
||||
|
||||
from .utils import random_image
|
||||
|
||||
@ -75,12 +73,15 @@ from .utils import random_image
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("start_idx", [0, 4, 8])
|
||||
# yapf: enable
|
||||
def test_iter_token_matches(token_ids, match_ids, expected):
|
||||
result = list(iter_token_matches(token_ids, match_ids))
|
||||
def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
|
||||
result = list(iter_token_matches(token_ids, match_ids,
|
||||
start_idx=start_idx))
|
||||
|
||||
# Manually constructed results
|
||||
assert [item._asdict() for item in result] == expected
|
||||
assert [item._asdict() for item in result
|
||||
] == [item for item in expected if item["start_idx"] >= start_idx]
|
||||
|
||||
# Invariants
|
||||
match_lens = [end - start for start, end in result]
|
||||
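The hunk above threads a new start_idx argument through iter_token_matches. Purely as a conceptual illustration (not vLLM's actual implementation, whose signature and match semantics may differ), a token-subsequence matcher honoring that offset could look like this:

from typing import Iterator, NamedTuple


class TokenMatch(NamedTuple):
    start_idx: int
    end_idx: int


def iter_token_matches_sketch(
    token_ids: list[int],
    match_ids: list[int],
    *,
    start_idx: int = 0,
) -> Iterator[TokenMatch]:
    """Yield non-overlapping occurrences of match_ids, scanning from start_idx."""
    n = len(match_ids)
    if n == 0:
        return
    i = start_idx
    while i + n <= len(token_ids):
        if token_ids[i:i + n] == match_ids:
            yield TokenMatch(start_idx=i, end_idx=i + n)
            i += n  # skip past the match so results do not overlap
        else:
            i += 1


# Matches that begin before start_idx are skipped, which is what the updated
# test asserts via `item["start_idx"] >= start_idx`.
assert list(iter_token_matches_sketch([1, 2, 3, 1, 2, 3], [1, 2],
                                      start_idx=2)) == [TokenMatch(3, 5)]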
@ -241,21 +242,23 @@ def test_find_token_matches(
|
||||
# Should not be used since there is nothing to convert to token IDs
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_updates = [
|
||||
update_type(key, target, []).bind(mock_tokenizer)
|
||||
prompt_updates = {
|
||||
key: update_type(key, target, []).resolve(0)
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
result = find_token_matches(prompt, prompt_updates)
|
||||
}
|
||||
result = {
|
||||
key: list(update.iter_token_matches(prompt, mock_tokenizer))
|
||||
for key, update in prompt_updates.items()
|
||||
}
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
|
||||
assert {
|
||||
key: [
|
||||
dict(start_idx=item.start_idx, end_idx=item.end_idx)
|
||||
for item in result_groups.get(key, [])
|
||||
for item in result.get(key, [])
|
||||
]
|
||||
for key in expected_by_key
|
||||
} == expected_by_key
|
||||
@ -388,21 +391,23 @@ def test_find_text_matches(
|
||||
# Should not be used since there is nothing to convert to text
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_updates = [
|
||||
update_type(key, target, []).bind(mock_tokenizer)
|
||||
prompt_updates = {
|
||||
key: update_type(key, target, []).resolve(0)
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
result = find_text_matches(prompt, prompt_updates)
|
||||
}
|
||||
result = {
|
||||
key: list(update.iter_text_matches(prompt, mock_tokenizer))
|
||||
for key, update in prompt_updates.items()
|
||||
}
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
|
||||
assert {
|
||||
key: [
|
||||
dict(start_idx=item.start_idx, end_idx=item.end_idx)
|
||||
for item in result_groups.get(key, [])
|
||||
for item in result.get(key, [])
|
||||
]
|
||||
for key in expected_by_key
|
||||
} == expected_by_key
|
||||
@ -552,39 +557,35 @@ def test_find_update_text(
|
||||
update_type,
|
||||
expected_by_mm_count,
|
||||
) in expected_by_update_type_mm_count.items():
|
||||
mm_prompt_updates = {
|
||||
key:
|
||||
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
|
||||
for key, target in target_by_key.items()
|
||||
}
|
||||
mm_matches = {
|
||||
key: find_text_matches(prompt, updates)
|
||||
for key, updates in mm_prompt_updates.items()
|
||||
}
|
||||
|
||||
for mm_count, expected in expected_by_mm_count.items():
|
||||
result = apply_text_matches(
|
||||
mm_prompt_updates = {
|
||||
key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
|
||||
for i in range(mm_count)]
|
||||
for key, target in target_by_key.items()
|
||||
}
|
||||
|
||||
new_prompt, result = apply_text_matches(
|
||||
prompt,
|
||||
mm_matches,
|
||||
{key: mm_count
|
||||
for key in repl_by_key},
|
||||
mm_prompt_updates,
|
||||
mock_tokenizer,
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("update_type:", update_type)
|
||||
print("mm_count:", mm_count)
|
||||
print("mm_matches:", mm_matches)
|
||||
print("mm_prompt_updates:", mm_prompt_updates)
|
||||
print("new_prompt:", new_prompt)
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
assert result == expected
|
||||
assert new_prompt == expected
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
|
||||
[
|
||||
# Tokenized test cases of `test_find_replace_text`
|
||||
# Tokenized test cases of `test_find_update_text`
|
||||
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
|
||||
(
|
||||
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
@ -726,32 +727,28 @@ def test_find_update_tokens(
|
||||
update_type,
|
||||
expected_by_mm_count,
|
||||
) in expected_by_update_type_mm_count.items():
|
||||
mm_prompt_updates = {
|
||||
key:
|
||||
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
|
||||
for key, target in target_by_key.items()
|
||||
}
|
||||
mm_matches = {
|
||||
key: find_token_matches(prompt, updates)
|
||||
for key, updates in mm_prompt_updates.items()
|
||||
}
|
||||
|
||||
for mm_count, expected in expected_by_mm_count.items():
|
||||
result = apply_token_matches(
|
||||
mm_prompt_updates = {
|
||||
key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
|
||||
for i in range(mm_count)]
|
||||
for key, target in target_by_key.items()
|
||||
}
|
||||
|
||||
new_prompt, result = apply_token_matches(
|
||||
prompt,
|
||||
mm_matches,
|
||||
{key: mm_count
|
||||
for key in repl_by_key},
|
||||
mm_prompt_updates,
|
||||
mock_tokenizer,
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("update_type:", update_type)
|
||||
print("mm_count:", mm_count)
|
||||
print("mm_matches:", mm_matches)
|
||||
print("mm_prompt_updates:", mm_prompt_updates)
|
||||
print("new_prompt:", new_prompt)
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
assert result == expected
|
||||
assert new_prompt == expected
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@ -878,17 +875,11 @@ def test_find_mm_placeholders(
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
mm_prompt_updates = {
|
||||
key: [update_type(key, [], repl).bind(mock_tokenizer)]
|
||||
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
|
||||
for key, repl in repl_by_key.items()
|
||||
}
|
||||
|
||||
result = find_mm_placeholders(
|
||||
mm_prompt_updates,
|
||||
prompt,
|
||||
# Effectively match all occurrences in the prompt
|
||||
{key: 3
|
||||
for key in repl_by_key},
|
||||
)
|
||||
result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
|
||||
@ -38,8 +38,7 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
|
||||
with vllm_runner(model_id) as llm:
|
||||
# note: this does not test accuracy, just that we can run through
|
||||
# see lm-eval tests for accuracy
|
||||
outputs = llm.generate_greedy(prompts=["Hello my name is"],
|
||||
max_tokens=10)
|
||||
outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10)
|
||||
print(outputs[0][1])
|
||||
|
||||
|
||||
@ -90,8 +89,7 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
|
||||
|
||||
# note: this does not test accuracy, just that we can run through
|
||||
# see lm-eval tests for accuracy
|
||||
outputs = llm.generate_greedy(prompts=["Hello my name is"],
|
||||
max_tokens=10)
|
||||
outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10)
|
||||
print(outputs[0][1])
|
||||
|
||||
|
||||
|
||||
@ -46,5 +46,5 @@ def test_lm_head(
|
||||
vllm_model.apply_model(check_model)
|
||||
|
||||
print(
|
||||
vllm_model.generate_greedy(prompts=["Hello my name is"],
|
||||
vllm_model.generate_greedy(["Hello my name is"],
|
||||
max_tokens=10)[0][1])
|
||||
|
||||
459
tests/tool_use/test_seed_oss_tool_parser.py
Normal file
@ -0,0 +1,459 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaMessage, FunctionCall,
|
||||
ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
|
||||
from vllm.transformers_utils.detokenizer import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def seed_oss_tokenizer():
|
||||
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def seed_oss_tool_parser(seed_oss_tokenizer):
|
||||
return SeedOssToolParser(seed_oss_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools():
|
||||
return [
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for a given location.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"City and country e.g. Bogotá, Colombia"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "this is the unit of temperature"
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
"additionalProperties": False
|
||||
},
|
||||
"returns": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"description": "temperature in celsius"
|
||||
}
|
||||
},
|
||||
"required": ["temperature"],
|
||||
"additionalProperties": False
|
||||
},
|
||||
"strict": True
|
||||
}),
|
||||
]
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
# Seed-OSS tool call will not generate id
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||
assert actual_tool_call.function.arguments == expected_tool_call.function.arguments
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
model_output = "This is a test response without any tool calls"
|
||||
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"tool_call_0_thinking_budget",
|
||||
"tool_call_512_thinkg_budget",
|
||||
"tool_call_unlimited_thinking_budget",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
),
|
||||
(
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
|
||||
"""\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
|
||||
),
|
||||
(
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
|
||||
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
"unit": "celsius",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think>""",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
|
||||
model_output, request=request) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
model_output = "This is a test response without any tool calls"
|
||||
|
||||
result = seed_oss_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="his is a test response",
|
||||
current_text=model_output,
|
||||
delta_text=" without any tool calls.",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Should return the delta text as content
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert result.content == " without any tool calls."
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
seed_oss_tool_parser: SeedOssToolParser,
|
||||
seed_oss_tokenizer: AnyTokenizer,
|
||||
model_output: str,
|
||||
request: Optional[ChatCompletionRequest] = None
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = seed_oss_tokenizer.encode(model_output,
|
||||
add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[:i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset,
|
||||
new_read_offset) = detokenize_incrementally(
|
||||
tokenizer=seed_oss_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message = seed_oss_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
request=request,
|
||||
)
|
||||
if delta_message:
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (previous_tokens +
|
||||
new_tokens if previous_tokens else new_tokens)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"tool_call_0_thinking_budget",
|
||||
"tool_call_512_thinkg_budget",
|
||||
"tool_call_unlimited_thinking_budget",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
),
|
||||
(
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
|
||||
"""\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
|
||||
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
|
||||
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
|
||||
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
|
||||
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
|
||||
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
|
||||
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
|
||||
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
|
||||
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
|
||||
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
|
||||
),
|
||||
(
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
|
||||
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
"unit": "celsius",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
|
||||
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
|
||||
"""let me check the function docstring again. Oh, the function says unit is optional, and """
|
||||
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
|
||||
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
|
||||
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
|
||||
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
|
||||
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
|
||||
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
|
||||
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
|
||||
"""call should be as above. Then wait for the result to come back and tell the user the """
|
||||
"""temperature in Celsius.</seed:think>""",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
|
||||
sample_tools, model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
"""Test incremental streaming behavior"""
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
|
||||
other_content = ''
|
||||
tool_states = {} # Track state per tool index
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
seed_oss_tool_parser, seed_oss_tokenizer, model_output, request):
|
||||
# role should never be streamed from tool parser
|
||||
assert not delta_message.role
|
||||
|
||||
if delta_message.content:
|
||||
other_content += delta_message.content
|
||||
|
||||
if delta_message.tool_calls:
|
||||
for tool_call in delta_message.tool_calls:
|
||||
idx = tool_call.index
|
||||
|
||||
# Initialize state for new tool
|
||||
if idx not in tool_states:
|
||||
tool_states[idx] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
"type": None
|
||||
}
|
||||
|
||||
# First chunk should have id, name, and type
|
||||
if tool_call.id:
|
||||
tool_states[idx]["id"] = tool_call.id
|
||||
|
||||
if tool_call.type:
|
||||
assert tool_call.type == "function"
|
||||
tool_states[idx]["type"] = tool_call.type
|
||||
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
# Should only be set once
|
||||
assert tool_states[idx]["name"] is None
|
||||
tool_states[idx]["name"] = tool_call.function.name
|
||||
|
||||
if tool_call.function.arguments is not None:
|
||||
# Accumulate arguments incrementally
|
||||
tool_states[idx][
|
||||
"arguments"] += tool_call.function.arguments
|
||||
|
||||
# Verify final content
|
||||
assert other_content == expected_content
|
||||
|
||||
# Verify we got all expected tool calls
|
||||
assert len(tool_states) == len(expected_tool_calls)
|
||||
|
||||
# Verify each tool call
|
||||
for idx, expected_tool in enumerate(expected_tool_calls):
|
||||
state = tool_states[idx]
|
||||
assert state["id"] is not None
|
||||
assert state["type"] == "function"
|
||||
assert state["name"] == expected_tool.function.name
|
||||
|
||||
# Parse accumulated arguments
|
||||
arguments_str = state["arguments"]
|
||||
assert arguments_str is not None
|
||||
actual_args = json.loads(arguments_str)
|
||||
expected_args = json.loads(expected_tool.function.arguments)
|
||||
assert actual_args == expected_args
|
||||
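The fixtures above show the tagged format Seed-OSS emits for tool calls. Purely as an illustrative sketch — this is not the SeedOssToolParser implementation, which also handles streaming deltas and schema-aware type coercion — a complete (non-streaming) output could be reduced to name/argument pairs like so:

import json
import re

_TOOL_CALL_RE = re.compile(r"<seed:tool_call>(.*?)</seed:tool_call>", re.DOTALL)
_FUNCTION_RE = re.compile(r"<function=([^>]+)>(.*?)</function>", re.DOTALL)
_PARAM_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)


def extract_seed_oss_tool_calls(text: str) -> list[dict]:
    calls = []
    for block in _TOOL_CALL_RE.findall(text):
        for name, body in _FUNCTION_RE.findall(block):
            # Every value is kept as a string here; the real parser can use
            # the tool schema to coerce numbers and booleans.
            args = {key: value for key, value in _PARAM_RE.findall(body)}
            calls.append({"name": name, "arguments": json.dumps(args)})
    return calls


output = ("<seed:tool_call>\n<function=get_weather>\n"
          "<parameter=location>Barcelona, Spain</parameter>\n"
          "</function>\n</seed:tool_call>")
# -> [{'name': 'get_weather', 'arguments': '{"location": "Barcelona, Spain"}'}]
print(extract_seed_oss_tool_calls(output))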
@ -10,14 +10,15 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend)
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
set_kv_cache_layout)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
BACKENDS_TO_TEST = [
|
||||
_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
|
||||
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
|
||||
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN,
|
||||
"FLEX_ATTENTION_SLOW"
|
||||
]
|
||||
|
||||
# Remove flashinfer from the list if it's not available
|
||||
@ -97,7 +98,7 @@ def create_and_prepopulate_kv_cache(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
randomize_blocks: bool = True) -> torch.Tensor:
|
||||
"""Create and prepopulate a KV cache with context data.
|
||||
|
||||
|
||||
Args:
|
||||
k_contexts: List of key context tensors for each sequence
|
||||
v_contexts: List of value context tensors for each sequence
|
||||
@ -109,9 +110,9 @@ def create_and_prepopulate_kv_cache(
|
||||
device: Device to create the cache on
|
||||
num_blocks: Total number of blocks in the cache
|
||||
block_table: Block table tensor to populate
|
||||
randomize_blocks: Whether to randomly permute blocks
|
||||
randomize_blocks: Whether to randomly permute blocks
|
||||
or use sequential order
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple of (kv_cache, updated_block_table)
|
||||
"""
|
||||
@ -206,10 +207,18 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
kv_cache: torch.Tensor) -> torch.Tensor:
|
||||
"""Run attention computation using the specified backend's AttentionImpl."""
|
||||
|
||||
builder_cls, impl_cls = get_attention_backend(backend)
|
||||
# Handle special case for FLEX_ATTENTION_SLOW
|
||||
actual_backend = backend
|
||||
|
||||
use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
|
||||
if backend == "FLEX_ATTENTION_SLOW":
|
||||
actual_backend = _Backend.FLEX_ATTENTION
|
||||
use_direct_block_mask = False
|
||||
|
||||
builder_cls, impl_cls = get_attention_backend(actual_backend)
|
||||
|
||||
# Mock flashinfer's get_per_layer_parameters if needed
|
||||
if backend == _Backend.FLASHINFER_VLLM_V1:
|
||||
if actual_backend == _Backend.FLASHINFER_VLLM_V1:
|
||||
import unittest.mock
|
||||
|
||||
from vllm.v1.attention.backends.utils import PerLayerParameters
|
||||
@ -239,6 +248,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
else:
|
||||
# Build metadata
|
||||
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
|
||||
if actual_backend == _Backend.FLEX_ATTENTION:
|
||||
builder.direct_build = use_direct_block_mask
|
||||
attn_metadata = builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
@ -453,11 +464,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
|
||||
rtol = 1e-2
|
||||
atol = 5e-3
|
||||
|
||||
if backend_name == _Backend.FLEX_ATTENTION:
|
||||
atol = 5e-1 # TODO: figure out why flex_attention has such large
|
||||
# numerical differences for medium_decode, medium_prefill,
|
||||
# mixed_medium
|
||||
|
||||
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
|
||||
max_rel_diff = torch.max(
|
||||
torch.abs(backend_output - sdpa_output) /
|
||||
|
||||
104
tests/v1/attention/test_attention_backends_selection.py
Normal file
@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for mamba attention backend selectors."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
|
||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||
from vllm.model_executor.layers.mamba.short_conv import ShortConv
|
||||
from vllm.model_executor.models.minimax_text_01 import (
|
||||
MiniMaxText01LinearAttention)
|
||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
|
||||
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
|
||||
from vllm.v1.attention.backends.short_conv_attn import (
|
||||
ShortConvAttentionBackend)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"layer_class, init_kwargs, expected_backend, expected_mamba_type", [
|
||||
(
|
||||
MambaMixer,
|
||||
dict(
|
||||
hidden_size=128,
|
||||
ssm_state_size=16,
|
||||
conv_kernel_size=4,
|
||||
intermediate_size=256,
|
||||
time_step_rank=8,
|
||||
use_conv_bias=True,
|
||||
use_bias=False,
|
||||
use_rms_norm=True,
|
||||
),
|
||||
Mamba1AttentionBackend,
|
||||
"mamba1",
|
||||
),
|
||||
(
|
||||
MambaMixer2,
|
||||
dict(
|
||||
hidden_size=128,
|
||||
ssm_state_size=16,
|
||||
conv_kernel_size=4,
|
||||
intermediate_size=256,
|
||||
use_conv_bias=True,
|
||||
use_bias=False,
|
||||
n_groups=1,
|
||||
num_heads=8,
|
||||
head_dim=32,
|
||||
),
|
||||
Mamba2AttentionBackend,
|
||||
"mamba2",
|
||||
),
|
||||
(
|
||||
MiniMaxText01LinearAttention,
|
||||
dict(
|
||||
hidden_size=128,
|
||||
hidden_inner_size=256,
|
||||
num_heads=8,
|
||||
head_dim=32,
|
||||
max_position=2048,
|
||||
block_size=64,
|
||||
num_hidden_layer=12,
|
||||
layer_idx=0,
|
||||
linear_layer_idx=0,
|
||||
),
|
||||
LinearAttentionBackend,
|
||||
"linear_attention",
|
||||
),
|
||||
(
|
||||
ShortConv,
|
||||
dict(
|
||||
config=SimpleNamespace(conv_L_cache=32, conv_bias=True),
|
||||
dim=128,
|
||||
layer_idx=0,
|
||||
),
|
||||
ShortConvAttentionBackend,
|
||||
"short_conv",
|
||||
),
|
||||
])
|
||||
def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
|
||||
expected_backend, expected_mamba_type):
|
||||
"""Test that Mamba-like layers return the correct attention backend."""
|
||||
layer = layer_class(**init_kwargs)
|
||||
|
||||
backend_class = layer.get_attn_backend()
|
||||
assert backend_class is expected_backend
|
||||
assert layer.mamba_type == expected_mamba_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [
|
||||
(MambaMixer, Mamba1AttentionBackend, "mamba1"),
|
||||
(MambaMixer2, Mamba2AttentionBackend, "mamba2"),
|
||||
(MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
|
||||
(ShortConv, ShortConvAttentionBackend, "short_conv"),
|
||||
])
|
||||
def test_mamba_layers_have_unified_interface(layer_class, expected_backend,
|
||||
expected_mamba_type):
|
||||
"""Test that all Mamba layers have the unified get_attn_backend
|
||||
interface."""
|
||||
assert hasattr(layer_class, 'get_attn_backend'), (
|
||||
f"{layer_class.__name__} should have get_attn_backend method")
|
||||
assert hasattr(layer_class, 'mamba_type'), (
|
||||
f"{layer_class.__name__} should have mamba_type property")
|
||||
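The tests above encode the new convention that each Mamba-like layer owns its backend choice through a mamba_type string and a get_attn_backend() method, replacing the string-keyed get_mamba_attn_backend selector removed in the next hunk. A hedged sketch of a layer satisfying that interface — the class below is a placeholder, not vLLM code; only the two member names are taken from the tests:

from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend


class MyMamba2Layer:          # placeholder layer, not a vLLM class
    mamba_type = "mamba2"     # string tag checked by the tests

    def get_attn_backend(self):
        # Each layer returns its own backend class directly, so no central
        # string-to-backend dispatch table is needed.
        return Mamba2AttentionBackend


layer = MyMamba2Layer()
assert layer.get_attn_backend() is Mamba2AttentionBackend
assert layer.mamba_type == "mamba2"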
@ -1,25 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for mamba attention backend selectors."""
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
|
||||
|
||||
|
||||
@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"],
|
||||
argvalues=[("mamba2", Mamba2AttentionBackend)])
|
||||
def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
|
||||
backend_class = get_mamba_attn_backend(mamba_type)
|
||||
|
||||
assert backend_class is expected_backend
|
||||
|
||||
|
||||
def test_get_mamba_attn_backend_unsupported():
|
||||
unsupported_types = ["mamba", ""]
|
||||
|
||||
for mamba_type in unsupported_types:
|
||||
err_message = f"Mamba Attention type {mamba_type} is not supported yet."
|
||||
with pytest.raises(NotImplementedError, match=err_message):
|
||||
get_mamba_attn_backend(mamba_type)
|
||||
144
tests/v1/core/test_encoder_cache_manager.py
Normal file
@ -0,0 +1,144 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
||||
|
||||
|
||||
# ------------------ Mock Classes ------------------ #
|
||||
class MockRequest:
|
||||
|
||||
def __init__(self, request_id, mm_hashes, token_counts):
|
||||
self.request_id = request_id
|
||||
self.mm_hashes = mm_hashes
|
||||
self._token_counts = token_counts
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
return self._token_counts[input_id]
|
||||
|
||||
|
||||
# ------------------ Unit Tests ------------------ #
|
||||
def test_basic_allocate_and_reuse():
|
||||
cache = EncoderCacheManager(cache_size=10)
|
||||
req = MockRequest("r1", ["imgA"], [4])
|
||||
|
||||
assert not cache.check_and_update_cache(req, 0)
|
||||
assert cache.try_allocate(req, 0, int(1e9))
|
||||
|
||||
cache.allocate(req, 0)
|
||||
|
||||
assert cache.check_and_update_cache(req, 0)
|
||||
assert "r1" in cache.cached["imgA"]
|
||||
assert cache.num_free_slots == 6
|
||||
|
||||
# Free twice to bring refcount to 0.
|
||||
cache.free_encoder_input(req, 0)
|
||||
cache.free_encoder_input(req, 0)
|
||||
|
||||
assert not cache.cached["imgA"]
|
||||
assert "imgA" in cache.freeable
|
||||
assert cache.num_freeable_slots == 10
|
||||
assert cache.num_free_slots == 6
|
||||
|
||||
|
||||
def test_freeing_decreases_refcount_and_moves_to_freeable():
|
||||
manager = EncoderCacheManager(cache_size=10)
|
||||
req = MockRequest("req2", ["img3"], [5])
|
||||
|
||||
assert manager.try_allocate(req, 0, int(1e9))
|
||||
manager.allocate(req, 0)
|
||||
|
||||
assert len(manager.cached["img3"]) == 1
|
||||
|
||||
manager.free_encoder_input(req, 0)
|
||||
|
||||
assert not manager.cached["img3"]
|
||||
assert "img3" in manager.freeable
|
||||
assert manager.num_freeable_slots == 10
|
||||
|
||||
|
||||
def test_free_request_frees_all_inputs():
|
||||
manager = EncoderCacheManager(cache_size=10)
|
||||
req = MockRequest("req3", ["a", "b"], [2, 3])
|
||||
|
||||
assert manager.try_allocate(req, 0, int(1e9))
|
||||
manager.allocate(req, 0)
|
||||
|
||||
assert manager.try_allocate(req, 1, int(1e9))
|
||||
manager.allocate(req, 1)
|
||||
|
||||
assert len(manager.cached["a"]) == 1
|
||||
assert len(manager.cached["b"]) == 1
|
||||
|
||||
manager.free(req)
|
||||
|
||||
assert not manager.cached["a"]
|
||||
assert not manager.cached["b"]
|
||||
assert "a" in manager.freeable
|
||||
assert "b" in manager.freeable
|
||||
assert manager.num_freeable_slots == 10
|
||||
|
||||
|
||||
def test_eviction_when_cache_is_full():
|
||||
manager = EncoderCacheManager(cache_size=10)
|
||||
|
||||
req1 = MockRequest("req1", ["x"], [6])
|
||||
req2 = MockRequest("req2", ["y"], [5])
|
||||
|
||||
assert manager.try_allocate(req1, 0, int(1e9))
|
||||
manager.allocate(req1, 0)
|
||||
manager.free_encoder_input(req1, 0)
|
||||
|
||||
assert manager.try_allocate(req2, 0, int(1e9))
|
||||
manager.allocate(req2, 0)
|
||||
|
||||
# 'x' should have been evicted.
|
||||
assert "x" not in manager.cached
|
||||
assert "x" in manager.get_freed_mm_hashes()
|
||||
|
||||
|
||||
def test_get_cached_input_ids():
|
||||
manager = EncoderCacheManager(cache_size=10)
|
||||
req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3])
|
||||
|
||||
assert manager.try_allocate(req, 0, int(1e9))
|
||||
manager.allocate(req, 0)
|
||||
|
||||
assert manager.try_allocate(req, 2, int(1e9))
|
||||
manager.allocate(req, 2)
|
||||
|
||||
cached_ids = manager.get_cached_input_ids(req)
|
||||
assert cached_ids == {0, 2}
|
||||
|
||||
|
||||
def test_has_cache_restores_from_freeable():
|
||||
manager = EncoderCacheManager(cache_size=10)
|
||||
req = MockRequest("reqY", ["imgZ"], [4])
|
||||
|
||||
assert manager.try_allocate(req, 0, int(1e9))
|
||||
manager.allocate(req, 0)
|
||||
|
||||
manager.free_encoder_input(req, 0)
|
||||
|
||||
# Should restore from freeable.
|
||||
assert manager.check_and_update_cache(req, 0)
|
||||
assert len(manager.cached["imgZ"]) == 1
|
||||
assert "imgZ" not in manager.freeable
|
||||
assert manager.num_freeable_slots == 6
|
||||
|
||||
|
||||
def test_get_freed_mm_hashes_clears_freed_list():
|
||||
manager = EncoderCacheManager(cache_size=10)
|
||||
req1 = MockRequest("reqA", ["a"], [5])
|
||||
req2 = MockRequest("reqB", ["b"], [6])
|
||||
|
||||
assert manager.try_allocate(req1, 0, int(1e9))
|
||||
manager.allocate(req1, 0)
|
||||
manager.free_encoder_input(req1, 0)
|
||||
|
||||
# Should trigger eviction of 'a'.
|
||||
assert manager.try_allocate(req2, 0, int(1e9))
|
||||
manager.allocate(req2, 0)
|
||||
|
||||
freed = manager.get_freed_mm_hashes()
|
||||
assert "a" in freed
|
||||
assert manager.get_freed_mm_hashes() == []
|
||||
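Taken together, the tests above describe a hash-keyed, reference-counted cache whose entries become freeable when their refcount reaches zero but are only evicted when a later allocation needs the space. A condensed walk-through under the same assumptions, reusing the MockRequest stand-in defined above (token counts are illustrative):

from vllm.v1.core.encoder_cache_manager import EncoderCacheManager

cache = EncoderCacheManager(cache_size=10)
req_a = MockRequest("reqA", ["imgA"], [6])   # one input worth 6 encoder tokens
req_b = MockRequest("reqB", ["imgB"], [5])

assert cache.try_allocate(req_a, 0, int(1e9))
cache.allocate(req_a, 0)                      # 6 of 10 slots in use

cache.free_encoder_input(req_a, 0)            # refcount hits 0: "imgA" becomes
assert "imgA" in cache.freeable               # freeable, but is not evicted yet

assert cache.try_allocate(req_b, 0, int(1e9))
cache.allocate(req_b, 0)                      # needs 5 slots, so "imgA" is evicted
assert "imgA" in cache.get_freed_mm_hashes()  # evicted hashes are reported once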
@ -338,7 +338,7 @@ def test_stop_via_update_from_output():
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
|
||||
@ -391,7 +391,7 @@ def test_stop_via_update_from_output():
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
@ -443,7 +443,7 @@ def test_stop_via_update_from_output():
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
@ -490,7 +490,7 @@ def test_stop_via_update_from_output():
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
|
||||
|
||||
@ -143,7 +143,11 @@ def create_requests(
|
||||
mm_position = mm_positions[i]
|
||||
mm_item = MultiModalKwargsItem.dummy("dummy_m")
|
||||
mm_kwargs = [mm_item] * len(mm_position)
|
||||
mm_hashes = ["hash"] * len(mm_position)
|
||||
# Dummy hash for each mm item should be unique
|
||||
# since encoder cache tracks entries by hash
|
||||
mm_hashes = [
|
||||
"hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
|
||||
]
|
||||
else:
|
||||
mm_position = None
|
||||
mm_kwargs = None
|
||||
|
Some files were not shown because too many files have changed in this diff.