Compare commits
157 Commits
khluu/test
...
copilot/fi
| SHA1 | Author | Date | |
|---|---|---|---|
| 0b85cc9fd4 | |||
| 1c3d99d6a3 | |||
| 1eec2bf88b | |||
| 55812718ab | |||
| 79dff4ac72 | |||
| 2a97ffc33d | |||
| efc88cf64a | |||
| 7b6a837275 | |||
| c34c82b7fe | |||
| 8a044754bd | |||
| 9188ae7cb5 | |||
| 8a3cd90af5 | |||
| 2a167b2eeb | |||
| 0ff902f3b4 | |||
| a9082a4d14 | |||
| e0329ed4b4 | |||
| 6879cd80ae | |||
| e269be2ba2 | |||
| 5c4b6e66fe | |||
| d0a4a3f645 | |||
| ebafb0936d | |||
| 0cb7b065c3 | |||
| 2da02dd0d8 | |||
| d765cf01fe | |||
| 712d0f88d8 | |||
| 49ab23b3cc | |||
| c9abb10489 | |||
| 787cdb3829 | |||
| a5203d04df | |||
| 99f8094400 | |||
| 170e8ea9ea | |||
| a71e4765cc | |||
| 39971db3aa | |||
| 504d914314 | |||
| 47455c424f | |||
| c7fc6b1354 | |||
| ad78868450 | |||
| e2db1164a1 | |||
| 416f05929a | |||
| 5e021b4981 | |||
| 1b9b16649c | |||
| e76e233540 | |||
| a75277285b | |||
| 9dc30b7068 | |||
| 053278a5dc | |||
| c55c028998 | |||
| 65197a5fb3 | |||
| b8f17f5d98 | |||
| d9a55204ba | |||
| b4e9fd811f | |||
| 308fa287a8 | |||
| fa78de9dc3 | |||
| f6818a92cb | |||
| 23c939fd30 | |||
| add1adfec7 | |||
| c80c53a30f | |||
| 24d0c9e6ed | |||
| cc7ae5e7ca | |||
| 0313cf854d | |||
| 0483fabc74 | |||
| da65bec309 | |||
| 4645024d3a | |||
| cd7a3df26f | |||
| 32d2b4064f | |||
| 22cf679aad | |||
| b6d7d34fc6 | |||
| 341923b982 | |||
| 424fb7a5d2 | |||
| 88491c1b6b | |||
| 613a23b57f | |||
| 51a215300b | |||
| ebe14621e3 | |||
| 325aa3dee9 | |||
| a073be6d87 | |||
| 695e7adcd2 | |||
| 281710ef9a | |||
| 808d2e9aa0 | |||
| 285178b3b8 | |||
| 88016c372a | |||
| 998720859c | |||
| 0ba1b54ac6 | |||
| 53415653ff | |||
| 17373dcd93 | |||
| 5964069367 | |||
| de9c085e17 | |||
| 111692bb8c | |||
| 394591e343 | |||
| 3ac849665d | |||
| 0b9cc56fac | |||
| 8896eb72eb | |||
| 19fe1a0510 | |||
| 480bdf5a7b | |||
| 5368f76855 | |||
| 8ef6b8a38c | |||
| 3bbe11cc13 | |||
| c5041f899f | |||
| 8b5fe6eb51 | |||
| 800349c2a5 | |||
| 044931f97b | |||
| 1d353b6352 | |||
| 3496274663 | |||
| 8a19303173 | |||
| 603fbbbce0 | |||
| 10f535c086 | |||
| 48bfb0c9b7 | |||
| f8ce022948 | |||
| 0278f1ac3a | |||
| a482e4e769 | |||
| e0b056e443 | |||
| 79f05e4436 | |||
| f8daddcc4c | |||
| c8e33c72c6 | |||
| d70a16625d | |||
| 5cc54f7c5b | |||
| 0c6e40bbaa | |||
| 2e2000f352 | |||
| 31282401b6 | |||
| 0c31e28e95 | |||
| f571ff8eb6 | |||
| f64ee61d9e | |||
| 8993073dc1 | |||
| 655a09f653 | |||
| f94bf9b924 | |||
| 3663870c72 | |||
| 2461d9e562 | |||
| 7be5d113d8 | |||
| b029de9902 | |||
| bbea1cefdd | |||
| f5aa307d77 | |||
| 4b795020ed | |||
| c86af22f31 | |||
| 10cc12ba66 | |||
| a4fbb32fab | |||
| 1b125004be | |||
| 4fbda0b20c | |||
| 4e51fa8cba | |||
| bf7c99dfc4 | |||
| b95697d731 | |||
| 582bbe6bd7 | |||
| 0cdbf5e61c | |||
| ebe56a0064 | |||
| f77a0802b7 | |||
| c4477f55e5 | |||
| dfd2382039 | |||
| 3b11b26b50 | |||
| d6d13bd49e | |||
| 5efd6905bc | |||
| b17109beea | |||
| 4449235843 | |||
| 38217877aa | |||
| c6d80a7a96 | |||
| 7cd17e22d7 | |||
| 50df09fe13 | |||
| 68fcd3fa73 | |||
| 83e69a09d6 | |||
| 3aa8c10038 | |||
| 103f1ec8d3 |
@@ -8,7 +8,8 @@ template = """<!DOCTYPE html>
<html>
<body>
<h1>Links for vLLM</h1/>
<a href="../{wheel_html_escaped}">{wheel}</a><br/>
<a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
<a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
</body>
</html>
"""

@@ -21,7 +22,25 @@ filename = os.path.basename(args.wheel)

with open("index.html", "w") as f:
    print(f"Generated index.html for {args.wheel}")
    # sync the abi tag with .buildkite/scripts/upload-wheels.sh
    if "x86_64" in filename:
        x86_wheel = filename
        arm_wheel = filename.replace("x86_64", "aarch64").replace(
            "manylinux1", "manylinux2014"
        )
    elif "aarch64" in filename:
        x86_wheel = filename.replace("aarch64", "x86_64").replace(
            "manylinux2014", "manylinux1"
        )
        arm_wheel = filename
    else:
        raise ValueError(f"Unsupported wheel: {filename}")
    # cloudfront requires escaping the '+' character
    f.write(
        template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B"))
        template.format(
            x86_wheel=x86_wheel,
            x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"),
            arm_wheel=arm_wheel,
            arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"),
        )
    )
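For reference, a minimal sketch of the filename mapping that the new branch above performs. The wheel filename used here is a made-up example, not an artifact produced by this change.

```python
# Illustrative only: a hypothetical x86_64 wheel name and the aarch64 name
# derived from it by the replacements shown in the diff above.
filename = "vllm-0.0.0+cu128-cp38-abi3-manylinux1_x86_64.whl"  # hypothetical example
x86_wheel = filename
arm_wheel = filename.replace("x86_64", "aarch64").replace("manylinux1", "manylinux2014")
print(arm_wheel)  # vllm-0.0.0+cu128-cp38-abi3-manylinux2014_aarch64.whl
```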
@@ -1,12 +0,0 @@
# For vllm script, with -t option (tensor parallel size).
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.419
  - name: "exact_match,flexible-extract"
    value: 0.416
limit: 1000
num_fewshot: 5

@@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml
DeepSeek-V2-Lite-Chat.yaml
Meta-Llama-3-8B-QQQ.yaml
@@ -2,7 +2,7 @@
# We can use this script to compute baseline accuracy on GSM for transformers.
#
# Make sure you have lm-eval-harness installed:
# pip install lm-eval==0.4.4
# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]

usage() {
    echo``

@@ -3,7 +3,7 @@
# We use this for fp8, which HF does not support.
#
# Make sure you have lm-eval-harness installed:
# pip install lm-eval==0.4.4
# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]

usage() {
    echo``
@@ -17,7 +17,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/
- SGLang: `lmsysorg/sglang:v0.3.2-cu121`
- LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12`
- TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3`
  - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.*
  - *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.*
- Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark.
- Hardware
  - 8x Nvidia A100 GPUs

@@ -382,7 +382,7 @@ run_genai_perf_tests() {
        client_command="genai-perf profile \
            -m $model \
            --service-kind openai \
            --backend vllm \
            --backend "$backend" \
            --endpoint-type chat \
            --streaming \
            --url localhost:$port \
@@ -27,7 +27,12 @@ steps:
    env:
      DOCKER_BUILDKIT: "1"

  - block: "Build CUDA 12.6 wheel"
    key: block-build-cu126-wheel
    depends_on: ~

  - label: "Build wheel - CUDA 12.6"
    depends_on: block-build-cu126-wheel
    id: build-wheel-cuda-12-6
    agents:
      queue: cpu_queue_postmerge

@@ -68,7 +73,7 @@ steps:
      queue: cpu_queue_postmerge
    commands:
      - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
      - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
      - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
      - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"

  - label: "Annotate release workflow"
@@ -46,6 +46,11 @@ function cpu_tests() {
    set -e
    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"

  # Run kernel tests
  docker exec cpu-test-"$NUMA_NODE" bash -c "
    set -e
    pytest -v -s tests/kernels/test_onednn.py"

  # Run basic model test
  docker exec cpu-test-"$NUMA_NODE" bash -c "
    set -e

@@ -99,4 +104,4 @@ function cpu_tests() {

# All of CPU tests are expected to be finished less than 40 mins.
export -f cpu_tests
timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
@@ -61,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR"
echo "--- Installing Python dependencies ---"
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
    && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
    && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \
    && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
    && python3 -m pip install --progress-bar off hf-transfer
echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1

@@ -61,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR"
echo "--- Installing Python dependencies ---"
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
    && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
    && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \
    && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
    && python3 -m pip install --progress-bar off hf-transfer
echo "--- Python dependencies installed ---"
export VLLM_USE_V1=1

@@ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then
    # Remove dangling images (those that are not tagged and not used by any container)
    docker image prune -f
    # Remove unused volumes / force the system prune for old images as well.
    docker volume prune -f && docker system prune --force --filter "until=72h" --all
    docker volume prune -f && docker system prune --force --filter "until=24h" --all
    echo "Docker images and volumes cleanup completed."
else
    echo "Disk usage is below $threshold%. No cleanup needed."
@@ -14,8 +14,19 @@ fi
# Get the single wheel file
wheel="${wheel_files[0]}"

# Rename 'linux' to 'manylinux1' in the wheel filename
new_wheel="${wheel/linux/manylinux1}"
# Detect architecture and rename 'linux' to appropriate manylinux version
arch=$(uname -m)
if [[ $arch == "x86_64" ]]; then
    manylinux_version="manylinux1"
elif [[ $arch == "aarch64" ]]; then
    manylinux_version="manylinux2014"
else
    echo "Warning: Unknown architecture $arch, using manylinux1 as default"
    manylinux_version="manylinux1"
fi

# Rename 'linux' to the appropriate manylinux version in the wheel filename
new_wheel="${wheel/linux/$manylinux_version}"
mv -- "$wheel" "$new_wheel"
wheel="$new_wheel"
@@ -244,6 +244,7 @@ steps:
  - pytest -v -s v1/core
  - pytest -v -s v1/engine
  - pytest -v -s v1/entrypoints
  - pytest -v -s v1/executor
  - pytest -v -s v1/sample
  - pytest -v -s v1/logits_processors
  - pytest -v -s v1/worker
@@ -328,6 +329,7 @@ steps:
  - pytest -v -s compile/test_sequence_parallelism.py
  - pytest -v -s compile/test_async_tp.py
  - pytest -v -s compile/test_fusion_all_reduce.py
  - pytest -v -s compile/test_decorator.py

- label: PyTorch Fullgraph Smoke Test # 9min
  mirror_hardwares: [amdexperimental]
@@ -341,6 +343,7 @@ steps:
  - pytest -v -s compile/piecewise/test_simple.py
  - pytest -v -s compile/piecewise/test_toy_llama.py
  - pytest -v -s compile/piecewise/test_full_cudagraph.py
  - pytest -v -s compile/piecewise/test_multiple_graphs.py

- label: PyTorch Fullgraph Test # 18min
  mirror_hardwares: [amdexperimental]
@@ -543,6 +546,15 @@ steps:
  commands:
  - pytest -v -s models/language/pooling -m 'not core_model'

- label: Multi-Modal Processor Test
  source_file_dependencies:
  - vllm/
  - tests/models/multimodal
  commands:
  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
  - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
  - pytest -v -s models/multimodal/processing/test_tensor_schema.py

- label: Multi-Modal Models Test (Standard)
  mirror_hardwares: [amdexperimental]
  torch_nightly: true
@@ -552,9 +564,7 @@ steps:
  commands:
  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
  - pip freeze | grep -E 'torch'
  - pytest -v -s models/multimodal/processing
  - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model
  - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model # Needs mp_method="spawn"
  - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing
  - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work

- label: Multi-Modal Models Test (Extended) 1
@@ -565,7 +575,7 @@ steps:
  - tests/models/multimodal
  commands:
  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
  - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model'
  - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing

- label: Multi-Modal Models Test (Extended) 2
  mirror_hardwares: [amdexperimental]
@@ -646,6 +656,7 @@ steps:
  - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
  - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
  - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
  - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
  # Fusion
  - pytest -v -s tests/compile/test_fusion_all_reduce.py
  - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
@@ -832,3 +843,10 @@ steps:
  commands:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4

- label: Qwen MoE EP Test # optional
  gpu: h200
  optional: true
  num_gpus: 2
  commands:
  - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
.github/PULL_REQUEST_TEMPLATE.md (3 changes, vendored)

@@ -7,8 +7,6 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT

## Test Result

## (Optional) Documentation Update

---
<details>
<summary> Essential Elements of an Effective PR Description Checklist </summary>
@@ -17,6 +15,7 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT
- [ ] The test plan, such as providing test command.
- [ ] The test results, such as pasting the results comparison before and after, or e2e results
- [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model.
- [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the [Google Doc](https://docs.google.com/document/d/1YyVqrgX4gHTtrstbq8oWUImOyPCKSGnJ7xtTpmXzlRs/edit?tab=t.0).
</details>

**BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing>** (anything written below this line will be removed by GitHub Actions)
.github/workflows/lint-and-deploy.yaml (vendored, 89 lines removed)

@@ -1,89 +0,0 @@
|
||||
name: Lint and Deploy Charts
|
||||
|
||||
on: pull_request
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
lint-and-deploy:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0
|
||||
with:
|
||||
version: v3.14.4
|
||||
|
||||
#Python is required because ct lint runs Yamale and yamllint which require Python.
|
||||
- uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||
with:
|
||||
python-version: '3.13'
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0
|
||||
with:
|
||||
version: v3.10.1
|
||||
|
||||
- name: Run chart-testing (lint)
|
||||
run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm
|
||||
|
||||
- name: Setup minio
|
||||
run: |
|
||||
docker network create vllm-net
|
||||
docker run -d -p 9000:9000 --name minio --net vllm-net \
|
||||
-e "MINIO_ACCESS_KEY=minioadmin" \
|
||||
-e "MINIO_SECRET_KEY=minioadmin" \
|
||||
-v /tmp/data:/data \
|
||||
-v /tmp/config:/root/.minio \
|
||||
minio/minio server /data
|
||||
export AWS_ACCESS_KEY_ID=minioadmin
|
||||
export AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
export AWS_EC2_METADATA_DISABLED=true
|
||||
mkdir opt-125m
|
||||
cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd ..
|
||||
aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket
|
||||
aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive
|
||||
|
||||
- name: Create kind cluster
|
||||
uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0
|
||||
|
||||
- name: Build the Docker image vllm cpu
|
||||
run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env .
|
||||
|
||||
- name: Configuration of docker images, network and namespace for the kind cluster
|
||||
run: |
|
||||
docker pull amazon/aws-cli:2.6.4
|
||||
kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing
|
||||
kind load docker-image vllm-cpu-env:latest --name chart-testing
|
||||
docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")"
|
||||
kubectl create ns ns-vllm
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
run: |
|
||||
export AWS_ACCESS_KEY_ID=minioadmin
|
||||
export AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
|
||||
- name: curl test
|
||||
run: |
|
||||
kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 &
|
||||
sleep 10
|
||||
CODE="$(curl -v -f --location http://localhost:8001/v1/completions \
|
||||
--header "Content-Type: application/json" \
|
||||
--data '{
|
||||
"model": "opt-125m",
|
||||
"prompt": "San Francisco is a",
|
||||
"max_tokens": 7,
|
||||
"temperature": 0
|
||||
}'):$CODE"
|
||||
echo "$CODE"
|
||||
.github/workflows/publish.yml (vendored, 111 lines removed)

@@ -1,111 +0,0 @@
|
||||
# This workflow will upload a Python Package to Release asset
|
||||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Create Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- v*
|
||||
|
||||
# Needed to create release and upload assets
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
# Retrieve tag and create release
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Extract branch info
|
||||
shell: bash
|
||||
run: |
|
||||
echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
env:
|
||||
RELEASE_TAG: ${{ env.release_tag }}
|
||||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
script: |
|
||||
const script = require('.github/workflows/scripts/create_release.js')
|
||||
await script(github, context, core)
|
||||
|
||||
# NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow.
|
||||
# wheel:
|
||||
# name: Build Wheel
|
||||
# runs-on: ${{ matrix.os }}
|
||||
# needs: release
|
||||
|
||||
# strategy:
|
||||
# fail-fast: false
|
||||
# matrix:
|
||||
# os: ['ubuntu-20.04']
|
||||
# python-version: ['3.9', '3.10', '3.11', '3.12']
|
||||
# pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements/cuda.txt.
|
||||
# cuda-version: ['11.8', '12.1']
|
||||
|
||||
# steps:
|
||||
# - name: Checkout
|
||||
# uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
# - name: Setup ccache
|
||||
# uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
|
||||
# with:
|
||||
# create-symlink: true
|
||||
# key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
|
||||
|
||||
# - name: Set up Linux Env
|
||||
# if: ${{ runner.os == 'Linux' }}
|
||||
# run: |
|
||||
# bash -x .github/workflows/scripts/env.sh
|
||||
|
||||
# - name: Set up Python
|
||||
# uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
# with:
|
||||
# python-version: ${{ matrix.python-version }}
|
||||
|
||||
# - name: Install CUDA ${{ matrix.cuda-version }}
|
||||
# run: |
|
||||
# bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
|
||||
|
||||
# - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
|
||||
# run: |
|
||||
# bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
|
||||
|
||||
# - name: Build wheel
|
||||
# shell: bash
|
||||
# env:
|
||||
# CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
|
||||
# run: |
|
||||
# bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
|
||||
# wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename)
|
||||
# asset_name=${wheel_name//"linux"/"manylinux1"}
|
||||
# echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV"
|
||||
# echo "asset_name=${asset_name}" >> "$GITHUB_ENV"
|
||||
|
||||
# - name: Upload Release Asset
|
||||
# uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2
|
||||
# env:
|
||||
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# with:
|
||||
# upload_url: ${{ needs.release.outputs.upload_url }}
|
||||
# asset_path: ./dist/${{ env.wheel_name }}
|
||||
# asset_name: ${{ env.asset_name }}
|
||||
# asset_content_type: application/*
|
||||
|
||||
# (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
|
||||
# - name: Publish package
|
||||
# uses: pypa/gh-action-pypi-publish@release/v1.8
|
||||
# with:
|
||||
# repository-url: https://test.pypi.org/legacy/
|
||||
# password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
# skip-existing: true
|
||||
.github/workflows/reminder_comment.yml (49 changes, vendored)

@@ -12,16 +12,43 @@ jobs:
|
||||
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
|
||||
with:
|
||||
script: |
|
||||
github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
|
||||
'💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
|
||||
'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' +
|
||||
'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
|
||||
'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
|
||||
'🚀'
|
||||
})
|
||||
try {
|
||||
// Get the PR author
|
||||
const prAuthor = context.payload.pull_request.user.login;
|
||||
|
||||
// Check if this is the author's first PR in this repository
|
||||
// Use GitHub's search API to find all PRs by this author
|
||||
const { data: searchResults } = await github.rest.search.issuesAndPullRequests({
|
||||
q: `repo:${context.repo.owner}/${context.repo.repo} type:pr author:${prAuthor}`,
|
||||
per_page: 100
|
||||
});
|
||||
|
||||
const authorPRCount = searchResults.total_count;
|
||||
|
||||
console.log(`Found ${authorPRCount} PRs by ${prAuthor}`);
|
||||
|
||||
// Only post comment if this is the first PR (only one PR by this author)
|
||||
if (authorPRCount === 1) {
|
||||
console.log(`Posting welcome comment for first-time contributor: ${prAuthor}`);
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' +
|
||||
'💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' +
|
||||
'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. \n\n' +
|
||||
'You ask your reviewers to trigger select CI tests on top of `fastcheck` CI. \n\n' +
|
||||
'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' +
|
||||
'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' +
|
||||
'If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.\n\n' +
|
||||
'🚀'
|
||||
});
|
||||
} else {
|
||||
console.log(`Skipping comment for ${prAuthor} - not their first PR (${authorPRCount} PRs found)`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error checking PR history or posting comment:', error);
|
||||
// Don't fail the workflow, just log the error
|
||||
}
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
@ -357,9 +357,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
|
||||
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
|
||||
@ -752,6 +750,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Only build W4A8 kernels if we are building for something compatible with sm90a
|
||||
cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${W4A8_ARCHS}")
|
||||
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
|
||||
message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
||||
AND W4A8_ARCHS)
|
||||
message(STATUS "Not building W4A8 kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running w4a16 quantized models on "
|
||||
"Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building W4A8 kernels as no compatible archs "
|
||||
"found in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# if CUDA endif
|
||||
endif()
|
||||
|
||||
@ -792,7 +817,9 @@ set(VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/topk_softmax_kernels.cu")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
|
||||
list(APPEND VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/moe_wna16.cu"
|
||||
"csrc/moe/grouped_topk_kernels.cu")
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
@@ -18,14 +18,15 @@ Easy, fast, and cheap LLM serving for everyone

*Latest News* 🔥

- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH).
- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152).
- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).

<details>
<summary>Previous News</summary>

- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
@@ -59,6 +59,12 @@ become available.
      <td style="text-align: center;">✅</td>
      <td><code>synthetic</code></td>
    </tr>
    <tr>
      <td><strong>RandomMultiModal (Image/Video)</strong></td>
      <td style="text-align: center;">🟡</td>
      <td style="text-align: center;">🚧</td>
      <td><code>synthetic</code></td>
    </tr>
    <tr>
      <td><strong>Prefix Repetition</strong></td>
      <td style="text-align: center;">✅</td>

@@ -722,4 +728,75 @@ python benchmarks/benchmark_serving.py \
    --endpoint /v1/chat/completion
```

### Synthetic Random Images (random-mm)

Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.

Notes:

- Works only with the online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`.
- Video sampling is not yet implemented.

Start the server (example):

```bash
vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
    --dtype bfloat16 \
    --max-model-len 16384 \
    --limit-mm-per-prompt '{"image": 3, "video": 0}' \
    --mm-processor-kwargs max_pixels=1003520
```

Then run the benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses; the size of the output can be set via `--random-output-len`.

Ex.1: Fixed number of items and a single image resolution, enforcing generation of approximately 40 tokens:

```bash
vllm bench serve \
    --backend openai-chat \
    --model Qwen/Qwen2.5-VL-3B-Instruct \
    --endpoint /v1/chat/completions \
    --dataset-name random-mm \
    --num-prompts 100 \
    --max-concurrency 10 \
    --random-prefix-len 25 \
    --random-input-len 300 \
    --random-output-len 40 \
    --random-range-ratio 0.2 \
    --random-mm-base-items-per-request 2 \
    --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
    --random-mm-bucket-config '{(224, 224, 1): 1.0}' \
    --request-rate inf \
    --ignore-eos \
    --seed 42
```

The number of items per request can be controlled by passing multiple image buckets:

```bash
    --random-mm-base-items-per-request 2 \
    --random-mm-num-mm-items-range-ratio 0.5 \
    --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
    --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
```

Flags specific to `random-mm`:

- `--random-mm-base-items-per-request`: base number of multimodal items per request.
- `--random-mm-num-mm-items-range-ratio`: vary the item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]; for example, a base of 2 with r=0.5 draws from [1, 3]. Set r=0 to keep the count fixed; r=1 allows 0 items.
- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; the remaining probabilities are renormalized to sum to 1. Use T=1 for images; any T>1 denotes video (video sampling is not yet supported).

Behavioral notes:

- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.

How sampling works (see the sketch after this list):

- Determine the per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of the per-modality limits.
- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. This is an edge case; it can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number, although that may then conflict with the engine config `--limit-mm-per-prompt` and cause errors.
- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.
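To make the item-count and bucket sampling above concrete, here is a minimal Python sketch of the procedure. It is illustrative only: the names `sample_mm_items` and `modality` are hypothetical, and the actual implementation in vLLM's benchmark code may differ in detail.

```python
# Sketch of the random-mm sampling steps described above; not vLLM's actual API.
import math
import random


def modality(bucket):
    # T == 1 is an image bucket; any T > 1 is treated as video.
    return "image" if bucket[2] == 1 else "video"


def sample_mm_items(base_items, range_ratio, bucket_probs, limits):
    """Pick the multimodal items for one request, mirroring the steps above."""
    # 1) Sample the item count k from the closed integer range, then clamp it
    #    to the total allowed by the per-modality limits.
    lo = math.floor(base_items * (1 - range_ratio))
    hi = math.ceil(base_items * (1 + range_ratio))
    k = min(random.randint(lo, hi), sum(limits.values()))

    counts = dict.fromkeys(limits, 0)
    items = []
    for _ in range(k):
        # 2) Keep only buckets whose modality is still under its cap,
        #    then renormalize the remaining probabilities.
        allowed = {b: p for b, p in bucket_probs.items()
                   if counts[modality(b)] < limits[modality(b)]}
        if not allowed:
            break
        total = sum(allowed.values())
        weights = [p / total for p in allowed.values()]
        bucket = random.choices(list(allowed), weights)[0]
        counts[modality(bucket)] += 1
        items.append(bucket)  # an (H, W, T) shape to synthesize
    return items


# Mirrors the multi-bucket example above.
print(sample_mm_items(
    base_items=2,
    range_ratio=0.5,
    bucket_probs={(256, 256, 1): 0.7, (720, 1280, 1): 0.3},
    limits={"image": 4, "video": 0},
))
```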
</details>
@@ -958,8 +958,10 @@ class InstructCoderDataset(HuggingFaceDataset):
        for i, item in enumerate(self.data):
            if len(sampled_requests) >= num_requests:
                break
            prompt = f"{item['input']}\n\n{item['instruction']} Just output \
                the code, do not include any explanation."
            prompt = (
                f"{item['input']}\n\n{item['instruction']} Just output "
                "the code, do not include any explanation."
            )

            # apply template
            prompt = tokenizer.apply_chat_template(
@ -80,6 +80,11 @@ def bench_run(
|
||||
a, score, topk, renormalize=False
|
||||
)
|
||||
|
||||
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
def run_triton_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@ -111,6 +116,10 @@ def bench_run(
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
@ -125,6 +134,10 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
@ -136,6 +149,10 @@ def bench_run(
|
||||
w2_q: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
ab_strides1: torch.Tensor,
|
||||
ab_strides2: torch.Tensor,
|
||||
c_strides1: torch.Tensor,
|
||||
c_strides2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
):
|
||||
@ -150,6 +167,10 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
@ -194,6 +215,10 @@ def bench_run(
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
)
|
||||
@ -231,6 +256,10 @@ def bench_run(
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"per_act_token": per_act_token,
|
||||
"ab_strides1": ab_strides1,
|
||||
"ab_strides2": ab_strides2,
|
||||
"c_strides1": c_strides1,
|
||||
"c_strides2": c_strides2,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
@ -289,6 +318,10 @@ def bench_run(
|
||||
w2_q,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
ab_strides1,
|
||||
ab_strides2,
|
||||
c_strides1,
|
||||
c_strides2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
@ -297,7 +330,7 @@ def bench_run(
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
|
||||
@ -253,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
else:
|
||||
assert bt.a.dtype == torch.int8
|
||||
assert bt.wtype == scalar_types.uint4b8
|
||||
|
||||
if bt.w_ch_s is not None:
|
||||
s_ch = bt.w_ch_s.to(torch.float32)
|
||||
else:
|
||||
s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
|
||||
|
||||
if bt.w_tok_s is not None:
|
||||
s_tok = bt.w_tok_s.to(torch.float32)
|
||||
else:
|
||||
s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
|
||||
|
||||
fn = lambda: ops.marlin_qqq_gemm(
|
||||
a=bt.a,
|
||||
b_q_weight=w_q,
|
||||
s_group=w_s,
|
||||
s_tok=s_tok,
|
||||
s_ch=s_ch,
|
||||
workspace=workspace.scratch,
|
||||
size_m=bt.a.shape[0],
|
||||
size_n=bt.w_ref.shape[1],
|
||||
size_k=bt.w_ref.shape[0],
|
||||
)
|
||||
raise NotImplementedError("QQQ is not supported anymore")
|
||||
|
||||
return fn
|
||||
|
||||
@ -305,6 +284,25 @@ def machete_create_bench_fn(
|
||||
)
|
||||
|
||||
|
||||
def cutlass_w4a8_create_bench_fn(
|
||||
bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
|
||||
) -> Callable:
|
||||
w_q = bt.w_q.t().contiguous().t() # make col major
|
||||
w_q = ops.cutlass_encode_and_reorder_int4b(w_q)
|
||||
# expects fp8 scales
|
||||
w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn))
|
||||
|
||||
return lambda: ops.cutlass_w4a8_mm(
|
||||
a=bt.a,
|
||||
b_q=w_q,
|
||||
b_group_scales=w_s,
|
||||
b_group_size=bt.group_size,
|
||||
b_channel_scales=bt.w_ch_s,
|
||||
a_token_scales=bt.w_tok_s,
|
||||
maybe_schedule=schedule,
|
||||
)
|
||||
|
||||
|
||||
# impl
|
||||
|
||||
# bench
|
||||
@ -406,6 +404,20 @@ def bench(
|
||||
)
|
||||
)
|
||||
|
||||
# cutlass w4a8
|
||||
if types.act_type == torch.float8_e4m3fn and group_size == 128:
|
||||
timers.append(
|
||||
bench_fns(
|
||||
label,
|
||||
sub_label,
|
||||
f"cutlass w4a8 ({name_type_string})",
|
||||
[
|
||||
cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type)
|
||||
for bt in benchmark_tensors
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
if sweep_schedules:
|
||||
global _SWEEP_SCHEDULES_RESULTS
|
||||
|
||||
|
||||
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py (new file, 77 lines)

@@ -0,0 +1,77 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time

import torch

from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    silu_mul_fp8_quant_deep_gemm,
)
from vllm.platforms import current_platform


def benchmark(E, T, H, G=128, runs=50):
    current_platform.seed_everything(42)
    y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
    tokens_per_expert = torch.randint(
        T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
    )

    # Warmup
    for _ in range(10):
        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
    torch.cuda.synchronize()

    # Benchmark
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
    torch.cuda.synchronize()

    avg_time = (time.perf_counter() - start) / runs * 1000

    # Calculate actual work done (only count valid tokens)
    actual_tokens = tokens_per_expert.sum().item()
    actual_elements = actual_tokens * H

    # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
    ops_per_element = 8
    total_ops = actual_elements * ops_per_element
    gflops = total_ops / (avg_time / 1000) / 1e9

    # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
    input_bytes = actual_tokens * 2 * H * 2  # 2*H bfloat16 inputs
    output_bytes = actual_tokens * H * 1  # H fp8 outputs
    scale_bytes = actual_tokens * (H // G) * 4  # scales in float32
    total_bytes = input_bytes + output_bytes + scale_bytes
    memory_bw = total_bytes / (avg_time / 1000) / 1e9

    return avg_time, gflops, memory_bw


configs = [
    (8, 32, 1024),
    (16, 64, 2048),
    (32, 128, 4096),
    # DeepSeekV3 Configs
    (256, 16, 7168),
    (256, 32, 7168),
    (256, 64, 7168),
    (256, 128, 7168),
    (256, 256, 7168),
    (256, 512, 7168),
    (256, 1024, 7168),
]

print(f"GPU: {torch.cuda.get_device_name()}")
print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
print("-" * 50)

for E, T, H in configs:
    try:
        time_ms, gflops, gbps = benchmark(E, T, H)
        print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
    except Exception:
        print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
@ -9,8 +9,11 @@ from typing import Optional
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
from vllm.utils import round_up
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@ -61,13 +64,13 @@ def benchmark_decode(
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
|
||||
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
q_scale = 1.0
|
||||
ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
query, _ = to_float8(ref_query)
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
query = ref_query
|
||||
|
||||
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_seq_len
|
||||
@ -75,14 +78,13 @@ def benchmark_decode(
|
||||
seq_lens = kv_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
k_scale = v_scale = 1.0
|
||||
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
kv_cache, _ = to_float8(ref_kv_cache)
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
kv_cache = ref_kv_cache
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
@ -110,7 +112,7 @@ def benchmark_decode(
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout,
|
||||
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
|
||||
use_tensor_cores=True,
|
||||
)
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
@ -142,11 +144,31 @@ def benchmark_decode(
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
o_scale = 1.0
|
||||
o_sf_scale = None
|
||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = 500.0
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
def baseline_decode():
|
||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
||||
return wrapper.run(
|
||||
ref_query,
|
||||
ref_kv_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out=output_baseline,
|
||||
)
|
||||
|
||||
def trtllm_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
@ -158,6 +180,7 @@ def benchmark_decode(
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
|
||||
@ -237,6 +260,7 @@ if __name__ == "__main__":
|
||||
(None, None, None),
|
||||
(None, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
|
||||
for quant_dtype in quant_dtypes:
|
||||
|
||||
@ -9,8 +9,11 @@ from typing import Optional
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
from vllm.utils import round_up
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = torch.float8_e4m3fn
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
@ -72,13 +75,15 @@ def benchmark_prefill(
|
||||
]
|
||||
)
|
||||
|
||||
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
q_scale = 1.0
|
||||
ref_query = torch.randn(
|
||||
torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
|
||||
)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
query, _ = to_float8(ref_query)
|
||||
else:
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
query = ref_query
|
||||
|
||||
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
@ -86,14 +91,13 @@ def benchmark_prefill(
|
||||
seq_lens = kv_lens + q_lens
|
||||
max_seq_len = torch.max(seq_lens).item()
|
||||
|
||||
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||
k_scale = v_scale = 1.0
|
||||
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
if kv_quant_dtype == FP8_DTYPE:
|
||||
kv_cache, kv_scale = to_float8(kv_cache)
|
||||
ref_kv_cache = kv_cache.to(dtype) * kv_scale
|
||||
kv_cache, _ = to_float8(ref_kv_cache)
|
||||
else:
|
||||
kv_scale = 1.0
|
||||
ref_kv_cache = kv_cache
|
||||
k_scale = v_scale = kv_scale
|
||||
kv_cache = ref_kv_cache
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
@ -152,11 +156,31 @@ def benchmark_prefill(
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
o_scale = 1.0
|
||||
o_sf_scale = None
|
||||
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = 500.0
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
|
||||
def baseline_prefill():
|
||||
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
|
||||
return wrapper.run(
|
||||
ref_query,
|
||||
ref_kv_cache,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
out=output_baseline,
|
||||
)
|
||||
|
||||
def trtllm_prefill():
|
||||
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||
@ -172,6 +196,7 @@ def benchmark_prefill(
|
||||
batch_size=batch_size,
|
||||
cum_seq_lens_q=q_indptr,
|
||||
cum_seq_lens_kv=kv_indptr,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
|
||||
@ -250,6 +275,7 @@ if __name__ == "__main__":
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
|
||||
for quant_dtype in quant_dtypes:
|
||||
|
||||
@ -11,8 +11,8 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
import triton
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_w8a8_block_fp8_matmul,
|
||||
|
||||
@ -95,4 +95,10 @@ WEIGHT_SHAPES = {
|
||||
([2048, 2816], 1),
|
||||
([1408, 2048], 0),
|
||||
],
|
||||
"CohereLabs/c4ai-command-a-03-2025": [
|
||||
([12288, 14336], 1),
|
||||
([12288, 12288], 0),
|
||||
([12288, 73728], 1),
|
||||
([36864, 12288], 0),
|
||||
],
|
||||
}
|
||||
|
||||
@ -182,17 +182,17 @@ endif()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
# Flag to enable ACL kernels for AARCH64 platforms
if ( VLLM_BUILD_ACL STREQUAL "ON")
if (VLLM_BUILD_ACL STREQUAL "ON")
  set(USE_ACL ON)
else()
  set(USE_ACL OFF)
endif()

if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
  FetchContent_Declare(
    oneDNN
    GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
    GIT_TAG v3.8.1
    GIT_TAG v3.9
    GIT_PROGRESS TRUE
    GIT_SHALLOW TRUE
  )
@ -204,7 +204,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
    endif()
    set(ONEDNN_AARCH64_USE_ACL "ON")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
  endif()
endif()

  set(ONEDNN_LIBRARY_TYPE "STATIC")
  set(ONEDNN_BUILD_DOC "OFF")
@ -217,38 +217,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
  set(ONEDNN_ENABLE_ITT_TASKS "OFF")
  set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
  set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
  set(ONEDNN_VERBOSE "OFF")
  set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)

  FetchContent_MakeAvailable(oneDNN)

  list(APPEND LIBS dnnl)
elseif(POWER10_FOUND)
  FetchContent_Declare(
    oneDNN
    GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
    GIT_TAG v3.7.2
    GIT_PROGRESS TRUE
    GIT_SHALLOW TRUE
  add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
  target_include_directories(
    dnnl_ext
    PUBLIC ${oneDNN_SOURCE_DIR}/include
    PUBLIC ${oneDNN_BINARY_DIR}/include
    PRIVATE ${oneDNN_SOURCE_DIR}/src
  )

  set(ONEDNN_LIBRARY_TYPE "STATIC")
  set(ONEDNN_BUILD_DOC "OFF")
  set(ONEDNN_BUILD_EXAMPLES "OFF")
  set(ONEDNN_BUILD_TESTS "OFF")
  set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
  set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
  set(ONEDNN_BUILD_GRAPH "OFF")
  set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
  set(ONEDNN_ENABLE_ITT_TASKS "OFF")
  set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
  set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
  set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)

  set(DNNL_CPU_RUNTIME "OMP")

  FetchContent_MakeAvailable(oneDNN)

  list(APPEND LIBS dnnl)
  target_link_libraries(dnnl_ext dnnl)
  target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
  list(APPEND LIBS dnnl_ext)
  set(USE_ONEDNN ON)
else()
  set(USE_ONEDNN OFF)
endif()

message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
@ -275,7 +260,6 @@ set(VLLM_EXT_SRC

if (AVX512_FOUND AND NOT AVX512_DISABLED)
  set(VLLM_EXT_SRC
    "csrc/cpu/quant.cpp"
    "csrc/cpu/shm.cpp"
    ${VLLM_EXT_SRC})
  if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
@ -289,14 +273,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
      ${VLLM_EXT_SRC})
    add_compile_definitions(-DCPU_CAPABILITY_AVX512)
  endif()
elseif(POWER10_FOUND)
  set(VLLM_EXT_SRC
    "csrc/cpu/quant.cpp"
    ${VLLM_EXT_SRC})
endif()
if (ASIMD_FOUND)

if(USE_ONEDNN)
  set(VLLM_EXT_SRC
    "csrc/cpu/quant.cpp"
    "csrc/cpu/dnnl_kernels.cpp"
    ${VLLM_EXT_SRC})
endif()

@ -19,7 +19,7 @@ else()
  FetchContent_Declare(
    flashmla
    GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
    GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
    GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
    GIT_PROGRESS TRUE
    CONFIGURE_COMMAND ""
    BUILD_COMMAND ""
@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
  set(FlashMLA_SOURCES
    ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
    ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
    ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
    ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
    ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
    ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
    ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)

  set(FlashMLA_INCLUDES
    ${flashmla_SOURCE_DIR}/csrc/cutlass/include
    ${flashmla_SOURCE_DIR}/csrc/include)
    ${flashmla_SOURCE_DIR}/csrc)

  set_gencode_flags_for_srcs(
    SRCS "${FlashMLA_SOURCES}"

@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options(
      // TODO(trevor-m): Change split_kv back to -1 when
      // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
      // perform worse with larger context length and smaller batch sizes.
      num_kv_splits,  // split_kv
      static_cast<int>(num_kv_splits),  // split_kv
      nullptr,  // is_var_split_kv
  };
  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
  // Assumes device 0 when getting sm_count.
  arguments.hw_info.sm_count =
      sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
  arguments.split_kv = num_kv_splits;
  arguments.split_kv = static_cast<int>(num_kv_splits);
  MlaSm100Type::Fmha::set_split_kv(arguments);

  return MlaSm100Type::Fmha::get_workspace_size(arguments);

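The two hunks above both replace a bare int64_t num_kv_splits with static_cast<int>(...) where the CUTLASS argument struct expects an int. A minimal standalone sketch of making that narrowing explicit and checked; checked_narrow is a hypothetical helper used only for illustration, not part of the diff:

#include <cstdint>
#include <cstdio>
#include <limits>
#include <stdexcept>

// Hypothetical helper: refuse to silently truncate a 64-bit value.
int checked_narrow(int64_t v) {
  if (v < std::numeric_limits<int>::min() ||
      v > std::numeric_limits<int>::max()) {
    throw std::out_of_range("value does not fit in int");
  }
  return static_cast<int>(v);
}

int main() {
  int64_t num_kv_splits = 4;  // same role as the variable in the hunks above
  std::printf("split_kv = %d\n", checked_narrow(num_kv_splits));
  return 0;
}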
@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype);

void gather_cache(
void gather_and_maybe_dequant_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
    int64_t batch_size, const std::string& kv_cache_dtype,
    torch::Tensor const& scale,
    std::optional<torch::Tensor> seq_starts = std::nullopt);
@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace vllm {

// grid is launched with dimensions (batch, num_splits)
template <typename scalar_t>
__global__ void gather_cache(
    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void gather_and_maybe_dequant_cache(
    const cache_t* __restrict__ src_cache,    // [NUM_BLOCKS, BLOCK_SIZE,
                                              // ENTRIES...]
    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
@ -634,6 +634,7 @@ __global__ void gather_cache(
    const int32_t block_size, const int32_t entry_size,
    const int64_t block_table_stride, const int64_t cache_block_stride,
    const int64_t cache_entry_stride, const int64_t dst_entry_stride,
    const float* __restrict__ scale,
    const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
                                               // batch

@ -675,10 +676,16 @@ __global__ void gather_cache(
    if (partial_block_size) full_blocks_end -= 1;
  }

  auto copy_entry = [&](const scalar_t* __restrict__ _src,
  auto copy_entry = [&](const cache_t* __restrict__ _src,
                        scalar_t* __restrict__ _dst) {
    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
      _dst[i] = _src[i];
    for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
        _dst[i] = static_cast<scalar_t>(_src[i]);
      } else {
        _dst[i] =
            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
      }
    }
  };

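The rewritten copy_entry lambda above picks between a plain copy and an fp8 dequantizing copy at compile time, via if constexpr on the kv_dt template parameter. A self-contained sketch of the same pattern; CacheDType and fake_dequant are hypothetical stand-ins for vLLM's Fp8KVCacheDataType and fp8::scaled_convert:

#include <cstdint>
#include <cstdio>

enum class CacheDType { kAuto, kFp8E4M3 };

// Stand-in for a real fp8 dequantization routine: just applies the scale.
template <typename DstT, typename SrcT>
DstT fake_dequant(SrcT v, float scale) {
  return static_cast<DstT>(static_cast<float>(v) * scale);
}

template <typename scalar_t, typename cache_t, CacheDType kv_dt>
void copy_or_dequant(const cache_t* src, scalar_t* dst, int n, float scale) {
  for (int i = 0; i < n; ++i) {
    if constexpr (kv_dt == CacheDType::kAuto) {
      dst[i] = static_cast<scalar_t>(src[i]);          // plain copy path
    } else {
      dst[i] = fake_dequant<scalar_t>(src[i], scale);  // dequantizing path
    }
  }
}

int main() {
  uint8_t cache[4] = {1, 2, 3, 4};  // pretend fp8-coded kv-cache entries
  float out[4];
  copy_or_dequant<float, uint8_t, CacheDType::kFp8E4M3>(cache, out, 4, 0.5f);
  std::printf("%.1f %.1f %.1f %.1f\n", out[0], out[1], out[2], out[3]);
  return 0;
}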
  for (int pid = split_start; pid < full_blocks_end; ++pid) {
@ -705,25 +712,31 @@ __global__ void gather_cache(
}  // namespace vllm

// Macro to dispatch the kernel based on the data type.
#define CALL_GATHER_CACHE(CPY_DTYPE)                                      \
  vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(              \
      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),                 \
      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                       \
      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),   \
      block_size, entry_size, block_table_stride, cache_block_stride,     \
      cache_entry_stride, dst_entry_stride, seq_starts_ptr);
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                       \
  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE>         \
      <<<grid, block, 0, stream>>>(                                         \
          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                 \
          reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                      \
          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
          block_size, entry_size, block_table_stride, cache_block_stride,   \
          cache_entry_stride, dst_entry_stride,                             \
          reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);

// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting block index by
//   (seq_starts[bid] / page_size)
void gather_cache(
void gather_and_maybe_dequant_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size,
    int64_t batch_size, const std::string& kv_cache_dtype,
    torch::Tensor const& scale,
    std::optional<torch::Tensor> seq_starts = std::nullopt) {
  at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@ -761,20 +774,8 @@ void gather_cache(
  dim3 grid(batch_size, num_splits);
  dim3 block(1024);

  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
              "src_cache and dst must have the same dtype");

  const int dtype_bits = src_cache.element_size() * 8;
  const int32_t* seq_starts_ptr =
      seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

  if (dtype_bits == 32) {
    CALL_GATHER_CACHE(uint32_t);
  } else if (dtype_bits == 16) {
    CALL_GATHER_CACHE(uint16_t);
  } else if (dtype_bits == 8) {
    CALL_GATHER_CACHE(uint8_t);
  } else {
    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
  }
  DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
}

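The host function above no longer dispatches on element width; it hands the destination dtype and the kv_cache_dtype string to DISPATCH_BY_KV_CACHE_DTYPE, which supplies the (SCALAR_T, CACHE_T, KV_DTYPE) triple that CALL_GATHER_CACHE needs. A hedged sketch of that style of dispatch, using hypothetical names rather than vLLM's actual macro:

#include <cstdio>
#include <stdexcept>
#include <string>

enum class KvDt { kAuto, kFp8E4M3 };

template <typename ScalarT, typename CacheT, KvDt kDt>
void launch_gather() {
  std::printf("launch: scalar=%zu bytes, cache=%zu bytes, dt=%d\n",
              sizeof(ScalarT), sizeof(CacheT), static_cast<int>(kDt));
}

// Hypothetical launch macro; the real one issues a CUDA kernel launch.
#define CALL_GATHER(SCALAR_T, CACHE_T, KV_DTYPE) \
  launch_gather<SCALAR_T, CACHE_T, KV_DTYPE>();

void dispatch(const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    // Cache stored in the same type as the destination: plain copy path.
    CALL_GATHER(float, float, KvDt::kAuto);
  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
    // Cache stored as 8-bit fp8 and dequantized on the fly.
    CALL_GATHER(float, unsigned char, KvDt::kFp8E4M3);
  } else {
    throw std::runtime_error("unsupported kv_cache_dtype: " + kv_cache_dtype);
  }
}

int main() {
  dispatch("auto");
  dispatch("fp8");
  return 0;
}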
@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> {

  explicit FP16Vec16(const FP32Vec16&);

  void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
  void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }

  void save(void* ptr, const int elem_num) const {
    constexpr uint32_t M = 0xFFFFFFFF;
@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> {

  explicit BF16Vec16(const FP32Vec16&);

  void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
  void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }

  void save(void* ptr, const int elem_num) const {
    constexpr uint32_t M = 0xFFFFFFFF;
@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
                    (__m128i)vec8_data.reg, 1)) {}

  void save(void* ptr) const {
    *reinterpret_cast<__m256i*>(ptr) = reg_low;
    *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
    _mm256_storeu_si256((__m256i*)ptr, reg_low);
    _mm256_storeu_si256((__m256i*)ptr + 1, reg_high);
  }
};
#endif

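Each save() rewrite above swaps a dereferenced __m256i* assignment for _mm256_storeu_si256. The assignment lets the compiler assume 32-byte alignment and emit an aligned store, which can fault when the destination pointer is not aligned; the storeu intrinsic has no alignment requirement. A standalone sketch of the difference (compile with -mavx2; not part of the vLLM header):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

void save_unaligned(void* ptr, __m256i reg) {
  // Explicitly unaligned 256-bit store: safe for any destination address.
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), reg);
}

int main() {
  alignas(32) uint8_t buf[64] = {};
  __m256i v = _mm256_set1_epi8(7);
  save_unaligned(buf + 1, v);  // deliberately misaligned destination
  std::printf("buf[1]=%d buf[32]=%d\n", buf[1], buf[32]);
  return 0;
}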
csrc/cpu/dnnl_helper.cpp  (new file, 346 lines)
@ -0,0 +1,346 @@
|
||||
#include <list>
|
||||
#include <optional>
|
||||
|
||||
#include "common/memory_desc.hpp"
|
||||
#include "common/memory.hpp"
|
||||
|
||||
#include "dnnl_helper.h"
|
||||
|
||||
static dnnl::engine& default_engine() {
|
||||
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||
return engine;
|
||||
}
|
||||
|
||||
static dnnl::stream& default_stream() {
|
||||
static dnnl::stream stream(default_engine());
|
||||
return stream;
|
||||
}
|
||||
|
||||
void release_dnnl_matmul_handler(int64_t handler) {
|
||||
DNNLMatMulPrimitiveHandler* ptr =
|
||||
reinterpret_cast<DNNLMatMulPrimitiveHandler*>(handler);
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
template <typename KT, typename VT>
|
||||
class DNNLPrimitiveCache {
|
||||
public:
|
||||
using cache_value_t = std::pair<KT, VT>;
|
||||
using result_value_t = VT;
|
||||
using container_t = std::list<cache_value_t>;
|
||||
using value_iterator_t = typename container_t::iterator;
|
||||
using map_t = std::unordered_map<KT, value_iterator_t>;
|
||||
using creator_t = VT (*)();
|
||||
|
||||
public:
|
||||
DNNLPrimitiveCache(size_t capacity)
|
||||
: capacity_(capacity),
|
||||
values_(),
|
||||
key_to_value_(std::min(256lu, capacity)) {
|
||||
assert(capacity > 0);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
result_value_t get_or_create(const KT& key, F&& creator) {
|
||||
std::optional<value_iterator_t> value = get_value(key);
|
||||
if (value.has_value()) {
|
||||
return value.value()->second;
|
||||
} else {
|
||||
return add_value({key, creator()})->second;
|
||||
}
|
||||
}
|
||||
|
||||
size_t size() const { return values_.size(); }
|
||||
|
||||
private:
|
||||
void dump_data() {
|
||||
std::stringstream ss;
|
||||
ss << "table_id: " << std::hex << reinterpret_cast<size_t>(this) << std::dec
|
||||
<< "\n";
|
||||
ss << "container: [";
|
||||
for (auto&& iter : values_) {
|
||||
ss << "(" << iter.first << ", " << std::hex
|
||||
<< reinterpret_cast<size_t>(iter.second.get()) << "), " << std::dec;
|
||||
}
|
||||
ss << "]\n";
|
||||
|
||||
ss << "map: [";
|
||||
for (auto&& iter : key_to_value_) {
|
||||
ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex
|
||||
<< reinterpret_cast<size_t>(iter.second->second.get()) << std::dec
|
||||
<< "), ";
|
||||
}
|
||||
ss << "]\n";
|
||||
std::printf("%s\n", ss.str().c_str());
|
||||
}
|
||||
|
||||
value_iterator_t add_value(cache_value_t&& new_value) {
|
||||
if (size() == capacity_) {
|
||||
cache_value_t& last_item = values_.back();
|
||||
key_to_value_.erase(last_item.first);
|
||||
values_.pop_back();
|
||||
}
|
||||
|
||||
auto& added_value_ = values_.emplace_front(std::move(new_value));
|
||||
key_to_value_.emplace(added_value_.first, values_.begin());
|
||||
return values_.begin();
|
||||
}
|
||||
|
||||
std::optional<value_iterator_t> get_value(const KT& key) {
|
||||
if (key_to_value_.size() > 0 && key == values_.begin()->first) {
|
||||
return values_.begin();
|
||||
}
|
||||
|
||||
auto value_map_iterator = key_to_value_.find(key);
|
||||
if (value_map_iterator != key_to_value_.end()) {
|
||||
values_.splice(values_.begin(), values_, value_map_iterator->second);
|
||||
return value_map_iterator->second;
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const size_t capacity_;
|
||||
container_t values_;
|
||||
map_t key_to_value_;
|
||||
};
|
||||
|
||||
DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
|
||||
const Args& args, dnnl::memory::data_type b_type)
|
||||
: b_n_size_(args.b_n_size),
|
||||
b_n_stride_(args.b_n_stride),
|
||||
b_k_size_(args.b_k_size),
|
||||
b_k_stride_(args.b_k_stride),
|
||||
b_type_(b_type),
|
||||
c_type_(args.c_type),
|
||||
runtime_memory_ptrs_(8),
|
||||
primitive_cache_size_(args.primitive_cache_size) {
|
||||
assert(primitive_cache_size_ > 0);
|
||||
}
|
||||
|
||||
void DNNLMatMulPrimitiveHandler::prepack_weight(
|
||||
void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
|
||||
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
|
||||
{b_k_stride_, b_n_stride_});
|
||||
dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
|
||||
dnnl::memory packed_weight(b_target_mem_desc, default_engine());
|
||||
{
|
||||
dnnl::reorder(original_weight, packed_weight)
|
||||
.execute(default_stream(), original_weight, packed_weight);
|
||||
default_stream().wait();
|
||||
}
|
||||
memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight;
|
||||
b_target_mem_desc_ = b_target_mem_desc;
|
||||
}
|
||||
|
||||
void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr(
|
||||
size_t index, dnnl_memory* memory_ptr) {
|
||||
dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage();
|
||||
dnnl_memory_desc* mem_desc = const_cast<dnnl_memory_desc*>(memory_ptr->md());
|
||||
runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc};
|
||||
}
|
||||
|
||||
std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
|
||||
DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) {
|
||||
return runtime_memory_ptrs_[index];
|
||||
}
|
||||
|
||||
namespace std {
|
||||
template <>
|
||||
struct hash<W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey> {
|
||||
size_t operator()(
|
||||
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
|
||||
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
|
||||
hash<int>()(static_cast<int>(val.a_qs)) ^
|
||||
hash<int>()(static_cast<int>(val.b_qs)) ^ hash<bool>()(val.use_azp) ^
|
||||
hash<int>()(static_cast<int>(val.c_type));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
|
||||
size_t operator()(
|
||||
const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const {
|
||||
return hash<dnnl_dim_t>()(val.a_m_size) ^ hash<bool>()(val.use_bias) ^
|
||||
hash<int>()(static_cast<int>(val.bias_type));
|
||||
}
|
||||
};
|
||||
} // namespace std
|
||||
|
||||
bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
|
||||
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
|
||||
return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
|
||||
l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp &&
|
||||
l.c_type == r.c_type;
|
||||
}
|
||||
|
||||
bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
|
||||
const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) {
|
||||
return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size &&
|
||||
l.bias_type == r.bias_type;
|
||||
}
|
||||
|
||||
static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
|
||||
get_w8a8_class_primitive_cache(
|
||||
const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
|
||||
int64_t cache_size) {
|
||||
static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128);
|
||||
assert(cache_size > 0);
|
||||
return cache.get_or_create(key, [&]() {
|
||||
return std::make_shared<W8A8MatMulPrimitiveHandler::MSizeCache>(cache_size);
|
||||
});
|
||||
}
|
||||
|
||||
W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
|
||||
: DNNLMatMulPrimitiveHandler(
|
||||
static_cast<const DNNLMatMulPrimitiveHandler::Args&>(args),
|
||||
dnnl::memory::data_type::s8),
|
||||
use_azp_(args.use_a_zero_point),
|
||||
a_qs_(args.a_quantization_strategy),
|
||||
b_qs_(args.b_quantization_strategy),
|
||||
m_size_cache_(nullptr) {
|
||||
assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL);
|
||||
assert(b_qs_ != QuantizationStrategy::PER_TOKEN);
|
||||
if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
|
||||
assert(!use_azp_);
|
||||
};
|
||||
prepack_weight(args.b_ptr,
|
||||
create_primitive_desc(
|
||||
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
|
||||
.use_bias = false,
|
||||
.bias_type = dnnl::memory::data_type::undef},
|
||||
true)
|
||||
.weights_desc());
|
||||
init_runtime_memory_cache(args);
|
||||
}
|
||||
|
||||
void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
|
||||
auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
|
||||
auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
|
||||
a_storage->set_data_handle((void*)args.a_ptr);
|
||||
a_mem_desc->dims[0] = args.a_m_size;
|
||||
c_storage->set_data_handle((void*)args.c_ptr);
|
||||
c_mem_desc->dims[0] = args.a_m_size;
|
||||
|
||||
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2);
|
||||
a_scale_storage->set_data_handle((void*)args.a_scales_ptr);
|
||||
}
|
||||
if (use_azp_) {
|
||||
auto&& [a_zero_point_storage, a_zero_point_mem_desc] =
|
||||
get_runtime_memory_ptr(3);
|
||||
a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr);
|
||||
}
|
||||
|
||||
if (args.use_bias) {
|
||||
auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4);
|
||||
bias_storage->set_data_handle((void*)args.bias_ptr);
|
||||
}
|
||||
|
||||
dnnl::matmul matmul = get_matmul_cache(args);
|
||||
matmul.execute(default_stream(), memory_cache_);
|
||||
default_stream().wait();
|
||||
}
|
||||
|
||||
dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
|
||||
const MSizeCacheKey& key) {
|
||||
if (m_size_cache_.get() == nullptr) {
|
||||
ClassMatmulCacheKey key = {.b_n_size = b_n_size_,
|
||||
.b_k_size = b_k_size_,
|
||||
.a_qs = a_qs_,
|
||||
.b_qs = b_qs_,
|
||||
.use_azp = use_azp_,
|
||||
.c_type = c_type_};
|
||||
m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_);
|
||||
}
|
||||
|
||||
return m_size_cache_->get_or_create(key, [&]() {
|
||||
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
|
||||
return dnnl::matmul(desc);
|
||||
});
|
||||
}
|
||||
|
||||
void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
|
||||
memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_},
|
||||
dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::ab},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
|
||||
memory_cache_[DNNL_ARG_DST] =
|
||||
dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
|
||||
|
||||
// For PER_TOKEN, scales will be applied in outside epilogue
|
||||
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory(
|
||||
{{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(
|
||||
2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get());
|
||||
if (use_azp_) {
|
||||
memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory(
|
||||
{{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(
|
||||
3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get());
|
||||
}
|
||||
}
|
||||
|
||||
if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
|
||||
dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(),
|
||||
(void*)args.b_scales_ptr);
|
||||
} else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
|
||||
memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
|
||||
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
|
||||
default_engine(), (void*)args.b_scales_ptr);
|
||||
}
|
||||
|
||||
memory_cache_[DNNL_ARG_BIAS] =
|
||||
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
|
||||
default_engine(), nullptr);
|
||||
set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
|
||||
const MSizeCacheKey& key, bool first_time) {
|
||||
dnnl::memory::desc a_md({key.a_m_size, b_k_size_},
|
||||
dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::ab);
|
||||
dnnl::memory::desc b_md;
|
||||
if (first_time) {
|
||||
b_md =
|
||||
dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
} else {
|
||||
b_md = b_target_mem_desc_;
|
||||
}
|
||||
dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
|
||||
dnnl::memory::format_tag::ab);
|
||||
|
||||
dnnl::primitive_attr attr;
|
||||
// For PER_TOKEN, scales will be applied in outside epilogue
|
||||
if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
attr.set_scales_mask(DNNL_ARG_SRC, 0);
|
||||
if (use_azp_) {
|
||||
attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
|
||||
}
|
||||
}
|
||||
|
||||
if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
|
||||
} else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
|
||||
}
|
||||
|
||||
if (key.use_bias) {
|
||||
// For PER_TOKEN, bias will be applied in epilogue
|
||||
assert(a_qs_ == QuantizationStrategy::PER_TENSOR);
|
||||
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
|
||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
|
||||
c_md, attr);
|
||||
} else {
|
||||
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
|
||||
attr);
|
||||
}
|
||||
}
|
||||
csrc/cpu/dnnl_helper.h  (new file, 169 lines)
@ -0,0 +1,169 @@
|
||||
#ifndef DNNL_HELPER_H
|
||||
#define DNNL_HELPER_H
|
||||
|
||||
#include <optional>
|
||||
#include <cassert>
|
||||
|
||||
#include "oneapi/dnnl/dnnl.hpp"
|
||||
|
||||
namespace c10 {
|
||||
struct BFloat16;
|
||||
struct Half;
|
||||
} // namespace c10
|
||||
|
||||
namespace dnnl {
|
||||
namespace impl {
|
||||
struct memory_storage_t;
|
||||
struct matmul_pd_t;
|
||||
struct matmul_desc_t;
|
||||
} // namespace impl
|
||||
} // namespace dnnl
|
||||
struct dnnl_memory_desc;
|
||||
|
||||
template <typename KT, typename VT>
|
||||
class DNNLPrimitiveCache;
|
||||
|
||||
template <typename T>
|
||||
struct DNNLType {
|
||||
static constexpr dnnl::memory::data_type type =
|
||||
dnnl::memory::data_type::undef;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int8_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int32_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<float> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::BFloat16> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::Half> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||
return DNNLType<std::decay_t<T>>::type;
|
||||
}
|
||||
|
||||
class DNNLMatMulPrimitiveHandler {
|
||||
public:
|
||||
virtual ~DNNLMatMulPrimitiveHandler() = default;
|
||||
|
||||
protected:
|
||||
struct Args {
|
||||
dnnl_dim_t b_n_size;
|
||||
dnnl_dim_t b_n_stride;
|
||||
dnnl_dim_t b_k_size;
|
||||
dnnl_dim_t b_k_stride;
|
||||
void* b_ptr;
|
||||
dnnl::memory::data_type c_type;
|
||||
size_t primitive_cache_size;
|
||||
};
|
||||
|
||||
protected:
|
||||
DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);
|
||||
|
||||
void prepack_weight(void* original_b_ptr,
|
||||
dnnl::memory::desc b_target_mem_desc);
|
||||
|
||||
void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);
|
||||
|
||||
std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
|
||||
get_runtime_memory_ptr(size_t index);
|
||||
|
||||
protected:
|
||||
const dnnl_dim_t b_n_size_;
|
||||
const dnnl_dim_t b_n_stride_;
|
||||
const dnnl_dim_t b_k_size_;
|
||||
const dnnl_dim_t b_k_stride_;
|
||||
dnnl::memory::data_type b_type_;
|
||||
dnnl::memory::data_type c_type_;
|
||||
std::unordered_map<int, dnnl::memory> memory_cache_;
|
||||
std::vector<std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>>
|
||||
runtime_memory_ptrs_;
|
||||
dnnl::memory::desc b_target_mem_desc_;
|
||||
int64_t primitive_cache_size_;
|
||||
};
|
||||
|
||||
class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
|
||||
public:
|
||||
enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL };
|
||||
|
||||
struct Args : public DNNLMatMulPrimitiveHandler::Args {
|
||||
bool use_a_zero_point;
|
||||
QuantizationStrategy a_quantization_strategy;
|
||||
QuantizationStrategy b_quantization_strategy;
|
||||
float* b_scales_ptr;
|
||||
};
|
||||
|
||||
struct ClassMatmulCacheKey {
|
||||
dnnl_dim_t b_n_size;
|
||||
dnnl_dim_t b_k_size;
|
||||
QuantizationStrategy a_qs;
|
||||
QuantizationStrategy b_qs;
|
||||
bool use_azp;
|
||||
dnnl::memory::data_type c_type;
|
||||
|
||||
friend bool operator==(const ClassMatmulCacheKey& l,
|
||||
const ClassMatmulCacheKey& r);
|
||||
};
|
||||
|
||||
struct MSizeCacheKey {
|
||||
dnnl_dim_t a_m_size;
|
||||
bool use_bias;
|
||||
dnnl::memory::data_type bias_type;
|
||||
|
||||
friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
|
||||
};
|
||||
|
||||
using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>;
|
||||
using ClassMatmulCache =
|
||||
DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>;
|
||||
|
||||
struct ExecArgs : public MSizeCacheKey {
|
||||
const int8_t* a_ptr;
|
||||
const float* a_scales_ptr;
|
||||
const int32_t* a_zero_points_ptr;
|
||||
const void* bias_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
|
||||
public:
|
||||
W8A8MatMulPrimitiveHandler(const Args& args);
|
||||
|
||||
QuantizationStrategy get_input_scale_strategy() const { return a_qs_; }
|
||||
|
||||
bool get_input_use_zero_point() const { return use_azp_; }
|
||||
|
||||
void execute(ExecArgs& args);
|
||||
|
||||
private:
|
||||
dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
|
||||
bool first_time);
|
||||
|
||||
void init_runtime_memory_cache(const Args& args);
|
||||
|
||||
dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);
|
||||
|
||||
private:
|
||||
const bool use_azp_;
|
||||
const QuantizationStrategy a_qs_;
|
||||
const QuantizationStrategy b_qs_;
|
||||
std::shared_ptr<MSizeCache> m_size_cache_;
|
||||
};
|
||||
|
||||
#endif
|
||||
@ -1,206 +0,0 @@
|
||||
#ifndef DNNL_HELPER_HPP
|
||||
#define DNNL_HELPER_HPP
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
|
||||
#include "oneapi/dnnl/dnnl.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
struct DNNLType {
|
||||
static constexpr dnnl::memory::data_type type =
|
||||
dnnl::memory::data_type::undef;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int8_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<int32_t> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<float> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::BFloat16> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DNNLType<c10::Half> {
|
||||
static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
constexpr inline dnnl::memory::data_type get_dnnl_type() {
|
||||
return DNNLType<std::decay_t<T>>::type;
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
template <bool InputNoScale>
|
||||
class DNNLPrimitiveHelper {
|
||||
public:
|
||||
// I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
|
||||
// A: [M, K], row-major
|
||||
// B: [K, N], column-major
|
||||
// C: [M, N], row-major
|
||||
// bias: [N], row-major, optional
|
||||
// a_scales: [MS]
|
||||
// b_scales: [NS]
|
||||
// Note: Due to the limitation of oneDNN
|
||||
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||
// not supported.
|
||||
|
||||
template <typename OutputT, typename BiasT>
|
||||
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||
dnnl_dim_t K, const float* a_scales,
|
||||
const float* b_scales, dnnl_dim_t MS,
|
||||
dnnl_dim_t NS) {
|
||||
auto&& OutputType = get_dnnl_type<OutputT>();
|
||||
auto&& BiasType = get_dnnl_type<BiasT>();
|
||||
|
||||
dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
|
||||
dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
|
||||
dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
|
||||
|
||||
dnnl::primitive_attr attr;
|
||||
if constexpr (!InputNoScale) {
|
||||
if (MS == 1) {
|
||||
// per-tensor
|
||||
attr.set_scales_mask(DNNL_ARG_SRC, 0);
|
||||
} else {
|
||||
// per-token
|
||||
TORCH_CHECK(false, "per-token quantization is unsupported.");
|
||||
}
|
||||
}
|
||||
|
||||
if (NS == 1) {
|
||||
// per-tensor
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
|
||||
} else {
|
||||
// per-channel
|
||||
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc matmul_pd;
|
||||
// Create memory descriptors with format_tag::any for the primitive. This
|
||||
// enables the matmul primitive to choose memory layouts for an
|
||||
// optimized primitive implementation, and these layouts may differ from the
|
||||
// ones provided by the user.
|
||||
#ifdef __aarch64__
|
||||
auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
auto mat_weights_md = dnnl::memory::desc(
|
||||
{K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
|
||||
auto mat_dst_md =
|
||||
dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
|
||||
mat_weights_md, bias_md,
|
||||
mat_dst_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(
|
||||
default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
|
||||
}
|
||||
#else
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
bias_md, c_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
c_md, attr);
|
||||
}
|
||||
#endif
|
||||
dnnl::matmul matmul(matmul_pd);
|
||||
|
||||
auto& engine = default_engine();
|
||||
|
||||
dnnl::memory a_m(a_md, engine, (void*)a);
|
||||
dnnl::memory b_m(b_md, engine, (void*)b);
|
||||
dnnl::memory c_m(c_md, engine, (void*)c);
|
||||
dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
|
||||
(void*)a_scales);
|
||||
dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
|
||||
(void*)b_scales);
|
||||
|
||||
auto& stream = default_stream();
|
||||
|
||||
auto mat_src_mem = a_m;
|
||||
auto mat_weights_mem = b_m;
|
||||
auto mat_dst_mem = c_m;
|
||||
#ifdef __aarch64__
|
||||
if (matmul_pd.weights_desc() != b_m.get_desc()) {
|
||||
mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
|
||||
dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
|
||||
}
|
||||
#endif
|
||||
if constexpr (InputNoScale) {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
} else {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
}
|
||||
stream.wait();
|
||||
}
|
||||
|
||||
private:
|
||||
static dnnl::engine& default_engine() {
|
||||
static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
|
||||
return engine;
|
||||
}
|
||||
|
||||
static dnnl::stream& default_stream() {
|
||||
static dnnl::stream stream(default_engine());
|
||||
return stream;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
csrc/cpu/dnnl_kernels.cpp  (new file, 494 lines)
@ -0,0 +1,494 @@
|
||||
#include "cpu_types.hpp"
|
||||
#include "dnnl_helper.h"
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t>
|
||||
struct KernelVecType {
|
||||
using load_vec_type = void;
|
||||
using cvt_vec_type = void;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelVecType<float> {
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
|
||||
template <>
|
||||
struct KernelVecType<c10::BFloat16> {
|
||||
using load_vec_type = vec_op::BF16Vec16;
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct KernelVecType<c10::Half> {
|
||||
#if defined(__powerpc64__) || defined(__s390x__)
|
||||
// Power architecture-specific vector type
|
||||
using load_vec_type = vec_op::FP32Vec16;
|
||||
#else
|
||||
// Fallback for other architectures
|
||||
using load_vec_type = vec_op::FP16Vec16;
|
||||
#endif
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int64_t num_tokens,
|
||||
const int64_t input_stride,
|
||||
const int64_t hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t inv_scale(1.0 / *scale);
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
cvt_vec_t zp_vec;
|
||||
if constexpr (AZP) {
|
||||
zp_vec = cvt_vec_t(static_cast<float>(*azp));
|
||||
}
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
int8_t* output_ptr = output + i * hidden_size;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = elems_fp32 * inv_scale;
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + zp_vec;
|
||||
}
|
||||
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, typename scalar_t>
|
||||
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int64_t num_tokens,
|
||||
const int64_t input_stride,
|
||||
const int64_t hidden_size) {
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
constexpr float i8_min =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::min());
|
||||
constexpr float i8_max =
|
||||
static_cast<float>(std::numeric_limits<int8_t>::max());
|
||||
const cvt_vec_t i8_min_vec(i8_min);
|
||||
const cvt_vec_t i8_max_vec(i8_max);
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
|
||||
cvt_vec_t min_value(std::numeric_limits<float>::max());
|
||||
{
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
|
||||
if (j + vec_elem_num == hidden_size) {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32);
|
||||
min_value = min_value.min(elems_fp32);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs());
|
||||
}
|
||||
} else {
|
||||
if constexpr (AZP) {
|
||||
max_value = max_value.max(elems_fp32, hidden_size - j);
|
||||
min_value = min_value.min(elems_fp32, hidden_size - j);
|
||||
} else {
|
||||
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float scale_val, azp_val;
|
||||
if constexpr (AZP) {
|
||||
float max_scalar = max_value.reduce_max();
|
||||
float min_scalar = min_value.reduce_min();
|
||||
scale_val = (max_scalar - min_scalar) / 255.0f;
|
||||
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
|
||||
azp[i] = azp_val;
|
||||
scale[i] = scale_val;
|
||||
} else {
|
||||
scale_val = max_value.reduce_max() / 127.0f;
|
||||
scale[i] = scale_val;
|
||||
}
|
||||
|
||||
const cvt_vec_t inv_scale(1.0 / scale_val);
|
||||
const cvt_vec_t azp_vec(azp_val);
|
||||
|
||||
{
|
||||
int64_t j = 0;
|
||||
const scalar_t* input_ptr = input + i * input_stride;
|
||||
int8_t* output_ptr = output + i * hidden_size;
|
||||
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j);
|
||||
}
|
||||
|
||||
load_vec_t elems(input_ptr + j);
|
||||
cvt_vec_t elems_fp32(elems);
|
||||
elems_fp32 = (elems_fp32 * inv_scale);
|
||||
|
||||
if constexpr (AZP) {
|
||||
elems_fp32 = elems_fp32 + azp_vec;
|
||||
}
|
||||
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
|
||||
vec_op::INT8Vec16 elems_int8(elems_fp32);
|
||||
elems_int8.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool AZP, bool Bias, typename scalar_t>
|
||||
void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float* a_scale, const int32_t* azp,
|
||||
const float* azp_adj, const scalar_t* bias,
|
||||
const int64_t num_tokens,
|
||||
const int64_t hidden_size) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
|
||||
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
|
||||
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
|
||||
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
|
||||
|
||||
const int64_t thread_num = omp_get_max_threads();
|
||||
if (num_tokens > thread_num) {
|
||||
#pragma omp parallel for
|
||||
for (int64_t i = 0; i < num_tokens; ++i) {
|
||||
const float* input_ptr = input + i * hidden_size;
|
||||
scalar_t* output_ptr = output + i * hidden_size;
|
||||
int64_t j = 0;
|
||||
cvt_vec_t token_scale_vec(a_scale[i]);
|
||||
cvt_vec_t token_zp_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
|
||||
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
|
||||
}
|
||||
for (; j < hidden_size - vec_elem_num; ++j) {
|
||||
cvt_vec_t elems_fp32(input_ptr + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + j);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + j);
|
||||
}
|
||||
cvt_vec_t elems_fp32(input_ptr + j);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + j);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + j);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + j, hidden_size - j);
|
||||
}
|
||||
} else {
|
||||
const int64_t vec_iteration =
|
||||
(hidden_size + vec_elem_num - 1) / vec_elem_num;
|
||||
const int64_t vec_iteration_per_thread =
|
||||
(vec_iteration + thread_num - 1) / thread_num;
|
||||
const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num;
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (int64_t i = 0; i < thread_num; ++i) {
|
||||
const int64_t start = elem_num_per_thread * i;
|
||||
const int64_t end = std::min(hidden_size, elem_num_per_thread + start);
|
||||
for (int64_t j = 0; j < num_tokens; ++j) {
|
||||
cvt_vec_t token_scale_vec(a_scale[j]);
|
||||
cvt_vec_t token_zp_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
float zp_scale_val = a_scale[j] * static_cast<float>(azp[j]);
|
||||
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
|
||||
}
|
||||
int64_t k = start;
|
||||
const float* input_ptr = input + j * hidden_size;
|
||||
scalar_t* output_ptr = output + j * hidden_size;
|
||||
for (; k < end - vec_elem_num; k += vec_elem_num) {
|
||||
cvt_vec_t elems_fp32(input_ptr + k);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + k);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + k);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + k);
|
||||
}
|
||||
if (k < end) {
|
||||
cvt_vec_t elems_fp32(input_ptr + k);
|
||||
elems_fp32 = elems_fp32 * token_scale_vec;
|
||||
if constexpr (AZP) {
|
||||
cvt_vec_t azp_adj_fp32(azp_adj + k);
|
||||
elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
|
||||
}
|
||||
if constexpr (Bias) {
|
||||
load_vec_t bias_vec(bias + k);
|
||||
cvt_vec_t bias_vec_fp32(bias_vec);
|
||||
elems_fp32 = elems_fp32 + bias_vec_fp32;
|
||||
}
|
||||
load_vec_t elems_out(elems_fp32);
|
||||
elems_out.save(output_ptr + k, end - k);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
int64_t create_onednn_scaled_mm_handler(
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
at::ScalarType output_type, bool dynamic_act_quant, bool use_azp,
|
||||
int64_t primitive_cache_size) {
|
||||
TORCH_CHECK(b.dim() == 2);
|
||||
TORCH_CHECK(b.stride(0) == 1); // Column-major
|
||||
TORCH_CHECK(b_scales.is_contiguous());
|
||||
|
||||
W8A8MatMulPrimitiveHandler::Args args;
|
||||
args.primitive_cache_size = primitive_cache_size;
|
||||
|
||||
if (b_scales.numel() == 1) {
|
||||
args.b_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
|
||||
} else {
|
||||
TORCH_CHECK_EQ(b_scales.numel(), b.size(1));
|
||||
args.b_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL;
|
||||
}
|
||||
args.b_scales_ptr = b_scales.data_ptr<float>();
|
||||
args.b_k_size = b.size(0);
|
||||
args.b_k_stride = b.stride(0);
|
||||
args.b_n_size = b.size(1);
|
||||
args.b_n_stride = b.stride(1);
|
||||
args.b_ptr = b.data_ptr<int8_t>();
|
||||
|
||||
if (dynamic_act_quant) {
|
||||
// dynamic per-token, bias, A scales and A zps will be applied in outside.
|
||||
args.a_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN;
|
||||
args.use_a_zero_point = false;
|
||||
} else {
|
||||
// static per-tensor
|
||||
args.a_quantization_strategy =
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
|
||||
args.use_a_zero_point = use_azp;
|
||||
}
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler",
|
||||
[&] {
|
||||
if (dynamic_act_quant) {
|
||||
args.c_type = get_dnnl_type<float>();
|
||||
} else {
|
||||
args.c_type = get_dnnl_type<scalar_t>();
|
||||
}
|
||||
});
|
||||
|
||||
return reinterpret_cast<int64_t>(new W8A8MatMulPrimitiveHandler(args));
|
||||
}
|
||||
|
||||
void onednn_scaled_mm(
|
||||
torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a, // [M, IC], row-major
|
||||
const torch::Tensor& a_scales, // [M] or [1]
|
||||
const std::optional<torch::Tensor>& azp, // [M] or [1]
|
||||
const std::optional<torch::Tensor>& azp_adj, // [M] or [1]
|
||||
const std::optional<torch::Tensor>& bias, // [N]
|
||||
int64_t handler) {
|
||||
CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
|
||||
TORCH_CHECK(a.dim() == 2);
|
||||
TORCH_CHECK(a.is_contiguous());
|
||||
TORCH_CHECK(c.is_contiguous());
|
||||
W8A8MatMulPrimitiveHandler* ptr =
|
||||
reinterpret_cast<W8A8MatMulPrimitiveHandler*>(handler);
|
||||
const int32_t* azp_ptr = nullptr;
|
||||
if (azp.has_value()) {
|
||||
azp_ptr = azp->data_ptr<int32_t>();
|
||||
}
|
||||
if (ptr->get_input_scale_strategy() ==
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
|
||||
TORCH_CHECK_EQ(a_scales.numel(), 1);
|
||||
}
|
||||
|
||||
W8A8MatMulPrimitiveHandler::ExecArgs exec_args;
|
||||
exec_args.a_ptr = a.data_ptr<int8_t>();
|
||||
exec_args.a_m_size = a.size(0);
|
||||
exec_args.bias_ptr = nullptr;
|
||||
exec_args.use_bias = false;
|
||||
exec_args.a_scales_ptr = nullptr;
|
||||
exec_args.a_zero_points_ptr = nullptr;
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] {
|
||||
if (ptr->get_input_scale_strategy() ==
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
|
||||
if (bias.has_value()) {
|
||||
exec_args.bias_ptr = bias->data_ptr<scalar_t>();
|
||||
exec_args.bias_type = get_dnnl_type<scalar_t>();
|
||||
exec_args.use_bias = true;
|
||||
}
|
||||
exec_args.a_scales_ptr = a_scales.data_ptr<float>();
|
||||
exec_args.a_zero_points_ptr = azp_ptr;
|
||||
exec_args.c_ptr = c.data_ptr<scalar_t>();
|
||||
ptr->execute(exec_args);
|
||||
} else if (ptr->get_input_scale_strategy() ==
|
||||
W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) {
|
||||
torch::Tensor tmp_fp32_out =
|
||||
torch::empty_like(c, ::at::ScalarType::Float);
|
||||
exec_args.c_ptr = tmp_fp32_out.data_ptr<float>();
|
||||
ptr->execute(exec_args);
|
||||
if (bias.has_value()) {
|
||||
if (azp.has_value()) {
|
||||
dynamic_quant_epilogue<true, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(),
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
} else {
|
||||
dynamic_quant_epilogue<false, true>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, nullptr,
|
||||
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
|
||||
}
|
||||
} else {
|
||||
if (azp.has_value()) {
|
||||
dynamic_quant_epilogue<true, false>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(),
|
||||
(scalar_t*)nullptr, c.size(0), c.size(1));
|
||||
} else {
|
||||
dynamic_quant_epilogue<false, false>(
|
||||
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
|
||||
a_scales.data_ptr<float>(), azp_ptr, nullptr, (scalar_t*)nullptr,
|
||||
c.size(0), c.size(1));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "invalid act quant type.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// static-per-tensor quantization.
|
||||
void static_scaled_int8_quant(
|
||||
torch::Tensor& out, // [batch, hidden_size]
|
||||
const torch::Tensor& input, // [batch, hidden_size]
|
||||
const torch::Tensor& scale, std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK_EQ(input.dim(), 2);
|
||||
TORCH_CHECK_EQ(input.stride(1), 1);
|
||||
TORCH_CHECK(scale.numel() == 1);
|
||||
TORCH_CHECK(!azp.has_value() || azp->numel() == 1);
|
||||
|
||||
const int64_t stride = input.stride(0);
|
||||
const int64_t hidden_size = input.size(1);
|
||||
const int64_t num_tokens = input.size(0);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
|
||||
if (azp.has_value()) {
|
||||
static_scaled_int8_quant_impl<true>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
|
||||
stride, hidden_size);
|
||||
} else {
|
||||
static_scaled_int8_quant_impl<false>(input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), nullptr,
|
||||
num_tokens, stride, hidden_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// dynamic-per-token quantization.
|
||||
void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [batch, hidden_size]
|
||||
const torch::Tensor& input, // [batch, hidden_size]
|
||||
torch::Tensor& scale, // [batch, 1]
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK_EQ(input.dim(), 2);
|
||||
TORCH_CHECK_EQ(input.stride(1), 1);
|
||||
|
||||
const int64_t hidden_size = input.size(1);
|
||||
const int64_t num_tokens = input.size(0);
|
||||
const int64_t stride = input.stride(0);
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
|
||||
if (azp.has_value()) {
|
||||
dynamic_scaled_int8_quant_impl<true>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
|
||||
stride, hidden_size);
|
||||
} else {
|
||||
dynamic_scaled_int8_quant_impl<false>(
|
||||
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
|
||||
scale.data_ptr<float>(), nullptr, num_tokens, stride,
|
||||
hidden_size);
|
||||
}
|
||||
});
|
||||
}

@ -1,951 +0,0 @@
#include "cpu_types.hpp"
#include "dnnl_helper.hpp"

namespace {
template <typename scalar_t>
struct KernelVecType {
  using load_vec_type = void;
  using azp_adj_load_vec_type = void;
  using cvt_vec_type = void;
};

template <>
struct KernelVecType<float> {
  using load_vec_type = vec_op::FP32Vec16;
  using azp_adj_load_vec_type = vec_op::INT32Vec16;
  using cvt_vec_type = vec_op::FP32Vec16;
};

#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
template <>
struct KernelVecType<c10::BFloat16> {
  using load_vec_type = vec_op::BF16Vec16;
  using azp_adj_load_vec_type = vec_op::INT32Vec16;
  using cvt_vec_type = vec_op::FP32Vec16;
};
#endif

template <>
struct KernelVecType<c10::Half> {
#if defined(__powerpc64__) || defined(__s390x__)
  // Power architecture-specific vector type
  using load_vec_type = vec_op::FP32Vec16;
#else
  // Fallback for other architectures
  using load_vec_type = vec_op::FP16Vec16;
#endif
  using azp_adj_load_vec_type = vec_op::INT32Vec16;
  using cvt_vec_type = vec_op::FP32Vec16;
};

#if defined(__AVX512F__) || defined(__aarch64__)
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                   const float* scale, const int32_t* azp,
                                   const int num_tokens,
                                   const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  const cvt_vec_t inv_scale(1.0 / *scale);
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

  cvt_vec_t zp_vec;
  if constexpr (AZP) {
    zp_vec = cvt_vec_t(static_cast<float>(*azp));
  }

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = elems_fp32 * inv_scale;

      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + zp_vec;
      }

      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j);
    }

    load_vec_t elems(input + i * hidden_size + j);
    cvt_vec_t elems_fp32(elems);
    elems_fp32 = elems_fp32 * inv_scale;

    if constexpr (AZP) {
      elems_fp32 = elems_fp32 + zp_vec;
    }

    elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
    vec_op::INT8Vec16 elems_int8(elems_fp32);
    elems_int8.save(output + i * hidden_size + j, hidden_size - j);
  }
}

template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t max_value(std::numeric_limits<float>::lowest());
    cvt_vec_t min_value(std::numeric_limits<float>::max());
    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);

      if (j + vec_elem_num == hidden_size) {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      } else {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32, hidden_size - j);
          min_value = min_value.min(elems_fp32, hidden_size - j);
        } else {
          max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
        }
      }
    }

    float scale_val, azp_val;
    if constexpr (AZP) {
      float max_scalar = max_value.reduce_max();
      float min_scalar = min_value.reduce_min();
      scale_val = (max_scalar - min_scalar) / 255.0f;
      azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
      azp[i] = static_cast<int32_t>(azp_val);
      scale[i] = scale_val;
    } else {
      scale_val = max_value.reduce_max() / 127.0f;
      scale[i] = scale_val;
    }

    const cvt_vec_t inv_scale(1.0 / scale_val);
    const cvt_vec_t azp_vec(azp_val);

    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        elems_fp32 = (elems_fp32 * inv_scale);

        if constexpr (AZP) {
          elems_fp32 = elems_fp32 + azp_vec;
        }
        elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
        vec_op::INT8Vec16 elems_int8(elems_fp32);
        elems_int8.save(output + i * hidden_size + j);
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = (elems_fp32 * inv_scale);

      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + azp_vec;
      }
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j, hidden_size - j);
    }
  }
}
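The scale and zero-point selection in the per-token reduction above can be summarised in scalar form; the following sketch (illustrative only, using the same 255/127 ranges as the kernel) mirrors that step:

#include <cmath>
#include <cstdint>

// Asymmetric: map [min, max] onto [-128, 127]; symmetric: map max|x| onto 127.
inline void token_scale_azp_ref(float min_v, float max_v, float abs_max,
                                bool use_azp, float& scale, int32_t& azp) {
  if (use_azp) {
    scale = (max_v - min_v) / 255.0f;
    azp = static_cast<int32_t>(std::nearbyint(-128.0f - min_v / scale));
  } else {
    scale = abs_max / 127.0f;
    azp = 0;
  }
}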

template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
                           const float a_scale, const float* b_scale,
                           const int32_t* azp_with_adj, const int num_tokens,
                           const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t a_scale_vec(a_scale);
    cvt_vec_t b_scale_vec(*b_scale);
    cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;

    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);

      if constexpr (PerChannel) {
        b_scale_vec = cvt_vec_t(b_scale + j);
        scale_vec = b_scale_vec * a_scale_vec;
      }

      elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;

      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
    cvt_vec_t azp_adj_fp32(azp_adj_vec);

    if constexpr (PerChannel) {
      b_scale_vec = cvt_vec_t(b_scale + j);
      scale_vec = b_scale_vec * a_scale_vec;
    }

    elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}

template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const float* b_scale,
                            const int32_t* azp, const int32_t* azp_adj,
                            const scalar_t* bias, const int num_tokens,
                            const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
    cvt_vec_t token_scale_vec(a_scale[i]);
    cvt_vec_t token_zp_scale_vec;
    if constexpr (AZP) {
      float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
      if constexpr (!PerChannel) {
        zp_scale_val *= *b_scale;
      }
      token_zp_scale_vec = cvt_vec_t(zp_scale_val);
    }

    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      elems_fp32 = elems_fp32 * token_scale_vec;

      if constexpr (AZP) {
        azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
        cvt_vec_t azp_adj_fp32(azp_adj_vec);
        azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

        if constexpr (PerChannel) {
          cvt_vec_t b_scale_vec(b_scale + j);
          azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
        }

        elems_fp32 = elems_fp32 - azp_adj_fp32;
      }

      if constexpr (Bias) {
        load_vec_t bias_vec(bias + j);
        cvt_vec_t bias_vec_fp32(bias_vec);
        elems_fp32 = elems_fp32 + bias_vec_fp32;
      }

      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    elems_fp32 = elems_fp32 * token_scale_vec;

    if constexpr (AZP) {
      azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);
      azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

      if constexpr (PerChannel) {
        cvt_vec_t b_scale_vec(b_scale + j);
        azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
      }

      elems_fp32 = elems_fp32 - azp_adj_fp32;
    }

    if constexpr (Bias) {
      load_vec_t bias_vec(bias + j);
      cvt_vec_t bias_vec_fp32(bias_vec);
      elems_fp32 = elems_fp32 + bias_vec_fp32;
    }

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}
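Why the epilogues subtract an azp_adj term: with an asymmetric activation quantization A_q ≈ A / s_a + azp, the integer GEMM result A_q @ B carries an extra azp * colsum(B) in every output column, and azp_adj is expected to hold those per-column sums. A scalar sketch of the correction (illustrative only, not part of the diff):

#include <cstdint>

// C[i][j] = s_a * s_b * (acc[i][j] - azp * azp_adj[j]) + bias[j],
// where acc is the raw int32 GEMM output and azp_adj[j] = sum_k B[k][j].
inline float dequant_with_azp_ref(int32_t acc, int32_t azp, int32_t azp_adj,
                                  float s_a, float s_b, float bias) {
  return s_a * s_b * static_cast<float>(acc - azp * azp_adj) + bias;
}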
#elif defined(__powerpc64__)
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                   const float* scale, const int32_t* azp,
                                   const int num_tokens,
                                   const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  const cvt_vec_t inv_scale(1.0 / *scale);
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

  cvt_vec_t zp_vec;
  if constexpr (AZP) {
    zp_vec = cvt_vec_t(static_cast<float>(*azp));
  }
#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = elems_fp32 * inv_scale;
      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + zp_vec;
      }
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j);
    }
    load_vec_t elems(input + i * hidden_size + j);
    cvt_vec_t elems_fp32(elems);
    elems_fp32 = elems_fp32 * inv_scale;

    if constexpr (AZP) {
      elems_fp32 = elems_fp32 + zp_vec;
    }

    elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
    vec_op::INT8Vec16 elems_int8(elems_fp32);
    elems_int8.save(output + i * hidden_size + j, hidden_size - j);
  }
}
template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

  constexpr float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  const cvt_vec_t i8_min_vec(i8_min);
  const cvt_vec_t i8_max_vec(i8_max);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t max_value(std::numeric_limits<float>::lowest());
    cvt_vec_t min_value(std::numeric_limits<float>::max());
    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);

      if (j + vec_elem_num == hidden_size) {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32);
          min_value = min_value.min(elems_fp32);
        } else {
          max_value = max_value.max(elems_fp32.abs());
        }
      } else {
        if constexpr (AZP) {
          max_value = max_value.max(elems_fp32, hidden_size - j);
          min_value = min_value.min(elems_fp32, hidden_size - j);
        } else {
          max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
        }
      }
    }

    float scale_val, azp_val;
    if constexpr (AZP) {
      float max_scalar = max_value.reduce_max();
      float min_scalar = min_value.reduce_min();
      scale_val = (max_scalar - min_scalar) / 255.0f;
      azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
      azp[i] = static_cast<int32_t>(azp_val);
      scale[i] = scale_val;
    } else {
      scale_val = max_value.reduce_max() / 127.0f;
      scale[i] = scale_val;
    }

    const cvt_vec_t inv_scale(1.0 / scale_val);
    const cvt_vec_t azp_vec(azp_val);

    {
      int j = 0;
      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
        load_vec_t elems(input + i * hidden_size + j);
        cvt_vec_t elems_fp32(elems);
        elems_fp32 = (elems_fp32 * inv_scale);

        if constexpr (AZP) {
          elems_fp32 = elems_fp32 + azp_vec;
        }
        elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
        vec_op::INT8Vec16 elems_int8(elems_fp32);
        elems_int8.save(output + i * hidden_size + j);
      }

      load_vec_t elems(input + i * hidden_size + j);
      cvt_vec_t elems_fp32(elems);
      elems_fp32 = (elems_fp32 * inv_scale);

      if constexpr (AZP) {
        elems_fp32 = elems_fp32 + azp_vec;
      }
      elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
      vec_op::INT8Vec16 elems_int8(elems_fp32);
      elems_int8.save(output + i * hidden_size + j, hidden_size - j);
    }
  }
}
template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
                           const float a_scale, const float* b_scale,
                           const int32_t* azp_with_adj, const int num_tokens,
                           const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    cvt_vec_t a_scale_vec(a_scale);
    cvt_vec_t b_scale_vec(*b_scale);
    cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;

    int j = 0;
    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);

      if constexpr (PerChannel) {
        b_scale_vec = cvt_vec_t(b_scale + j);
        scale_vec = b_scale_vec * a_scale_vec;
      }
      elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
    cvt_vec_t azp_adj_fp32(azp_adj_vec);

    if constexpr (PerChannel) {
      b_scale_vec = cvt_vec_t(b_scale + j);
      scale_vec = b_scale_vec * a_scale_vec;
    }

    elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const float* b_scale,
                            const int32_t* azp, const int32_t* azp_adj,
                            const scalar_t* bias, const int num_tokens,
                            const int hidden_size) {
  CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
  using azp_adj_load_vec_t =
      typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    int j = 0;
    cvt_vec_t token_scale_vec(a_scale[i]);
    cvt_vec_t token_zp_scale_vec;
    if constexpr (AZP) {
      float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
      if constexpr (!PerChannel) {
        zp_scale_val *= *b_scale;
      }
      token_zp_scale_vec = cvt_vec_t(zp_scale_val);
    }

    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
      cvt_vec_t elems_fp32(input + i * hidden_size + j);
      elems_fp32 = elems_fp32 * token_scale_vec;

      if constexpr (AZP) {
        azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
        cvt_vec_t azp_adj_fp32(azp_adj_vec);
        azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

        if constexpr (PerChannel) {
          cvt_vec_t b_scale_vec(b_scale + j);
          azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
        }

        elems_fp32 = elems_fp32 - azp_adj_fp32;
      }

      if constexpr (Bias) {
        load_vec_t bias_vec(bias + j);
        cvt_vec_t bias_vec_fp32(bias_vec);
        elems_fp32 = elems_fp32 + bias_vec_fp32;
      }

      load_vec_t elems_out(elems_fp32);
      elems_out.save(output + i * hidden_size + j);
    }

    cvt_vec_t elems_fp32(input + i * hidden_size + j);
    elems_fp32 = elems_fp32 * token_scale_vec;

    if constexpr (AZP) {
      azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
      cvt_vec_t azp_adj_fp32(azp_adj_vec);
      azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;

      if constexpr (PerChannel) {
        cvt_vec_t b_scale_vec(b_scale + j);
        azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
      }

      elems_fp32 = elems_fp32 - azp_adj_fp32;
    }

    if constexpr (Bias) {
      load_vec_t bias_vec(bias + j);
      cvt_vec_t bias_vec_fp32(bias_vec);
      elems_fp32 = elems_fp32 + bias_vec_fp32;
    }

    load_vec_t elems_out(elems_fp32);
    elems_out.save(output + i * hidden_size + j, hidden_size - j);
  }
}
#else
template <typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                   const float* scale, const int32_t* azp,
                                   const int num_tokens,
                                   const int hidden_size) {
  TORCH_CHECK(false,
              "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
              "support.")
}

template <typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
                                    float* scale, int32_t* azp,
                                    const int num_tokens,
                                    const int hidden_size) {
  TORCH_CHECK(false,
              "dynamic_scaled_int8_quant_impl requires "
              "AVX512/powerpc64/AArch64 support.")
}

template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
                           const float a_scale, const float* b_scale,
                           const int32_t* azp_with_adj, const int num_tokens,
                           const int hidden_size) {
  TORCH_CHECK(
      false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
}

template <typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
                            const float* a_scale, const float* b_scale,
                            const int32_t* azp, const int32_t* azp_with_adj,
                            const scalar_t* bias, const int num_tokens,
                            const int hidden_size) {
  TORCH_CHECK(
      false,
      "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
}
#endif
}  // namespace

void int8_scaled_mm(torch::Tensor& c,               // [M, OC], row-major
                    const torch::Tensor& a,         // [M, IC], row-major
                    const torch::Tensor& b,         // [IC, OC], column-major
                    const torch::Tensor& a_scales,  // [1] or [M]
                    const torch::Tensor& b_scales,  // [1] or [OC]
                    const std::optional<torch::Tensor>& bias  // [OC]
) {
  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm only supports INT8 inputs.")
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] {
    if (a_scales.numel() != 1) {
      // per-token
      // Note: oneDNN doesn't support per-token activation quantization.
      // Ideally we want to fuse the GEMM and the scale procedure with oneDNN
      // JIT, so the intermediate data stays cached in registers or L1. But for
      // now the oneDNN GEMM code generation only supports two quantization
      // patterns: per-tensor or per-output-channel of weight.
      // So we have to apply the per-token scale with an epilogue. In
      // C = s_a * s_b * (A@B) + bias, the intermediate C_inter = s_b * (A@B)
      // is computed by the oneDNN GEMM, then the per-token scale (and bias)
      // is applied with the epilogue C = s_a * C_inter + bias.
      torch::Tensor tmp_fp32_out =
          torch::empty_like(c, ::at::ScalarType::Float);
      // Compute C_inter=s_b * (A@B)
      DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
          a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
          tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
          a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
      if (bias.has_value()) {
        // Compute C=s_a * C_inter + bias
        dynamic_quant_epilogue<false, true, true>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
            bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
      } else {
        // Compute C=s_a * C_inter
        dynamic_quant_epilogue<false, true, false, scalar_t>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
            c.size(0), c.size(1));
      }
    } else {
      // per-tensor
      if (bias.has_value()) {
        // Compute C=s_a * s_b * (A@B) + bias
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
            bias->data_ptr<scalar_t>(), a.size(0), b.size(1), a.size(1),
            a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            a_scales.numel(), b_scales.numel());
      } else {
        // Compute C=s_a * s_b * (A@B)
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit<scalar_t, void>(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
            nullptr, a.size(0), b.size(1), a.size(1),
            a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            a_scales.numel(), b_scales.numel());
      }
    }
  });
}
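The per-token branch above is just a two-stage evaluation of C = s_a * s_b * (A@B) + bias; a minimal scalar sketch of the split (illustrative only):

// stage 1 (oneDNN GEMM):  C_inter[i][j] = s_b[j] * sum_k A[i][k] * B[k][j]
// stage 2 (epilogue):     C[i][j]       = s_a[i] * C_inter[i][j] + bias[j]
inline float per_token_epilogue_ref(float c_inter, float s_a_row, float bias) {
  return s_a_row * c_inter + bias;
}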

void int8_scaled_mm_azp(torch::Tensor& c,               // [M, OC], row-major
                        const torch::Tensor& a,         // [M, IC], row-major
                        const torch::Tensor& b,         // [IC, OC], column-major
                        const torch::Tensor& a_scales,  // [1] or [M]
                        const torch::Tensor& b_scales,  // [1] or [OC]
                        const torch::Tensor& azp_adj,   // [OC]
                        const std::optional<torch::Tensor>& azp,  // [1] or [M]
                        const std::optional<torch::Tensor>& bias  // [OC]
) {
  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm_azp only supports INT8 inputs.")
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // azp & bias types
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] {
    torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
    if (a_scales.numel() != 1) {
      // per-token
      // Note: oneDNN doesn't support per-token activation quantization
      // Compute C_inter=s_b * (A@B)
      DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
          a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
          tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
          a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
      if (bias.has_value()) {
        // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias
        if (b_scales.numel() != 1) {
          // Per-Channel
          dynamic_quant_epilogue<true, true, true>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
              bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
        } else {
          // Per-Tensor
          dynamic_quant_epilogue<true, false, true>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
              bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
        }
      } else {
        // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj
        if (b_scales.numel() != 1) {
          // Per-Channel
          dynamic_quant_epilogue<true, true, false, scalar_t>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
              c.size(0), c.size(1));
        } else {
          // Per-Tensor
          dynamic_quant_epilogue<true, false, false, scalar_t>(
              tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
              a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
              azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
              c.size(0), c.size(1));
        }
      }
    } else {
      // per-tensor
      if (bias.has_value()) {
        // Compute C_inter=s_a * s_b * (A@B) + bias
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
            tmp_fp32_out.data_ptr<float>(), bias->data_ptr<scalar_t>(),
            a.size(0), b.size(1), a.size(1), a_scales.data_ptr<float>(),
            b_scales.data_ptr<float>(), a_scales.numel(), b_scales.numel());
      } else {
        // Compute C_inter=s_a * s_b * (A@B)
        DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, void>(
            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
            tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
            a.size(1), a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            a_scales.numel(), b_scales.numel());
      }

      // Compute C=C_inter - s_a * s_b * azp_adj
      if (b_scales.numel() != 1) {
        // Per-Channel
        static_quant_epilogue<true>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            *a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
      } else {
        // Per-Tensor
        static_quant_epilogue<false>(
            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
            *a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
            azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
      }
    }
  });
}

// static-per-tensor quantization.
void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              const torch::Tensor& input,  // [..., hidden_size]
                              const torch::Tensor& scale,
                              std::optional<torch::Tensor> const& azp) {
  CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);
  TORCH_CHECK(!azp.has_value() || azp->numel() == 1);

  const int hidden_size = input.size(-1);
  const int num_tokens = input.numel() / hidden_size;
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
        if (azp.has_value()) {
          static_scaled_int8_quant_impl<true>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
              hidden_size);
        } else {
          static_scaled_int8_quant_impl<false>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
        }
      });
}

// dynamic-per-token quantization.
void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    const torch::Tensor& input,  // [..., hidden_size]
    torch::Tensor& scale,        // [..., 1]
    std::optional<torch::Tensor> const& azp) {
  CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
        if (azp.has_value()) {
          dynamic_scaled_int8_quant_impl<true>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
              hidden_size);
        } else {
          dynamic_scaled_int8_quant_impl<false>(
              input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
              scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
        }
      });
}

#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c,        // [M, OC], row-major
                            const torch::Tensor& a,  // [M, IC], row-major
                            const torch::Tensor& b,  // [IC, OC], column-major
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const std::optional<torch::Tensor>& bias  // [OC]
) {
  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
              "int8_scaled_mm_ppc64le only supports INT8 inputs.");
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  // We don't need this
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }
  VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
    torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
    // Compute C_inter=s_b * (A@B)
    DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
        a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
        tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
        a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
    if (bias.has_value()) {
      // Compute C=s_a * C_inter + bias
      dynamic_quant_epilogue<false, true, true>(
          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
          bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
    } else {
      // Compute C=s_a * C_inter
      dynamic_quant_epilogue<false, true, false, scalar_t>(
          tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
          a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
          c.size(0), c.size(1));
    }
  });
}

#endif

@ -6,25 +6,20 @@

std::string init_cpu_threads_env(const std::string& cpu_ids);

void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                    const torch::Tensor& b, const torch::Tensor& a_scales,
                    const torch::Tensor& b_scales,
                    const std::optional<torch::Tensor>& bias);
void release_dnnl_matmul_handler(int64_t handler);

void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
                        const torch::Tensor& b, const torch::Tensor& a_scales,
                        const torch::Tensor& b_scales,
                        const torch::Tensor& azp_adj,
                        const std::optional<torch::Tensor>& azp,
                        const std::optional<torch::Tensor>& bias);
int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b,
                                        const torch::Tensor& b_scales,
                                        at::ScalarType output_type,
                                        bool dynamic_act_quant, bool use_azp,
                                        int64_t primitive_cache_size);

#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
                            const torch::Tensor& b,
                            const torch::Tensor& a_scales,
                            const torch::Tensor& b_scales,
                            const std::optional<torch::Tensor>& bias);
#endif
void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
                      const torch::Tensor& a_scales,
                      const std::optional<torch::Tensor>& azp,
                      const std::optional<torch::Tensor>& azp_adj,
                      const std::optional<torch::Tensor>& bias,
                      int64_t handler);

void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& kv_cache, double scale,
@ -151,8 +146,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);

  // Quantization
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \
    defined(__powerpc64__)
  at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
  // Helper function to release oneDNN handlers
  ops.def("release_dnnl_matmul_handler(int handler) -> ()",
          &release_dnnl_matmul_handler);

  // Create oneDNN W8A8 handler
  ops.def(
      "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
      "output_type, bool dynamic_act_quant, bool use_azp, int "
      "primitive_cache_size) -> int",
      &create_onednn_scaled_mm_handler);

  // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization
  ops.def(
      "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, "
      "Tensor? azp_adj, Tensor? bias, int handler) -> ()");
  ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm);

  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
@ -168,50 +180,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      {stride_tag});
  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
           &dynamic_scaled_int8_quant);
  // W8A8 GEMM, supporting symmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales, Tensor? bias) -> ()",
      {stride_tag});
  ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
  // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales, Tensor azp_adj,"
      " Tensor? azp, Tensor? bias) -> ()",
      {stride_tag});
  ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#elif defined(__powerpc64__)
  // Compute int8 quantized tensor for given scaling factor.
  ops.def(
      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
      "Tensor? azp) -> ()");
  ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);

  // Compute int8 quantized tensor and scaling factor
  ops.def(
      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
      "Tensor!? azp) -> ()");
  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
           &dynamic_scaled_int8_quant);
  // W8A8 GEMM, supporting symmetric quantization.
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
  // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      " Tensor b, Tensor a_scales,"
      " Tensor b_scales, Tensor azp_adj,"
      " Tensor? azp, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif

  // SHM CCL

757 csrc/moe/grouped_topk_kernels.cu Normal file
@ -0,0 +1,757 @@
/*
 * Adapted from
 * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
 * Copyright (c) 2025, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;

namespace vllm {
namespace moe {

constexpr unsigned FULL_WARP_MASK = 0xffffffff;
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t BLOCK_SIZE = 512;
constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;

namespace warp_topk {

template <int size, typename T>
__host__ __device__ constexpr T round_up_to_multiple_of(T len) {
  if (len == 0) {
    return 0;
  }
  return ((len - 1) / size + 1) * size;
}

template <typename T>
constexpr __host__ __device__ bool isPowerOf2(T v) {
  return (v && !(v & (v - 1)));
}

template <bool greater, typename T>
__forceinline__ __device__ bool is_better_than(T val, T baseline) {
  return (val > baseline && greater) || (val < baseline && !greater);
}

template <bool greater, typename T, typename idxT>
__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
                                               idxT baseline_index) {
  bool res = (val > baseline && greater) || (val < baseline && !greater);
  if (val == baseline) {
    res = (index < baseline_index && greater) ||
          (index < baseline_index && !greater);
  }
  return res;
}

template <typename T, typename idxT>
int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
  int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
  int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE);
  return max(cache_topk,
             round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
}
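As a worked example of the sizing above, with assumed parameters T = float, idxT = int32_t, 16 warps per block and k = 8 (values chosen for illustration only, not taken from the diff):

#include <algorithm>
#include <cstdint>

// cache_topk = (4 + 4) * 16 * 8         = 1024 bytes
// n          = max(16 / 2 * 8, 16 * 32) = 512 elements
// result     = max(1024, round_up_256(512 * 4) + 512 * 4) = 4096 bytes
static_assert(std::max<int64_t>((4 + 4) * 16 * 8,
                                ((512 * 4 + 255) / 256) * 256 + 512 * 4) == 4096,
              "worked example of calc_smem_size_for_block_wide");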

template <int size, bool ascending, bool reverse, typename T, typename idxT,
          bool is_stable>
struct BitonicMerge {
  // input should be a bitonic sequence, and sort it to be a monotonic sequence
  __device__ static void merge(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
    static_assert(isPowerOf2(size));
    static_assert(size >= 2 * WARP_SIZE);
    constexpr int arr_len = size / WARP_SIZE;

    constexpr int stride = arr_len / 2;
    for (int i = 0; i < stride; ++i) {
      int const other_i = i + stride;
      T& val = val_arr[i];
      T& other_val = val_arr[other_i];
      bool is_better;
      if constexpr (is_stable) {
        is_better = is_better_than<ascending>(val, other_val, idx_arr[i],
                                              idx_arr[other_i]);
      } else {
        is_better = is_better_than<ascending>(val, other_val);
      }

      if (is_better) {
        T tmp = val;
        val = other_val;
        other_val = tmp;

        idxT tmp2 = idx_arr[i];
        idx_arr[i] = idx_arr[other_i];
        idx_arr[other_i] = tmp2;
      }
    }

    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
        val_arr, idx_arr);
    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
        val_arr + arr_len / 2, idx_arr + arr_len / 2);
  }
};

template <int size, bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort {
  __device__ static void sort(T* __restrict__ val_arr,
                              idxT* __restrict__ idx_arr) {
    static_assert(isPowerOf2(size));
    static_assert(size >= 2 * WARP_SIZE);
    constexpr int arr_len = size / WARP_SIZE;

    BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr);
    BitonicSort<size / 2, false, T, idxT, is_stable>::sort(
        val_arr + arr_len / 2, idx_arr + arr_len / 2);
    BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
        val_arr, idx_arr);
  }
};

template <bool ascending, typename T, typename idxT, bool is_stable>
struct BitonicSort<32, ascending, T, idxT, is_stable> {
  __device__ static void sort(T* __restrict__ val_arr,
                              idxT* __restrict__ idx_arr) {
    int const lane = threadIdx.x % WARP_SIZE;

    // ascending doesn't matter before merging since all we need is a bitonic
    // sequence
    for (int stage = 0; stage < 4; ++stage) {
      for (int stride = (1 << stage); stride > 0; stride /= 2) {
        bool reverse = (lane >> stage) & 2;
        bool is_second = lane & stride;

        T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
        idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);

        bool is_better;
        if constexpr (is_stable) {
          if constexpr (ascending) {
            is_better = ((*val_arr > other) ||
                         ((*val_arr == other) && (*idx_arr < other_idx))) !=
                        (reverse != is_second);
          } else {
            is_better = ((*val_arr > other) ||
                         ((*val_arr == other) && (*idx_arr > other_idx))) !=
                        (reverse != is_second);
          }
        } else {
          is_better = (*val_arr != other &&
                       (*val_arr > other) != (reverse != is_second));
        }
        if (is_better) {
          *val_arr = other;
          *idx_arr = other_idx;
        }
      }
    }

    BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
                                                                      idx_arr);
  }
};

template <bool ascending, bool reverse, typename T, typename idxT,
          bool is_stable>
struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
  __device__ static void merge(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
    int const lane = threadIdx.x % WARP_SIZE;
    for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) {
      bool is_second = lane & stride;
      T& val = *val_arr;
      T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
      idxT& idx = *idx_arr;
      idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);

      bool is_better;
      if constexpr (is_stable) {
        if constexpr (ascending) {
          is_better = ((*val_arr > other) ||
                       ((*val_arr == other) && (*idx_arr < other_idx))) ==
                      (reverse != is_second);  // for min
        } else {
          is_better = ((*val_arr > other) ||
                       ((*val_arr == other) && (*idx_arr > other_idx))) ==
                      (reverse != is_second);  // for max
        }
      } else {
        is_better =
            (val != other && ((val > other) == (ascending != is_second)));
      }

      if (is_better) {
        val = other;
        idx = other_idx;
      }
    }
  }
};

template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSort {
 public:
  __device__ WarpSort(idxT k, T dummy)
      : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
    static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));

    for (int i = 0; i < max_arr_len_; ++i) {
      val_arr_[i] = dummy_;
      idx_arr_[i] = 0;
    }
  }

  // load and merge k sorted values
  __device__ void load_sorted(T const* __restrict__ in,
                              idxT const* __restrict__ in_idx, idxT start) {
    idxT idx = start + WARP_SIZE - 1 - lane_;
    for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
      if (idx < start + k_) {
        T t = in[idx];
        bool is_better;
        if constexpr (is_stable) {
          is_better =
              is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
        } else {
          is_better = is_better_than<greater>(t, val_arr_[i]);
        }
        if (is_better) {
          val_arr_[i] = t;
          idx_arr_[i] = in_idx[idx];
        }
      }
    }

    BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
        val_arr_, idx_arr_);
  }

  __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
    for (int i = 0; i < max_arr_len_; ++i) {
      idxT out_i = i * WARP_SIZE + lane_;
      if (out_i < k_) {
        out[out_i] = val_arr_[i];
        out_idx[out_i] = idx_arr_[i];
      }
    }
  }

  __device__ void dumpIdx(idxT* __restrict__ out_idx) const {
    for (int i = 0; i < max_arr_len_; ++i) {
      idxT out_i = i * WARP_SIZE + lane_;
      if (out_i < k_) {
        out_idx[out_i] = idx_arr_[i];
      }
    }
  }

 protected:
  static constexpr int max_arr_len_ = capacity / WARP_SIZE;

  T val_arr_[max_arr_len_];
  idxT idx_arr_[max_arr_len_];

  int const lane_;
  idxT const k_;
  T const dummy_;

};  // end class WarpSort

template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
 public:
  __device__ WarpSelect(idxT k, T dummy)
      : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
        k_th_(dummy),
        k_th_lane_((k - 1) % WARP_SIZE) {
    extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];

    int const num_of_warp = blockDim.x / WARP_SIZE;
    int const warp_id = threadIdx.x / WARP_SIZE;
    val_smem_ = reinterpret_cast<T*>(smem_buf);
    val_smem_ += warp_id * WARP_SIZE;
    idx_smem_ = reinterpret_cast<idxT*>(
        smem_buf +
        round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE));
    idx_smem_ += warp_id * WARP_SIZE;
  }

  __device__ void add(T const* in, idxT start, idxT end) {
    idxT const end_for_fullwarp =
        round_up_to_multiple_of<WARP_SIZE>(end - start) + start;
    for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) {
      T val = (i < end) ? in[i] : dummy_;
      add(val, i);
    }
  }

  __device__ void add(T val, idxT idx) {
    bool do_add;
    if constexpr (is_stable) {
      do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_);
    } else {
      do_add = is_better_than<greater>(val, k_th_);
    }

    uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
    if (mask == 0) {
      return;
    }

    int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1));
    if (do_add && pos < WARP_SIZE) {
      val_smem_[pos] = val;
      idx_smem_[pos] = idx;
      do_add = false;
    }
    smem_buf_len_ += __popc(mask);
    if (smem_buf_len_ >= WARP_SIZE) {
      __syncwarp();
      merge_buf_(val_smem_[lane_], idx_smem_[lane_]);
      smem_buf_len_ -= WARP_SIZE;
    }
    if (do_add) {
      pos -= WARP_SIZE;
      val_smem_[pos] = val;
      idx_smem_[pos] = idx;
    }
    __syncwarp();
  }

  __device__ void done() {
    if (smem_buf_len_) {
      T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_;
      idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0;
      merge_buf_(val, idx);
    }

    // after done(), smem is used for merging results among warps
    __syncthreads();
  }

 private:
  __device__ void set_k_th_() {
    k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
    if constexpr (is_stable) {
      k_th_idx_ =
          __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
    }
  }

  __device__ void merge_buf_(T val, idxT idx) {
    BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx);

    T& old = val_arr_[max_arr_len_ - 1];

    bool is_better;
    if constexpr (is_stable) {
      is_better =
          is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
    } else {
      is_better = is_better_than<greater>(val, old);
    }

    if (is_better) {
      old = val;
      idx_arr_[max_arr_len_ - 1] = idx;
    }

    BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge(
        val_arr_, idx_arr_);

    set_k_th_();
  }

  using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
  using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;

  T* val_smem_;
  idxT* idx_smem_;
  int smem_buf_len_ = 0;

  T k_th_;
  idxT k_th_idx_;
  int const k_th_lane_;
};  // end class WarpSelect
}  // namespace warp_topk
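A minimal sketch of how WarpSelect appears to be meant to be used (illustrative only, not part of the diff): each warp streams a range of candidates through add(), flushes with done(), and then writes out its selection. Capacity, the dummy value, and the output layout below are assumptions; a real kernel would also give each warp its own output offset and size dynamic shared memory with calc_smem_size_for_block_wide at launch time.

// Hypothetical usage sketch of warp_topk::WarpSelect.
template <int Capacity, typename T>
__global__ void warp_topk_example(T const* in, int32_t* out_idx, int len,
                                  int k) {
  warp_topk::WarpSelect<Capacity, /*greater=*/true, T, int32_t,
                        /*is_stable=*/true>
      queue(k, -INFINITY);   // dummy = worst possible value for "greater"
  queue.add(in, 0, len);     // keep the best k candidates seen by this warp
  queue.done();              // flush the shared-memory staging buffer
  queue.dumpIdx(out_idx);    // write the indices of the selected elements
}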

template <typename T_OUT, typename T_IN>
__device__ inline T_OUT cuda_cast(T_IN val) {
  return val;
}

template <>
__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
  return __bfloat162float(val);
}

template <typename T>
__device__ void topk_with_k2(T* output, T const* input,
                             cg::thread_block_tile<32> const& tile,
                             int32_t const lane_id,
                             int const num_experts_per_group) {
  // Get the top2 per thread
  T largest = -INFINITY;
  T second_largest = -INFINITY;

  if (num_experts_per_group > WARP_SIZE) {
    for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
      T value = input[i];
      if (value > largest) {
        second_largest = largest;
        largest = value;
      } else if (value > second_largest) {
        second_largest = value;
      }
    }
  } else {
    for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
      largest = input[i];
    }
  }

  __syncwarp();  // Ensure all threads have valid data before reduction
  // Get the top2 warpwise
  T max1 = cg::reduce(tile, largest, cg::greater<T>());

  T max2 = max1;
  bool equal_to_max1 = (max1 == largest);

  int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1));

  if (count_max1 == 1) {
    largest = (largest == max1) ? second_largest : largest;
    max2 = cg::reduce(tile, largest, cg::greater<T>());
  }

  if (lane_id == 0) {
    *output = max1 + max2;
  }
}
|
||||
|
||||
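Per group, topk_with_k2 emits the sum of the two largest biased scores; that sum is what the routing kernel below ranks to pick topk_group groups. When validating the kernel, a plain host-side reference is handy. The sketch below is not part of the diff; the function name is hypothetical, and it assumes each group holds at least two experts.

// Host-side reference for one token: group score = sum of its two largest
// bias-adjusted expert scores, matching what topk_with_k2 computes per warp.
#include <algorithm>
#include <functional>
#include <vector>

std::vector<float> group_scores_reference(
    const std::vector<float>& scores_with_bias, int n_group) {
  int experts_per_group = static_cast<int>(scores_with_bias.size()) / n_group;
  std::vector<float> out(n_group);
  for (int g = 0; g < n_group; ++g) {
    auto first = scores_with_bias.begin() + g * experts_per_group;
    std::vector<float> grp(first, first + experts_per_group);
    std::partial_sort(grp.begin(), grp.begin() + 2, grp.end(),
                      std::greater<float>());
    out[g] = grp[0] + grp[1];  // top-2 sum, as in the kernel
  }
  return out;
}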
template <typename T>
|
||||
__global__ void topk_with_k2_kernel(T* output, T* input,
|
||||
int64_t const num_tokens,
|
||||
int64_t const num_cases,
|
||||
int64_t const n_group,
|
||||
int64_t const num_experts_per_group) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;
|
||||
if (case_id < num_cases) {
|
||||
input += case_id * num_experts_per_group;
|
||||
output += case_id;
|
||||
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
#endif
|
||||
topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, typename IdxT>
|
||||
__global__ void group_idx_and_topk_idx_kernel(
|
||||
T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices,
|
||||
T* scores_with_bias, int64_t const num_tokens, int64_t const n_group,
|
||||
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
|
||||
int64_t const num_experts_per_group, bool renormalize,
|
||||
double routed_scaling_factor) {
|
||||
int32_t warp_id = threadIdx.x / WARP_SIZE;
|
||||
int32_t lane_id = threadIdx.x % WARP_SIZE;
|
||||
int32_t case_id =
|
||||
blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
|
||||
scores_with_bias += case_id * num_experts;
|
||||
scores += case_id * num_experts;
|
||||
group_scores += case_id * n_group;
|
||||
topk_values += case_id * topk;
|
||||
topk_indices += case_id * topk;
|
||||
|
||||
int32_t align_num_experts_per_group =
|
||||
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);
|
||||
|
||||
cg::thread_block block = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
|
||||
|
||||
extern __shared__ char smem_buf[]; // NOTE: reuse the shared memory here to
|
||||
// store the target topk idx
|
||||
int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
|
||||
T* s_topk_value =
|
||||
reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
|
||||
warp_id * topk;
|
||||
s_topk_idx += warp_id * topk;
|
||||
|
||||
T value = cuda::std::numeric_limits<T>::min();
|
||||
T topk_group_value = cuda::std::numeric_limits<T>::min();
|
||||
int32_t num_equalto_topkth_group;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;"); // I think all prolog can be put before
|
||||
// acqbulk because it's ptr arithmetic
|
||||
#endif
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
// calculate group_idx
|
||||
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
|
||||
if (lane_id < n_group &&
|
||||
(isfinite(cuda_cast<float, T>(
|
||||
group_scores[lane_id])))) // The check is necessary to avoid
|
||||
// abnormal input
|
||||
{
|
||||
value = group_scores[lane_id];
|
||||
}
|
||||
|
||||
int count_equal_to_top_value = WARP_SIZE - n_group;
|
||||
int pre_count_equal_to_top_value = 0;
|
||||
// Use loop to find the largset top_group
|
||||
while (count_equal_to_top_value < target_num_min) {
|
||||
__syncwarp(); // Ensure all threads have valid data before reduction
|
||||
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
|
||||
if (value == topk_group_value) {
|
||||
value = cuda::std::numeric_limits<T>::min();
|
||||
}
|
||||
pre_count_equal_to_top_value = count_equal_to_top_value;
|
||||
count_equal_to_top_value = __popc(__ballot_sync(
|
||||
FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
|
||||
}
|
||||
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
|
||||
/* is_stable */ true>
|
||||
queue((int32_t)topk, -INFINITY);
|
||||
|
||||
int count_equalto_topkth_group = 0;
|
||||
bool if_proceed_next_topk =
|
||||
(topk_group_value != cuda::std::numeric_limits<T>::min());
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i_group = 0; i_group < n_group; i_group++) {
|
||||
if ((group_scores[i_group] > topk_group_value) ||
|
||||
((group_scores[i_group] == topk_group_value) &&
|
||||
(count_equalto_topkth_group < num_equalto_topkth_group))) {
|
||||
int32_t offset = i_group * num_experts_per_group;
|
||||
for (int32_t i = lane_id; i < align_num_experts_per_group;
|
||||
i += WARP_SIZE) {
|
||||
T candidates =
|
||||
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
|
||||
scores_with_bias[offset + i]))
|
||||
? scores_with_bias[offset + i]
|
||||
: cuda::std::numeric_limits<T>::min();
|
||||
queue.add(candidates, offset + i);
|
||||
}
|
||||
if (group_scores[i_group] == topk_group_value) {
|
||||
count_equalto_topkth_group++;
|
||||
}
|
||||
}
|
||||
}
|
||||
queue.done();
|
||||
__syncwarp();
|
||||
// Get the topk_idx
|
||||
queue.dumpIdx(s_topk_idx);
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Load the valid score value
|
||||
// Calculate the summation
|
||||
float topk_sum = 1e-20;
|
||||
if (case_id < num_tokens && if_proceed_next_topk) {
|
||||
for (int i = lane_id;
|
||||
i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
|
||||
i += WARP_SIZE) {
|
||||
T value =
|
||||
i < topk
|
||||
? scores[s_topk_idx[i]]
|
||||
: cuda_cast<T, float>(0.0f); // Load the valid value of expert
|
||||
if (i < topk) {
|
||||
s_topk_value[i] = value;
|
||||
}
|
||||
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (case_id < num_tokens) {
|
||||
if (if_proceed_next_topk) {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
float value;
|
||||
if (renormalize) {
|
||||
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
|
||||
routed_scaling_factor;
|
||||
} else {
|
||||
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
|
||||
}
|
||||
topk_indices[i] = s_topk_idx[i];
|
||||
topk_values[i] = cuda_cast<T, float>(value);
|
||||
}
|
||||
} else {
|
||||
for (int i = lane_id; i < topk; i += WARP_SIZE) {
|
||||
topk_indices[i] = i;
|
||||
topk_values[i] = cuda_cast<T, float>(1.0f / topk);
|
||||
}
|
||||
}
|
||||
// Note: when if_proceed_next_topk==false, choose the first 8 experts as the
|
||||
// default result.
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
#endif
|
||||
}
|
||||
|
||||
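The dynamic shared memory of this kernel holds, per warp, topk int32 indices followed by topk values of type T (the s_topk_idx / s_topk_value layout above). The helper calc_smem_size_for_block_wide used by the launcher below is not shown in this hunk; the sketch that follows is only an estimate of what it must provide, and the real helper may add alignment padding.

// Hedged sketch of the dynamic shared-memory requirement implied by the layout
// above: NUM_WARPS_PER_BLOCK warps, each keeping `topk` indices plus `topk` values.
template <typename T, typename IdxT>
size_t smem_bytes_estimate(int num_warps_per_block, int64_t topk) {
  return static_cast<size_t>(num_warps_per_block) * topk *
         (sizeof(IdxT) + sizeof(T));
}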
template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
                   IdxT* topk_indices, T* scores_with_bias,
                   int64_t const num_tokens, int64_t const num_experts,
                   int64_t const n_group, int64_t const topk_group,
                   int64_t const topk, bool const renormalize,
                   double const routed_scaling_factor, bool enable_pdl = false,
                   cudaStream_t const stream = 0) {
  int64_t num_cases = num_tokens * n_group;
  int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
  auto* kernel_instance1 = &topk_with_k2_kernel<T>;
  cudaLaunchConfig_t config;
  config.gridDim = topk_with_k2_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = 0;
  config.stream = stream;
  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
                     num_tokens, num_cases, n_group, num_experts / n_group);

  int64_t topk_with_k_group_num_blocks =
      (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
  size_t dynamic_smem_in_bytes =
      warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                           topk);
  auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
  config.gridDim = topk_with_k_group_num_blocks;
  config.blockDim = BLOCK_SIZE;
  config.dynamicSmemBytes = dynamic_smem_in_bytes;
  config.stream = stream;
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
  config.numAttrs = 1;
  config.attrs = attrs;
  cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
                     topk_values, topk_indices, scores_with_bias, num_tokens,
                     n_group, topk_group, topk, num_experts,
                     num_experts / n_group, renormalize, routed_scaling_factor);
}
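Both launches go through cudaLaunchKernelEx with cudaLaunchAttributeProgrammaticStreamSerialization, so when enable_pdl is true the second kernel may be scheduled while the first is still draining; the griddepcontrol.wait / launch_dependents instructions inside the kernels provide the actual ordering. A small stand-alone helper (a sketch, not part of the diff; requires CUDA 12 and a kernel that issues griddepcontrol.wait before reading its inputs) shows the same launch pattern in isolation:

// Hedged sketch of a programmatic-dependent-launch (PDL) style launch helper.
template <typename Kernel, typename... Args>
cudaError_t launch_with_pdl(Kernel kernel, dim3 grid, dim3 block,
                            size_t dyn_smem, cudaStream_t stream,
                            bool enable_pdl, Args... args) {
  cudaLaunchConfig_t config{};
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = dyn_smem;
  config.stream = stream;
  cudaLaunchAttribute attr{};
  attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr.val.programmaticStreamSerializationAllowed = enable_pdl;
  config.attrs = &attr;
  config.numAttrs = 1;
  return cudaLaunchKernelEx(&config, kernel, args...);
}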
#define INSTANTIATE_NOAUX_TC(T, IdxT)                                       \
  template void invokeNoAuxTc<T, IdxT>(                                     \
      T * scores, T * group_scores, T * topk_values, IdxT * topk_indices,   \
      T * scores_with_bias, int64_t const num_tokens,                       \
      int64_t const num_experts, int64_t const n_group,                     \
      int64_t const topk_group, int64_t const topk, bool const renormalize, \
      double const routed_scaling_factor, bool enable_pdl,                  \
      cudaStream_t const stream);

INSTANTIATE_NOAUX_TC(float, int32_t);
INSTANTIATE_NOAUX_TC(half, int32_t);
INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
}  // end namespace moe
}  // namespace vllm

std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
    torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
    int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
    double routed_scaling_factor) {
  auto data_type = scores_with_bias.scalar_type();
  auto input_size = scores_with_bias.sizes();
  int64_t num_tokens = input_size[0];
  int64_t num_experts = input_size[1];
  TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor");
  TORCH_CHECK(num_experts % n_group == 0,
              "num_experts should be divisible by n_group");
  TORCH_CHECK(n_group <= 32,
              "n_group should be smaller than or equal to 32 for now");
  TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");

  torch::Tensor group_scores = torch::empty(
      {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA));
  torch::Tensor topk_values = torch::empty(
      {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA));
  torch::Tensor topk_indices = torch::empty(
      {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA));

  auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device());

  switch (data_type) {
    case torch::kFloat16:
      // Handle Float16
      vllm::moe::invokeNoAuxTc<half, int32_t>(
          reinterpret_cast<half*>(scores.mutable_data_ptr()),
          reinterpret_cast<half*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<half*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<half*>(scores_with_bias.data_ptr()), num_tokens,
          num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, false, stream);
      break;
    case torch::kFloat32:
      // Handle Float32
      vllm::moe::invokeNoAuxTc<float, int32_t>(
          reinterpret_cast<float*>(scores.mutable_data_ptr()),
          reinterpret_cast<float*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<float*>(scores_with_bias.data_ptr()), num_tokens,
          num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, false, stream);
      break;
    case torch::kBFloat16:
      // Handle BFloat16
      vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>(
          reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()),
          reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()),
          reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()),
          num_tokens, num_experts, n_group, topk_group, topk, renormalize,
          routed_scaling_factor, false, stream);
      break;
    default:
      // Handle other data types
      throw std::invalid_argument(
          "Invalid dtype, only supports float16, float32, and bfloat16");
      break;
  }
  return {topk_values, topk_indices};
}
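A direct call into this host wrapper is a quick smoke test for the kernels above. The snippet is a sketch, not part of the diff: tensor shapes, the 0.01 stand-in bias, and the scaling factor are made up, and in practice the wrapper is reached through the grouped_topk custom op registered further down.

// Illustrative smoke test for the wrapper above (values are hypothetical).
#include <torch/torch.h>

void grouped_topk_smoke_test() {
  const int64_t num_tokens = 4, num_experts = 256;
  const int64_t n_group = 8, topk_group = 4, topk = 8;
  auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  torch::Tensor scores =
      torch::rand({num_tokens, num_experts}, opts).to(torch::kBFloat16);
  torch::Tensor scores_with_bias = scores + 0.01f;  // stand-in for the score bias
  auto [topk_values, topk_indices] =
      grouped_topk(scores, scores_with_bias, n_group, topk_group, topk,
                   /*renormalize=*/true, /*routed_scaling_factor=*/2.5);
  TORCH_CHECK(topk_values.size(0) == num_tokens && topk_values.size(1) == topk);
  TORCH_CHECK(topk_indices.dtype() == torch::kInt32);
}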
@ -22,6 +22,11 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor num_tokens_post_pad, int64_t top_k,
                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                             int64_t BLOCK_SIZE_K, int64_t bit);

std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
    torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
    int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
    double routed_scaling_factor);
#endif

bool moe_permute_unpermute_supported();

@ -45,8 +45,6 @@ void moe_permute(
  auto copy_topk_ids = topk_ids.clone();  // copy topk_ids for preprocess
  auto permuted_experts_id = torch::empty_like(topk_ids);
  auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
  auto align_expert_first_token_offset =
      torch::zeros_like(expert_first_token_offset);

  CubKeyValueSorter sorter{};
  int64_t* valid_num_ptr = nullptr;
@ -85,12 +83,14 @@ void moe_permute(
  });

  // get m_indices and update expert_first_token_offset with align block
  getMIndices(get_ptr<int64_t>(expert_first_token_offset),
              get_ptr<int64_t>(align_expert_first_token_offset),
              get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
              stream);
  // this is only required for DeepGemm and not required for CUTLASS group gemm
  if (align_block_size.has_value()) {
    // update align_expert_first_token_offset
    auto align_expert_first_token_offset =
        torch::zeros_like(expert_first_token_offset);
    getMIndices(get_ptr<int64_t>(expert_first_token_offset),
                get_ptr<int64_t>(align_expert_first_token_offset),
                get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
                stream);
    expert_first_token_offset.copy_(align_expert_first_token_offset);
  }
}
@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
                 torch::Tensor& expert_first_token_offset,
                 torch::Tensor& src_row_id2dst_row_id_map,
                 torch::Tensor& m_indices) {
  TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
  TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
}

void moe_unpermute(const torch::Tensor& input,
                   const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
                   const torch::Tensor& token_expert_indices,
                   const std::optional<torch::Tensor>& expert_map,
                   int64_t n_expert, int64_t n_local_expert, int64_t topk,
                   const std::optional<int64_t>& align_block_size,
                   torch::Tensor& permuted_input,
                   torch::Tensor& expert_first_token_offset,
                   torch::Tensor& src_row_id2dst_row_id_map,
                   torch::Tensor& m_indices) {
void moe_unpermute(
    const torch::Tensor& permuted_hidden_states,
    const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
    const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
    torch::Tensor& hidden_states) {
  TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
}

@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() {
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("moe_permute", &moe_permute);
  m.impl("moe_unpermute", &moe_unpermute);
}
}
@ -78,6 +78,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
      "output_tensor) -> ()");
  m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);

  // Apply grouped topk routing to select experts.
  m.def(
      "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int "
      "topk_group, int topk, bool renormalize, float "
      "routed_scaling_factor) -> (Tensor, Tensor)");
  m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
#endif
}

@ -229,6 +229,11 @@ void get_cutlass_moe_mm_data(
    const int64_t num_experts, const int64_t n, const int64_t k,
    const std::optional<torch::Tensor>& blockscale_offsets);

void get_cutlass_moe_mm_problem_sizes(
    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);

void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                  torch::Tensor& problem_sizes1,
                                  torch::Tensor& problem_sizes2,
418  csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu  Normal file
@ -0,0 +1,418 @@
//
// Based off of:
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
//

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/torch_utils.hpp"

#include "core/registration.h"

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"

#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/mixed_dtype_utils.hpp"

#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm::cutlass_w4a8 {

using namespace cute;

// -------------------------------------------------------------------------------------
// Static configuration shared across all instantiations
// -------------------------------------------------------------------------------------
using MmaType = cutlass::float_e4m3_t;  // A/scale element type
using QuantType = cutlass::int4b_t;     // B element type (packed int4)

static int constexpr TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value;
static int constexpr ScalePackSize = 8;  // pack 8 scale elements together
static int constexpr PackFactor = 8;     // 8 4-bit values packed into one int32

// A matrix configuration
using ElementA = MmaType;                  // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
using LayoutA_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutA>::type;
constexpr int AlignmentA =
    128 / cutlass::sizeof_bits<
              ElementA>::value;  // Memory access granularity/alignment of A
                                 // matrix in units of elements (up to 16 bytes)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;

// B matrix configuration
using ElementB = QuantType;  // Element type for B matrix operand
using LayoutB =
    cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
using LayoutB_Transpose =
    typename cutlass::layout::LayoutTranspose<LayoutB>::type;
constexpr int AlignmentB =
    128 / cutlass::sizeof_bits<
              ElementB>::value;  // Memory access granularity/alignment of B
                                 // matrix in units of elements (up to 16 bytes)
using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>;

// Define the CuTe layout for reordered quantized tensor B
// LayoutAtomQuant places values that will be read by the same thread in
// contiguous locations in global memory. It specifies the reordering within a
// single warp's fragment
using LayoutAtomQuant =
    decltype(cutlass::compute_memory_reordering_atom<MmaType>());
using LayoutB_Reordered = decltype(cute::tile_to_shape(
    LayoutAtomQuant{}, Layout<Shape<int, int, int>, StrideB>{}));

// Group-wise scales
using ElementScale = MmaType;
using LayoutScale = cutlass::layout::RowMajor;

// Per-token, per-channel scales
using ElementSChannel = float;

// C/D matrix configuration
using ElementC =
    cutlass::bfloat16_t;  // Element type for C and D matrix operands
using LayoutC =
    cutlass::layout::RowMajor;  // Layout type for C and D matrix operands
constexpr int AlignmentC =
    128 / cutlass::sizeof_bits<
              ElementC>::value;  // Memory access granularity/alignment of C
                                 // matrix in units of elements (up to 16 bytes)

using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

// Core kernel configurations
using ElementAccumulator = float;  // Element type for internal accumulation
using ElementCompute = float;      // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90;  // Tag indicating the minimum SM that
                                      // supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
using KernelSchedule =
    cutlass::gemm::KernelTmaWarpSpecializedCooperative;  // Kernel to launch
                                                         // based on the default
                                                         // setting in the
                                                         // Collective Builder
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

// ----------------------------------------------------------------------------
// Kernel template: Tile/Cluster shapes
// ----------------------------------------------------------------------------
template <class TileShape_MN, class ClusterShape_MNK>
struct W4A8GemmKernel {
  using TileShape =
      decltype(cute::append(TileShape_MN{}, cute::Int<TileShapeK>{}));
  using ClusterShape = ClusterShape_MNK;

  // Epilogue per-token, per-channel scales
  using ChTokScalesEpilogue =
      typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD,
                                         TileShape>;
  using EVTCompute = typename ChTokScalesEpilogue::EVTCompute;
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
          ElementAccumulator, ElementSChannel,
          // Transpose layout of D here since we use explicit swap + transpose.
          // The void type for C tells the builder to allocate 0 smem for the C
          // matrix. We can enable this if beta == 0 by changing ElementC to
          // void below.
          ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type,
          AlignmentC, ElementD,
          typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD,
          EpilogueSchedule,  // This is the only epi supporting the required
                             // swap + transpose.
          EVTCompute>::CollectiveOp;

  // The Scale information must get paired with the operand that will be scaled.
  // In this example, B is scaled so we make a tuple of B's information and the
  // scale information.
  using CollectiveMainloopShuffled =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass,
          cute::tuple<ElementB, cutlass::Array<ElementScale, ScalePackSize>>,
          LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose,
          AlignmentA, ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloopShuffled, CollectiveEpilogue>;
  using GemmShuffled =
      cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>;

  using StrideC = typename GemmKernelShuffled::StrideC;
  using StrideD = typename GemmKernelShuffled::StrideD;
  using StrideS = typename CollectiveMainloopShuffled::StrideScale;

  static torch::Tensor mm(torch::Tensor const& A,
                          torch::Tensor const& B,             // already packed
                          torch::Tensor const& group_scales,  // already packed
                          int64_t group_size,
                          torch::Tensor const& channel_scales,
                          torch::Tensor const& token_scales,
                          std::optional<at::ScalarType> const& maybe_out_type) {
    // TODO: param validation
    int m = A.size(0);
    int k = A.size(1);
    int n = B.size(1);

    // Allocate output
    const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
    auto device = A.device();
    auto stream = at::cuda::getCurrentCUDAStream(device.index());
    torch::Tensor D =
        torch::empty({m, n}, torch::TensorOptions()
                                 .dtype(equivalent_scalar_type_v<ElementD>)
                                 .device(device));
    // prepare arg pointers
    auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr());
    auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
    auto D_ptr = static_cast<ElementD*>(D.data_ptr());
    // can we avoid hardcoding the pack size of 8 here?
    auto S_ptr =
        static_cast<cutlass::Array<ElementScale, ScalePackSize> const*>(
            group_scales.const_data_ptr());

    // runtime layout for B
    auto shape_B = cute::make_shape(n, k, 1);
    LayoutB_Reordered layout_B_reordered =
        cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

    // strides
    int const scale_k = cutlass::ceil_div(k, group_size);
    StrideA stride_A =
        cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
    // Reverse stride here due to swap and transpose
    StrideD stride_D =
        cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
    StrideS stride_S = cutlass::make_cute_packed_stride(
        StrideS{}, cute::make_shape(n, scale_k, 1));

    // Create a structure of gemm kernel arguments suitable for invoking an
    // instance of Gemm. (The upstream CUTLASS example does this via
    // `auto arguments = args_from_options<GemmShuffled>(options);`.)
    // Populates a Gemm::Arguments structure from the given arguments;
    // swap the A and B tensors, as well as the problem shapes, here.
    using Args = typename GemmShuffled::Arguments;
    using MainloopArguments = typename GemmKernelShuffled::MainloopArguments;
    using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments;

    MainloopArguments mainloop_arguments{
        B_ptr, layout_B_reordered, A_ptr, stride_A,
        S_ptr, stride_S,           group_size};

    EpilogueArguments epilogue_arguments{
        ChTokScalesEpilogue::prepare_args(channel_scales, token_scales),
        nullptr,
        {},  // no C
        D_ptr,
        stride_D};

    Args arguments{cutlass::gemm::GemmUniversalMode::kGemm,
                   {n, m, k, 1},  // shape
                   mainloop_arguments,
                   epilogue_arguments};

    // Workspace
    size_t workspace_size = GemmShuffled::get_workspace_size(arguments);
    torch::Tensor workspace =
        torch::empty(workspace_size,
                     torch::TensorOptions().dtype(torch::kU8).device(device));

    // Run GEMM
    GemmShuffled gemm;
    CUTLASS_CHECK(gemm.can_implement(arguments));
    CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
    CUTLASS_CHECK(gemm.run(stream));

    return D;
  }
};
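A quick worked check of the K extent that gets appended to every tile shape above: MmaType is FP8 (e4m3), so sizeof_bits is 8 and TileShapeK = 128 * 8 / 8 = 128, meaning a Shape<_256, _128> instantiation runs a 256 x 128 x 128 CTA tile. The asserts below are illustrative only and simply restate the definitions in this file.

static_assert(cutlass::sizeof_bits<MmaType>::value == 8,
              "A/scale element type is 8-bit FP8 in this file");
static_assert(TileShapeK == 128,
              "K tile extent implied by the FP8 MmaType: 128 * 8 / 8");
// Example of the full tile type produced by W4A8GemmKernel for Shape<_256, _128>.
using ExampleTile_256x128xK =
    decltype(cute::append(Shape<_256, _128>{}, cute::Int<TileShapeK>{}));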
// ----------------------------------------------------------------------------
// Kernel instantiations and dispatch logic
// ----------------------------------------------------------------------------
using Kernel_256x128_1x1x1 =
    W4A8GemmKernel<Shape<_256, _128>, Shape<_1, _1, _1>>;
using Kernel_256x64_1x1x1 = W4A8GemmKernel<Shape<_256, _64>, Shape<_1, _1, _1>>;
using Kernel_256x32_1x1x1 = W4A8GemmKernel<Shape<_256, _32>, Shape<_1, _1, _1>>;
using Kernel_256x16_1x1x1 = W4A8GemmKernel<Shape<_256, _16>, Shape<_1, _1, _1>>;
using Kernel_128x256_2x1x1 =
    W4A8GemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>>;
using Kernel_128x256_1x1x1 =
    W4A8GemmKernel<Shape<_128, _256>, Shape<_1, _1, _1>>;
using Kernel_128x128_1x1x1 =
    W4A8GemmKernel<Shape<_128, _128>, Shape<_1, _1, _1>>;
using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>;
using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>;
using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>;

torch::Tensor mm_dispatch(torch::Tensor const& A,
                          torch::Tensor const& B,             // already packed
                          torch::Tensor const& group_scales,  // already packed
                          int64_t group_size,
                          torch::Tensor const& channel_scales,
                          torch::Tensor const& token_scales,
                          std::optional<at::ScalarType> const& maybe_out_type,
                          const std::string& schedule) {
  if (schedule == "256x128_1x1x1") {
    return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size,
                                    channel_scales, token_scales,
                                    maybe_out_type);
  } else if (schedule == "256x64_1x1x1") {
    return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  } else if (schedule == "256x32_1x1x1") {
    return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  } else if (schedule == "256x16_1x1x1") {
    return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  } else if (schedule == "128x256_2x1x1") {
    return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size,
                                    channel_scales, token_scales,
                                    maybe_out_type);
  } else if (schedule == "128x256_1x1x1") {
    return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size,
                                    channel_scales, token_scales,
                                    maybe_out_type);
  } else if (schedule == "128x128_1x1x1") {
    return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size,
                                    channel_scales, token_scales,
                                    maybe_out_type);
  } else if (schedule == "128x64_1x1x1") {
    return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  } else if (schedule == "128x32_1x1x1") {
    return Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  } else if (schedule == "128x16_1x1x1") {
    return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size,
                                   channel_scales, token_scales,
                                   maybe_out_type);
  }
  TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule);
  return {};
}
torch::Tensor mm(torch::Tensor const& A,
                 torch::Tensor const& B,             // already packed
                 torch::Tensor const& group_scales,  // already packed
                 int64_t group_size, torch::Tensor const& channel_scales,
                 torch::Tensor const& token_scales,
                 std::optional<at::ScalarType> const& maybe_out_type,
                 std::optional<std::string> maybe_schedule) {
  // requested a specific schedule
  if (maybe_schedule) {
    return mm_dispatch(A, B, group_scales, group_size, channel_scales,
                       token_scales, maybe_out_type, *maybe_schedule);
  }
  std::string schedule;
  int M = A.size(0);
  int K = A.size(1);
  int N = B.size(1);
  // heuristic
  if (M <= 16) {
    schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1";
  } else if (M <= 32) {
    schedule = (K == 16384 && N == 18432) ? "256x32_1x1x1" : "128x32_1x1x1";
  } else if (M <= 64) {
    if (K == 16384 && N == 18432)
      schedule = "256x64_1x1x1";
    else if (N <= 8192 && K <= 8192)
      schedule = "128x32_1x1x1";
    else
      schedule = "128x64_1x1x1";
  } else if (M <= 128) {
    if (K == 16384 && N == 18432)
      schedule = "256x128_1x1x1";
    else if (N <= 8192)
      schedule = "128x64_1x1x1";
    else
      schedule = "128x128_1x1x1";
  } else if (M <= 256) {
    if (N <= 4096)
      schedule = "128x64_1x1x1";
    else if (N <= 8192)
      schedule = "128x128_1x1x1";
    else
      schedule = "128x256_1x1x1";
  } else if (M <= 512 && N <= 4096) {
    schedule = "128x128_1x1x1";
  } else if (M <= 1024) {
    schedule = "128x256_1x1x1";
  } else {
    schedule = "128x256_2x1x1";
  }
  return mm_dispatch(A, B, group_scales, group_size, channel_scales,
                     token_scales, maybe_out_type, schedule);
}
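To make the heuristic above concrete, here is a trace on a few made-up shapes plus the escape hatch of pinning a schedule explicitly. The function and tensor names are illustrative only; the shapes are not from this diff.

// Heuristic trace (hypothetical shapes):
//   M = 48,   K = 16384, N = 18432 -> "256x64_1x1x1"   (M <= 64, big-K/N case)
//   M = 200,  K = 7168,  N = 6144  -> "128x128_1x1x1"  (M <= 256, N <= 8192)
//   M = 4096, any K/N              -> "128x256_2x1x1"  (large-M fallback)
// Passing maybe_schedule bypasses the heuristic entirely:
torch::Tensor w4a8_mm_pinned_schedule(torch::Tensor const& A,
                                      torch::Tensor const& B_packed,
                                      torch::Tensor const& group_scales_packed,
                                      torch::Tensor const& channel_scales,
                                      torch::Tensor const& token_scales) {
  return mm(A, B_packed, group_scales_packed, /*group_size=*/128,
            channel_scales, token_scales, /*maybe_out_type=*/std::nullopt,
            /*maybe_schedule=*/std::string("128x128_1x1x1"));
}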
// ----------------------------------------------------------------------------
// Pre-processing utils
// ----------------------------------------------------------------------------
torch::Tensor pack_scale_fp8(torch::Tensor const& scales) {
  TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(scales.is_contiguous());
  TORCH_CHECK(scales.is_cuda());

  auto packed_scales = torch::empty(
      {scales.numel() * ScalePackSize},
      torch::TensorOptions().dtype(scales.dtype()).device(scales.device()));
  auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr());
  auto packed_scales_ptr =
      static_cast<cutlass::Array<ElementScale, ScalePackSize>*>(
          packed_scales.data_ptr());

  cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel());

  return packed_scales;
}

torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) {
  TORCH_CHECK(B.dtype() == torch::kInt32);
  TORCH_CHECK(B.dim() == 2);

  torch::Tensor B_packed = torch::empty_like(B);

  int k = B.size(0) * PackFactor;  // logical k
  int n = B.size(1);

  auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr());
  auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr());
  auto shape_B = cute::make_shape(n, k, 1);
  auto layout_B = make_layout(shape_B, LayoutRight{});  // row major
  LayoutB_Reordered layout_B_reordered =
      cute::tile_to_shape(LayoutAtomQuant{}, shape_B);

  cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k);
  cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered);

  return B_packed;
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("cutlass_w4a8_mm", &mm);
  m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8);
  m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b);
}

}  // namespace vllm::cutlass_w4a8
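The expected call order for the three ops registered above is: encode and reorder the int4 weight once, pack the FP8 group scales once, then run the GEMM. The sketch below is not part of the diff; tensor names are illustrative, and the shape comments are inferred from the pointer arithmetic in this file (B enters as [k / 8, n] int32 with 8 nibbles per int32, group scales as FP8 with k / group_size entries per output column).

// Hedged end-to-end sketch for the W4A8 path defined above.
torch::Tensor w4a8_linear_example(torch::Tensor const& A,            // [m, k] fp8
                                  torch::Tensor const& B_int4,       // [k/8, n] int32
                                  torch::Tensor const& group_scales, // fp8 group scales
                                  torch::Tensor const& channel_scales,
                                  torch::Tensor const& token_scales,
                                  int64_t group_size) {
  using namespace vllm::cutlass_w4a8;
  // One-time weight preprocessing (normally cached alongside the checkpoint).
  torch::Tensor B_packed = encode_and_reorder_int4b(B_int4);
  torch::Tensor scales_packed = pack_scale_fp8(group_scales);
  return mm(A, B_packed, scales_packed, group_size, channel_scales,
            token_scales, /*maybe_out_type=*/std::nullopt,
            /*maybe_schedule=*/std::nullopt);
}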
@ -10,7 +10,7 @@

template <typename ElementAB, typename ElementC, typename ElementAccumulator>
__global__ void get_group_gemm_starts(
    int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
    int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
    ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
    ElementAB* b_base_as_int, ElementC* out_base_as_int,
@ -34,7 +34,7 @@ __global__ void get_group_gemm_starts(
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                \
    get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float>   \
        <<<1, num_experts, 0, stream>>>(                          \
            static_cast<int32_t*>(expert_offsets.data_ptr()),     \
            static_cast<int64_t*>(expert_offsets.data_ptr()),     \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
            static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),           \
@ -61,6 +61,8 @@ void run_get_group_gemm_starts(
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  // expect int64_t to avoid overflow during offset calculations
  TORCH_CHECK(expert_offsets.dtype() == torch::kInt64);

  int num_experts = static_cast<int>(expert_offsets.size(0));
  bool per_act_token = a_scales.numel() != 1;
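The switch of expert_offsets from int32 to int64 in the hunk above is about the "overflow during offset calculations" the comment mentions: per-expert base offsets scale with token count times hidden size, which leaves int32 range well within realistic serving shapes. The numbers below are illustrative, not taken from the diff.

#include <cstdint>
#include <limits>
// 64Ki tokens x top-k 8 x hidden size 7168 ~= 3.76e9 elements, past INT32_MAX
// (~2.15e9), so the element offsets must be accumulated in 64-bit.
static_assert(int64_t{65536} * 8 * 7168 >
                  std::numeric_limits<int32_t>::max(),
              "offsets of this magnitude do not fit in int32");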
@ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
  }
}

namespace {
inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids,
                                         torch::Tensor& problem_sizes1,
                                         torch::Tensor& problem_sizes2,
                                         torch::Tensor& atomic_buffer,
                                         int64_t num_experts, int64_t n,
                                         int64_t k, cudaStream_t stream,
                                         const bool swap_ab) {
  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());

  const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr());
  int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr());
  int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr());
  int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr());

  if (swap_ab) {
    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
        static_cast<int>(k));
  } else {
    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
        topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr,
        static_cast<int>(topk_ids.numel()), static_cast<int>(n),
        static_cast<int>(k));
  }
}
}  // namespace

void get_cutlass_moe_mm_problem_sizes_caller(
    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 =
      torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  // Swap-AB should be disabled for FP4 path
  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);

  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
                               atomic_buffer, num_experts, n, k, stream,
                               may_swap_ab);
}

void get_cutlass_moe_mm_data_caller(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller(
  bool may_swap_ab = (!blockscale_offsets.has_value()) &&
                     (topk_ids.numel() <= SWAP_AB_THRESHOLD);

  if (may_swap_ab) {
    compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
        static_cast<const int32_t*>(topk_ids.data_ptr()),
        static_cast<int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(problem_sizes2.data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
        k);
  } else {
    compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
        static_cast<const int32_t*>(topk_ids.data_ptr()),
        static_cast<int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(problem_sizes2.data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
        k);
  }
  launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2,
                               atomic_buffer, num_experts, n, k, stream,
                               may_swap_ab);

  if (blockscale_offsets.has_value()) {
    // fp4 path

@ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller(
    const int64_t num_experts, const int64_t n, const int64_t k,
    const std::optional<torch::Tensor>& blockscale_offsets);

void get_cutlass_moe_mm_problem_sizes_caller(
    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets);

void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
                                         torch::Tensor& problem_sizes1,
                                         torch::Tensor& problem_sizes2,

@ -293,6 +298,25 @@ void get_cutlass_moe_mm_data(
      version_num, ". Required capability: 90 or 100");
}

void get_cutlass_moe_mm_problem_sizes(
    const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n,
    const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) {
  int32_t version_num = get_sm_version_num();
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
    (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
  get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1,
                                          problem_sizes2, num_experts, n, k,
                                          blockscale_offsets);
  return;
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm "
      "kernel for CUDA device capability: ",
      version_num, ". Required capability: 90 or 100");
}

void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
                                  torch::Tensor& problem_sizes1,
                                  torch::Tensor& problem_sizes2,
@ -571,78 +571,79 @@ def generate():
                 itertools.repeat(default_heuristic))
    ]

    # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
    # TODO (LucasWilkinson): Further tuning required
    qqq_tile_heuristic_config = {
        #### M = 257+
        # ((128, 256), (2, 1, 1)) Broken for QQQ types
        # TODO (LucasWilkinson): Investigate further
        # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
        # "M > 256": ((128, 256), (2, 1, 1)),
        "M > 256": ((128, 128), (2, 1, 1)),
        #### M = 129-256
        "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
        "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
        # ((128, 256), (2, 1, 1)) Broken for QQQ types
        # TODO (LucasWilkinson): Investigate further
        # "M > 128": ((128, 256), (2, 1, 1)),
        "M > 128": ((128, 128), (2, 1, 1)),
        #### M = 65-128
        "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
        "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
        "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
        "M > 64": ((128, 128), (2, 1, 1)),
        #### M = 33-64
        "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
        # Broken for QQQ types
        # TODO (LucasWilkinson): Investigate further
        #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
        "M > 32": ((128, 64), (2, 1, 1)),
        #### M = 17-32
        "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
        "M > 16": ((256, 32), (2, 1, 1)),
        #### M = 1-16
        "N >= 26624": ((256, 16), (1, 1, 1)),
        None: ((128, 16), (1, 1, 1)),
    }
    # TODO: Support W4A8 when ready
    # # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk))
    # # TODO (LucasWilkinson): Further tuning required
    # qqq_tile_heuristic_config = {
    #     #### M = 257+
    #     # ((128, 256), (2, 1, 1)) Broken for QQQ types
    #     # TODO (LucasWilkinson): Investigate further
    #     # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)),
    #     # "M > 256": ((128, 256), (2, 1, 1)),
    #     "M > 256": ((128, 128), (2, 1, 1)),
    #     #### M = 129-256
    #     "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)),
    #     "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)),
    #     # ((128, 256), (2, 1, 1)) Broken for QQQ types
    #     # TODO (LucasWilkinson): Investigate further
    #     # "M > 128": ((128, 256), (2, 1, 1)),
    #     "M > 128": ((128, 128), (2, 1, 1)),
    #     #### M = 65-128
    #     "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)),
    #     "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)),
    #     "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)),
    #     "M > 64": ((128, 128), (2, 1, 1)),
    #     #### M = 33-64
    #     "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)),
    #     # Broken for QQQ types
    #     # TODO (LucasWilkinson): Investigate further
    #     #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)),
    #     "M > 32": ((128, 64), (2, 1, 1)),
    #     #### M = 17-32
    #     "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)),
    #     "M > 16": ((256, 32), (2, 1, 1)),
    #     #### M = 1-16
    #     "N >= 26624": ((256, 16), (1, 1, 1)),
    #     None: ((128, 16), (1, 1, 1)),
    # }

    # For now we use the same heuristic for all types
    # Heuristic is currently tuned for H100s
    qqq_heuristic = [
        (cond, ScheduleConfig(*tile_config,
                              **sch_common_params))  # type: ignore
        for cond, tile_config in qqq_tile_heuristic_config.items()
    ]
    # # For now we use the same heuristic for all types
    # # Heuristic is currently tuned for H100s
    # qqq_heuristic = [
    #     (cond, ScheduleConfig(*tile_config,
    #                           **sch_common_params))  # type: ignore
    #     for cond, tile_config in qqq_tile_heuristic_config.items()
    # ]

    QQQ_kernel_types = [
        *(TypeConfig(
            a=DataType.s8,
            b=VLLMDataType.u4b8,
            b_group_scale=b_group_scale,
            b_group_zeropoint=DataType.void,
            b_channel_scale=DataType.f32,
            a_token_scale=DataType.f32,
            out=DataType.f16,
            accumulator=DataType.s32,
        ) for b_group_scale in (DataType.f16, DataType.void)),
        *(TypeConfig(
            a=DataType.e4m3,
            b=VLLMDataType.u4b8,
            b_group_scale=b_group_scale,
            b_group_zeropoint=DataType.void,
            b_channel_scale=DataType.f32,
            a_token_scale=DataType.f32,
            out=DataType.f16,
            accumulator=DataType.f32,
        ) for b_group_scale in (DataType.f16, DataType.void)),
    ]
    # QQQ_kernel_types = [
    #     *(TypeConfig(
    #         a=DataType.s8,
    #         b=VLLMDataType.u4b8,
    #         b_group_scale=b_group_scale,
    #         b_group_zeropoint=DataType.void,
    #         b_channel_scale=DataType.f32,
    #         a_token_scale=DataType.f32,
    #         out=DataType.f16,
    #         accumulator=DataType.s32,
    #     ) for b_group_scale in (DataType.f16, DataType.void)),
    #     *(TypeConfig(
    #         a=DataType.e4m3,
    #         b=VLLMDataType.u4b8,
    #         b_group_scale=b_group_scale,
    #         b_group_zeropoint=DataType.void,
    #         b_channel_scale=DataType.f32,
    #         a_token_scale=DataType.f32,
    #         out=DataType.f16,
    #         accumulator=DataType.f32,
    #     ) for b_group_scale in (DataType.f16, DataType.void)),
    # ]

    impl_configs += [
        ImplConfig(x[0], x[1], x[2])
        for x in zip(QQQ_kernel_types,
                     itertools.repeat(get_unique_schedules(qqq_heuristic)),
                     itertools.repeat(qqq_heuristic))
    ]
    # impl_configs += [
    #     ImplConfig(x[0], x[1], x[2])
    #     for x in zip(QQQ_kernel_types,
    #                  itertools.repeat(get_unique_schedules(qqq_heuristic)),
    #                  itertools.repeat(qqq_heuristic))
    # ]

    output_dir = os.path.join(SCRIPT_DIR, "generated")
@ -1,209 +0,0 @@
Contains code from https://github.com/IST-DASLab/marlin

                              Apache License
                        Version 2.0, January 2004
                     http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

   "License" shall mean the terms and conditions for use, reproduction,
   and distribution as defined by Sections 1 through 9 of this document.

   "Licensor" shall mean the copyright owner or entity authorized by
   the copyright owner that is granting the License.

   "Legal Entity" shall mean the union of the acting entity and all
   other entities that control, are controlled by, or are under common
   control with that entity. For the purposes of this definition,
   "control" means (i) the power, direct or indirect, to cause the
   direction or management of such entity, whether by contract or
   otherwise, or (ii) ownership of fifty percent (50%) or more of the
   outstanding shares, or (iii) beneficial ownership of such entity.

   "You" (or "Your") shall mean an individual or Legal Entity
   exercising permissions granted by this License.

   "Source" form shall mean the preferred form for making modifications,
   including but not limited to software source code, documentation
   source, and configuration files.

   "Object" form shall mean any form resulting from mechanical
   transformation or translation of a Source form, including but
   not limited to compiled object code, generated documentation,
   and conversions to other media types.

   "Work" shall mean the work of authorship, whether in Source or
   Object form, made available under the License, as indicated by a
   copyright notice that is included in or attached to the work
   (an example is provided in the Appendix below).

   "Derivative Works" shall mean any work, whether in Source or Object
   form, that is based on (or derived from) the Work and for which the
   editorial revisions, annotations, elaborations, or other modifications
   represent, as a whole, an original work of authorship. For the purposes
   of this License, Derivative Works shall not include works that remain
   separable from, or merely link (or bind by name) to the interfaces of,
   the Work and Derivative Works thereof.

   "Contribution" shall mean any work of authorship, including
   the original version of the Work and any modifications or additions
   to that Work or Derivative Works thereof, that is intentionally
   submitted to Licensor for inclusion in the Work by the copyright owner
   or by an individual or Legal Entity authorized to submit on behalf of
   the copyright owner. For the purposes of this definition, "submitted"
   means any form of electronic, verbal, or written communication sent
   to the Licensor or its representatives, including but not limited to
   communication on electronic mailing lists, source code control systems,
   and issue tracking systems that are managed by, or on behalf of, the
   Licensor for the purpose of discussing and improving the Work, but
   excluding communication that is conspicuously marked or otherwise
   designated in writing by the copyright owner as "Not a Contribution."

   "Contributor" shall mean Licensor and any individual or Legal Entity
   on behalf of whom a Contribution has been received by Licensor and
   subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   copyright license to reproduce, prepare Derivative Works of,
   publicly display, publicly perform, sublicense, and distribute the
   Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
   this License, each Contributor hereby grants to You a perpetual,
   worldwide, non-exclusive, no-charge, royalty-free, irrevocable
   (except as stated in this section) patent license to make, have made,
   use, offer to sell, sell, import, and otherwise transfer the Work,
   where such license applies only to those patent claims licensable
   by such Contributor that are necessarily infringed by their
   Contribution(s) alone or by combination of their Contribution(s)
   with the Work to which such Contribution(s) was submitted. If You
   institute patent litigation against any entity (including a
   cross-claim or counterclaim in a lawsuit) alleging that the Work
   or a Contribution incorporated within the Work constitutes direct
   or contributory patent infringement, then any patent licenses
   granted to You under this License for that Work shall terminate
   as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
   Work or Derivative Works thereof in any medium, with or without
   modifications, and in Source or Object form, provided that You
   meet the following conditions:

   (a) You must give any other recipients of the Work or
       Derivative Works a copy of this License; and

   (b) You must cause any modified files to carry prominent notices
       stating that You changed the files; and

   (c) You must retain, in the Source form of any Derivative Works
       that You distribute, all copyright, patent, trademark, and
       attribution notices from the Source form of the Work,
       excluding those notices that do not pertain to any part of
       the Derivative Works; and

   (d) If the Work includes a "NOTICE" text file as part of its
       distribution, then any Derivative Works that You distribute must
       include a readable copy of the attribution notices contained
       within such NOTICE file, excluding those notices that do not
       pertain to any part of the Derivative Works, in at least one
       of the following places: within a NOTICE text file distributed
       as part of the Derivative Works; within the Source form or
       documentation, if provided along with the Derivative Works; or,
       within a display generated by the Derivative Works, if and
       wherever such third-party notices normally appear. The contents
       of the NOTICE file are for informational purposes only and
       do not modify the License. You may add Your own attribution
       notices within Derivative Works that You distribute, alongside
       or as an addendum to the NOTICE text from the Work, provided
       that such additional attribution notices cannot be construed
       as modifying the License.

   You may add Your own copyright statement to Your modifications and
   may provide additional or different license terms and conditions
   for use, reproduction, or distribution of Your modifications, or
   for any such Derivative Works as a whole, provided Your use,
   reproduction, and distribution of the Work otherwise complies with
   the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
   any Contribution intentionally submitted for inclusion in the Work
   by You to the Licensor shall be under the terms and conditions of
   this License, without any additional terms or conditions.
   Notwithstanding the above, nothing herein shall supersede or modify
   the terms of any separate license agreement you may have executed
   with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
   names, trademarks, service marks, or product names of the Licensor,
   except as required for reasonable and customary use in describing the
   origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
   agreed to in writing, Licensor provides the Work (and each
   Contributor provides its Contributions) on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
   implied, including, without limitation, any warranties or conditions
   of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
   PARTICULAR PURPOSE. You are solely responsible for determining the
   appropriateness of using or redistributing the Work and assume any
   risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
   whether in tort (including negligence), contract, or otherwise,
   unless required by applicable law (such as deliberate and grossly
   negligent acts) or agreed to in writing, shall any Contributor be
   liable to You for damages, including any direct, indirect, special,
   incidental, or consequential damages of any character arising as a
   result of this License or out of the use or inability to use the
   Work (including but not limited to damages for loss of goodwill,
   work stoppage, computer failure or malfunction, or any and all
   other commercial damages or losses), even if such Contributor
   has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
   the Work or Derivative Works thereof, You may choose to offer,
   and charge a fee for, acceptance of support, warranty, indemnity,
   or other liability obligations and/or rights consistent with this
   License. However, in accepting such obligations, You may act only
   on Your own behalf and on Your sole responsibility, not on behalf
   of any other Contributor, and only if You agree to indemnify,
   defend, and hold each Contributor harmless for any liability
   incurred by, or claims asserted against, such Contributor by reason
   of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

   To apply the Apache License to your work, attach the following
   boilerplate notice, with the fields enclosed by brackets "{}"
   replaced with your own identifying information. (Don't include
   the brackets!) The text should be enclosed in the appropriate
   comment syntax for the file format. We also recommend that a
   file or class name and description of purpose be included on the
   same "printed page" as the copyright notice for easier
   identification within third-party archives.

Copyright {yyyy} {name of copyright owner}

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
|
||||
This product bundles various third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses. See licenses/
|
||||
for text of these licenses.
|
||||
@ -1,32 +0,0 @@
/*
 * Modified by HandH1998
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};
@ -1,89 +0,0 @@
/*
 * Modified by HandH1998
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -41,8 +41,10 @@ __device__ inline void vectorize_with_alignment(

    for (int i = tid; i < num_vec; i += stride) {
      vout_t tmp;
      vec_op(tmp, v_in[i]);
      v_out[i] = tmp;
      // Make a local copy of the entire pack
      vin_t src = v_in[i];  // <- encourages a single vector ld
      vec_op(tmp, src);
      v_out[i] = tmp;  // <- encourages a single vector st
    }
    return;
  }
@ -71,8 +73,10 @@ __device__ inline void vectorize_with_alignment(
  // 2. vectorize the main part
  for (int i = tid; i < num_vec; i += stride) {
    vout_t tmp;
    vec_op(tmp, v_in[i]);
    v_out[i] = tmp;
    // Make a local copy of the entire pack
    vin_t src = v_in[i];  // <- encourages a single vector ld
    vec_op(tmp, src);
    v_out[i] = tmp;  // <- encourages a single vector st
  }

  // 3. handle the tail
@ -125,7 +129,8 @@ __device__ inline void vectorize_read_with_alignment(const InT* in, int len,
    auto* v_in = reinterpret_cast<const vin_t*>(in);

    for (int i = tid; i < num_vec; i += stride) {
      vec_op(v_in[i]);
      vin_t tmp = v_in[i];
      vec_op(tmp);
    }
    return;
  }
@ -241,14 +241,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // custom types:
  // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA

  // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
  ops.def(
      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
      "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> "
      "Tensor",
      {stride_tag});
  // conditionally compiled so impl in source file

  // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
  ops.def(
      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
@ -317,6 +309,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
      "SymInt size_n, int num_bits) -> Tensor");
  // conditionally compiled so impl registrations are in source file

  // CUTLASS w4a8 GEMM
  ops.def(
      "cutlass_w4a8_mm("
      "   Tensor A,"
      "   Tensor B,"
      "   Tensor group_scales,"
      "   int group_size,"
      "   Tensor channel_scales,"
      "   Tensor token_scales,"
      "   ScalarType? out_type,"
      "   str? maybe_schedule"
      ") -> Tensor",
      {stride_tag});
  // pack scales
  ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor");
  // encode and reorder weight matrix
  ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor");
  // conditionally compiled so impl registration is in source file

#endif

  // Dequantization for GGML.
@ -353,15 +365,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);

#ifndef USE_ROCM
  // marlin_qqq_gemm for QQQ.
  ops.def(
      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
      "Tensor! workspace, SymInt size_m, SymInt size_n, "
      "SymInt size_k) -> Tensor",
      {stride_tag});
  // conditionally compiled so impl registration is in source file

  // CUTLASS nvfp4 block scaled GEMM
  ops.def(
      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
@ -440,6 +443,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      {stride_tag});
  ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data);

  // A function that computes problem sizes for each expert's multiplication
  // used by the two mms called from fused MoE operation. It takes topk_ids as
  // an input, and computes problem_sizes1 and problem_sizes2 only.
  ops.def(
      "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, "
      "   Tensor! problem_sizes1, "
      "   Tensor! problem_sizes2, "
      "   int num_experts, int n, int k, "
      "   Tensor? blockscale_offsets) -> ()",
      {stride_tag});
  ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA,
           &get_cutlass_moe_mm_problem_sizes);

  // A function that computes data required to run fused MoE with w8a8 grouped
  // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs
  // as an input, and computes expert_offsets (token start indices of each
@ -676,11 +692,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
      "str kv_cache_dtype) -> ()");
  cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);

  // Gather cache blocks from src_cache to dst.
  // Gather cache blocks from src_cache to dst, dequantizing from
  // src_cache's dtype to dst's dtype if necessary.
  cache_ops.def(
      "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
      "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
  cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
      "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
      "   Tensor block_table, Tensor cu_seq_lens, "
      "   int batch_size, "
      "   str kv_cache_dtype, "
      "   Tensor scale, Tensor? seq_starts) -> ()");
  cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
                 &gather_and_maybe_dequant_cache);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {

@ -372,31 +372,45 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist

# Install FlashInfer from source
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt
# We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel.
ARG FLASHINFER_GIT_REF="v0.2.11"
# Keep this in sync with "flashinfer" extra in setup.py
ARG FLASHINFER_GIT_REF="v0.2.12"
# Flag to control whether to compile FlashInfer AOT kernels
# Set to "true" to enable AOT compilation:
# docker build --build-arg FLASHINFER_AOT_COMPILE=true ...
ARG FLASHINFER_AOT_COMPILE=false
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
    . /etc/environment
    git clone --depth 1 --recursive --shallow-submodules \
        --branch ${FLASHINFER_GIT_REF} \
        ${FLASHINFER_GIT_REPO} flashinfer
    # Exclude CUDA arches for older versions (11.x and 12.0-12.7)
    # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
    if [[ "${CUDA_VERSION}" == 11.* ]]; then
        FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
    elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
        FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
    else
        # CUDA 12.8+ supports 10.0a and 12.0
        FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
    fi
    echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
    # Needed to build AOT kernels
    pushd flashinfer
    TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
        python3 -m flashinfer.aot
    TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
        uv pip install --system --no-build-isolation --force-reinstall --no-deps .
    if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then
        # Exclude CUDA arches for older versions (11.x and 12.0-12.7)
        # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
        if [[ "${CUDA_VERSION}" == 11.* ]]; then
            FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
        elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
            FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
        else
            # CUDA 12.8+ supports 10.0a and 12.0
            FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
        fi
        echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
        # Build AOT kernels
        TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
            python3 -m flashinfer.aot
        # Install with no-build-isolation since we already built AOT kernels
        TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
            uv pip install --system --no-build-isolation . \
            --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
        # Download pre-compiled cubins
        TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
            python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins."
    else
        echo "🏗️ Installing FlashInfer without AOT compilation in JIT mode"
        uv pip install --system . \
            --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
    fi
    popd
    rm -rf flashinfer
BASH
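A usage sketch for the `FLASHINFER_AOT_COMPILE` build argument above; the image tag, build target, and Dockerfile path are assumptions for illustration, not taken from this change:

```bash
docker build --build-arg FLASHINFER_AOT_COMPILE=true \
  --target vllm-openai -t vllm/vllm-openai:flashinfer-aot \
  -f docker/Dockerfile .
```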
@ -418,31 +432,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \
    --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')

# Install DeepGEMM from source
ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
    . /etc/environment
    CUDA_MAJOR="${CUDA_VERSION%%.*}"
    CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}"
    CUDA_MINOR="${CUDA_MINOR%%.*}"
    if [ "$CUDA_MAJOR" -ge 12 ] && [ "$CUDA_MINOR" -ge 8 ]; then
        git clone --recursive --shallow-submodules \
            ${DEEPGEMM_GIT_REPO} deepgemm
        echo "🏗️ Building DeepGEMM"
        pushd deepgemm
        git checkout ${DEEPGEMM_GIT_REF}
        # Build DeepGEMM
        # (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh)
        rm -rf build dist
        rm -rf *.egg-info
        python3 setup.py bdist_wheel
        uv pip install --system dist/*.whl
        popd
        rm -rf deepgemm
    else
        echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
    fi
BASH
COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
RUN --mount=type=cache,target=/root/.cache/uv \
    VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \
    && rm /tmp/install_deepgemm.sh

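A minimal sketch of invoking the same helper directly when reproducing this step outside the image build; the CUDA version value is illustrative, while the `--ref` value comes from `DEEPGEMM_GIT_REF` above:

```bash
bash tools/install_deepgemm.sh \
  --cuda-version "12.8.1" \
  --ref "7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
```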
# Install EP kernels(pplx-kernels and DeepEP), NixL
COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh
COPY tools/install_nixl.sh install_nixl.sh
ENV CUDA_HOME=/usr/local/cuda
RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \
    && bash install_python_libraries.sh \
    && bash install_nixl.sh --force

#################### vLLM installation IMAGE ####################

@ -71,7 +71,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
RUN cd /vllm-workspace \
    && rm -rf vllm \
    && python3 -m pip install -e tests/vllm_test_utils \
    && python3 -m pip install lm-eval[api]==0.4.4 \
    && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \
    && python3 -m pip install pytest-shard

# -----------------------

@ -7,7 +7,8 @@ WORKDIR /workspace/vllm
# Install some basic utilities
RUN apt-get update && apt-get install -y \
    git \
    ffmpeg libsm6 libxext6 libgl1
    ffmpeg libsm6 libxext6 libgl1 && \
    rm -rf /var/lib/apt/lists/*

# Build vLLM.
COPY . .
@ -16,6 +17,9 @@ RUN --mount=type=bind,source=.git,target=.git \
    if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi

# Remove existing versions of dependencies
# TODO: These packages will remain as dead weight in the Docker image layers.
# We should find a way to build the image without uninstalling these.
# Consider using a different base image.
RUN pip uninstall -y torch torch_xla torchvision

ENV VLLM_TARGET_DEVICE="tpu"
@ -23,9 +27,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \
    --mount=type=bind,source=.git,target=.git \
    python3 -m pip install \
    -r requirements/tpu.txt
RUN python3 -m pip install -e .

RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e .

# install development dependencies (for testing)
RUN python3 -m pip install -e tests/vllm_test_utils
RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e tests/vllm_test_utils

CMD ["/bin/bash"]

@ -2,6 +2,7 @@

We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:

- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH)
- [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152).
- [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing)
- [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).

@ -129,6 +129,53 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.

### Batch-level DP for Multi-Modal Encoders

By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
in order to reduce the memory and compute load on each GPU.

However, since the size of multi-modal encoders is very small compared to language decoders,
there is relatively little gain from TP. On the other hand, TP incurs significant communication
overhead because of all-reduce being performed after every layer.

Given this, it may be advantageous to instead shard the batched input data using TP, essentially
performing batch-level DP. This has been shown to improve the throughput by around 10% for
`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations,
batch-level DP can provide another 40% increase to throughput compared to regular TP.

Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank,
there will be a minor increase in memory consumption, which may cause OOM if you can barely fit the model already.

You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example:

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-72B-Instruct",
    tensor_parallel_size=4,
    # When mm_encoder_tp_mode="data",
    # the vision encoder uses TP=4 (not DP=1) to shard the input data,
    # so the TP size becomes the effective DP size.
    # Note that this is independent of the DP size for the language decoder, which is used in the expert parallel setting.
    mm_encoder_tp_mode="data",
    # The language decoder uses TP=4 to shard the weights regardless
    # of the setting of mm_encoder_tp_mode
)
```

!!! important
    Batch-level DP is not to be confused with API request-level DP
    (which is instead controlled by `data_parallel_size`).

The availability of batch-level DP is based on the model implementation.
Currently, the following models support `mm_encoder_tp_mode="data"`:

- Llama4 (<gh-pr:18368>)
- MiniCPM-V-4 (<gh-pr:23327>)
- Qwen2.5-VL (<gh-pr:22742>)
- Step3 (<gh-pr:22697>)

## Input Processing

### Parallel Processing
@ -149,6 +196,13 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
!!! note
    API server scale-out is only available for online inference.

!!! warning
    By default, 8 CPU threads are used in each API server to load media items (e.g. images)
    from request data.

    If you apply API server scale-out, consider adjusting `VLLM_MEDIA_LOADING_THREAD_COUNT`
    to avoid CPU resource exhaustion.

!!! note
    [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
    because it requires a one-to-one correspondence between API and engine core processes.

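A minimal sketch combining these settings; the serve command is the one shown in the hunk header above, while the thread count of 4 is an illustrative value, not a recommendation from this change:

```bash
VLLM_MEDIA_LOADING_THREAD_COUNT=4 \
  vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
```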
@ -18,7 +18,7 @@ vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096

- Download and install [Anything LLM desktop](https://anythingllm.com/desktop).

- On the bottom left of open settings, AI Prooviders --> LLM:
- On the bottom left of open settings, AI Providers --> LLM:
    - LLM Provider: Generic OpenAI
    - Base URL: http://{vllm server host}:{vllm server port}/v1
    - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ`

@ -9,7 +9,7 @@ vLLM can be run on a cloud based GPU machine with [dstack](https://dstack.ai/),
To install dstack client, run:

```bash
pip install "dstack[all]
pip install dstack[all]
dstack server
```

@ -226,7 +226,7 @@ Doing this will add the new implementation to the test suite.

The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script.
Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts`
As a side-effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked
As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked
with incompatible types, the script will error.

### How To Profile

@ -565,7 +565,7 @@ model and then validate those tokens with the larger model.
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)

There is a PR under review (<gh-pr:12193>) to add "prompt lookup (ngram)"
seculative decoding to v1. Other techniques will follow. We should
speculative decoding to v1. Other techniques will follow. We should
revisit the v0 metrics in this context.

!!! note

@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`.

There are other miscellaneous places hard-coding the use of `spawn`:

- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>

Related PRs:

@ -422,7 +422,7 @@ a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle
a whole block of value tokens. And each `accs` in each thread
contains 8 elements that accumulated at 8 different head positions.
For the thread 0, the `accs` variable will have 8 elements, which
are 0th, 32th … 224th elements of a value head that are accumulated
are 0th, 32nd … 224th elements of a value head that are accumulated
from all assigned 8 tokens.

## LV

@ -79,7 +79,7 @@ Since simple RTN does not require data for weight quantization and the activatio
Install `vllm` and `lm-evaluation-harness` for evaluation:

```bash
pip install vllm lm-eval==0.4.4
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
```

Load and run the model in `vllm`:

@ -7,7 +7,7 @@ Intel Gaudi supports quantization of various modules and functions, including, b
[Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules).

!!! note
    Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.
    Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.

!!! note
    `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options).

@ -18,7 +18,7 @@ pip install llmcompressor
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:

```bash
pip install vllm lm-eval==0.4.4
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
```

## Quantization Process

@ -19,7 +19,7 @@ pip install llmcompressor
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:

```bash
pip install vllm lm-eval==0.4.4
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
```

## Quantization Process

@ -20,7 +20,7 @@ for more installation details.
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:

```bash
pip install vllm lm-eval==0.4.4
pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api]
```

## Quantization Process

@ -284,6 +284,14 @@ Supported models:

Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}`

### DeepSeek-V3.1 Models (`deepseek_v31`)

Supported models:

* `deepseek-ai/DeepSeek-V3.1` (use with <gh-file:examples/tool_chat_template_deepseekv31.jinja>)

Flags: `--tool-call-parser deepseek_v31 --chat-template {see_above}`

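A minimal serving sketch for the flags above; `--enable-auto-tool-choice` and any parallelism settings are assumptions for illustration, not part of this change:

```bash
vllm serve deepseek-ai/DeepSeek-V3.1 \
  --enable-auto-tool-choice \
  --tool-call-parser deepseek_v31 \
  --chat-template examples/tool_chat_template_deepseekv31.jinja
```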
### Kimi-K2 Models (`kimi_k2`)

Supported models:

@ -170,7 +170,7 @@ This value is 4GB by default. Larger space can support more concurrent requests,

First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`.

Inference batch size is a important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM:
Inference batch size is an important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM:

- `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as:
    - Offline Inference: `4096 * world_size`
@ -179,7 +179,7 @@ Inference batch size is a important parameter for the performance. Larger batch
    - Offline Inference: `256 * world_size`
    - Online Serving: `128 * world_size`

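A minimal tuning sketch for these limits; the model name, the values, and the `--max-num-seqs` flag name (the second parameter's name is cut off by the hunk above) are assumptions for illustration:

```bash
vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --max-num-batched-tokens 4096 \
  --max-num-seqs 128
```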
vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more detials of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP togther if there are enough CPU sockets and memory nodes.
vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP together if there are enough CPU sockets and memory nodes.

### Which quantization configs does vLLM CPU support?

@ -190,6 +190,6 @@ vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage mu

### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`?

- Both of them requires `amx` CPU flag.
- Both of them require `amx` CPU flag.
- `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models
- `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios.

@ -261,13 +261,13 @@ Lower value corresponds to less usable graph memory reserved for prefill stage,

User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented:

- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode
- `max_bs` - graph capture queue will be sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode
- `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt

When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy.

!!! note
    `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below.
    `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt to do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below.

Each described step is logged by vLLM server, as follows (negative values correspond to memory being released):

@ -328,11 +328,11 @@ th {
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | ✅︎ |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
| `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | |
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R, Command-A | `CohereLabs/c4ai-command-r-v01`, `CohereLabs/c4ai-command-r7b-12-2024`, `CohereLabs/c4ai-command-a-03-2025`, `CohereLabs/command-a-reasoning-08-2025`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ |
| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ |
@ -373,6 +373,7 @@ th {
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
@ -384,8 +385,8 @@ th {
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
@ -400,6 +401,7 @@ th {
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
@ -613,6 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ |
| `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I<sup>+</sup> | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ |
| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ |
| `DonutForConditionalGeneration`<sup>^</sup> | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | |
| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
@ -652,6 +655,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
@ -107,7 +107,7 @@ to enable simultaneous generation and embedding using the same engine instance i
#### Mamba Models

Models using selective state-space mechanisms instead of standard transformer attention are supported.
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1.

Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that
@ -154,16 +154,19 @@ differences compared to V0:

##### Logprobs Calculation

Logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
By default, logprobs in V1 are now returned immediately once computed from the model’s raw output (i.e.
before applying any logits post-processing such as temperature scaling or penalty
adjustments). As a result, the returned logprobs do not reflect the final adjusted
probabilities used during sampling.

Support for logprobs with post-sampling adjustments is in progress and will be added in future updates.
You can adjust this behavior by setting the `--logprobs-mode` flag.
Four modes are supported: `raw_logprobs` (default), `processed_logprobs`, `raw_logits`, `processed_logits`.
Raw means the values before applying any logit processors, like bad words.
Processed means the values after applying all processors, including temperature and top_k/top_p.

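As a minimal illustration of the flag described above (the model name is an assumption, not part of this change):

```bash
vllm serve Qwen/Qwen2.5-1.5B-Instruct --logprobs-mode processed_logprobs
```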
##### Prompt Logprobs with Prefix Caching

Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](gh-issue:13414).
Logprobs are not cached. For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of full prompt to generate the logprobs.

#### Deprecated Features

311
examples/offline_inference/dolphin.py
Normal file
@ -0,0 +1,311 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import copy
import os
from dataclasses import dataclass

import cv2
import numpy as np
import regex as re
from PIL import Image
from transformers import DonutProcessor

from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.multimodal.utils import fetch_image


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
@dataclass
class ImageDimensions:
    original_w: int
    original_h: int
    padded_w: int
    padded_h: int


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def map_to_original_coordinates(
    x1, y1, x2, y2, dims: ImageDimensions
) -> tuple[int, int, int, int]:
    try:
        top = (dims.padded_h - dims.original_h) // 2
        left = (dims.padded_w - dims.original_w) // 2
        orig_x1 = max(0, x1 - left)
        orig_y1 = max(0, y1 - top)
        orig_x2 = min(dims.original_w, x2 - left)
        orig_y2 = min(dims.original_h, y2 - top)
        if orig_x2 <= orig_x1:
            orig_x2 = min(orig_x1 + 1, dims.original_w)
        if orig_y2 <= orig_y1:
            orig_y2 = min(orig_y1 + 1, dims.original_h)
        return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2)
    except Exception as e:
        print(f"map_to_original_coordinates error: {str(e)}")
        return 0, 0, min(100, dims.original_w), min(100, dims.original_h)


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2):
    if isinstance(image, str):
        image = cv2.imread(image)
    img_h, img_w = image.shape[:2]
    new_boxes = []
    for box in boxes:
        best_box = copy.deepcopy(box)

        def check_edge(img, current_box, i, is_vertical):
            edge = current_box[i]
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            _, binary = cv2.threshold(
                gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
            )
            if is_vertical:
                line = binary[current_box[1] : current_box[3] + 1, edge]
            else:
                line = binary[edge, current_box[0] : current_box[2] + 1]
            transitions = np.abs(np.diff(line))
            return np.sum(transitions) / len(transitions)

        edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)]
        current_box = copy.deepcopy(box)
        current_box[0] = min(max(current_box[0], 0), img_w - 1)
        current_box[1] = min(max(current_box[1], 0), img_h - 1)
        current_box[2] = min(max(current_box[2], 0), img_w - 1)
        current_box[3] = min(max(current_box[3], 0), img_h - 1)

        for i, direction, is_vertical in edges:
            best_score = check_edge(image, current_box, i, is_vertical)
            if best_score <= threshold:
                continue
            for step in range(max_pixels):
                current_box[i] += direction
                if i == 0 or i == 2:
                    current_box[i] = min(max(current_box[i], 0), img_w - 1)
                else:
                    current_box[i] = min(max(current_box[i], 0), img_h - 1)
                score = check_edge(image, current_box, i, is_vertical)
                if score < best_score:
                    best_score = score
                    best_box = copy.deepcopy(current_box)
                if score <= threshold:
                    break
        new_boxes.append(best_box)
    return new_boxes


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None):
    try:
        x1, y1 = int(coords[0] * dims.padded_w), int(coords[1] * dims.padded_h)
        x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h)
        x1, y1, x2, y2 = (
            max(0, min(x1, dims.padded_w - 1)),
            max(0, min(y1, dims.padded_h - 1)),
            max(0, min(x2, dims.padded_w)),
            max(0, min(y2, dims.padded_h)),
        )
        if x2 <= x1:
            x2 = min(x1 + 1, dims.padded_w)
        if y2 <= y1:
            y2 = min(y1 + 1, dims.padded_h)
        new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]])
        x1, y1, x2, y2 = new_boxes[0]
        x1, y1, x2, y2 = (
            max(0, min(x1, dims.padded_w - 1)),
            max(0, min(y1, dims.padded_h - 1)),
            max(0, min(x2, dims.padded_w)),
            max(0, min(y2, dims.padded_h)),
        )
        if x2 <= x1:
            x2 = min(x1 + 1, dims.padded_w)
        if y2 <= y1:
            y2 = min(y1 + 1, dims.padded_h)
        if previous_box is not None:
            prev_x1, prev_y1, prev_x2, prev_y2 = previous_box
            if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1):
                y1 = prev_y2
                y1 = min(y1, dims.padded_h - 1)
                if y2 <= y1:
                    y2 = min(y1 + 1, dims.padded_h)
        new_previous_box = [x1, y1, x2, y2]
        orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates(
            x1, y1, x2, y2, dims
        )
        return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box
    except Exception as e:
        print(f"process_coordinates error: {str(e)}")
        orig_x1, orig_y1, orig_x2, orig_y2 = (
            0,
            0,
            min(100, dims.original_w),
            min(100, dims.original_h),
        )
        return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100]


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]:
    try:
        image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        original_h, original_w = image_cv.shape[:2]
        max_size = max(original_h, original_w)
        top = (max_size - original_h) // 2
        bottom = max_size - original_h - top
        left = (max_size - original_w) // 2
        right = max_size - original_w - left
        padded_image = cv2.copyMakeBorder(
            image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )
        padded_h, padded_w = padded_image.shape[:2]
        dimensions = ImageDimensions(
            original_w=original_w,
            original_h=original_h,
            padded_w=padded_w,
            padded_h=padded_h,
        )
        return padded_image, dimensions
    except Exception as e:
        print(f"prepare_image error: {str(e)}")
        h, w = image.height, image.width
        dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h)
        return np.zeros((h, w, 3), dtype=np.uint8), dimensions


# Copied from https://github.com/bytedance/Dolphin/utils/utils.py
def parse_layout_string(bbox_str):
    """Parse layout string using regular expressions"""
    pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)"
    matches = re.finditer(pattern, bbox_str)

    parsed_results = []
    for match in matches:
        coords = [float(match.group(i)) for i in range(1, 5)]
        label = match.group(5).strip()
        parsed_results.append((coords, label))

    return parsed_results


model_id = "ByteDance/Dolphin"

# The input image size for Dolphin is 896 x 896,
# and the patch_size is 4 x 4.
# Therefore, the initial number of patches is:
# Height: 896 / 4 = 224 patches
# Width: 896 / 4 = 224 patches

# The Dolphin model uses a staged downsampling approach,
# defined by the "depths": [2, 2, 14, 2] configuration.
# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed,
# which halves the feature map's dimensions (dividing both height and width by 2).
# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112.
# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56.
# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28.

# Because vLLM needs to fill the image features with an encoder_prompt,
# and the encoder_prompt will have `<pad>` tokens added when tokenized,
# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783.
encoder_prompt = "".join(["0"] * 783)
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=2048,
)

processor = DonutProcessor.from_pretrained(model_id)
llm = LLM(
    model=model_id,
    dtype="float16",
    max_num_seqs=8,
    hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--image_path", type=str, default=None, help="Path to a local image file."
)
args = parser.parse_args()

if args.image_path:
    if not os.path.exists(args.image_path):
        raise FileNotFoundError(f"Error: File not found at {args.image_path}")
    image = Image.open(args.image_path).convert("RGB")
else:
    image = fetch_image(
        "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
    )


prompt = "Parse the reading order of this document. "
decoder_prompt = f"<s>{prompt}<Answer/>"
decoder_prompt_tokens = TokensPrompt(
    prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[
        "input_ids"
    ]
)
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}),
    decoder_prompt=decoder_prompt_tokens,
)
layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params)
layout_result_str = layout_outputs[0].outputs[0].text
print(f"Layout analysis output:\n{layout_result_str}")

padded_image, dims = prepare_image(image)
layout_results = parse_layout_string(layout_result_str)
text_table_elements = []
previous_box = None
reading_order = 0
for bbox_coords, label in layout_results:
    if label == "fig":
        continue
    try:
        x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = (
            process_coordinates(bbox_coords, padded_image, dims, previous_box)
        )
        cropped = padded_image[y1:y2, x1:x2]
        if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3:
            pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
            prompt_ocr = (
                "Parse the table in the image. "
                if label == "tab"
                else "Read text in the image. "
            )
            text_table_elements.append(
                {
                    "crop": pil_crop,
                    "prompt": prompt_ocr,
                    "reading_order": reading_order,
                }
            )
            reading_order += 1
    except Exception as e:
        print(f"Error processing bbox (label: {label}): {str(e)}")
        continue

if text_table_elements:
    batch_prompts = []
    for elem in text_table_elements:
        decoder_prompt_str = f"<s>{elem['prompt']}<Answer/>"
        decoder_prompt_tokens = TokensPrompt(
            prompt_token_ids=processor.tokenizer(
                decoder_prompt_str, add_special_tokens=False
            )["input_ids"]
        )
        enc_dec_prompt = ExplicitEncoderDecoderPrompt(
|
||||
encoder_prompt=TextPrompt(
|
||||
prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]}
|
||||
),
|
||||
decoder_prompt=decoder_prompt_tokens,
|
||||
)
|
||||
batch_prompts.append(enc_dec_prompt)
|
||||
batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params)
|
||||
for i, output in enumerate(batch_outputs):
|
||||
text_table_elements[i]["text"] = output.outputs[0].text.strip()
|
||||
|
||||
print("------" * 8)
|
||||
text_table_elements.sort(key=lambda x: x["reading_order"])
|
||||
for elem in text_table_elements:
|
||||
print(elem.get("text", ""))
|
||||
@ -13,6 +13,7 @@ from typing import NamedTuple
|
||||
from vllm import LLM, EngineArgs, PromptType, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -21,6 +22,50 @@ class ModelRequestData(NamedTuple):
|
||||
prompts: Sequence[PromptType]
|
||||
|
||||
|
||||
def run_donut():
|
||||
engine_args = EngineArgs(
|
||||
model="naver-clova-ix/donut-base-finetuned-docvqa",
|
||||
max_num_seqs=2,
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
dtype="float16",
|
||||
hf_overrides={"architectures": ["DonutForConditionalGeneration"]},
|
||||
)
|
||||
|
||||
# The input image size for donut-base-finetuned-docvqa is 2560 x 1920,
|
||||
# and the patch_size is 4 x 4.
|
||||
# Therefore, the initial number of patches is:
|
||||
# Height: 1920 / 4 = 480 patches
|
||||
# Width: 2560 / 4 = 640 patches
|
||||
# The Swin model uses a staged downsampling approach,
|
||||
# defined by the "depths": [2, 2, 14, 2] configuration.
|
||||
# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed,
|
||||
# which halves the feature map's dimensions (dividing both height and width by 2).
|
||||
# Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320.
|
||||
# Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160.
|
||||
# Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80.
|
||||
# Because vLLM needs to fill the image features with an encoder_prompt,
|
||||
# and the encoder_prompt will have `<pad>` tokens added when tokenized,
|
||||
# we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799.
|
||||
prompts = [
|
||||
{
|
||||
"encoder_prompt": {
|
||||
"prompt": "".join(["$"] * 4799),
|
||||
"multi_modal_data": {
|
||||
"image": fetch_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg"
|
||||
) # noqa: E501
|
||||
},
|
||||
},
|
||||
"decoder_prompt": "<s_docvqa><s_question>What time is the coffee break?</s_question><s_answer>", # noqa: E501
|
||||
},
|
||||
]
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
def run_florence2():
|
||||
engine_args = EngineArgs(
|
||||
model="microsoft/Florence-2-large",
|
||||
@ -118,6 +163,7 @@ def run_whisper():
|
||||
|
||||
|
||||
model_example_map = {
|
||||
"donut": run_donut,
|
||||
"florence2": run_florence2,
|
||||
"mllama": run_mllama,
|
||||
"whisper": run_whisper,
|
||||
|
||||
@ -5,6 +5,7 @@ from transformers import AutoTokenizer
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.benchmarks.datasets import add_dataset_parser, get_samples
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.v1.metrics.reader import Counter, Vector
|
||||
|
||||
try:
|
||||
@ -137,7 +138,8 @@ def main():
|
||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||
if not args.custom_mm_prompts:
|
||||
outputs = llm.generate(
|
||||
prompt_token_ids=prompt_ids, sampling_params=sampling_params
|
||||
TokensPrompt(prompt_token_ids=prompt_ids),
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
else:
|
||||
outputs = llm.chat(prompts, sampling_params=sampling_params)
|
||||
|
||||
@ -85,7 +85,7 @@ def format_output(title: str, output: str):
|
||||
|
||||
|
||||
def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
|
||||
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
|
||||
outputs = llm.generate(prompt, sampling_params=sampling_params)
|
||||
return outputs[0].outputs[0].text
|
||||
|
||||
|
||||
|
||||
@ -283,8 +283,10 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
prompts = [
|
||||
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
|
||||
{question}<|assistant|>"
|
||||
(
|
||||
"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>"
|
||||
f"{question}<|assistant|>"
|
||||
)
|
||||
for question in questions
|
||||
]
|
||||
|
||||
@ -767,15 +769,13 @@ def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestDat
|
||||
def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData:
|
||||
if modality == "video":
|
||||
prompts = [
|
||||
f"<|im_start|>user <video>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n"
|
||||
f"<|im_start|>user <video>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
elif modality == "image":
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n"
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
@ -998,8 +998,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
|
||||
)
|
||||
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|> \
|
||||
<|im_start|>assistant\n"
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
@ -1436,6 +1435,28 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
|
||||
)
|
||||
|
||||
|
||||
# R-4B
|
||||
def run_r_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "YannQi/R-4B"
|
||||
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=16384,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# SkyworkR1V
|
||||
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -1622,6 +1643,7 @@ model_example_map = {
|
||||
"qwen2_vl": run_qwen2_vl,
|
||||
"qwen2_5_vl": run_qwen2_5_vl,
|
||||
"qwen2_5_omni": run_qwen2_5_omni,
|
||||
"rvl": run_r_vl,
|
||||
"skywork_chat": run_skyworkr1v,
|
||||
"smolvlm": run_smolvlm,
|
||||
"step3": run_step3,
|
||||
|
||||
@ -992,6 +992,39 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "YannQi/R-4B"
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=16384,
|
||||
max_num_seqs=16,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
prompt = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
)
|
||||
|
||||
|
||||
def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
|
||||
|
||||
@ -1193,6 +1226,7 @@ model_example_map = {
|
||||
"qwen_vl_chat": load_qwen_vl_chat,
|
||||
"qwen2_vl": load_qwen2_vl,
|
||||
"qwen2_5_vl": load_qwen2_5_vl,
|
||||
"rvl": load_r_vl,
|
||||
"smolvlm": load_smolvlm,
|
||||
"step3": load_step3,
|
||||
"tarsier": load_tarsier,
|
||||
|
||||
91
examples/tool_chat_template_deepseekv31.jinja
Normal file
91
examples/tool_chat_template_deepseekv31.jinja
Normal file
@ -0,0 +1,91 @@
|
||||
{% if not add_generation_prompt is defined %}
|
||||
{% set add_generation_prompt = false %}
|
||||
{% endif %}
|
||||
{% if not thinking is defined %}
|
||||
{% set thinking = false %}
|
||||
{% endif %}
|
||||
{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'system' %}
|
||||
{%- if ns.is_first_sp %}
|
||||
{% set ns.system_prompt = ns.system_prompt + message['content'] %}
|
||||
{% set ns.is_first_sp = false %}
|
||||
{%- else %}
|
||||
{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
|
||||
{% if tools is defined and tools is not none %}
|
||||
{% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %}
|
||||
{% for tool in tools %}
|
||||
{% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %}
|
||||
{% endfor %}
|
||||
{% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %}
|
||||
{% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %}
|
||||
{% endif %}
|
||||
|
||||
{{ bos_token }}{{ ns.system_prompt }}
|
||||
{%- for message in messages %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- set ns.is_first = false -%}
|
||||
{%- set ns.is_last_user = true -%}
|
||||
{{'<|User|>' + message['content']}}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
|
||||
{%- if ns.is_last_user %}
|
||||
{{'<|Assistant|></think>'}}
|
||||
{%- endif %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- set ns.is_first = false %}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- for tool in message['tool_calls'] %}
|
||||
{%- if not ns.is_first %}
|
||||
{%- if message['content'] is none %}
|
||||
{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
|
||||
{%- else %}
|
||||
{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- set ns.is_first = true -%}
|
||||
{%- else %}
|
||||
{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}
|
||||
{%- if ns.is_last_user %}
|
||||
{{'<|Assistant|>'}}
|
||||
{%- if message['prefix'] is defined and message['prefix'] and thinking %}
|
||||
{{'<think>'}}
|
||||
{%- else %}
|
||||
{{'</think>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- if ns.is_tool %}
|
||||
{{message['content'] + '<|end▁of▁sentence|>'}}
|
||||
{%- set ns.is_tool = false -%}
|
||||
{%- else %}
|
||||
{%- set content = message['content'] -%}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set content = content.split('</think>', 1)[1] -%}
|
||||
{%- endif %}
|
||||
{{content + '<|end▁of▁sentence|>'}}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'tool' %}
|
||||
{%- set ns.is_last_user = false -%}
|
||||
{%- set ns.is_tool = true -%}
|
||||
{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}
|
||||
{%- endif %}
|
||||
{%- endfor -%}
|
||||
{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}
|
||||
{{'<|Assistant|>'}}
|
||||
{%- if not thinking %}
|
||||
{{'</think>'}}
|
||||
{%- else %}
|
||||
{{'<think>'}}
|
||||
{%- endif %}
|
||||
{% endif %}
|
||||
123
examples/tool_chat_template_gemma3_pythonic.jinja
Normal file
123
examples/tool_chat_template_gemma3_pythonic.jinja
Normal file
@ -0,0 +1,123 @@
|
||||
{#- Begin-of-sequence token to start the model prompt -#}
|
||||
{{ bos_token }}
|
||||
{#- Extracts the system message. Gemma does not support system messages so it will be prepended to first user message. -#}
|
||||
{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- if messages[0]['content'] is string -%}
|
||||
{%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}
|
||||
{%- endif -%}
|
||||
{%- set loop_messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set first_user_prefix = "" -%}
|
||||
{%- set loop_messages = messages -%}
|
||||
{%- endif -%}
|
||||
{#- Set tools to none if not defined for this ChatCompletion request (helps avoid errors later) -#}
|
||||
{%- if not tools is defined %}
|
||||
{%- set tools = none %}
|
||||
{%- endif %}
|
||||
{#- Validate alternating user/assistant messages (excluding 'tool' messages and ones with tool_calls) -#}
|
||||
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | selectattr("tool_calls", "undefined") -%}
|
||||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
|
||||
{{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- Main loop over all messages in the conversation history -#}
|
||||
{%- for message in loop_messages -%}
|
||||
{#- Normalize roles for model prompt formatting -#}
|
||||
{%- if (message['role'] == 'assistant') -%}
|
||||
{%- set role = "model" -%}
|
||||
{%- elif (message['role'] == 'tool') -%}
|
||||
{%- set role = "user" -%}
|
||||
{%- else -%}
|
||||
{%- set role = message['role'] -%}
|
||||
{%- endif -%}
|
||||
{#- Mark the start of a message block with the appropriate role -#}
|
||||
{{ '<start_of_turn>' + role + '\n' -}}
|
||||
|
||||
{#- Insert system message content (if present) at the beginning of the first message. -#}
|
||||
{%- if loop.first -%}
|
||||
{{ first_user_prefix }}
|
||||
{#- Append system message with tool information if using tools in message request. -#}
|
||||
{%- if tools is not none -%}
|
||||
{{- "Tools (functions) are available. If you decide to invoke one or more of the tools, you must respond with a python list of the function calls.\n" -}}
|
||||
{{- "Example Format: [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] \n" -}}
|
||||
{{- "Do not use variables. DO NOT USE MARKDOWN SYNTAX. You SHOULD NOT include any other text in the response if you call a function. If none of the functions can be used, point it out. If you lack the parameters required by the function, also point it out.\n" -}}
|
||||
{{- "Here is a list of functions in JSON format that you can invoke.\n" -}}
|
||||
{{- tools | tojson(indent=4) -}}
|
||||
{{- "\n\n" -}}
|
||||
{%- endif -%}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Format model tool calls (turns where model indicates they want to call a tool) -#}
|
||||
{%- if 'tool_calls' in message -%}
|
||||
{#- Opening bracket for tool call list. -#}
|
||||
{{- '[' -}}
|
||||
{#- For each tool call -#}
|
||||
{%- for tool_call in message.tool_calls -%}
|
||||
{#- Get tool call function. -#}
|
||||
{%- if tool_call.function is defined -%}
|
||||
{%- set tool_call = tool_call.function -%}
|
||||
{%- endif -%}
|
||||
{#- Function name & opening parenthesis. -#}
|
||||
{{- tool_call.name + '(' -}}
|
||||
|
||||
{#-- Handle arguments as list (positional) or dict (named) --#}
|
||||
{#-- Named arguments (dict) --#}
|
||||
{%- if tool_call.arguments is iterable and tool_call.arguments is mapping -%}
|
||||
{%- set first = true -%}
|
||||
{%- for key, val in tool_call.arguments.items() -%}
|
||||
{%- if not first %}, {% endif -%}
|
||||
{{ key }}={{ val | tojson }}
|
||||
{%- set first = false -%}
|
||||
{%- endfor -%}
|
||||
{#-- Positional arguments (list) --#}
|
||||
{%- elif tool_call.arguments is iterable -%}
|
||||
{{- tool_call.arguments | map('tojson') | join(', ') -}}
|
||||
{#-- Fallback: single positional value --#}
|
||||
{%- else -%}
|
||||
{{- tool_call.arguments | tojson -}}
|
||||
{#-- Closing parenthesis. --#}
|
||||
{%- endif -%}
|
||||
{{- ')' -}}
|
||||
{#-- If more than one tool call, place comma and move to formatting next tool call --#}
|
||||
{%- if not loop.last -%}, {% endif -%}
|
||||
{%- endfor -%}
|
||||
{#- Closing bracket for tool call list. -#}
|
||||
{{- ']' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Tool response start tag (for messages from a tool) -#}
|
||||
{%- if (message['role'] == 'tool') -%}
|
||||
{{ '<tool_response>\n' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Render the message content: handle plain string or multimodal content like image/text -#}
|
||||
{%- if message['content'] is string -%}
|
||||
{{ message['content'] | trim }}
|
||||
{%- elif message['content'] is iterable -%}
|
||||
{%- for item in message['content'] -%}
|
||||
{%- if item['type'] == 'image' -%}
|
||||
{{ '<start_of_image>' }}
|
||||
{%- elif item['type'] == 'text' -%}
|
||||
{{ item['text'] | trim }}
|
||||
{%- endif -%}
|
||||
{%- endfor -%}
|
||||
{%- else -%}
|
||||
{{ raise_exception("Invalid content type") }}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Tool response end tag -#}
|
||||
{%- if (message['role'] == 'tool') -%}
|
||||
{{ '</tool_response>' -}}
|
||||
{%- endif -%}
|
||||
|
||||
{#- Mark end of a single turn -#}
|
||||
{{ '<end_of_turn>\n' }}
|
||||
{%- endfor -%}
|
||||
|
||||
{#- If generation is to be triggered, add model prompt prefix -#}
|
||||
{%- if add_generation_prompt -%}
|
||||
{{'<start_of_turn>model\n'}}
|
||||
{%- endif -%}
|
||||
@ -1,10 +1,14 @@
|
||||
{%- if messages %}
|
||||
{%- if system_message or tools %}
|
||||
<|system|>
|
||||
|
||||
{%- if system_message %}
|
||||
{{ system_message }}
|
||||
{%- if messages and messages[0]['role'] == 'system' %}
|
||||
{%- set system_message = messages[0]['content']|trim %}
|
||||
{%- set messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set system_message = "You are a helpful assistant." %}
|
||||
{%- endif %}
|
||||
|
||||
{%- if messages %}
|
||||
<|system|>
|
||||
{{ system_message }}
|
||||
{%- if tools %}
|
||||
In addition to plain text responses, you can chose to call one or more of the provided functions.
|
||||
|
||||
Use the following rule to decide when to call a function:
|
||||
@ -19,13 +23,11 @@ If you decide to call functions:
|
||||
* make sure you pick the right functions that match the user intent
|
||||
|
||||
|
||||
{%- if tools %}
|
||||
{%- for t in tools %}
|
||||
{{- t | tojson(indent=4) }}
|
||||
{{- "\n\n" }}
|
||||
{%- endfor %}
|
||||
{%- endif %}<|end|>
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in messages %}
|
||||
{%- if message.role != "system" %}
|
||||
|
||||
@ -13,12 +13,12 @@ protobuf # Required by LlamaTokenizer.
|
||||
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
|
||||
aiohttp
|
||||
openai >= 1.99.1 # For Responses API with reasoning content
|
||||
pydantic >= 2.10
|
||||
pydantic >= 2.11.7
|
||||
prometheus_client >= 0.18.0
|
||||
pillow # Required for image processing
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||
lm-format-enforcer >= 0.10.11, < 0.11
|
||||
lm-format-enforcer == 0.11.3
|
||||
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
|
||||
outlines_core == 0.2.10 ; platform_machine != "s390x"
|
||||
outlines == 0.1.11 ; platform_machine == "s390x"
|
||||
|
||||
@ -27,7 +27,7 @@ mistral_common[image,audio] >= 1.8.2 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
opencv-python-headless >= 4.11.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
lm-eval[api]==0.4.8 # required for model evaluation test
|
||||
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
|
||||
mteb>=1.38.11, <2 # required for mteb test
|
||||
transformers==4.52.4
|
||||
tokenizers==0.21.1
|
||||
|
||||
@ -6,7 +6,7 @@ torch==2.7.0
|
||||
torchvision==0.22.0
|
||||
torchaudio==2.7.0
|
||||
|
||||
triton==3.2
|
||||
triton==3.3.0
|
||||
cmake>=3.26.1,<4
|
||||
packaging>=24.2
|
||||
setuptools>=77.0.3,<80.0.0
|
||||
|
||||
@ -17,4 +17,4 @@ setuptools>=77.0.3,<80.0.0
|
||||
setuptools-scm>=8
|
||||
runai-model-streamer==0.11.0
|
||||
runai-model-streamer-s3==0.11.0
|
||||
conch-triton-kernels==1.2.1
|
||||
conch-triton-kernels==1.2.1
|
||||
@ -32,7 +32,8 @@ num2words # required for smolvlm test
|
||||
open_clip_torch==2.32.0 # Required for nemotron_vl test
|
||||
opencv-python-headless >= 4.11.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
lm-eval[api]==0.4.8 # required for model evaluation test
|
||||
# TODO: Use lm-eval[api]==0.4.10 once released
|
||||
lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test
|
||||
mteb[bm25s]>=1.38.11, <2 # required for mteb test
|
||||
transformers==4.55.2
|
||||
tokenizers==0.21.1
|
||||
|
||||
@ -408,7 +408,7 @@ lightning-utilities==0.14.3
|
||||
# torchmetrics
|
||||
llvmlite==0.44.0
|
||||
# via numba
|
||||
lm-eval==0.4.8
|
||||
lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d
|
||||
# via -r requirements/test.in
|
||||
lxml==5.3.0
|
||||
# via
|
||||
@ -742,7 +742,7 @@ pycparser==2.22
|
||||
# via cffi
|
||||
pycryptodomex==3.22.0
|
||||
# via blobfile
|
||||
pydantic==2.11.5
|
||||
pydantic==2.11.7
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# albumentations
|
||||
|
||||
19
setup.py
19
setup.py
@ -643,16 +643,25 @@ if envs.VLLM_USE_PRECOMPILED:
|
||||
if wheel_location is not None:
|
||||
wheel_url = wheel_location
|
||||
else:
|
||||
import platform
|
||||
arch = platform.machine()
|
||||
if arch == "x86_64":
|
||||
wheel_tag = "manylinux1_x86_64"
|
||||
elif arch == "aarch64":
|
||||
wheel_tag = "manylinux2014_aarch64"
|
||||
else:
|
||||
raise ValueError(f"Unsupported architecture: {arch}")
|
||||
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
|
||||
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||
nightly_wheel_url = f"https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-{wheel_tag}.whl"
|
||||
from urllib.request import urlopen
|
||||
try:
|
||||
with urlopen(wheel_url) as resp:
|
||||
if resp.status != 200:
|
||||
wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
wheel_url = nightly_wheel_url
|
||||
except Exception as e:
|
||||
print(f"[warn] Falling back to nightly wheel: {e}")
|
||||
wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
|
||||
wheel_url = nightly_wheel_url
|
||||
|
||||
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
|
||||
wheel_url)
|
||||
@ -685,7 +694,9 @@ setup(
|
||||
"mistral_common[audio]"], # Required for audio processing
|
||||
"video": [], # Kept for backwards compatibility
|
||||
# FlashInfer should be updated together with the Dockerfile
|
||||
"flashinfer": ["flashinfer-python==0.2.11"],
|
||||
"flashinfer": ["flashinfer-python==0.2.12"],
|
||||
# Optional deps for AMD FP4 quantization support
|
||||
"petit-kernel": ["petit-kernel"],
|
||||
},
|
||||
cmdclass=cmdclass,
|
||||
package_data=package_data,
|
||||
|
||||
@ -177,3 +177,34 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
|
||||
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output3[0].outputs[0].text
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_deep_sleep():
|
||||
model = "Qwen/Qwen3-0.6B"
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
used_bytes_baseline = total - free # in case other process is running
|
||||
llm = LLM(model, enable_sleep_mode=True)
|
||||
prompt = "How are you?"
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
output = llm.generate(prompt, sampling_params)
|
||||
|
||||
# Put the engine to deep sleep
|
||||
llm.sleep(level=2)
|
||||
|
||||
free_gpu_bytes_after_sleep, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_after_sleep - used_bytes_baseline
|
||||
assert used_bytes < 3 * GiB_bytes
|
||||
|
||||
llm.wake_up(tags=["weights"])
|
||||
llm.collective_rpc("reload_weights")
|
||||
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
|
||||
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
|
||||
assert used_bytes < 4 * GiB_bytes
|
||||
|
||||
# now allocate kv cache and cuda graph memory
|
||||
llm.wake_up(tags=["kv_cache"])
|
||||
output2 = llm.generate(prompt, sampling_params)
|
||||
|
||||
# cmp output
|
||||
assert output[0].outputs[0].text == output2[0].outputs[0].text
|
||||
|
||||
344
tests/benchmarks/test_random_dataset.py
Normal file
344
tests/benchmarks/test_random_dataset.py
Normal file
@ -0,0 +1,344 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
from typing import Any, NamedTuple, Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
|
||||
SampleRequest)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_tokenizer() -> PreTrainedTokenizerBase:
|
||||
# Use a small, commonly available tokenizer
|
||||
return AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
|
||||
class Params(NamedTuple):
|
||||
num_requests: int
|
||||
prefix_len: int
|
||||
range_ratio: float
|
||||
input_len: int
|
||||
output_len: int
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def random_dataset_params() -> Params:
|
||||
return Params(num_requests=16,
|
||||
prefix_len=7,
|
||||
range_ratio=0.3,
|
||||
input_len=50,
|
||||
output_len=20)
|
||||
|
||||
|
||||
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
||||
"""Project a SampleRequest into a comparable tuple."""
|
||||
return (req.prompt, req.prompt_len, req.expected_output_len)
|
||||
|
||||
|
||||
def _collect_samples(dataset: RandomDataset,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
num_requests: int = 16,
|
||||
prefix_len: int = 7,
|
||||
range_ratio: float = 0.3,
|
||||
input_len: int = 50,
|
||||
output_len: int = 20) -> list[tuple[str, int, int]]:
|
||||
samples = dataset.sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=num_requests,
|
||||
prefix_len=prefix_len,
|
||||
range_ratio=range_ratio,
|
||||
input_len=input_len,
|
||||
output_len=output_len,
|
||||
)
|
||||
return [_fingerprint_sample(s) for s in samples]
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_dataset_same_seed(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
random_dataset_params: Params) -> None:
|
||||
"""Same seed should yield identical outputs, even if global RNGs change.
|
||||
|
||||
This guards against accidental reliance on Python's random or np.random
|
||||
in RandomDataset after moving to numpy.default_rng.
|
||||
"""
|
||||
p = random_dataset_params
|
||||
common_seed = 123
|
||||
dataset_a = RandomDataset(random_seed=common_seed)
|
||||
dataset_b = RandomDataset(random_seed=common_seed)
|
||||
a = _collect_samples(dataset_a,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
|
||||
# Perturb global RNG state to ensure isolation
|
||||
random.seed(999)
|
||||
_ = [random.random() for _ in range(100)]
|
||||
np.random.seed(888)
|
||||
_ = [np.random.random() for _ in range(100)]
|
||||
|
||||
b = _collect_samples(dataset_b,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
assert a == b
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_dataset_different_seeds(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
random_dataset_params: Params) -> None:
|
||||
"""Different seeds should change outputs with overwhelming likelihood."""
|
||||
p = random_dataset_params
|
||||
seed_a = 0
|
||||
dataset_a = RandomDataset(random_seed=seed_a)
|
||||
a = _collect_samples(dataset_a,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
|
||||
seed_b = 999
|
||||
dataset_b = RandomDataset(random_seed=seed_b)
|
||||
# Perturb global RNG with same seed as dataset_a to ensure isolation
|
||||
random.seed(seed_a)
|
||||
np.random.seed(seed_a)
|
||||
b = _collect_samples(dataset_b,
|
||||
hf_tokenizer,
|
||||
num_requests=p.num_requests,
|
||||
prefix_len=p.prefix_len,
|
||||
range_ratio=p.range_ratio,
|
||||
input_len=p.input_len,
|
||||
output_len=p.output_len)
|
||||
assert a != b
|
||||
|
||||
|
||||
# -----------------------------
|
||||
# RandomMultiModalDataset tests
|
||||
# -----------------------------
|
||||
|
||||
def _mm_fingerprint_sample(
|
||||
req: SampleRequest,
|
||||
) -> tuple[str, int, int, int, list[str]]:
|
||||
"""Create a compact fingerprint for multimodal samples.
|
||||
|
||||
Includes:
|
||||
- prompt string
|
||||
- prompt_len
|
||||
- expected_output_len
|
||||
- count of multimodal items
|
||||
- per-item type and URL prefix (e.g., 'data:image/jpeg;base64,')
|
||||
"""
|
||||
items = req.multi_modal_data or []
|
||||
item_prefixes: list[str] = []
|
||||
for it in items:
|
||||
if isinstance(it, dict) and it.get("type") == "image_url":
|
||||
url = it.get("image_url", {}).get("url", "")
|
||||
# Only keep a short identifying prefix to avoid huge strings
|
||||
item_prefixes.append(f"image:{url[:22]}")
|
||||
elif isinstance(it, dict) and it.get("type") == "video_url":
|
||||
url = it.get("video_url", {}).get("url", "")
|
||||
item_prefixes.append(f"video:{url[:22]}")
|
||||
else:
|
||||
item_prefixes.append("unknown:")
|
||||
return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
|
||||
item_prefixes)
|
||||
|
||||
|
||||
def _collect_mm_samples(
|
||||
dataset: RandomMultiModalDataset,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
*,
|
||||
num_requests: int = 8,
|
||||
prefix_len: int = 3,
|
||||
range_ratio: float = 0.0,
|
||||
input_len: int = 20,
|
||||
output_len: int = 5,
|
||||
base_items_per_request: int = 2,
|
||||
num_mm_items_range_ratio: float = 0.0,
|
||||
limit_mm_per_prompt: Optional[dict[str, int]] = None,
|
||||
bucket_config: Optional[dict[tuple[int, int, int], float]] = None,
|
||||
enable_multimodal_chat: bool = False,
|
||||
) -> list[SampleRequest]:
|
||||
if limit_mm_per_prompt is None:
|
||||
limit_mm_per_prompt = {"image": 5, "video": 0}
|
||||
if bucket_config is None:
|
||||
bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5}
|
||||
return dataset.sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=num_requests,
|
||||
prefix_len=prefix_len,
|
||||
range_ratio=range_ratio,
|
||||
input_len=input_len,
|
||||
output_len=output_len,
|
||||
base_items_per_request=base_items_per_request,
|
||||
num_mm_items_range_ratio=num_mm_items_range_ratio,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
bucket_config=bucket_config,
|
||||
enable_multimodal_chat=enable_multimodal_chat,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
seed = 42
|
||||
ds_a = RandomMultiModalDataset(random_seed=seed)
|
||||
ds_b = RandomMultiModalDataset(random_seed=seed)
|
||||
a = _collect_mm_samples(ds_a, hf_tokenizer)
|
||||
b = _collect_mm_samples(ds_b, hf_tokenizer)
|
||||
fa = [_mm_fingerprint_sample(s) for s in a]
|
||||
fb = [_mm_fingerprint_sample(s) for s in b]
|
||||
assert fa == fb
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_different_seeds(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
) -> None:
|
||||
ds_a = RandomMultiModalDataset(random_seed=0)
|
||||
ds_b = RandomMultiModalDataset(random_seed=999)
|
||||
a = _collect_mm_samples(ds_a, hf_tokenizer)
|
||||
b = _collect_mm_samples(ds_b, hf_tokenizer)
|
||||
fa = [_mm_fingerprint_sample(s) for s in a]
|
||||
fb = [_mm_fingerprint_sample(s) for s in b]
|
||||
assert fa != fb
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_respects_limits(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
) -> None:
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
# Requesting 3 items with a per-prompt limit of 1 should error per current
|
||||
# design (dataset refuses to silently clamp below the requested baseline).
|
||||
with pytest.raises(ValueError):
|
||||
_collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=12,
|
||||
base_items_per_request=3,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt={"image": 1, "video": 0},
|
||||
bucket_config={(32, 32, 1): 1.0},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_zero_prob_entries_are_removed(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
) -> None:
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
# Second bucket has zero probability and should be ignored after
|
||||
# normalization
|
||||
samples = _collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=6,
|
||||
base_items_per_request=2,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt={"image": 10, "video": 0},
|
||||
bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0},
|
||||
)
|
||||
for s in samples:
|
||||
assert isinstance(s.multi_modal_data, list)
|
||||
typed_mm = cast(list[dict[str, Any]], s.multi_modal_data)
|
||||
for it in typed_mm:
|
||||
assert it.get("type") == "image_url"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
samples = _collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=5,
|
||||
base_items_per_request=0,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt={"image": 5, "video": 0},
|
||||
bucket_config={(32, 32, 1): 1.0},
|
||||
)
|
||||
for s in samples:
|
||||
assert s.multi_modal_data == []
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_num_items_per_prompt(
|
||||
hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
# Fixed number of images per prompt
|
||||
# set num_mm_items_range_ratio to 0.0
|
||||
# TODO: modify video values when video sampling is implemented
|
||||
samples_fixed_items = _collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=5,
|
||||
base_items_per_request=3,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt={"image": 3, "video": 0},
|
||||
bucket_config={(32, 32, 1): 1.0},
|
||||
)
|
||||
# Must have 5 requests each with 3 mm items per prompt
|
||||
assert len(samples_fixed_items) == 5
|
||||
for s in samples_fixed_items:
|
||||
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
|
||||
assert len(mm_data) == 3
|
||||
for it in mm_data:
|
||||
assert it.get("type") == "image_url"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_random_mm_bucket_config_not_mutated(
|
||||
hf_tokenizer: PreTrainedTokenizerBase,
|
||||
) -> None:
|
||||
|
||||
ds = RandomMultiModalDataset(random_seed=0)
|
||||
# This bucket config is not normalized to sum to 1
|
||||
# and has more buckets than requested images
|
||||
original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3}
|
||||
# Keep a snapshot to compare after sampling
|
||||
snapshot = dict(original)
|
||||
|
||||
_ = _collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=4,
|
||||
base_items_per_request=1,
|
||||
num_mm_items_range_ratio=0.0,
|
||||
limit_mm_per_prompt={"image": 1, "video": 0},
|
||||
bucket_config=original,
|
||||
)
|
||||
|
||||
# Ensure the original dict content is unchanged
|
||||
assert original == snapshot
|
||||
|
||||
|
||||
# Vary number of mm items per prompt
|
||||
# set num_mm_items_range_ratio to 0.5
|
||||
samples_varying_items = _collect_mm_samples(
|
||||
ds,
|
||||
hf_tokenizer,
|
||||
num_requests=5,
|
||||
base_items_per_request=2,
|
||||
num_mm_items_range_ratio=0.5,
|
||||
limit_mm_per_prompt={"image": 4, "video": 0},
|
||||
bucket_config={(32, 32, 1): 1.0},
|
||||
)
|
||||
# Must have 5 requests each with less than 4 mm items per prompt
|
||||
# but at least 1 mm item per prompt
|
||||
assert len(samples_varying_items) == 5
|
||||
for s in samples_varying_items:
|
||||
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
|
||||
assert len(mm_data) <= 4
|
||||
assert len(mm_data) >= 1
|
||||
for it in mm_data:
|
||||
assert it.get("type") == "image_url"
|
||||
@ -12,10 +12,9 @@ from vllm.compilation.backends import set_model_tag
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
support_torch_compile)
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.envs import VLLM_USE_V1
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
@ -164,104 +163,34 @@ class SimpleModelWithTwoGraphs(ParentModel):
|
||||
return x
|
||||
|
||||
|
||||
def test_ignore_torch_compile_decorator():
|
||||
assert VLLM_USE_V1
|
||||
|
||||
# piecewise
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
|
||||
@support_torch_compile
|
||||
class A(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + x
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = x * 3
|
||||
return x
|
||||
|
||||
@ignore_torch_compile
|
||||
class B(A):
|
||||
...
|
||||
|
||||
@support_torch_compile
|
||||
class C(B):
|
||||
...
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# A has support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
), set_forward_context({}, vllm_config=vllm_config):
|
||||
# first run is for compile
|
||||
mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||
# run cudagraph captured sizes
|
||||
mod_A(torch.randn(2, MLP_SIZE).cuda())
|
||||
mod_A(torch.randn(1, MLP_SIZE).cuda())
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# B's ignore_torch_compile should override A's support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
), set_forward_context({}, vllm_config=vllm_config):
|
||||
mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||
mod_B(torch.randn(2, MLP_SIZE).cuda())
|
||||
mod_B(torch.randn(1, MLP_SIZE).cuda())
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# C's support_torch_compile should override B's ignore_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
), set_forward_context({}, vllm_config=vllm_config):
|
||||
mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||
mod_C(torch.randn(2, MLP_SIZE).cuda())
|
||||
mod_C(torch.randn(1, MLP_SIZE).cuda())
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor):
|
||||
def run_model(vllm_config: VllmConfig, model: nn.Module, inputs: torch.Tensor,
|
||||
cudagraph_runtime_mode: CUDAGraphMode):
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
# First run is for compile
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(inputs)
|
||||
|
||||
# Run CUDAGraph captured sizes
|
||||
model(inputs[:2])
|
||||
model(inputs[:1])
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
model(inputs[:2])
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1, )):
|
||||
model(inputs[:1])
|
||||
|
||||
output = model(inputs[:2])
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
output = model(inputs[:2])
|
||||
|
||||
output = output.cpu()
|
||||
return output.cpu()
|
||||
@ -277,6 +206,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
@ -299,11 +229,13 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
outputs.append(run_model(vllm_config, model, inputs))
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# no compile or cudagraph
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.NO_COMPILATION, ))
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
@ -318,7 +250,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
outputs.append(run_model(vllm_config, model, inputs))
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# piecewise compile without CUDA graph
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
@ -326,6 +259,7 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
use_cudagraph=False,
|
||||
splitting_ops=["silly.attention"],
|
||||
))
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE,
|
||||
@ -340,7 +274,8 @@ def test_multi_graph_piecewise_compile_outputs_equal():
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=0, # no cudagraph captured
|
||||
):
|
||||
outputs.append(run_model(vllm_config, model, inputs))
|
||||
outputs.append(
|
||||
run_model(vllm_config, model, inputs, cudagraph_runtime_mode))
|
||||
|
||||
# Generally don't expect outputs with and without inductor
|
||||
# to be bitwise equivalent
|
||||
|
||||
251
tests/compile/test_decorator.py
Normal file
251
tests/compile/test_decorator.py
Normal file
@ -0,0 +1,251 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.library import Library
|
||||
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.decorators import (ignore_torch_compile,
|
||||
support_torch_compile)
|
||||
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
|
||||
CUDAGraphMode, VllmConfig, set_current_vllm_config)
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
# create a library to hold the custom op
|
||||
silly_lib = Library("silly", "FRAGMENT") # noqa
|
||||
|
||||
BATCH_SIZE = 32
|
||||
MLP_SIZE = 128
|
||||
|
||||
|
||||
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
out.copy_(q)
|
||||
out += k
|
||||
out += v
|
||||
|
||||
|
||||
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
||||
out: torch.Tensor) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="attention",
|
||||
op_func=silly_attention,
|
||||
mutates_args=["out"],
|
||||
fake_impl=silly_attention_fake,
|
||||
target_lib=silly_lib,
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def run_model(vllm_config: VllmConfig, model: nn.Module,
|
||||
cudagraph_runtime_mode: CUDAGraphMode):
|
||||
with set_forward_context({}, vllm_config=vllm_config):
|
||||
# warmup for the model with cudagraph_mode NONE
|
||||
model(torch.randn(BATCH_SIZE, MLP_SIZE).cuda())
|
||||
|
||||
# simulate cudagraphs capturing
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
model(torch.randn(2, MLP_SIZE).cuda())
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=1, )):
|
||||
model(torch.randn(1, MLP_SIZE).cuda())
|
||||
|
||||
# simulate cudagraphs replay
|
||||
with set_forward_context({},
|
||||
vllm_config=vllm_config,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=BatchDescriptor(
|
||||
num_tokens=2, )):
|
||||
output = model(torch.randn(2, MLP_SIZE).cuda())
|
||||
|
||||
output = output.cpu()
|
||||
return output.cpu()
|
||||
|
||||
|
||||
def test_ignore_torch_compile_decorator():
|
||||
# piecewise
|
||||
vllm_config = VllmConfig(compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
@support_torch_compile
|
||||
class A(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + x
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = x * 3
|
||||
return x
|
||||
|
||||
@ignore_torch_compile
|
||||
class B(A):
|
||||
...
|
||||
|
||||
@support_torch_compile
|
||||
class C(B):
|
||||
...
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# A has support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# B's ignore_torch_compile should override A's support_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=0,
|
||||
num_piecewise_graphs_seen=0,
|
||||
num_piecewise_capturable_graphs_seen=0,
|
||||
num_backend_compilations=0,
|
||||
num_cudagraph_captured=0,
|
||||
):
|
||||
run_model(vllm_config, mod_B, cudagraph_runtime_mode)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# C's support_torch_compile should override B's ignore_torch_compile
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=3,
|
||||
num_piecewise_capturable_graphs_seen=2,
|
||||
num_backend_compilations=2,
|
||||
num_cudagraph_captured=4,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_C, cudagraph_runtime_mode)
|
||||
|
||||
|
||||
# Only enable torch.compile if
|
||||
# vllm_config.cache_config.kv_sharing_fast_prefill=True
|
||||
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||
kv_sharing_fast_prefill)
|
||||
class B(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x + x
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = x + x
|
||||
return x
|
||||
|
||||
|
||||
# Only enable torch.compile if
|
||||
# vllm_config.cache_config.kv_sharing_fast_prefill=False
|
||||
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
||||
cache_config.kv_sharing_fast_prefill)
|
||||
class A(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = '',
|
||||
**kwargs) -> None:
|
||||
super().__init__()
|
||||
self.mod1 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
self.mod2 = B(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.mod1(x)
|
||||
attn_output = torch.empty_like(x)
|
||||
torch.ops.silly.attention(x, x, x, attn_output)
|
||||
x = attn_output
|
||||
x = self.mod2(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_conditional_compile_enable_if():
|
||||
vllm_config = VllmConfig(cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=True, ),
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
# A has support_torch_compile but enable_if fn returns False
|
||||
# enalbe_if will be True for B, so we expect mod1 and mod2
|
||||
# to be compiled
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=2,
|
||||
num_piecewise_graphs_seen=6,
|
||||
# 3 piecewise graphs per instance of B()
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
|
||||
# Set kv_sharing_fast_prefill=False
|
||||
# which will cause A to be compiled and B to not be compiled
|
||||
vllm_config = VllmConfig(cache_config=CacheConfig(
|
||||
kv_sharing_fast_prefill=False, ),
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
use_cudagraph=True,
|
||||
splitting_ops=["silly.attention"],
|
||||
cudagraph_capture_sizes=[1, 2],
|
||||
))
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda()
|
||||
|
||||
with compilation_counter.expect(
|
||||
num_graphs_seen=1,
|
||||
num_piecewise_graphs_seen=7,
|
||||
# 3 attn ops and 4 non-attn ops
|
||||
num_piecewise_capturable_graphs_seen=4,
|
||||
num_backend_compilations=4,
|
||||
num_cudagraph_captured=8,
|
||||
# num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
|
||||
):
|
||||
run_model(vllm_config, mod_A, cudagraph_runtime_mode)
|
||||
@ -53,12 +53,6 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
|
||||
"quantization": "gptq_marlin_24"
|
||||
}))
|
||||
|
||||
if is_quant_method_supported("marlin"):
|
||||
TEST_MODELS.append(
|
||||
("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", {
|
||||
"quantization": "marlin"
|
||||
}))
|
||||
|
||||
if not current_platform.is_rocm() and is_quant_method_supported("awq"):
|
||||
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", {
|
||||
"quantization": "AWQ"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user