Compare commits

..

3 Commits

Author SHA1 Message Date
bcf3c8230d Merge branch 'main' into woosuk-jf 2025-05-04 11:16:07 -07:00
a01af39aa8 Merge branch 'main' into woosuk-jf 2025-05-03 10:42:43 -07:00
eeb5761cf1 Implement Jump-Forward (Fast-Forwrd) Decoding
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-05-01 18:08:52 -07:00
344 changed files with 9955 additions and 14892 deletions

View File

@ -1,11 +0,0 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Llama-3.2-1B-Instruct-FP8 -b "auto" -l 1319 -f 5 -t 1
model_name: "RedHatAI/Llama-3.2-1B-Instruct-FP8"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.335
- name: "exact_match,flexible-extract"
value: 0.323
limit: 1319
num_fewshot: 5

View File

@ -1,11 +0,0 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2.5-1.5B-Instruct -b auto -l 1319 -f 5 -t 1
model_name: "Qwen/Qwen2.5-1.5B-Instruct"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.54
- name: "exact_match,flexible-extract"
value: 0.59
limit: 1319
num_fewshot: 5

View File

@ -1,11 +0,0 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -b auto -l 1319 -f 5 -t 1
model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.47
- name: "exact_match,flexible-extract"
value: 0.64
limit: 1319
num_fewshot: 5

View File

@ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml
DeepSeek-V2-Lite-Chat.yaml
Meta-Llama-3-8B-QQQ.yaml

View File

@ -1,6 +1,10 @@
Qwen2.5-1.5B-Instruct.yaml
Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
Meta-Llama-3-8B-QQQ.yaml

View File

@ -1,39 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
import pytest
def pytest_addoption(parser):
parser.addoption(
"--config-list-file",
action="store",
help="Path to the file listing model config YAMLs (one per line)")
parser.addoption("--tp-size",
action="store",
default="1",
help="Tensor parallel size to use for evaluation")
@pytest.fixture(scope="session")
def config_list_file(pytestconfig, config_dir):
rel_path = pytestconfig.getoption("--config-list-file")
return config_dir / rel_path
@pytest.fixture(scope="session")
def tp_size(pytestconfig):
return pytestconfig.getoption("--tp-size")
def pytest_generate_tests(metafunc):
if "config_filename" in metafunc.fixturenames:
rel_path = metafunc.config.getoption("--config-list-file")
config_list_file = Path(rel_path).resolve()
config_dir = config_list_file.parent
with open(config_list_file, encoding="utf-8") as f:
configs = [
config_dir / line.strip() for line in f
if line.strip() and not line.startswith("#")
]
metafunc.parametrize("config_filename", configs)

View File

@ -0,0 +1,59 @@
#!/bin/bash
usage() {
echo``
echo "Runs lm eval harness on GSM8k using vllm and compares to "
echo "precomputed baseline (measured by HF transformers.)"
echo
echo "usage: ${0} <options>"
echo
echo " -c - path to the test data config (e.g. configs/small-models.txt)"
echo " -t - tensor parallel size"
echo
}
SUCCESS=0
while getopts "c:t:" OPT; do
case ${OPT} in
c )
CONFIG="$OPTARG"
;;
t )
TP_SIZE="$OPTARG"
;;
\? )
usage
exit 1
;;
esac
done
# Parse list of configs.
IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG"
for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do
LOCAL_SUCCESS=0
echo "=== RUNNING MODEL: $MODEL_CONFIG WITH TP SIZE: $TP_SIZE==="
export LM_EVAL_TEST_DATA_FILE=$PWD/configs/${MODEL_CONFIG}
export LM_EVAL_TP_SIZE=$TP_SIZE
pytest -s test_lm_eval_correctness.py || LOCAL_SUCCESS=$?
if [[ $LOCAL_SUCCESS == 0 ]]; then
echo "=== PASSED MODEL: ${MODEL_CONFIG} ==="
else
echo "=== FAILED MODEL: ${MODEL_CONFIG} ==="
fi
SUCCESS=$((SUCCESS + LOCAL_SUCCESS))
done
if [ "${SUCCESS}" -eq "0" ]; then
exit 0
else
exit 1
fi

View File

@ -3,25 +3,35 @@
LM eval harness on model to compare vs HF baseline computed offline.
Configs are found in configs/$MODEL.yaml
pytest -s -v test_lm_eval_correctness.py \
--config-list-file=configs/models-small.txt \
--tp-size=1
* export LM_EVAL_TEST_DATA_FILE=configs/Meta-Llama-3-70B-Instruct.yaml
* export LM_EVAL_TP_SIZE=4
* pytest -s test_lm_eval_correctness.py
"""
import os
from pathlib import Path
import lm_eval
import numpy as np
import numpy
import pytest
import yaml
RTOL = 0.08
TEST_DATA_FILE = os.environ.get(
"LM_EVAL_TEST_DATA_FILE",
".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml")
TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1)
def launch_lm_eval(eval_config, tp_size):
def launch_lm_eval(eval_config):
trust_remote_code = eval_config.get('trust_remote_code', False)
model_args = f"pretrained={eval_config['model_name']}," \
f"tensor_parallel_size={tp_size}," \
f"enforce_eager=true," \
f"tensor_parallel_size={TP_SIZE}," \
f"add_bos_token=true," \
f"trust_remote_code={trust_remote_code}"
results = lm_eval.simple_evaluate(
model="vllm",
model_args=model_args,
@ -29,14 +39,22 @@ def launch_lm_eval(eval_config, tp_size):
num_fewshot=eval_config["num_fewshot"],
limit=eval_config["limit"],
batch_size="auto")
return results
def test_lm_eval_correctness_param(config_filename, tp_size):
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
def test_lm_eval_correctness():
eval_config = yaml.safe_load(
Path(TEST_DATA_FILE).read_text(encoding="utf-8"))
results = launch_lm_eval(eval_config, tp_size)
if eval_config[
"model_name"] == "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform": #noqa: E501
pytest.skip("FBGEMM is currently failing on main.")
# Launch eval requests.
results = launch_lm_eval(eval_config)
# Confirm scores match ground truth.
success = True
for task in eval_config["tasks"]:
for metric in task["metrics"]:
@ -44,7 +62,8 @@ def test_lm_eval_correctness_param(config_filename, tp_size):
measured_value = results["results"][task["name"]][metric["name"]]
print(f'{task["name"]} | {metric["name"]}: '
f'ground_truth={ground_truth} | measured={measured_value}')
success = success and np.isclose(
success = success and numpy.isclose(
ground_truth, measured_value, rtol=RTOL)
# Assert at the end, print all scores even on failure for debugging.
assert success

View File

@ -1,6 +1,6 @@
#!/bin/bash
set -xu
set -xue
# Build the docker image.
docker build -f docker/Dockerfile.tpu -t vllm-tpu .
@ -24,80 +24,31 @@ docker run --privileged --net host --shm-size=16G -it \
&& export VLLM_XLA_CHECK_RECOMPILATION=1 \
&& echo HARDWARE \
&& tpu-info \
&& { \
echo TEST_0: Running test_perf.py; \
pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \
echo TEST_0_EXIT_CODE: \$?; \
} & \
&& { \
echo TEST_1: Running test_compilation.py; \
pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \
echo TEST_1_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_2: Running test_basic.py; \
pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \
echo TEST_2_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_3: Running test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \
echo TEST_3_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_4: Running test_quantization_accuracy.py; \
pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \
echo TEST_4_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_5: Running examples/offline_inference/tpu.py; \
python3 /workspace/vllm/examples/offline_inference/tpu.py; \
echo TEST_5_EXIT_CODE: \$?; \
} & \
{ \
echo TEST_6: Running test_tpu_model_runner.py; \
pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \
echo TEST_6_EXIT_CODE: \$?; \
} & \
&& { \
echo TEST_7: Running test_sampler.py; \
pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \
echo TEST_7_EXIT_CODE: \$?; \
} & \
&& { \
echo TEST_8: Running test_topk_topp_sampler.py; \
pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \
echo TEST_8_EXIT_CODE: \$?; \
} & \
&& { \
echo TEST_9: Running test_multimodal.py; \
pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \
echo TEST_9_EXIT_CODE: \$?; \
} & \
&& { \
echo TEST_10: Running test_pallas.py; \
pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \
echo TEST_10_EXIT_CODE: \$?; \
} & \
&& { \
echo TEST_11: Running test_struct_output_generate.py; \
pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \
echo TEST_11_EXIT_CODE: \$?; \
} & \
&& { \
echo TEST_12: Running test_moe_pallas.py; \
pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \
echo TEST_12_EXIT_CODE: \$?; \
} & \
# Disable the TPU LoRA tests until the feature is activated
# && { \
# echo TEST_13: Running test_moe_pallas.py; \
# pytest -s -v /workspace/vllm/tests/tpu/lora/; \
# echo TEST_13_EXIT_CODE: \$?; \
# } & \
wait \
&& echo 'All tests have attempted to run. Check logs for individual test statuses and exit codes.' \
"
&& echo TEST_0 \
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
&& echo TEST_1 \
&& pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
&& echo TEST_2 \
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
&& echo TEST_3 \
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
&& echo TEST_4 \
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
&& echo TEST_5 \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
&& echo TEST_6 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
&& echo TEST_7 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \
&& echo TEST_8 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
&& echo TEST_9 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \
&& echo TEST_10 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
&& echo TEST_11 \
&& pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \
# TODO: This test fails because it uses RANDOM_SEED sampling
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \

View File

@ -408,7 +408,7 @@ steps:
- vllm/model_executor/layers/quantization
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1
- bash ./run-tests.sh -c configs/models-small.txt -t 1
- label: OpenAI API correctness
source_file_dependencies:
@ -713,4 +713,4 @@ steps:
- vllm/model_executor/layers/quantization
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
- bash ./run-tests.sh -c configs/models-large.txt -t 4

View File

@ -75,7 +75,7 @@ body:
```
```
The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present.
The error message you got, with the full traceback.
```
validations:
required: true

View File

@ -101,8 +101,8 @@ repos:
args:
- -c
- |
if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" "$(git rev-parse --git-path COMMIT_EDITMSG)"; then
printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> "$(git rev-parse --git-path COMMIT_EDITMSG)"
if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" .git/COMMIT_EDITMSG; then
printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> .git/COMMIT_EDITMSG
fi
language: system
verbose: true
@ -125,6 +125,8 @@ repos:
name: Update Dockerfile dependency graph
entry: tools/update-dockerfile-graph.sh
language: script
files: ^docker/Dockerfile$
pass_filenames: false
# Keep `suggestion` last
- id: suggestion
name: Suggestion

View File

@ -250,7 +250,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "v3.9.1" CACHE STRING "CUTLASS revision to use")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@ -301,52 +301,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_ARCHS)
#
# For the Marlin kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}")
message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}")
if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH}
OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT}
RESULT_VARIABLE marlin_generation_result
OUTPUT_VARIABLE marlin_generation_result
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log
)
if (NOT marlin_generation_result EQUAL 0)
message(FATAL_ERROR "Marlin generation failed."
" Result: \"${marlin_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log")
else()
set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH}
CACHE STRING "Last run Marlin generate script hash" FORCE)
message(STATUS "Marlin generation completed successfully.")
endif()
else()
message(STATUS "Marlin generation script has not changed, skipping generation.")
endif()
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}")
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
@ -418,7 +374,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
@ -689,7 +644,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=$PYTHONPATH
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
RESULT_VARIABLE moe_marlin_generation_result
OUTPUT_VARIABLE moe_marlin_generation_output

View File

@ -16,20 +16,18 @@ Easy, fast, and cheap LLM serving for everyone
---
*Latest News* 🔥
- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
<details>
<summary>Previous News</summary>
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing).
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
<details>
<summary>Previous News</summary>
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users!

View File

@ -315,15 +315,13 @@ class RandomDataset(BenchmarkDataset):
)
vocab_size = tokenizer.vocab_size
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens
prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(real_input_len * (1 - range_ratio))
input_high = int(real_input_len * (1 + range_ratio))
input_low = int(input_len * (1 - range_ratio))
input_high = int(input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio))
@ -346,17 +344,6 @@ class RandomDataset(BenchmarkDataset):
vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
# This is done because in some cases N consecutive tokens
# give a string tokenized into != N number of tokens.
# For example for GPT2Tokenizer:
# [6880, 6881] -> ['Ġcalls', 'here'] ->
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decode again.
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:input_lens[i]]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = prefix_len + int(input_lens[i])
requests.append(
SampleRequest(
@ -887,94 +874,6 @@ class AIMODataset(HuggingFaceDataset):
return sampled_requests
# -----------------------------------------------------------------------------
# Next Edit Prediction Dataset Implementation
# -----------------------------------------------------------------------------
zeta_prompt = """### Instruction:
You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
### User Edits:
{}
### User Excerpt:
{}
### Response:
""" # noqa: E501
def _format_zeta_prompt(
sample: dict,
original_start_marker: str = "<|editable_region_start|>") -> dict:
"""Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
This function formats examples from the NEP dataset
into prompts and expected outputs. It could be
further extended to support more NEP datasets.
Args:
sample: The dataset sample containing events,
inputs, and outputs.
original_start_marker: The marker indicating the
start of the editable region. Defaults to
"<|editable_region_start|>".
Returns:
A dictionary with the formatted prompts and expected outputs.
"""
events = sample["events"]
input = sample["input"]
output = sample["output"]
prompt = zeta_prompt.format(events, input)
# following the original implementation, extract the focused region
# from the raw output
output_start_index = output.find(original_start_marker)
output_focused_region = output[output_start_index:]
expected_output = output_focused_region
return {"prompt": prompt, "expected_output": expected_output}
class NextEditPredictionDataset(HuggingFaceDataset):
"""
Dataset class for processing a Next Edit Prediction dataset.
"""
SUPPORTED_DATASET_PATHS = {
"zed-industries/zeta",
}
MAPPING_PROMPT_FUNCS = {
"zed-industries/zeta": _format_zeta_prompt,
}
def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
**kwargs):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
self.dataset_path)
if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = []
for sample in self.data:
sample = formatting_prompt_func(sample)
samples.append(
SampleRequest(
prompt=sample["prompt"],
prompt_len=len(tokenizer(sample["prompt"]).input_ids),
expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids),
))
if len(samples) >= num_requests:
break
self.maybe_oversample_requests(samples, num_requests)
return samples
# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------

View File

@ -53,9 +53,8 @@ except ImportError:
from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
ConversationDataset, HuggingFaceDataset,
InstructCoderDataset, MTBenchDataset,
NextEditPredictionDataset, RandomDataset,
SampleRequest, ShareGPTDataset, SonnetDataset,
VisionArenaDataset)
RandomDataset, SampleRequest, ShareGPTDataset,
SonnetDataset, VisionArenaDataset)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@ -604,9 +603,6 @@ def main(args: argparse.Namespace):
elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
dataset_class = AIMODataset
args.hf_split = "train"
elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501
dataset_class = NextEditPredictionDataset
args.hf_split = "train"
elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
dataset_class = ASRDataset
args.hf_split = "train"

View File

@ -414,6 +414,7 @@ async def benchmark(
ignore_eos: bool,
max_concurrency: Optional[int],
structured_output_ratio: float,
structured_output_backend: str,
goodput_config_dict: Optional[dict[str, float]] = None,
):
if backend in ASYNC_REQUEST_FUNCS:
@ -425,6 +426,8 @@ async def benchmark(
extra_body = {}
# Add the schema to the extra_body
extra_body[request.structure_type] = request.schema
# Add the specific structured_output_backend
extra_body["guided_decoding_backend"] = structured_output_backend
return extra_body
print("Starting initial single prompt test run...")
@ -782,6 +785,7 @@ def main(args: argparse.Namespace):
ignore_eos=args.ignore_eos,
max_concurrency=args.max_concurrency,
structured_output_ratio=args.structured_output_ratio,
structured_output_backend=args.structured_output_backend,
goodput_config_dict=goodput_config_dict,
))
@ -996,6 +1000,14 @@ if __name__ == "__main__":
type=float,
default=1.0,
help="Ratio of Structured Outputs requests")
parser.add_argument("--structured-output-backend",
type=str,
choices=[
"outlines", "lm-format-enforcer", "xgrammar",
"guidance", "auto"
],
default="auto",
help="Backend to use for structured outputs")
args = parser.parse_args()
main(args)

View File

@ -6,17 +6,16 @@ import time
from contextlib import nullcontext
from datetime import datetime
from itertools import product
from types import SimpleNamespace
from typing import Any, TypedDict
import ray
import torch
import triton
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()
@ -535,12 +534,8 @@ def get_weight_block_size_safety(config, default_value=None):
def main(args: argparse.Namespace):
print(args)
config = get_config(model=args.model,
trust_remote_code=args.trust_remote_code)
if args.model_prefix:
config = getattr(config, args.model_prefix)
config = SimpleNamespace(**config)
config = AutoConfig.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
@ -551,14 +546,15 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif (config.architectures[0]
in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM")):
elif (config.architectures[0] == "DeepseekV3ForCausalLM"
or config.architectures[0] == "DeepseekV2ForCausalLM"):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] in ("Qwen2MoeForCausalLM",
"Qwen3MoeForCausalLM"):
elif config.architectures[0] in [
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
]:
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
@ -573,8 +569,7 @@ def main(args: argparse.Namespace):
shard_intermediate_size = 2 * intermediate_size // args.tp_size
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else getattr(
torch, config.torch_dtype)
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config)
@ -664,7 +659,6 @@ if __name__ == "__main__":
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--model-prefix", type=str, required=False)
args = parser.parse_args()
main(args)

View File

@ -4,11 +4,11 @@ import itertools
from typing import Optional, Union
import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn
from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton
class HuggingFaceRMSNorm(nn.Module):

View File

@ -6,13 +6,13 @@ import time
# Import DeepGEMM functions
import deep_gemm
import torch
import triton
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
# Import vLLM functions
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.triton_utils import triton
# Copied from

View File

@ -9,10 +9,13 @@ BACKEND=${2:-"vllm"}
# Define the dataset to use
DATASET=${3:-"xgrammar_bench"}
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR=${4:-"$SCRIPT_DIR/structured_output_benchmark_results"}
# Define the guided decoding backend
GUIDED_BACKEND=${4:-"xgrammar"}
GUIDED_RATIO=${5:-0.5}
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"}
GUIDED_RATIO=${6:-0.5}
# Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"
@ -24,6 +27,7 @@ QPS_VALUES=(70 60 50 25 20 15 10)
COMMON_PARAMS="--backend $BACKEND \
--model $MODEL \
--dataset $DATASET \
--structured-output-backend $GUIDED_BACKEND \
--structured-output-ratio $GUIDED_RATIO \
--save-results \
--result-dir $OUTPUT_DIR"
@ -31,6 +35,7 @@ COMMON_PARAMS="--backend $BACKEND \
echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND"
echo "Dataset: $DATASET"
echo "Structured output backend: $GUIDED_BACKEND"
echo "Results will be saved to: $OUTPUT_DIR"
echo "----------------------------------------"
@ -43,7 +48,7 @@ for qps in "${QPS_VALUES[@]}"; do
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
# Construct filename for this run
FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
# Run the benchmark
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \

View File

@ -167,33 +167,6 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
FetchContent_MakeAvailable(oneDNN)
list(APPEND LIBS dnnl)
elseif(POWER10_FOUND)
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.7.2
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
set(ONEDNN_LIBRARY_TYPE "STATIC")
set(ONEDNN_BUILD_DOC "OFF")
set(ONEDNN_BUILD_EXAMPLES "OFF")
set(ONEDNN_BUILD_TESTS "OFF")
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ONEDNN_BUILD_GRAPH "OFF")
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
set(DNNL_CPU_RUNTIME "OMP")
FetchContent_MakeAvailable(oneDNN)
list(APPEND LIBS dnnl)
endif()
@ -224,10 +197,6 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
"csrc/cpu/quant.cpp"
"csrc/cpu/shm.cpp"
${VLLM_EXT_SRC})
elseif(POWER10_FOUND)
set(VLLM_EXT_SRC
"csrc/cpu/quant.cpp"
${VLLM_EXT_SRC})
endif()
#
@ -245,4 +214,4 @@ define_gpu_extension_target(
WITH_SOABI
)
message(STATUS "Enabling C extension.")
message(STATUS "Enabling C extension.")

View File

@ -4,7 +4,6 @@
#include <altivec.h>
#include <cmath>
#include <algorithm>
#include <torch/all.h>
namespace vec_op {
@ -63,10 +62,6 @@ typedef struct f32x4x4_t {
__vector float val[4];
} f32x4x4_t;
typedef struct i32x4x4_t {
__vector int32_t val[4];
} i32x4x4_t;
struct FP32Vec8;
struct FP32Vec16;
@ -103,28 +98,6 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst(reg.val[1], 16, (signed short*)ptr);
}
void save(void* ptr, const int elem_num) const {
const int clamped_elem = std::max(0, std::min(elem_num, 16));
// Calculate elements to store in each 128-bit part (8 elements each)
const int elements_val0 = std::min(clamped_elem, 8);
const int elements_val1 = std::max(clamped_elem - 8, 0);
// Convert elements to bytes (2 bytes per element)
const size_t bytes_val0 = elements_val0 * sizeof(signed short);
const size_t bytes_val1 = elements_val1 * sizeof(signed short);
signed short* dest = static_cast<signed short*>(ptr);
// Store the first part using vec_xst_len
if (bytes_val0 > 0) {
vec_xst_len(reg.val[0], dest, bytes_val0);
}
// Store the second part if needed
if (bytes_val1 > 0) {
vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1);
}
}
};
const static __vector signed short zero = vec_splats((signed short)0);
@ -284,64 +257,6 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
}
};
struct INT32Vec16 : public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
i32x4x4_t reg;
int32_t values[VEC_ELEM_NUM];
};
i32x4x4_t reg;
explicit INT32Vec16(const void* data_ptr) {
reg.val[0] = vec_xl(0, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[1] =
vec_xl(16, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[2] =
vec_xl(32, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[3] =
vec_xl(48, reinterpret_cast<const __vector int32_t*>(data_ptr));
}
void save(int32_t* ptr) const {
vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr));
}
void save(int32_t* ptr, const int elem_num) const {
const int elements_in_chunk1 =
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
const int elements_in_chunk2 =
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
const int elements_in_chunk3 =
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
const int elements_in_chunk4 =
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
const size_t bytes_chunk1 =
static_cast<size_t>(elements_in_chunk1 * sizeof(int32_t));
const size_t bytes_chunk2 =
static_cast<size_t>(elements_in_chunk2 * sizeof(int32_t));
const size_t bytes_chunk3 =
static_cast<size_t>(elements_in_chunk3 * sizeof(int32_t));
const size_t bytes_chunk4 =
static_cast<size_t>(elements_in_chunk4 * sizeof(int32_t));
vec_xst_len(reg.val[0], reinterpret_cast<int32_t*>(ptr), bytes_chunk1);
vec_xst_len(reg.val[1],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 16),
bytes_chunk2);
vec_xst_len(reg.val[2],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 32),
bytes_chunk3);
vec_xst_len(reg.val[3],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 48),
bytes_chunk4);
}
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
@ -404,13 +319,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16& v) {
reg.val[0] = vec_ctf(v.reg.val[0], 0);
reg.val[1] = vec_ctf(v.reg.val[1], 0);
reg.val[2] = vec_ctf(v.reg.val[2], 0);
reg.val[3] = vec_ctf(v.reg.val[3], 0);
}
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
vec_mul(reg.val[1], b.reg.val[1]),
@ -439,117 +347,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vec_div(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
return FP32Vec16(f32x4x4_t(
{vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])),
vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])),
vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])),
vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))}));
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
vec_max(reg.val[1], b.reg.val[1]),
vec_max(reg.val[2], b.reg.val[2]),
vec_max(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 max(const FP32Vec16& b, int elem_num) const {
FP32Vec16 result;
// Create a vector of element indices for each chunk
__vector unsigned int indices = {0, 1, 2, 3};
__vector unsigned int elem_num_vec =
vec_splats(static_cast<unsigned int>(elem_num));
// Compute masks for each chunk
__vector unsigned int chunk_offset0 = {0, 0, 0,
0}; // Chunk 0: Elements 0-3
__vector unsigned int chunk_offset1 = {4, 4, 4,
4}; // Chunk 1: Elements 4-7
__vector unsigned int chunk_offset2 = {8, 8, 8,
8}; // Chunk 2: Elements 8-11
__vector unsigned int chunk_offset3 = {12, 12, 12,
12}; // Chunk 3: Elements 12-15
// Compute masks for each chunk
__vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
__vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
__vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
__vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
// Apply masks to compute the result for each chunk
result.reg.val[0] = vec_sel(this->reg.val[0],
vec_max(this->reg.val[0], b.reg.val[0]), mask0);
result.reg.val[1] = vec_sel(this->reg.val[1],
vec_max(this->reg.val[1], b.reg.val[1]), mask1);
result.reg.val[2] = vec_sel(this->reg.val[2],
vec_max(this->reg.val[2], b.reg.val[2]), mask2);
result.reg.val[3] = vec_sel(this->reg.val[3],
vec_max(this->reg.val[3], b.reg.val[3]), mask3);
return FP32Vec16(result.reg);
}
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]),
vec_min(reg.val[1], b.reg.val[1]),
vec_min(reg.val[2], b.reg.val[2]),
vec_min(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
FP32Vec16 result;
vector unsigned int indices = {0, 1, 2, 3};
vector unsigned int elem_num_vec =
vec_splats(static_cast<unsigned int>(elem_num));
vector unsigned int chunk_offset0 = {0, 0, 0, 0};
vector unsigned int chunk_offset1 = {4, 4, 4, 4};
vector unsigned int chunk_offset2 = {8, 8, 8, 8};
vector unsigned int chunk_offset3 = {12, 12, 12, 12};
vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
result.reg.val[0] = vec_sel(this->reg.val[0],
vec_min(this->reg.val[0], b.reg.val[0]), mask0);
result.reg.val[1] = vec_sel(this->reg.val[1],
vec_min(this->reg.val[1], b.reg.val[1]), mask1);
result.reg.val[2] = vec_sel(this->reg.val[2],
vec_min(this->reg.val[2], b.reg.val[2]), mask2);
result.reg.val[3] = vec_sel(this->reg.val[3],
vec_min(this->reg.val[3], b.reg.val[3]), mask3);
return FP32Vec16(result.reg);
}
FP32Vec16 abs() const {
return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]),
vec_abs(reg.val[2]), vec_abs(reg.val[3])}));
}
float reduce_max() {
__vector float max01 = vec_max(reg.val[0], reg.val[1]);
__vector float max23 = vec_max(reg.val[2], reg.val[3]);
__vector float max_all = vec_max(max01, max23);
__vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8));
temp = vec_max(temp, vec_sld(temp, temp, 4));
return vec_extract(temp, 0);
}
float reduce_min() {
__vector float min01 = vec_min(reg.val[0], reg.val[1]);
__vector float min23 = vec_min(reg.val[2], reg.val[3]);
__vector float min_all = vec_min(min01, min23);
__vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8));
temp = vec_min(temp, vec_sld(temp, temp, 4));
return vec_extract(temp, 0);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
@ -580,68 +377,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vec_xst(reg.val[2], 32, ptr);
vec_xst(reg.val[3], 48, ptr);
}
void save(float* ptr, const int elem_num) const {
const int elements_in_chunk1 =
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
const int elements_in_chunk2 =
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
const int elements_in_chunk3 =
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
const int elements_in_chunk4 =
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
const size_t bytes_chunk1 =
static_cast<size_t>(elements_in_chunk1 * sizeof(float));
const size_t bytes_chunk2 =
static_cast<size_t>(elements_in_chunk2 * sizeof(float));
const size_t bytes_chunk3 =
static_cast<size_t>(elements_in_chunk3 * sizeof(float));
const size_t bytes_chunk4 =
static_cast<size_t>(elements_in_chunk4 * sizeof(float));
vec_xst_len(reg.val[0], ptr, bytes_chunk1);
vec_xst_len(reg.val[1],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16),
bytes_chunk2);
vec_xst_len(reg.val[2],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32),
bytes_chunk3);
vec_xst_len(reg.val[3],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48),
bytes_chunk4);
}
};
struct INT8Vec16 : public Vec<INT8Vec16> {
constexpr static int VEC_NUM_ELEM = 16; // 128 bits / 8 bits = 16
union AliasReg {
__vector signed char reg;
int8_t values[VEC_NUM_ELEM];
};
__vector signed char reg;
explicit INT8Vec16(const FP32Vec16& vec) {
__vector signed int ret[4];
ret[0] = vec_cts(vec.reg.val[0], 0);
ret[1] = vec_cts(vec.reg.val[1], 0);
ret[2] = vec_cts(vec.reg.val[2], 0);
ret[3] = vec_cts(vec.reg.val[3], 0);
__vector signed short packed1 = vec_packs(ret[0], ret[1]);
__vector signed short packed2 = vec_packs(ret[2], ret[3]);
reg = vec_packs(packed1, packed2);
}
void save(void* ptr) const {
*reinterpret_cast<__vector signed char*>(ptr) = reg;
}
void save(signed char* ptr, const int elem_num) {
vec_xst_len(reg, ptr, static_cast<size_t>(elem_num));
}
};
template <typename T>

View File

@ -9,8 +9,7 @@ void rotary_embedding_impl(
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size]
scalar_t* __restrict__ key, // nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@ -86,13 +85,10 @@ void rotary_embedding_impl(
compute_loop(token_head, cache_ptr, query);
}
if (key != nullptr) {
for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i;
const int64_t token_head =
token_idx * key_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, key);
}
for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, key);
}
}
}
@ -104,8 +100,7 @@ void rotary_embedding_gptj_impl(
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size]
scalar_t* __restrict__ key, // nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@ -143,10 +138,6 @@ void rotary_embedding_gptj_impl(
}
}
if (key == nullptr) {
return;
}
#pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) {
@ -177,13 +168,13 @@ void rotary_embedding_gptj_impl(
}; // namespace
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = positions.numel();
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads;
int64_t key_stride = key.has_value() ? key->stride(-2) : 0;
int num_kv_heads = key.size(-1) / head_size;
int64_t key_stride = key.stride(-2);
int64_t query_stride = query.stride(-2);
VLLM_DISPATCH_FLOATING_TYPES(
@ -192,15 +183,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
if (is_neox) {
rotary_embedding_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size, num_tokens);
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size, num_tokens);
} else {
rotary_embedding_gptj_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size, num_tokens);
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size, num_tokens);
}
CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)

View File

@ -239,280 +239,6 @@ void static_quant_epilogue(const float* input, scalar_t* output,
}
}
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
const float* a_scale, const float* b_scale,
const int32_t* azp, const int32_t* azp_adj,
const scalar_t* bias, const int num_tokens,
const int hidden_size) {
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using azp_adj_load_vec_t =
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
int j = 0;
cvt_vec_t token_scale_vec(a_scale[i]);
cvt_vec_t token_zp_scale_vec;
if constexpr (AZP) {
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
if constexpr (!PerChannel) {
zp_scale_val *= *b_scale;
}
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
}
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
cvt_vec_t elems_fp32(input + i * hidden_size + j);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
if constexpr (PerChannel) {
cvt_vec_t b_scale_vec(b_scale + j);
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
}
elems_fp32 = elems_fp32 - azp_adj_fp32;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + j);
cvt_vec_t bias_vec_fp32(bias_vec);
elems_fp32 = elems_fp32 + bias_vec_fp32;
}
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j);
}
cvt_vec_t elems_fp32(input + i * hidden_size + j);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
if constexpr (PerChannel) {
cvt_vec_t b_scale_vec(b_scale + j);
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
}
elems_fp32 = elems_fp32 - azp_adj_fp32;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + j);
cvt_vec_t bias_vec_fp32(bias_vec);
elems_fp32 = elems_fp32 + bias_vec_fp32;
}
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j, hidden_size - j);
}
}
#elif defined(__powerpc64__)
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const float* scale, const int32_t* azp,
const int num_tokens,
const int hidden_size) {
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
constexpr float i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
constexpr float i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
const cvt_vec_t inv_scale(1.0 / *scale);
const cvt_vec_t i8_min_vec(i8_min);
const cvt_vec_t i8_max_vec(i8_max);
cvt_vec_t zp_vec;
if constexpr (AZP) {
zp_vec = cvt_vec_t(static_cast<float>(*azp));
}
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = elems_fp32 * inv_scale;
if constexpr (AZP) {
elems_fp32 = elems_fp32 + zp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j);
}
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = elems_fp32 * inv_scale;
if constexpr (AZP) {
elems_fp32 = elems_fp32 + zp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
}
}
template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float* scale, int32_t* azp,
const int num_tokens,
const int hidden_size) {
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
constexpr float i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
constexpr float i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
const cvt_vec_t i8_min_vec(i8_min);
const cvt_vec_t i8_max_vec(i8_max);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
cvt_vec_t min_value(std::numeric_limits<float>::max());
{
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
if constexpr (AZP) {
max_value = max_value.max(elems_fp32);
min_value = min_value.min(elems_fp32);
} else {
max_value = max_value.max(elems_fp32.abs());
}
}
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
if (j + vec_elem_num == hidden_size) {
if constexpr (AZP) {
max_value = max_value.max(elems_fp32);
min_value = min_value.min(elems_fp32);
} else {
max_value = max_value.max(elems_fp32.abs());
}
} else {
if constexpr (AZP) {
max_value = max_value.max(elems_fp32, hidden_size - j);
min_value = min_value.min(elems_fp32, hidden_size - j);
} else {
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
}
}
}
float scale_val, azp_val;
if constexpr (AZP) {
float max_scalar = max_value.reduce_max();
float min_scalar = min_value.reduce_min();
scale_val = (max_scalar - min_scalar) / 255.0f;
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
azp[i] = static_cast<int32_t>(azp_val);
scale[i] = scale_val;
} else {
scale_val = max_value.reduce_max() / 127.0f;
scale[i] = scale_val;
}
const cvt_vec_t inv_scale(1.0 / scale_val);
const cvt_vec_t azp_vec(azp_val);
{
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = (elems_fp32 * inv_scale);
if constexpr (AZP) {
elems_fp32 = elems_fp32 + azp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j);
}
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = (elems_fp32 * inv_scale);
if constexpr (AZP) {
elems_fp32 = elems_fp32 + azp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
}
}
}
template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
const float a_scale, const float* b_scale,
const int32_t* azp_with_adj, const int num_tokens,
const int hidden_size) {
CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using azp_adj_load_vec_t =
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
cvt_vec_t a_scale_vec(a_scale);
cvt_vec_t b_scale_vec(*b_scale);
cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
cvt_vec_t elems_fp32(input + i * hidden_size + j);
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
if constexpr (PerChannel) {
b_scale_vec = cvt_vec_t(b_scale + j);
scale_vec = b_scale_vec * a_scale_vec;
}
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j);
}
cvt_vec_t elems_fp32(input + i * hidden_size + j);
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
if constexpr (PerChannel) {
b_scale_vec = cvt_vec_t(b_scale + j);
scale_vec = b_scale_vec * a_scale_vec;
}
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j, hidden_size - j);
}
}
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
const float* a_scale, const float* b_scale,
@ -598,8 +324,7 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const float* scale, const int32_t* azp,
const int num_tokens,
const int hidden_size) {
TORCH_CHECK(
false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.")
}
template <typename scalar_t>
@ -607,9 +332,7 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float* scale, int32_t* azp,
const int num_tokens,
const int hidden_size) {
TORCH_CHECK(
false,
"dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.")
}
template <bool PerChannel, typename scalar_t>
@ -617,7 +340,7 @@ void static_quant_epilogue(const float* input, scalar_t* output,
const float a_scale, const float* b_scale,
const int32_t* azp_with_adj, const int num_tokens,
const int hidden_size) {
TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.")
}
template <typename scalar_t>
@ -626,8 +349,7 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
const int32_t* azp, const int32_t* azp_with_adj,
const scalar_t* bias, const int num_tokens,
const int hidden_size) {
TORCH_CHECK(false,
"dynamic_quant_epilogue requires AVX512/powerpc64 support.")
TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.")
}
#endif
} // namespace
@ -889,58 +611,3 @@ void dynamic_scaled_int8_quant(
}
});
}
#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major
const torch::Tensor& a, // [M, IC], row-major
const torch::Tensor& b, // [IC, OC], column-major
const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
const std::optional<torch::Tensor>& bias // [OC]
) {
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
"int8_scaled_mm_ppc64le only supports INT8 inputs.");
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
// We dont need this
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
bias->dim() == 1);
}
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] {
torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
// Compute C_inter=s_b * (A@B)
DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
if (bias.has_value()) {
// Compute C=s_a * C_inter + bias
dynamic_quant_epilogue<false, true, true>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
} else {
// Compute C=s_a * C_inter
dynamic_quant_epilogue<false, true, false, scalar_t>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
c.size(0), c.size(1));
}
});
}
#endif

View File

@ -18,14 +18,6 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& azp,
const std::optional<torch::Tensor>& bias);
#if defined(__powerpc64__)
void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
const std::optional<torch::Tensor>& bias);
#endif
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens);
@ -125,7 +117,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
@ -158,33 +150,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#elif defined(__powerpc64__)
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
&dynamic_scaled_int8_quant);
// W8A8 GEMM, supporting symmetric quantization.
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif
// SHM CCL

View File

@ -59,13 +59,3 @@ struct enable_sm90_only : Kernel {
#endif
}
};
template <typename Kernel>
struct enable_sm100_only : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};

View File

@ -140,10 +140,6 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;

View File

@ -0,0 +1,31 @@
#include "marlin_moe_kernel_ku4.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = true;
if (false) {
}
AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe

View File

@ -0,0 +1,20 @@
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe

View File

@ -0,0 +1,31 @@
#include "marlin_moe_kernel_ku4b8.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128)
GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe

View File

@ -0,0 +1,20 @@
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe

View File

@ -0,0 +1,31 @@
#include "marlin_moe_kernel_ku8b128.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128)
GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe

View File

@ -0,0 +1,18 @@
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
}

588
csrc/moe/marlin_moe_ops.cu Normal file
View File

@ -0,0 +1,588 @@
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace marlin_moe {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
int start_row = block_rows * blockIdx.x;
int finish_row = start_row + block_rows;
if (finish_row > size_m) {
finish_row = size_m;
}
int cur_block_rows = finish_row - start_row;
int row_stride = size_k * sizeof(half) / 16;
auto permute_row = [&](int row) {
int iters = size_k / blockDim.x;
int rest = size_k % blockDim.x;
int offset = row * row_stride;
half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
base_k += blockDim.x;
}
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
}
}
};
for (int i = 0; i < cur_block_rows; i++) {
int cur_row = start_row + i;
if (cur_row < size_m) {
permute_row(cur_row);
}
}
}
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
int* __restrict__ expert_offsets,
int topk_length, int block_size) {
int expert_id = threadIdx.x;
int num_experts = blockDim.x;
int occurrences = 0;
for (int i = 0; i < topk_length; ++i) {
occurrences += (topk_ids[i] == expert_id);
}
expert_offsets[expert_id + 1] = occurrences;
__syncthreads();
if (threadIdx.x == 0) {
int tot_offset = 0;
expert_offsets[0] = 0;
for (int i = 0; i < num_experts; ++i) {
tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size;
expert_offsets[i + 1] = tot_offset;
}
}
__syncthreads();
}
#else
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
// Marlin is not implemented yet for SM < 8.0
assert(false);
return;
}
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
int* __restrict__ expert_offsets,
int topk_length, int block_size) {
// Marlin is not implemented yet for SM < 8.0
assert(false);
return;
}
#endif
typedef struct {
int thread_k;
int thread_n;
int num_threads;
} thread_config_t;
typedef struct {
int max_m_blocks;
thread_config_t tb_cfg;
} exec_config_t;
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256}, // Default
{128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N
{64, 64, 128}, // Reduce both 2X
};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256}, // Default
{128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X
{64, 64, 128}, // Reduce N 4X, same K
};
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
int tb_k = th_config.thread_k;
// Get max scale groups per thread-block
int tb_groups;
if (group_size == -1) {
tb_groups = 1;
} else if (group_size == 0) {
tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size
} else {
tb_groups = ceildiv(tb_k, group_size);
}
if (cache_scales_chunk) {
int load_groups =
tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 4;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * STAGES;
}
}
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int scales_cache_size, int max_shared_mem) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int b_size = (tb_k * tb_n / pack_factor) * 4;
// Get A size
int m_blocks = ceildiv(prob_m, 16);
int tb_max_m = 16;
while (true) {
if (m_blocks >= max_m_blocks) {
tb_max_m *= max_m_blocks;
break;
}
max_m_blocks--;
if (max_m_blocks == 0) {
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
}
}
int a_size = (tb_max_m * tb_k) * 2;
float pipe_size = (a_size + b_size) * STAGES;
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
return false;
}
// Verify K/N are divisible by thread K/N
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
return false;
}
// thread_k can be only 128 or 64 (because it must be less than groupsize
// which is 128)
if (th_config.thread_k != 128 && th_config.thread_k != 64) {
return false;
}
// Verify min for thread K/N
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
return false;
}
// num_threads must be at least 128 (= 4 warps)
if (th_config.num_threads < 128) {
return false;
}
// Determine cache for scales
int scales_cache_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
// Check that pipeline fits into cache
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, scales_cache_size, max_shared_mem)) {
return false;
}
return true;
}
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int num_bits, int group_size,
bool has_act_order, bool is_k_full,
int max_shared_mem) {
int max_m_blocks = 4;
while (max_m_blocks > 0) {
if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
} else {
for (auto th_config : large_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
}
max_m_blocks--; // Process less M blocks per invocation to reduce cache
// usage
}
return exec_config_t{0, {-1, -1, -1}};
}
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION( \
q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
group_blocks, num_threads, blocks, max_shared_mem, stream, \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
exec_cfg.max_m_blocks)) { \
}
void marlin_mm_moe(const void* A, const void* B, void* C,
const void* sorted_ids, const void* topk_weights,
const void* topk_ids, const void* s, void* zp,
const void* g_idx, const void* perm, void* a_tmp,
void* expert_offsets, int prob_m, int prob_n, int prob_k,
void* workspace, vllm::ScalarType const& q_type,
bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int num_experts, int topk,
int moe_block_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int max_par,
bool replicate_input, bool apply_weights) {
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
if (sms == -1) {
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
}
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
int num_bits = q_type.size_bits();
// Set thread config
exec_config_t exec_cfg;
if (thread_k != -1 && thread_n != -1) {
// User-defined config
exec_cfg =
exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
} else {
// Auto config
exec_cfg =
determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem);
}
TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem),
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
", thread_k = ", exec_cfg.tb_cfg.thread_k,
", thread_n = ", exec_cfg.tb_cfg.thread_n,
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", max_shared_mem = ", max_shared_mem);
int num_threads = exec_cfg.tb_cfg.num_threads;
thread_k = exec_cfg.tb_cfg.thread_k;
thread_n = exec_cfg.tb_cfg.thread_n;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int blocks = sms;
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
int group_blocks = 0;
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(group_size != -1);
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
} else {
TORCH_CHECK(group_size == 0);
group_blocks = 0;
}
} else {
if (group_size == -1) {
group_blocks = -1;
} else {
group_blocks = group_size / 16;
TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
" is not divisible by group_blocks = ", group_blocks);
}
}
int tot_m = prob_m;
const int* topk_ids_ptr = (const int*)topk_ids;
int* expert_offsets_ptr = (int*)expert_offsets;
compute_expert_offsets<<<1, num_experts, 0, stream>>>(
topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);
bool do_permute_a = has_act_order;
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if (is_k_full) {
has_act_order = false;
}
int pack_factor = 32 / q_type.size_bits();
for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
const int4* A_ptr = (const int4*)A;
int4* a_tmp_ptr = (int4*)a_tmp;
const int4* B_ptr =
(const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
int4* C_ptr = (int4*)C;
const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
const int4* zp_ptr =
(const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace;
if (do_permute_a) {
// Permute A columns
int topk_rows = replicate_input ? tot_m : tot_m * topk;
int block_rows = ceildiv(topk_rows, blocks);
permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
A_ptr = a_tmp_ptr;
}
int tot_m_blocks = ceildiv(tot_m, 16);
for (int m_block = 0; m_block < tot_m_blocks;
m_block += 4 * exec_cfg.max_m_blocks) {
if (false) {
}
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
", has_act_order = " + str(has_act_order) +
", num_groups = " + str(num_groups) +
", group_size = " + str(group_size) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
}
}
}
}
} // namespace marlin_moe
torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
torch::Tensor& b_zeros, const torch::Tensor& g_idx,
const torch::Tensor& perm, torch::Tensor& workspace,
vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
int64_t moe_block_size, bool replicate_input, bool apply_weights) {
vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
bool has_zp = b_zeros.size(1) != 0;
if (has_zp) {
TORCH_CHECK(
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str());
}
int pack_factor = 32 / b_q_type.size_bits();
int max_par = 4;
int dev = a.get_device();
auto options_dtype =
torch::TensorOptions().dtype(a.dtype()).device(a.device());
auto options_int =
torch::TensorOptions().dtype(torch::kInt).device(a.device());
torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
torch::Tensor a_tmp =
replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
: torch::zeros({size_m, topk, size_k}, options_dtype);
torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int sms = -1;
// Detect groupsize and act_order
int num_groups = -1;
int group_size = -1;
bool has_act_order = g_idx.size(1) != 0;
int b_rank = b_scales.sizes().size();
TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
" is not size_n = ", size_n);
num_groups = b_scales.size(1);
TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
"if is_k_full is false, has_act_order must be true");
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by num_groups = ", num_groups);
group_size = size_k / num_groups;
} else {
group_size = 0;
}
} else {
if (num_groups > 1) {
TORCH_CHECK(
size_k % num_groups == 0, "size_k = ", size_k,
", is not divisible by b_scales.size(0) = ", b_scales.size(0));
group_size = size_k / num_groups;
} else {
group_size = -1;
}
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
marlin_moe::marlin_mm_moe(
a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
num_experts, topk, moe_block_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
replicate_input, apply_weights);
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm_moe", &marlin_gemm_moe);
}

View File

@ -1 +0,0 @@
kernel_*.cu

View File

@ -25,13 +25,15 @@ TEMPLATE = ("template __global__ void Marlin<"
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
@ -50,29 +52,21 @@ def remove_old_kernels():
def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
]:
has_act_order = group_blocks == 0
if has_zp and has_act_order:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
@ -88,6 +82,8 @@ def generate_new_kernels():
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
has_act_order=has_act_order,
has_zp=has_zp,
group_blocks=group_blocks,
is_zp_float=False,
)

View File

@ -18,7 +18,7 @@
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
@ -33,9 +33,11 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

View File

@ -25,7 +25,6 @@
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@ -49,9 +48,11 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@ -76,8 +77,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {}
bool use_fp32_reduce // whether to use fp32 global reduce
) {}
} // namespace MARLIN_NAMESPACE_NAME
@ -165,6 +166,144 @@ __device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a,
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t, int bit>
__device__ inline typename ScalarType<scalar_t>::FragB dequant(
int q, typename ScalarType<scalar_t>::FragB& frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 4>(
int q, typename ScalarType<half>::FragB& frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 4>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 8>(
int q, typename ScalarType<half>::FragB& frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 8>(int q,
typename ScalarType<nv_bfloat16>::FragB& frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
@ -290,9 +429,11 @@ template <typename scalar_t, // compute dtype, half or nv_float16
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
@ -317,8 +458,8 @@ __global__ void Marlin(
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce, // whether to use fp32 global reduce
int max_shared_mem) {
bool use_fp32_reduce // whether to use fp32 global reduce
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
@ -340,8 +481,6 @@ __global__ void Marlin(
extern __shared__ int4 sh[];
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8;
constexpr bool has_act_order = group_blocks == 0;
constexpr int pack_factor = 32 / w_type.size_bits();
static_assert(thread_m_blocks == 1 || !m_block_size_8);
@ -395,20 +534,13 @@ __global__ void Marlin(
int64_t B_expert_off = 0;
int4* sh_block_sorted_ids_int4 = sh;
int4* sh_rd_block_sorted_ids_int4 =
sh_block_sorted_ids_int4 + moe_block_size / 4;
int4* sh_block_topk_weights_int4 =
sh_rd_block_sorted_ids_int4 + moe_block_size / 4;
// sh_block_topk_weights_int4 only need (moe_block_size / 4);
// but we pad to align to 256 bytes
int4* sh_new =
sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size;
int32_t* sh_block_sorted_ids =
reinterpret_cast<int*>(sh_block_sorted_ids_int4);
int32_t* sh_rd_block_sorted_ids =
reinterpret_cast<int*>(sh_rd_block_sorted_ids_int4);
int4* sh_block_topk_weights_int4 =
sh_block_sorted_ids_int4 + moe_block_size / 4;
scalar_t2* sh_block_topk_weights =
reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);
int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4;
int32_t block_num_valid_tokens = 0;
int32_t locks_off = 0;
@ -452,11 +584,6 @@ __global__ void Marlin(
sh_block_sorted_ids_int4[tid4] = reinterpret_cast<const int4*>(
sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
#pragma unroll
for (int i = 0; i < 4; i++)
sh_rd_block_sorted_ids[tid4 * 4 + i] =
sh_block_sorted_ids[tid4 * 4 + i] / top_k;
if (mul_topk_weights) {
#pragma unroll
for (int i = 0; i < 4; i++) {
@ -616,7 +743,6 @@ __global__ void Marlin(
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
constexpr int act_s_max_num_groups = 32;
int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4;
@ -632,9 +758,9 @@ __global__ void Marlin(
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread.
int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o;
int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o;
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
@ -648,8 +774,8 @@ __global__ void Marlin(
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
auto b_sh_wr = threadIdx.x * b_thread_vecs;
auto b_sh_rd = threadIdx.x * b_thread_vecs;
int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@ -668,7 +794,7 @@ __global__ void Marlin(
s_sh_stride * slice_col + threadIdx.x;
}
}
auto s_sh_wr = threadIdx.x;
int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
@ -681,7 +807,7 @@ __global__ void Marlin(
zp_sh_stride * slice_col + threadIdx.x;
}
}
auto zp_sh_wr = threadIdx.x;
int zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
@ -725,7 +851,7 @@ __global__ void Marlin(
// each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8);
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
@ -753,28 +879,12 @@ __global__ void Marlin(
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
// Shared memory storage for global fetch pipelines.
constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks;
constexpr int sh_b_size = stages * b_sh_stage;
int4* sh_b = sh_new;
int4* sh_red = sh_new;
int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
int4* sh_a = sh_new;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride)
: (stages * s_sh_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// shared memory reused by reduction should be smaller than
// shared memory used by weight.
static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <=
stages * b_sh_stage);
int4* sh_a = sh_s + sh_s_size;
constexpr int shm_size_used =
moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size +
(sh_red_size > sh_b_size ? sh_red_size : sh_b_size);
// all remaining shared memory is used to cache A (input)
// sh_a_max_row is at least ` stages * 16 * thread_m_blocks `
int sh_a_max_row =
((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2);
int4* sh_red = sh_b;
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
@ -795,14 +905,15 @@ __global__ void Marlin(
int sh_first_group_id = -1;
int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id,
int last_group_id) {
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups < act_s_max_num_groups) {
sh_num_groups = act_s_max_num_groups;
if (sh_num_groups < sh_max_num_groups) {
sh_num_groups = sh_max_num_groups;
}
if (sh_first_group_id + sh_num_groups > num_groups) {
@ -829,31 +940,27 @@ __global__ void Marlin(
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
bool should_load_a = true;
int max_num_stage_groups =
((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages;
max_num_stage_groups = max(max_num_stage_groups, 1);
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true,
int pipe_a = 0) {
int a_remaining_load_count_in_slice = stages;
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) {
if (should_load_a) {
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 ||
a_remaining_load_count_in_slice > 0) {
a_remaining_load_count_in_slice--;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row;
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
int row = a_idx / a_gl_stride;
int64_t sorted_row = 0;
if (!m_block_size_8 || row < 8)
sorted_row = sh_rd_block_sorted_ids[row];
int64_t true_idx =
sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off;
sorted_row = sh_block_sorted_ids[row] / top_k;
int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx],
row < block_num_valid_tokens);
}
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
@ -956,8 +1063,8 @@ __global__ void Marlin(
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) {
int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a;
auto fetch_to_registers = [&](int k, int pipe) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(
@ -1002,17 +1109,12 @@ __global__ void Marlin(
}
} else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
reinterpret_cast<int4*>(&frag_s[1])[0] =
reinterpret_cast<int4*>(&frag_s[0])[0];
}
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
auto warp_id = threadIdx.x / 32;
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
@ -1050,7 +1152,7 @@ __global__ void Marlin(
// Determine "position" inside the thread-block (based on warp and
// thread-id)
auto warp_id = threadIdx.x / 32;
int warp_id = threadIdx.x / 32;
int n_warps =
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
@ -1059,7 +1161,7 @@ __global__ void Marlin(
cur_k += warp_row * 16;
auto th_id = threadIdx.x % 32;
int th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
int s_col_shift =
@ -1120,18 +1222,15 @@ __global__ void Marlin(
}
} else if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
auto warp_id = threadIdx.x / 32;
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
@ -1152,7 +1251,6 @@ __global__ void Marlin(
sh_zp_stage += cur_group_id * zp_sh_stride;
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
@ -1165,16 +1263,12 @@ __global__ void Marlin(
if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
if (k % b_sh_wr_iters == 0) {
int4* sh_zp_stage =
sh_zp +
zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] =
sh_zp_stage[zp_sh_rd];
}
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
} else {
auto warp_id = threadIdx.x / 32;
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
@ -1198,25 +1292,6 @@ __global__ void Marlin(
}
};
auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) {
if constexpr (has_zp && is_zp_float || !has_zp) {
dequant<scalar_t2, w_type_id>(q, frag_b_ptr);
} else {
static_assert(has_zp && !is_zp_float);
static_assert(w_type_id == vllm::kU4.id() || w_type_id == vllm::kU8.id());
// If (has_zp && !is_zp_float),
// we use not-zp version `dequant` function
// to improve numerical accuracy.
// Since both weight and zero point are dequanted using this logic,
// the final dequanted weight would be correct.
if constexpr (w_type_id == vllm::kU4.id()) {
dequant<scalar_t2, vllm::kU4B8.id()>(q, frag_b_ptr);
} else if constexpr (w_type_id == vllm::kU8.id()) {
dequant<scalar_t2, vllm::kU8B128.id()>(q, frag_b_ptr);
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
bool is_first_matmul_in_slice = true;
auto matmul = [&](int k) {
@ -1240,17 +1315,15 @@ __global__ void Marlin(
zp_quant_1 = frag_qzp[k2][1];
}
dequant_data(zp_quant_0, reinterpret_cast<scalar_t2*>(&frag_zp));
dequant_data(zp_quant_1, reinterpret_cast<scalar_t2*>(&frag_zp) + 2);
}
}
if constexpr (has_zp && is_zp_float) {
if (is_new_zp) {
reinterpret_cast<int4*>(&frag_zp)[0] =
reinterpret_cast<int4*>(&frag_zpf[k2])[0];
}
}
dequant<scalar_t, w_type.size_bits()>(zp_quant_0, frag_zp_0);
dequant<scalar_t, w_type.size_bits()>(zp_quant_1, frag_zp_1);
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
}
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
@ -1269,8 +1342,8 @@ __global__ void Marlin(
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
}
dequant_data(b_quant_0, reinterpret_cast<scalar_t2*>(&frag_b0));
dequant_data(b_quant_1, reinterpret_cast<scalar_t2*>(&frag_b1));
dequant<scalar_t, w_type.size_bits()>(b_quant_0, frag_b0);
dequant<scalar_t, w_type.size_bits()>(b_quant_1, frag_b1);
// Apply scale to frag_b0
if constexpr (has_act_order) {
@ -1278,7 +1351,8 @@ __global__ void Marlin(
scale4<scalar_t>(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j],
act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
act_frag_s[k][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2(
@ -1287,12 +1361,18 @@ __global__ void Marlin(
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (has_zp && group_blocks != -1) {
} else if constexpr (has_zp && !is_zp_float && group_blocks != -1) {
if (is_new_zp)
frag_zp[j] = __hmul2(frag_zp[j],
*reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y);
scale_and_sub<scalar_t>(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y);
} else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
if (is_new_zp)
frag_zpf[k2][j] = __hmul2(
frag_zpf[k2][j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y);
} else if constexpr (group_blocks != -1) {
scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
@ -1317,7 +1397,7 @@ __global__ void Marlin(
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
auto red_idx = threadIdx.x / b_sh_stride_threads;
int red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@ -1651,7 +1731,7 @@ __global__ void Marlin(
fetch_col_scale_to_shared();
}
}
fetch_to_shared(i, i, i < slice_iters, i);
fetch_to_shared(i, i, i < slice_iters);
}
zero_accums();
@ -1660,10 +1740,8 @@ __global__ void Marlin(
fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd_col += a_gl_rd_delta_o * (stages - 1);
if constexpr (has_act_order) {
slice_k_start_shared_fetch += tb_k * (stages - 1);
}
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
};
if (slice_iters) {
start_pipes();
@ -1676,58 +1754,45 @@ __global__ void Marlin(
// have even length meaning that the next iteration will always start at
// index 0.
for (int stage_group_id = 0; stage_group_id < max_num_stage_groups;
stage_group_id++) {
#pragma unroll
for (int pipe = 0; pipe < stages;) {
for (int pipe = 0; pipe < stages;) {
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
int idx =
(pipe >= stages && stage_group_id == max_num_stage_groups - 1)
? (pipe - stages)
: (pipe + stage_group_id * stages);
fetch_to_registers(k + 1, pipe % stages, idx);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1)
? (pipe - 1)
: (pipe + (stage_group_id + 1) * stages - 1);
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages, idx);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
}
a_gl_rd_col += a_gl_rd_delta_o * stages;
if constexpr (has_act_order) {
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id,
last_group_id);
__syncthreads();
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
}
a_remaining_load_count_in_slice = 0;
a_gl_rd += a_gl_rd_delta_o * stages;
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
if constexpr (has_act_order) {
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
__syncthreads();
}
}
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
@ -1812,30 +1877,15 @@ __global__ void Marlin(
if (last || use_atomic_add)
// only the last block in a slice actually writes the result
write_result();
int old_slice_row = slice_row;
if (slice_row) a_remaining_load_count_in_slice = stages;
slice_row = 0;
slice_col_par++;
slice_col++;
is_first_matmul_in_slice = true;
init_slice();
// Should we load A matrix in next slice?
// `slice_col == 0`: when move to a new moe block
// `old_slice_row > 0`:
// when the last slice is not starting from k_index == 0
// (only happen when it is the first slice of a threadblock)
// `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`:
// when the required shared memory size is larger than
// the remaining shared memory
if (slice_col == 0 || old_slice_row ||
prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) {
should_load_a = true;
} else {
should_load_a = false;
}
if (slice_iters) {
a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o);
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
@ -1850,10 +1900,12 @@ __global__ void Marlin(
slice_k_finish = slice_k_start + tb_k * slice_iters;
slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col;
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
}
start_pipes();
}
}

View File

@ -116,7 +116,7 @@ __global__ void permute_cols_kernel(
int base_k = 0;
for (int i = 0; i < iters; i++) {
auto cur_k = base_k + threadIdx.x;
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
@ -126,7 +126,7 @@ __global__ void permute_cols_kernel(
if (rest) {
if (threadIdx.x < rest) {
auto cur_k = base_k + threadIdx.x;
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
@ -195,6 +195,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
@ -202,24 +203,22 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
}
}
int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n,
int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full, int has_zp,
int is_zp_float) {
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16);
int tb_m = thread_m_blocks * 16;
// shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights
// shm size for block_sorted_ids/block_topk_weights
// both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4;
int sh_block_meta_size = tb_m * 4 * 2;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_red_size = tb_m * (tb_n + 8) * 2;
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full);
@ -234,17 +233,16 @@ int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8,
sh_zp_size = sh_s_size / 2;
}
int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size +
sh_zp_size + sh_g_idx_size + sh_block_meta_size;
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
sh_g_idx_size + sh_block_meta_size;
return total_size;
}
bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
int thread_m_blocks, int prob_m, int prob_n, int prob_k,
int num_bits, int group_size, bool has_act_order,
bool is_k_full, int has_zp, int is_zp_float,
int max_shared_mem) {
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int has_zp, int is_zp_float, int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
@ -268,113 +266,143 @@ bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8,
// Check that pipeline fits into cache
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, has_zp, is_zp_float);
return cache_size <= max_shared_mem;
}
#define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, GROUP_BLOCKS, IS_ZP_FLOAT>; \
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS, IS_ZP_FLOAT) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
m_block_size_8 == M_BLOCK_SIZE_8 && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \
is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, \
pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
}
// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false)
// this is the most common cases
// BIGGROUP: cases for big group size (group_blocks in [-1, 8])
// FZP: cases for float-zero-point (is_zp_float = true)
// ACT: cases for act order case (group_blocks == 0)
#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \
NUM_THREADS, false)
#define COMMON_GET_IF(W_TYPE) \
COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \
COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \
COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \
COMMON_GET_IF_M234(W_TYPE, 8, 4, 128)
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \
false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \
\
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false)
#define BIGGROUP_GET_IF(W_TYPE) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \
BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \
BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128)
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \
NUM_THREADS, false)
// We currently have 4-bit models only with group_blocks == 4
#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true)
#define FZP_GET_IF(W_TYPE) \
FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \
FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \
FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \
FZP_GET_IF_M234(W_TYPE, 8, 4, 128)
// We currently have 4-bit models only with group_blocks == 4
#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
_GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false)
#define ACT_GET_IF(W_TYPE) \
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \
ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \
ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \
ACT_GET_IF_M234(W_TYPE, 8, 4, 128)
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \
true) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \
NUM_THREADS, true)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
@ -387,15 +415,23 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
auto kernel = MarlinDefault;
if (false) {
}
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)
COMMON_GET_IF(vllm::kU4)
COMMON_GET_IF(vllm::kU4B8)
COMMON_GET_IF(vllm::kU8B128)
GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)
BIGGROUP_GET_IF(vllm::kFE4M3fn)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)
ACT_GET_IF(vllm::kU4B8)
ACT_GET_IF(vllm::kU8B128)
GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)
AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)
AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)
return kernel;
}
@ -421,19 +457,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
for (int i = 0; i < thread_configs_size; i++) {
thread_config_t th_config = thread_configs[i];
if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem)) {
if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp,
is_zp_float, max_shared_mem)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float);
th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full, has_zp, is_zp_float);
int group_blocks = 0;
if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : (group_size / 16);
group_blocks = group_size == -1 ? -1 : group_size / 16;
}
auto kernel = get_marlin_kernel<scalar_t>(
@ -479,14 +515,14 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
bool m_block_size_8 = moe_block_size == 8;
if (has_zp) {
TORCH_CHECK(q_type == vllm::kU4,
"q_type must be u4 when has_zp = True. Got = ", q_type.str());
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(q_type == vllm::kU4B8 || q_type == vllm::kU8B128 ||
q_type == vllm::kFE4M3fn,
"q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
"False. Got = ",
q_type.str());
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
@ -595,18 +631,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
TORCH_CHECK(
is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ",
prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", has_act_order = ", has_act_order,
", is_k_full = ", is_k_full, ", has_zp = ", has_zp,
", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem);
TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
prob_k, num_bits, group_size, has_act_order,
is_k_full, has_zp, is_zp_float, max_shared_mem),
"Invalid thread config: thread_m_blocks = ", thread_m_blocks,
", thread_k = ", thread_tfg.thread_k,
", thread_n = ", thread_tfg.thread_n,
", num_threads = ", thread_tfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
", max_shared_mem = ", max_shared_mem);
auto kernel = get_marlin_kernel<scalar_t>(
q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
@ -630,7 +666,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem);
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
// clang-format on
}
@ -805,11 +841,10 @@ torch::Tensor moe_wna16_marlin_gemm(
b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 ||
b_q_type == vllm::kFE4M3fn,
"b_q_type must be uint4b8, uint8b128 or fp8e4m3 when has_zp = "
"False. Got = ",
b_q_type.str());
TORCH_CHECK(
b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type.str());
}
if (has_zp && is_zp_float) {

View File

@ -86,13 +86,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
std::optional<torch::Tensor> residual);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
std::optional<torch::Tensor> key,
int64_t head_size, torch::Tensor& cos_sin_cache,
bool is_neox, int64_t rot_dim,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox,
int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
@ -178,10 +178,6 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
torch::Tensor num_tokens_post_padded, int64_t type,
int64_t row, int64_t top_k, int64_t tokens);
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W,
torch::Tensor topk_ids, int64_t top_k,
int64_t type, int64_t row, int64_t tokens);
int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM

View File

@ -38,8 +38,7 @@ inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* cache_ptr, const int head_size, const int num_heads,
@ -58,15 +57,13 @@ inline __device__ void apply_rotary_embedding(
query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
if (key != nullptr) {
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(
key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
}
}
@ -77,8 +74,7 @@ __global__ void rotary_embedding_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@ -102,8 +98,7 @@ __global__ void batched_rotary_embedding_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t* __restrict__ key, // nullptr or
// [batch_size, seq_len, num_kv_heads,
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
@ -132,12 +127,10 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
std::optional<torch::Tensor> key,
// null or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
@ -145,40 +138,40 @@ void rotary_embedding(
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key
// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(query.size(0) == positions.size(0) &&
(!key.has_value() || key->size(0) == positions.size(0)),
"query, key and positions must have the same number of tokens");
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
(!key.has_value() || key->size(0) == positions.size(0)) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
(!key.has_value() || key->size(1) == positions.size(1)),
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}
// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);
int rot_dim = cos_sin_cache.size(1);
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
int64_t key_stride = key.stride(seq_dim_idx);
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@ -188,16 +181,15 @@ void rotary_embedding(
if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride, key_stride,
num_heads, num_kv_heads, head_size);
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
query_stride, key_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
head_size);
}
});
}
@ -212,12 +204,10 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
std::optional<torch::Tensor>
key, // null or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int64_t rot_dim,
@ -231,38 +221,38 @@ void batched_rotary_embedding(
"cos_sin_cache_offsets");
int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key
// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(query.size(0) == positions.size(0) &&
(!key.has_value() || key->size(0) == positions.size(0)),
"query, key and positions must have the same number of tokens");
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
(!key.has_value() || key->size(0) == positions.size(0)) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
(!key.has_value() || key->size(1) == positions.size(1)),
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}
// Make sure head_size is valid for query and key
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have concistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
int64_t key_stride = key.stride(seq_dim_idx);
dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
@ -273,16 +263,14 @@ void batched_rotary_embedding(
vllm::batched_rotary_embedding_kernel<scalar_t, true>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false>
<<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
cos_sin_cache.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
key_stride, num_heads, num_kv_heads, head_size);
}

View File

@ -1,27 +0,0 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace vllm {
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(
a.size(0) % 4 == 0,
"Input tensor must have a number of rows that is a multiple of 4. ",
"but got: ", a.size(0), " rows.");
if (out.dtype() == torch::kBFloat16) {
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
out, a, b, a_scales, b_scales);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@ -1,205 +0,0 @@
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace vllm {
using namespace cute;
template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
class ClusterShape, typename EpilogueScheduler,
typename MainloopScheduler>
struct cutlass_3x_gemm_fp8_blockwise {
using ElementAB = cutlass::float_e4m3_t;
using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
using ElementB = ElementAB;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
using ElementC = void;
using ElementD = OutType;
using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
using LayoutC = LayoutD;
static constexpr int AlignmentC = AlignmentD;
using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
// Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
static constexpr int ScaleGranularityM =
size<0>(MmaTileShape{}) / ScaleMsPerTile;
static constexpr int ScaleGranularityN =
size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
static constexpr int ScaleGranularityK =
size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
// Shape of the threadblocks in a cluster
using ClusterShape_MNK = ClusterShape;
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassTensorOp;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementScalar = float;
// clang-format off
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
MmaTileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
ElementD,
LayoutD,
AlignmentD,
EpilogueScheduler,
DefaultOperation
>::CollectiveOp;
using StageCountType = cutlass::gemm::collective::StageCountAuto;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp;
// clang-format on
using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
struct GemmKernel : public KernelType {};
};
template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutSFA = typename Gemm::LayoutSFA;
using LayoutSFB = typename Gemm::LayoutSFB;
using ScaleConfig = typename Gemm::ScaleConfig;
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0), n = b.size(1), k = a.size(1);
auto prob_shape = cute::make_shape(m, n, k, 1);
StrideA a_stride;
StrideB b_stride;
StrideC c_stride;
a_stride =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
LayoutSFA layout_SFA =
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB =
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());
typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, a_stride, b_ptr, b_stride,
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args);
}
template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto m = a.size(0);
auto k = a.size(1);
auto n = b.size(1);
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
return std::ceil(static_cast<float>(m) / tile1SM) *
std::ceil(static_cast<float>(n) / tile1SM) >=
sms;
};
bool use_2sm = should_use_2sm(m, n);
if (use_2sm) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
}
}
} // namespace vllm

View File

@ -1,57 +0,0 @@
#include <torch/all.h>
#include "cuda_utils.h"
template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias,
Fp8Func fp8_func, Int8Func int8_func,
BlockwiseFunc blockwise_func) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
int M = a.size(0), N = b.size(1), K = a.size(1);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
fp8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(false, "Int8 not supported for this architecture");
}
}
} else {
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
blockwise_func(c, a, b, a_scales, b_scales);
}
}

View File

@ -36,9 +36,4 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
} // namespace vllm

View File

@ -1,6 +1,8 @@
#include "c3x/scaled_mm_helper.hpp"
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell).
@ -13,10 +15,20 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm100_fp8,
nullptr, // int8 not supported on SM100
vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
int M = a.size(0), N = b.size(1), K = a.size(1);
TORCH_CHECK(
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
"Currently, only fp8 gemm is implemented for Blackwell");
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
}
#endif

View File

@ -1,6 +1,8 @@
#include "c3x/scaled_mm_helper.hpp"
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
@ -13,10 +15,49 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm90_fp8,
vllm::cutlass_scaled_mm_sm90_int8,
vllm::cutlass_scaled_mm_blockwise_sm90_fp8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
int M = a.size(0), N = b.size(1), K = a.size(1);
if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias);
}
} else {
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};
GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm");
vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,

View File

@ -110,8 +110,6 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 100) {
return CUDA_VERSION >= 12080;
}
#endif

File diff suppressed because it is too large Load Diff

View File

@ -13,7 +13,6 @@
#include "mmvq.cuh"
#include "mmq.cuh"
#include "moe.cuh"
#include "moe_vec.cuh"
// Q8 gemv
template <typename scalar_t>
@ -378,142 +377,6 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
return Y;
}
torch::Tensor ggml_moe_a8_vec(torch::Tensor X, // input
torch::Tensor W, // expert weights
torch::Tensor topk_ids, int64_t top_k,
int64_t type, int64_t row, int64_t tokens) {
int col = X.sizes()[1];
const int padded = (col + 512 - 1) / 512 * 512;
const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device());
at::Tensor Y = torch::zeros({tokens * top_k, row}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options);
VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] {
quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(),
(void*)quant_X.data_ptr(), col, tokens,
stream);
switch (type) {
case 2:
moe_vec_q4_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 3:
moe_vec_q4_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 6:
moe_vec_q5_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 7:
moe_vec_q5_1_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 8:
moe_vec_q8_0_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 10:
moe_vec_q2_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 11:
moe_vec_q3_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 12:
moe_vec_q4_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 13:
moe_vec_q5_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 14:
moe_vec_q6_K_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 16:
moe_vec_iq2_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 17:
moe_vec_iq2_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 18:
moe_vec_iq3_xxs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 19:
moe_vec_iq1_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 20:
moe_vec_iq4_nl_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 21:
moe_vec_iq3_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 22:
moe_vec_iq2_s_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 23:
moe_vec_iq4_xs_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
case 29:
moe_vec_iq1_m_q8_1_cuda<scalar_t>(
(void*)W.data_ptr(), (void*)quant_X.data_ptr(),
(scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens,
col, row, quant_X.stride(0), stream);
break;
}
});
return Y;
}
int64_t ggml_moe_get_block_size(int64_t type) {
switch (type) {
case 2:

View File

@ -1,338 +0,0 @@
// copied and adapted from
// https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr,
vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void moe_vec_q(const void* __restrict__ vx,
const void* __restrict__ vy,
scalar_t* __restrict__ dst,
const int* topk_ids, const int topk,
const int ncols, const int nrows,
const int token_stride) {
const auto row = blockIdx.x * blockDim.y + threadIdx.y;
const auto token = blockIdx.z / topk;
const auto expert = (topk_ids)[blockIdx.z];
if (row >= nrows) {
return;
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = vdr * WARP_SIZE / qi;
// partial sum for each thread
float tmp = 0.0f;
const block_q_t* x = ((const block_q_t*)vx) + expert * nrows * blocks_per_row;
const block_q8_1* y =
(const block_q8_1*)(((const int*)vy) + token * token_stride);
for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row;
i += blocks_per_warp) {
const int ibx = row * blocks_per_row + i; // x block index
const int iby = i * (qk / QK8_1); // y block index that aligns with ibx
const int iqs =
vdr *
(threadIdx.x %
(qi / vdr)); // x block quant index when casting the quants to int
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp += VLLM_SHFL_XOR_SYNC(tmp, mask);
}
if (threadIdx.x == 0) {
dst[blockIdx.z * nrows + row] = tmp;
}
}
template <typename scalar_t>
static void moe_vec_q4_0_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ,
vec_dot_q4_0_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q4_1_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ,
vec_dot_q4_1_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q5_0_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ,
vec_dot_q5_0_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q5_1_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ,
vec_dot_q5_1_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q8_0_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ,
vec_dot_q8_0_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q2_K_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ,
vec_dot_q2_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q3_K_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ,
vec_dot_q3_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q4_K_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ,
vec_dot_q4_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q5_K_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ,
vec_dot_q5_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_q6_K_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ,
vec_dot_q6_K_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq2_xxs_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq2_xs_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq2_s_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq3_xxs_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq1_s_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq1_m_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq4_nl_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ,
vec_dot_iq4_nl_q8_1><<<block_nums, block_dims, 0, stream>>>(
vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq4_xs_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
ncols, nrows, token_stride);
}
template <typename scalar_t>
static void moe_vec_iq3_s_q8_1_cuda(const void* vx, const void* vy,
scalar_t* dst, const int* topk_ids,
const int top_k, const int tokens,
const int ncols, const int nrows,
const int token_stride,
cudaStream_t stream) {
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, tokens * top_k);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
moe_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, topk_ids, top_k,
ncols, nrows, token_stride);
}

View File

@ -9,7 +9,7 @@ at::Tensor as_g_workspace;
torch::Tensor allspark_w8a16_gemm(
torch::Tensor const& a, torch::Tensor const& b_qweight,
torch::Tensor const& b_scales, std::optional<torch::Tensor> const& b_qzeros,
torch::Tensor const& b_scales, c10::optional<torch::Tensor> const& b_qzeros,
int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version,
int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) {
TORCH_CHECK_NOT_IMPLEMENTED(
@ -918,7 +918,7 @@ void allspark_qgemm_w8a16_perc_ampere(
torch::Tensor allspark_w8a16_gemm(
torch::Tensor const& a, torch::Tensor const& b_qweight,
torch::Tensor const& b_scales, std::optional<torch::Tensor> const& b_qzeros,
torch::Tensor const& b_scales, c10::optional<torch::Tensor> const& b_qzeros,
int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version,
int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) {
// Verify device and strides

View File

@ -100,9 +100,9 @@ void rearrange_kn_weight_as_n32k16_order_ldg16(
void rearrange_kn_weight_as_n32k16_order(
torch::Tensor const& b_qweight, torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& b_zeros, bool has_zp,
c10::optional<torch::Tensor> const& b_zeros, bool has_zp,
torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder,
std::optional<torch::Tensor> const& b_zeros_reorder, const int64_t K,
c10::optional<torch::Tensor> const& b_zeros_reorder, const int64_t K,
const int64_t N, const int64_t N_32align) {
// Verify device and strides
TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU");

View File

@ -1 +0,0 @@
kernel_*.cu

View File

@ -1,291 +0,0 @@
#include "marlin_dtypes.cuh"
namespace MARLIN_NAMESPACE_NAME {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(res)
: "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n"
: "=r"(res)
: "r"(a), "n"(start_byte), "n"(mask));
return res;
}
template <typename scalar_t2, vllm::ScalarTypeId w_type_id>
__device__ inline void dequant(int q, scalar_t2* frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline void dequant<half2, vllm::kU4B8.id()>(int q, half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
}
template <>
__device__ inline void dequant<half2, vllm::kU4.id()>(int q, half2* frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// clang-format on
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4B8.id()>(
int q, nv_bfloat162* frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
// clang-format on
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU4.id()>(
int q, nv_bfloat162* frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
// clang-format off
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
// clang-format on
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline void dequant<half2, vllm::kU8B128.id()>(int q,
half2* frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<half2, vllm::kU8.id()>(int q, half2* frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8B128.id()>(
int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kU8.id()>(
int q, nv_bfloat162* frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388608.f;
fp32_intermediates[1] -= 8388608.f;
fp32_intermediates[2] -= 8388608.f;
fp32_intermediates[3] -= 8388608.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
}
template <>
__device__ inline void dequant<half2, vllm::kFE4M3fn.id()>(int q,
half2* frag_b) {
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
// Calculate MASK for extracting mantissa and exponent
constexpr int MASK1 = 0x80000000;
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
}
template <>
__device__ inline void dequant<nv_bfloat162, vllm::kFE4M3fn.id()>(
int q, nv_bfloat162* frag_b) {
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
// Calculate MASK for extracting mantissa and exponent
constexpr int MASK1 = 0x80000000;
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to BF16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg =
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to bfloat162 and apply bias
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
}
#endif
} // namespace MARLIN_NAMESPACE_NAME

View File

@ -1,116 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import glob
import itertools
import os
import subprocess
import jinja2
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
""".strip()
TEMPLATE = ("template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );")
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
(128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
def remove_old_kernels():
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
subprocess.call(["rm", "-f", filename])
def generate_new_kernels():
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
all_template_str_list = []
for group_blocks, m_blocks, thread_configs in itertools.product(
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
# act order case only support gptq-int4 and gptq-int8
if group_blocks == 0 and scalar_type not in [
"vllm::kU4B8", "vllm::kU8B128"
]:
continue
if thread_configs[2] == 256:
# for small batch (m_blocks == 1), we only need (128, 128, 256)
# for large batch (m_blocks > 1), we only need (64, 256, 256)
if m_blocks <= 1 and thread_configs[0] != 128:
continue
if m_blocks > 1 and thread_configs[0] != 64:
continue
# we only support channelwise quantization and group_size == 128
# for fp8
if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]:
continue
k_blocks = thread_configs[0] // 16
n_blocks = thread_configs[1] // 16
threads = thread_configs[2]
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
is_zp_float_list = [False]
if dtype == "fp16" and scalar_type == "vllm::kU4" and \
group_blocks == 4:
# HQQ (is_zp_float = true) only supports
# 4bit quantization and fp16
is_zp_float_list.append(True)
for is_zp_float in is_zp_float_list:
template_str = jinja2.Template(TEMPLATE).render(
scalar_t=c_dtype,
w_type_id=scalar_type + ".id()",
threads=threads,
thread_m_blocks=max(m_blocks, 1),
thread_n_blocks=n_blocks,
thread_k_blocks=k_blocks,
m_block_size_8=m_blocks == 0.5,
stages="pipe_stages",
group_blocks=group_blocks,
is_zp_float=is_zp_float,
)
all_template_str_list.append(template_str)
file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content)
if __name__ == "__main__":
remove_old_kernels()
generate_new_kernels()

File diff suppressed because it is too large Load Diff

View File

@ -1,37 +0,0 @@
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin
#endif
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, \
int prob_k, int lda, int *locks, bool use_atomic_add, \
bool use_fp32_reduce, int max_shared_mem
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
}

View File

@ -1287,7 +1287,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const int max_num_partitions) {
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;
@ -1465,10 +1465,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float inv_global_exp_sum =
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
const float out_scale =
(fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
acc *= inv_global_exp_sum;
acc *= out_scale;
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
@ -1550,7 +1548,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const int max_num_partitions) {
UNREACHABLE_CODE
}
// clang-format on
@ -1584,8 +1582,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
fp8_out_scale_ptr);
context_lens_ptr, query_start_loc_ptr, max_num_partitions);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@ -1597,7 +1594,7 @@ void paged_attention_custom_launcher(
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale) {
torch::Tensor& v_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@ -1629,11 +1626,6 @@ void paged_attention_custom_launcher(
int* context_lens_ptr = context_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// NOTE: fp8_out_scale is optional.
const auto fp8_out_scale_ptr =
fp8_out_scale
? static_cast<const float*>(fp8_out_scale.value().data_ptr())
: nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
@ -1744,54 +1736,33 @@ void paged_attention_custom_launcher(
}
}
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE, ALIBI_ENABLED) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, \
ALIBI_ENABLED) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale);
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE) \
if (alibi_slopes) { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
true); \
} else { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
false); \
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
PSIZE) \
if (alibi_slopes) { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true); \
} else { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
}
#if defined(__HIPCC__) && defined(__gfx90a__)
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (fp8_out_scale) { \
TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
} else { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
256); \
}
#else
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (fp8_out_scale) { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
uint8_t, 256); \
} else { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
256); \
}
#endif
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
@ -1824,8 +1795,7 @@ void paged_attention(
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale,
const std::optional<torch::Tensor>& fp8_out_scale) {
torch::Tensor& v_scale) {
// clang-format on
const int head_size = query.size(2);
if (kv_cache_dtype == "auto") {

View File

@ -11,12 +11,14 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads,
double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);

View File

@ -126,8 +126,8 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
const int warp = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
const int num_warps = blockDim.x / WARP_SIZE;
const int qwarpid = threadid / 16;
const int qthreadid = threadid % 16;
const int qwarpid = threadid / num_warps;
const int qthreadid = threadid % num_warps;
float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
float acc[NUM_A_ROWS_PER_BLOCK];
@ -142,13 +142,15 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
// rowA_elem4[i] holds 8 * half numbers seen as a single float4.
rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
}
colB_elem4x = bf4[threadid * 4 + 0];
colB_elem4y = bf4[threadid * 4 + 1];
colB_elem4z = bf4[threadid * 4 + 2];
colB_elem4w = bf4[threadid * 4 + 3];
}
colB_elem4x = bf4[threadid * 4 + 0];
colB_elem4y = bf4[threadid * 4 + 1];
colB_elem4z = bf4[threadid * 4 + 2];
colB_elem4w = bf4[threadid * 4 + 3];
scalar2_t Af2;
[[maybe_unused]] scalar2_t Bf2;
float2 S;
auto Ah2ptr = reinterpret_cast<scalar2_t*>(&rowA_elem4);
@ -191,13 +193,12 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b,
if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
#pragma unroll
for (int mask = 16 / 2; mask >= 1; mask /= 2) {
for (int mask = num_warps / 2; mask >= 1; mask /= 2) {
acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
}
float oval2 = __shfl_xor(acc[qwarpid], 16);
float oval2 = __shfl_xor(acc[qwarpid], num_warps);
if (lane % 32 == 0) {
if (lane % (num_warps * 2) == 0) {
oval = __float22s2_rn<scalar2_t>(make_float2(acc[qwarpid], oval2));
c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
}
@ -221,10 +222,9 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b,
// NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle
// operations.
const int NUM_THREADS =
max(rows_per_block * 16,
K * 2 / 16 % WARP_SIZE == 0
? K * 2 / 16
: K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE));
K * 2 / 16 % WARP_SIZE == 0
? K * 2 / 16
: K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);
int NUM_BLOCKS = M / rows_per_block;
@ -275,22 +275,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
#if defined(__HIP__MI300__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
half4 h4[A_CHUNK / 4];
scalar8 h8;
};
@ -327,7 +318,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
@ -353,11 +343,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
for (int n = 0; n < N; n++) sum[n][i] = 0;
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
@ -388,8 +374,24 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
for (int y = 0; y < YTILE; y++)
bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K])));
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
@ -417,17 +419,32 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
#pragma unroll
for (uint32_t n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b])
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
}
}
}
@ -436,84 +453,37 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
m += CuCount * _WvPrGrp * YTILE;
}
}
@ -535,22 +505,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
#if defined(__HIP__MI300__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
half4 h4[A_CHUNK / 4];
scalar8 h8;
};
@ -612,7 +573,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (threadIdx.y >= _WvPrGrp) return;
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
@ -638,11 +598,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
for (int n = 0; n < N; n++) sum[n][i] = 0;
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
@ -672,8 +628,24 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
@ -704,17 +676,32 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
}
}
}
@ -723,82 +710,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
} else {
#pragma unroll
for (int n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
// float accm1 = 0;
// for (int i=0; i<64; i++)
// accm1 += __shfl(sum4[n][y][i%4], i);
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
@ -835,22 +774,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
wvSplitK_hf_big_(const int K, const int M, const scalar_t* B,
const scalar_t* __restrict__ A, scalar_t* C,
const int _WvPrGrp, const int CuCount) {
#if defined(__HIP__MI300__)
constexpr bool use_mfma = (std::is_same_v<scalar_t, __hip_bfloat16>);
#else
constexpr bool use_mfma = false;
#endif
using scalar8 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float;
using half4 =
__attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16;
union bigType {
scalar_t h[A_CHUNK];
float f[A_CHUNK / 2];
float2 f2[A_CHUNK / 4];
double d[A_CHUNK / 4];
half4 h4[A_CHUNK / 4];
scalar8 h8;
};
@ -926,7 +857,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
kFit = min(kFit, K);
float sum[N][YTILE];
scalar8 sum4[N][YTILE];
//----------------------------------------------------
// Each wave works on a single column of weight matrix.
@ -958,11 +888,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// are being worked on by each wave.
//----------------------------------------------------
for (int i = 0; i < YTILE; i++)
for (int n = 0; n < N; n++)
if constexpr (!use_mfma)
sum[n][i] = 0;
else
sum4[n][i] = {0, 0, 0, 0};
for (int n = 0; n < N; n++) sum[n][i] = 0;
bigType bigA[N][UNRL];
bigType bigB[YTILE][UNRL];
@ -1011,8 +937,24 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
if (k_ >= K) break;
const scalar_t* B_ = &B[(m + 0) * K + k_];
for (int b = 0; b < YTILE; b++)
bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K])));
bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K])));
//----------------------------------------------------
// The following code with YTILE > 1 has to be deleted
//----------------------------------------------------
if constexpr (YTILE >= 2)
bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K])));
if constexpr (YTILE >= 3)
bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K])));
if constexpr (YTILE >= 4)
bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K])));
if constexpr (YTILE >= 5)
bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K])));
if constexpr (YTILE >= 6)
bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K])));
if constexpr (YTILE >= 7)
bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K])));
if constexpr (YTILE >= 8)
bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K])));
}
// Fetch activation matrix from either just LDS or from both LDS / memory
@ -1047,17 +989,32 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
// Do the matrix multiplication of activation and weight matrix
// - Remember the accumulation is happening for K-split of 64!
#pragma unroll
for (int y = 0; y < YTILE; y++) {
if constexpr (!use_mfma)
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b])
}
else
#pragma unroll
for (uint32_t b = 0; b < A_CHUNK / 4; b++)
sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0);
for (uint32_t b = 0; b < A_CHUNK / 2; b++) {
DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]);
//----------------------------------------------------
// The following code with YTILE > 1
//----------------------------------------------------
if constexpr (YTILE >= 2) {
DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]);
}
if constexpr (YTILE >= 3) {
DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]);
}
if constexpr (YTILE >= 4) {
DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]);
}
if constexpr (YTILE >= 5) {
DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]);
}
if constexpr (YTILE >= 6) {
DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]);
}
if constexpr (YTILE >= 7) {
DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]);
}
if constexpr (YTILE >= 8) {
DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]);
}
}
}
}
@ -1074,78 +1031,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS)
//----------------------------------------------------
// Final reduction step using shuffle
//----------------------------------------------------
if constexpr (!use_mfma) {
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
for (int n = 0; n < N; n++) {
for (int y = 0; y < YTILE; y++) {
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 "
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(sum[n][y])
: "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y]));
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}
} else {
#pragma unroll
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
#pragma unroll
for (int y = 0; y < YTILE; y++) {
float accm = sum4[n][y][0];
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][1]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][2]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(sum4[n][y][3]), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 "
: "=v"(accm)
: "0"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0"
: "=v"(accm)
: "0"(accm), "v"(accm), "v"(accm));
sum4[n][y][0] = accm;
}
}
if (threadIdx.x == 63) {
for (int n = 0; n < N; n++) {
for (int i = 0; i < YTILE; i++) {
// if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]);
C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]);
}
for (int i = 0; i < YTILE; i++) {
if (commitColumn[i])
C[m + i + n * M] = __float2s<scalar_t>(sum[n][i]);
}
}
}

View File

@ -47,8 +47,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale,"
" Tensor? fp8_out_scale) -> ()");
" Tensor k_scale, Tensor v_scale) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}

View File

@ -176,7 +176,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
@ -184,7 +184,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// (supports multiple loras).
ops.def(
"batched_rotary_embedding(Tensor positions, Tensor! query,"
" Tensor!? key, int head_size,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" int rot_dim,"
" Tensor cos_sin_cache_offsets) -> ()");
@ -291,11 +291,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor b_scales, Tensor? b_zeros_or_none, Tensor? g_idx_or_none, "
"Tensor? perm_or_none, Tensor workspace, int b_q_type, "
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor",
"bool has_zp, bool use_atomic_add, bool use_fp32_reduce, "
"bool is_zp_float) -> Tensor",
{stride_tag});
// conditionally compiled so impl registration is in source file
@ -337,15 +338,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);
ops.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);
ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size);
#ifndef USE_ROCM
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def(
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, "
"SymInt size_k) -> Tensor",
{stride_tag});
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
ops.def(
"marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "

View File

@ -309,7 +309,5 @@ ENV HF_HUB_ENABLE_HF_TRANSFER 1
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/nightly_torch_test.txt
# Logging to confirm the torch versions
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
#################### UNITTEST IMAGE #############################

View File

@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="5a77249"
ARG AITER_BRANCH="7e1ed08"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
FROM ${BASE_IMAGE} AS base

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 118 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 95 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 143 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 265 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 52 KiB

View File

@ -4,7 +4,6 @@
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
- [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing)
- [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama), March 27th 2025. [[Slides]](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [The first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg), March 16th 2025. [[Slides]](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).

View File

@ -60,6 +60,9 @@ autodoc2_packages = [
autodoc2_output_dir = "api"
autodoc2_render_plugin = "myst"
autodoc2_hidden_objects = ["dunder", "private", "inherited"]
autodoc2_docstring_parser_regexes = [
(".*", "docs.source.autodoc2_docstring_parser"),
]
autodoc2_sort_names = True
autodoc2_index_template = None

View File

@ -1,36 +0,0 @@
(deployment-chatbox)=
# Chatbox
[Chatbox](https://github.com/chatboxai/chatbox) is a desktop client for LLMs, available on Windows, Mac, Linux.
It allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints.
## Prerequisites
- Setup vLLM environment
## Deploy
- Start the vLLM server with the supported chat completion model, e.g.
```console
vllm serve qwen/Qwen1.5-0.5B-Chat
```
- Download and install [Chatbox desktop](https://chatboxai.app/en#download).
- On the bottom left of settings, Add Custom Provider
- API Mode: `OpenAI API Compatible`
- Name: vllm
- API Host: `http://{vllm server host}:{vllm server port}/v1`
- API Path: `/chat/completions`
- Model: `qwen/Qwen1.5-0.5B-Chat`
:::{image} /assets/deployment/chatbox-settings.png
:::
- Go to `Just chat`, and start to chat:
:::{image} /assets/deployment/chatbox-chat.png
:::

View File

@ -1,56 +0,0 @@
(deployment-dify)=
# Dify
[Dify](https://github.com/langgenius/dify) is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production.
It supports vLLM as a model provider to efficiently serve large language models.
This guide walks you through deploying Dify using a vLLM backend.
## Prerequisites
- Setup vLLM environment
- Install [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/)
## Deploy
- Start the vLLM server with the supported chat completion model, e.g.
```console
vllm serve Qwen/Qwen1.5-7B-Chat
```
- Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)):
```console
git clone https://github.com/langgenius/dify.git
cd dify
cd docker
cp .env.example .env
docker compose up -d
```
- Open the browser to access `http://localhost/install`, config the basic login information and login.
- In the top-right user menu (under the profile icon), go to Settings, then click `Model Provider`, and locate the `vLLM` provider to install it.
- Fill in the model provider details as follows:
- **Model Type**: `LLM`
- **Model Name**: `Qwen/Qwen1.5-7B-Chat`
- **API Endpoint URL**: `http://{vllm_server_host}:{vllm_server_port}/v1`
- **Model Name for API Endpoint**: `Qwen/Qwen1.5-7B-Chat`
- **Completion Mode**: `Completion`
:::{image} /assets/deployment/dify-settings.png
:::
- To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type:
:::{image} /assets/deployment/dify-create-chatbot.png
:::
- Click the chatbot you just created to open the chat interface and start interacting with the model:
:::{image} /assets/deployment/dify-chat.png
:::

View File

@ -6,14 +6,11 @@
anything-llm
bentoml
cerebrium
chatbox
dify
dstack
helm
lws
modal
open-webui
retrieval_augmented_generation
skypilot
streamlit
triton

View File

@ -1,84 +0,0 @@
(deployment-retrieval-augmented-generation)=
# Retrieval-Augmented Generation
[Retrieval-augmented generation (RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) is a technique that enables generative artificial intelligence (Gen AI) models to retrieve and incorporate new information. It modifies interactions with a large language model (LLM) so that the model responds to user queries with reference to a specified set of documents, using this information to supplement information from its pre-existing training data. This allows LLMs to use domain-specific and/or updated information. Use cases include providing chatbot access to internal company data or generating responses based on authoritative sources.
Here are the integrations:
- vLLM + [langchain](https://github.com/langchain-ai/langchain) + [milvus](https://github.com/milvus-io/milvus)
- vLLM + [llamaindex](https://github.com/run-llama/llama_index) + [milvus](https://github.com/milvus-io/milvus)
## vLLM + langchain
### Prerequisites
- Setup vLLM and langchain environment
```console
pip install -U vllm \
langchain_milvus langchain_openai \
langchain_community beautifulsoup4 \
langchain-text-splitters
```
### Deploy
- Start the vLLM server with the supported embedding model, e.g.
```console
# Start embedding service (port 8000)
vllm serve ssmits/Qwen2-7B-Instruct-embed-base
```
- Start the vLLM server with the supported chat completion model, e.g.
```console
# Start chat service (port 8001)
vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
```
- Use the script: <gh-file:examples/online_serving/retrieval_augmented_generation_with_langchain.py>
- Run the script
```python
python retrieval_augmented_generation_with_langchain.py
```
## vLLM + llamaindex
### Prerequisites
- Setup vLLM and llamaindex environment
```console
pip install vllm \
llama-index llama-index-readers-web \
llama-index-llms-openai-like \
llama-index-embeddings-openai-like \
llama-index-vector-stores-milvus \
```
### Deploy
- Start the vLLM server with the supported embedding model, e.g.
```console
# Start embedding service (port 8000)
vllm serve ssmits/Qwen2-7B-Instruct-embed-base
```
- Start the vLLM server with the supported chat completion model, e.g.
```console
# Start chat service (port 8001)
vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001
```
- Use the script: <gh-file:examples/online_serving/retrieval_augmented_generation_with_llamaindex.py>
- Run the script
```python
python retrieval_augmented_generation_with_llamaindex.py
```

View File

@ -86,7 +86,7 @@ To improve privacy in shared environments, vLLM supports isolating prefix cache
{"role": "user", "content": "Here is a document with details about the world series: ..."},
{"role": "user", "content": "Who won the world series in 2020?"}
],
"cache_salt": "your-cache-salt"
"cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ=="
}
```

View File

@ -137,9 +137,3 @@ By default, vLLM will try to determine a set of sizes to capture cudagraph. You
`vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"`
Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture.
### Full Cudagraph capture
It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config "{'full_cuda_graph': True}"`
Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled.

View File

@ -19,6 +19,24 @@ FP8 computation is supported on NVIDIA GPUs with compute capability > 8.9 (Ada L
FP8 models will run on compute capability > 8.0 (Ampere) as weight-only W8A16, utilizing FP8 Marlin.
:::
## Quick Start with Online Dynamic Quantization
Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying `--quantization="fp8"` in the command line or setting `quantization="fp8"` in the LLM constructor.
In this mode, all Linear modules (except for the final `lm_head`) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode.
```python
from vllm import LLM
model = LLM("facebook/opt-125m", quantization="fp8")
# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB
result = model.generate("Hello, my name is")
print(result[0].outputs[0].text)
```
:::{warning}
Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model.
:::
## Installation
To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library:
@ -27,6 +45,12 @@ To produce performant FP8 quantized models with vLLM, you'll need to install the
pip install llmcompressor
```
Additionally, install `vllm` and `lm-evaluation-harness` for evaluation:
```console
pip install vllm lm-eval==0.4.4
```
## Quantization Process
The quantization process involves three main steps:
@ -77,12 +101,6 @@ tokenizer.save_pretrained(SAVE_DIR)
### 3. Evaluating Accuracy
Install `vllm` and `lm-evaluation-harness` for evaluation:
```console
pip install vllm lm-eval==0.4.4
```
Load and run the model in `vllm`:
```python
@ -117,22 +135,60 @@ Here's an example of the resulting scores:
## Troubleshooting and Support
If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository.
If you encounter any issues or have feature requests, please open an issue on the `vllm-project/llm-compressor` GitHub repository.
## Online Dynamic Quantization
## Deprecated Flow
Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying `--quantization="fp8"` in the command line or setting `quantization="fp8"` in the LLM constructor.
:::{note}
The following information is preserved for reference and search purposes.
The quantization method described below is deprecated in favor of the `llmcompressor` method described above.
:::
In this mode, all Linear modules (except for the final `lm_head`) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode.
For static per-tensor offline quantization to FP8, please install the [AutoFP8 library](https://github.com/neuralmagic/autofp8).
```bash
git clone https://github.com/neuralmagic/AutoFP8.git
pip install -e AutoFP8
```
This package introduces the `AutoFP8ForCausalLM` and `BaseQuantizeConfig` objects for managing how your model will be compressed.
## Offline Quantization with Static Activation Scaling Factors
You can use AutoFP8 with calibration data to produce per-tensor static scales for both the weights and activations by enabling the `activation_scheme="static"` argument.
```python
from datasets import load_dataset
from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
# Load and tokenize 512 dataset samples for calibration of activation scales
ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
# Define quantization config with static activation scales
quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
# Load the model, quantize, and save checkpoint
model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
model.quantize(examples)
model.save_quantized(quantized_model_dir)
```
Your model checkpoint with quantized weights and activations should be available at `Meta-Llama-3-8B-Instruct-FP8/`.
Finally, you can load the quantized model checkpoint directly in vLLM.
```python
from vllm import LLM
model = LLM("facebook/opt-125m", quantization="fp8")
# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB
model = LLM(model="Meta-Llama-3-8B-Instruct-FP8/")
# INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
result = model.generate("Hello, my name is")
print(result[0].outputs[0].text)
```
:::{warning}
Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model.
:::

View File

@ -169,4 +169,4 @@ recipe = GPTQModifier(
## Troubleshooting and Support
If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository. The full INT4 quantization example in `llm-compressor` is available [here](https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w4a16/llama3_example.py).
If you encounter any issues or have feature requests, please open an issue on the [`vllm-project/llm-compressor`](https://github.com/vllm-project/llm-compressor) GitHub repository. The full INT4 quantization example in `llm-compressor` is available [here](https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w4a16/llama3_example.py).

View File

@ -138,4 +138,4 @@ Quantized models can be sensitive to the presence of the `bos` token. Make sure
## Troubleshooting and Support
If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository.
If you encounter any issues or have feature requests, please open an issue on the [`vllm-project/llm-compressor`](https://github.com/vllm-project/llm-compressor) GitHub repository.

View File

@ -17,9 +17,7 @@ vLLM currently supports the following reasoning models:
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ |
:::{note}
IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
:::
- IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
## Quickstart
@ -85,7 +83,7 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
}
```
OpenAI Python client library does not officially support `reasoning_content` attribute for streaming output. But the client supports extra attributes in the response. You can use `hasattr` to check if the `reasoning_content` attribute is present in the response. For example:
OpenAI Python client library does not officially support `reasoning_content` attribute for streaming output. But the client support extra attributes in the response. You can use `hasattr` to check if the `reasoning_content` attribute is present in the response. For example:
```python
from openai import OpenAI
@ -223,7 +221,7 @@ print(f"Function called: {tool_call.name}")
print(f"Arguments: {tool_call.arguments}")
```
For more examples, please refer to <gh-file:examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py>.
For more examples, please refer to <gh-file:examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py> .
## Limitations
@ -231,7 +229,7 @@ For more examples, please refer to <gh-file:examples/online_serving/openai_chat_
## How to support a new reasoning model
You can add a new `ReasoningParser` similar to <gh-file:vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py>.
You can add a new `ReasoningParser` similar to `vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py`.
```python
# import the required packages
@ -288,7 +286,7 @@ class ExampleParser(ReasoningParser):
"""
```
Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in <gh-file:vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py>.
Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`.
```python
@dataclass
@ -314,7 +312,7 @@ class DeepSeekReasoner(Reasoner):
...
```
The structured output engine like [xgrammar](https://github.com/mlc-ai/xgrammar) will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case.
The structured output engine like `xgrammar` will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case.
Finally, you can enable reasoning for the model by using the `--reasoning-parser` flags.

View File

@ -141,9 +141,9 @@ Known issues:
much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided:
* <gh-file:examples/tool_chat_template_mistral.jinja> - this is the "official" Mistral chat template, but tweaked so that
* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that
it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits)
* <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt
* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt
when tools are provided, that results in much better reliability when working with parallel tool calling.
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
@ -170,15 +170,15 @@ Known issues:
VLLM provides two JSON based chat templates for Llama 3.1 and 3.2:
* <gh-file:examples/tool_chat_template_llama3.1_json.jinja> - this is the "official" chat template for the Llama 3.1
* `examples/tool_chat_template_llama3.1_json.jinja` - this is the "official" chat template for the Llama 3.1
models, but tweaked so that it works better with vLLM.
* <gh-file:examples/tool_chat_template_llama3.2_json.jinja> - this extends upon the Llama 3.1 chat template by adding support for
* `examples/tool_chat_template_llama3.2_json.jinja` - this extends upon the Llama 3.1 chat template by adding support for
images.
Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}`
VLLM also provides a JSON based chat template for Llama 4:
* <gh-file:examples/tool_chat_template_llama4_json.jinja> - this is based on the "official" chat template for the Llama 4
* `examples/tool_chat_template_llama4_json.jinja` - this is based on the "official" chat template for the Llama 4
models, but tweaked so that it works better with vLLM.
For Llama 4 use `--tool-call-parser llama4_json examples/tool_chat_template_llama4_json.jinja`.
@ -191,7 +191,7 @@ Supported models:
Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja`
<gh-file:examples/tool_chat_template_granite.jinja>: this is a modified chat template from the original on Huggingface. Parallel function calls are supported.
`examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported.
* `ibm-granite/granite-3.1-8b-instruct`
@ -203,7 +203,7 @@ The chat template from Huggingface can be used directly. Parallel function calls
Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja`
<gh-file:examples/tool_chat_template_granite_20b_fc.jinja>: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.
`examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported.
### InternLM Models (`internlm`)
@ -253,12 +253,12 @@ Limitations:
Example supported models:
* `meta-llama/Llama-3.2-1B-Instruct`\* (use with <gh-file:examples/tool_chat_template_llama3.2_pythonic.jinja>)
* `meta-llama/Llama-3.2-3B-Instruct`\* (use with <gh-file:examples/tool_chat_template_llama3.2_pythonic.jinja>)
* `Team-ACE/ToolACE-8B` (use with <gh-file:examples/tool_chat_template_toolace.jinja>)
* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with <gh-file:examples/tool_chat_template_toolace.jinja>)
* `meta-llama/Llama-4-Scout-17B-16E-Instruct`\* (use with <gh-file:examples/tool_chat_template_llama4_pythonic.jinja>)
* `meta-llama/Llama-4-Maverick-17B-128E-Instruct`\* (use with <gh-file:examples/tool_chat_template_llama4_pythonic.jinja>)
* `meta-llama/Llama-3.2-1B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`)
* `meta-llama/Llama-3.2-3B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`)
* `Team-ACE/ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`)
* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`)
* `meta-llama/Llama-4-Scout-17B-16E-Instruct`\* (use with `examples/tool_chat_template_llama4_pythonic.jinja`)
* `meta-llama/Llama-4-Maverick-17B-128E-Instruct`\* (use with `examples/tool_chat_template_llama4_pythonic.jinja`)
Flags: `--tool-call-parser pythonic --chat-template {see_above}`
@ -270,7 +270,7 @@ Llama's smaller models frequently fail to emit tool calls in the correct format.
## How to write a tool parser plugin
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in <gh-file:vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py>.
A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py.
Here is a summary of a plugin file:

View File

@ -14,6 +14,6 @@ Or you can create a new Python environment using [uv](https://docs.astral.sh/uv/
```console
# (Recommended) Create a new uv environment. Use `--seed` to install `pip` and `setuptools` in the environment.
uv venv --python 3.12 --seed
source .venv/bin/activate
uv venv vllm --python 3.12 --seed
source vllm/bin/activate
```

View File

@ -239,9 +239,7 @@ print(output)
See [this page](#generative-models) for more information on how to use generative models.
#### Text Generation
Specified using `--task generate`.
#### Text Generation (`--task generate`)
:::{list-table}
:widths: 25 25 50 5 5
@ -387,11 +385,6 @@ Specified using `--task generate`.
* `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc.
* ✅︎
* ✅︎
- * `GraniteMoeHybridForCausalLM`
* Granite 4.0 MoE Hybrid
* `ibm-granite/granite-4.0-tiny-preview`, etc.
* ✅︎
* ✅︎
- * `GraniteMoeSharedForCausalLM`
* Granite MoE Shared
* `ibm-research/moe-7b-1b-active-shared-experts` (test model)
@ -607,9 +600,7 @@ Since some model architectures support both generative and pooling tasks,
you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode.
:::
#### Text Embedding
Specified using `--task embed`.
#### Text Embedding (`--task embed`)
:::{list-table}
:widths: 25 25 50 5 5
@ -674,9 +665,7 @@ If your model is not in the above list, we will try to automatically convert the
{func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
#### Reward Modeling
Specified using `--task reward`.
#### Reward Modeling (`--task reward`)
:::{list-table}
:widths: 25 25 50 5 5
@ -717,9 +706,7 @@ For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b
e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
:::
#### Classification
Specified using `--task classify`.
#### Classification (`--task classify`)
:::{list-table}
:widths: 25 25 50 5 5
@ -745,9 +732,7 @@ Specified using `--task classify`.
If your model is not in the above list, we will try to automatically convert the model using
{func}`~vllm.model_executor.models.adapters.as_classification_model`. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token.
#### Sentence Pair Scoring
Specified using `--task score`.
#### Sentence Pair Scoring (`--task score`)
:::{list-table}
:widths: 25 25 50 5 5
@ -834,9 +819,7 @@ vLLM currently only supports adding LoRA to the language backbone of multimodal
See [this page](#generative-models) for more information on how to use generative models.
#### Text Generation
Specified using `--task generate`.
#### Text Generation (`--task generate`)
:::{list-table}
:widths: 25 25 15 20 5 5 5
@ -1130,6 +1113,11 @@ Specified using `--task generate`.
<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
:::{important}
Pan-and-scan image pre-processing is currently supported on V0 (but not V1).
You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": true}'`.
:::
:::{warning}
Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs.
However, there are differences in how they handle text + image inputs:
@ -1149,7 +1137,7 @@ This limitation exists because the model's mixed attention pattern (bidirectiona
:::
:::{note}
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80.
`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support backends other than FlashAttention.
:::
:::{note}
@ -1212,9 +1200,7 @@ Since some model architectures support both generative and pooling tasks,
you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode.
:::
#### Text Embedding
Specified using `--task embed`.
#### Text Embedding (`--task embed`)
Any text generation model can be converted into an embedding model by passing `--task embed`.
@ -1254,9 +1240,7 @@ The following table lists those that are tested in vLLM.
* ✅︎
:::
#### Transcription
Specified using `--task transcription`.
#### Transcription (`--task transcription`)
Speech2Text models trained specifically for Automatic Speech Recognition.

View File

@ -213,13 +213,10 @@ Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions
:::{important}
A chat template is **required** to use Chat Completions API.
For HF format models, the default chat template is defined inside `chat_template.json` or `tokenizer_config.json`.
If no default chat template is available, we will first look for a built-in fallback in <gh-file:vllm/transformers_utils/chat_templates/registry.py>.
If no fallback is available, an error is raised and you have to provide the chat template manually via the `--chat-template` argument.
For certain models, we provide alternative chat templates inside <gh-dir:vllm/examples>.
For example, VLM2Vec uses <gh-file:examples/template_vlm2vec.jinja> which is different from the default one for Phi-3-Vision.
Although most models come with a chat template, for others you have to define one yourself.
The chat template can be inferred based on the documentation on the model's HuggingFace repo.
For example, LLaVA-1.5 (`llava-hf/llava-1.5-7b-hf`) requires a chat template that can be found here: <gh-file:examples/template_llava.jinja>
:::
### Image Inputs

View File

@ -95,7 +95,7 @@ You can convert the model checkpoint to a sharded checkpoint using <gh-file:exam
Quantized models take less memory at the cost of lower precision.
Statically quantized models can be downloaded from HF Hub (some popular ones are available at [Red Hat AI](https://huggingface.co/RedHatAI))
Statically quantized models can be downloaded from HF Hub (some popular ones are available at [Neural Magic](https://huggingface.co/neuralmagic))
and used directly without extra configuration.
Dynamic quantization is also supported via the `quantization` option -- see [here](#quantization-index) for more details.

View File

@ -7,8 +7,9 @@ from vllm.utils import FlexibleArgumentParser
def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
EngineArgs.add_cli_args(parser)
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
engine_group = parser.add_argument_group("Engine arguments")
EngineArgs.add_cli_args(engine_group)
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)

View File

@ -7,8 +7,9 @@ from vllm.utils import FlexibleArgumentParser
def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
EngineArgs.add_cli_args(parser)
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
engine_group = parser.add_argument_group("Engine arguments")
EngineArgs.add_cli_args(engine_group)
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)

View File

@ -75,38 +75,43 @@ def initialize_engine(model: str, quantization: str,
lora_repo: Optional[str]) -> LLMEngine:
"""Initialize the LLMEngine."""
engine_args = EngineArgs(model=model,
quantization=quantization,
enable_lora=True,
max_lora_rank=64,
max_loras=4)
if quantization == "bitsandbytes":
# QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique.
# It quantizes the model when loading, with some config info from the
# LoRA adapter repo. So need to set the parameter of load_format and
# qlora_adapter_name_or_path as below.
engine_args = EngineArgs(model=model,
quantization=quantization,
qlora_adapter_name_or_path=lora_repo,
enable_lora=True,
max_lora_rank=64)
else:
engine_args = EngineArgs(model=model,
quantization=quantization,
enable_lora=True,
max_loras=4)
return LLMEngine.from_engine_args(engine_args)
def main():
"""Main function that sets up and runs the prompt processing."""
test_configs = [
# QLoRA (https://arxiv.org/abs/2305.14314)
{
"name": "qlora_inference_example",
'model': "huggyllama/llama-7b",
'quantization': "bitsandbytes",
'lora_repo': 'timdettmers/qlora-flan-7b'
},
{
"name": "AWQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
'quantization': "awq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
},
{
"name": "GPTQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
'quantization': "gptq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
}
]
test_configs = [{
"name": "qlora_inference_example",
'model': "huggyllama/llama-7b",
'quantization': "bitsandbytes",
'lora_repo': 'timdettmers/qlora-flan-7b'
}, {
"name": "AWQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
'quantization': "awq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
}, {
"name": "GPTQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
'quantization': "gptq",
'lora_repo': 'jashing/tinyllama-colorist-lora'
}]
for test_config in test_configs:
print(

View File

@ -1,54 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with an EAGLE speculative
decoding model on neuron. To use EAGLE speculative decoding, you must use
a draft model that is specifically fine-tuned for EAGLE speculation.
Additionally, to use EAGLE with NxD Inference, the draft model must include
the LM head weights from the target model. These weights are shared between
the draft and target model.
"""
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"What is annapurna labs?",
]
# Create a sampling params object.
sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True)
# Create an LLM.
llm = LLM(
model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct",
speculative_config={
"model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
"num_speculative_tokens": 5,
"max_model_len": 2048
},
max_num_seqs=4,
# The max_model_len and block_size arguments are required to be same as
# max sequence length when targeting neuron device.
# Currently, this is a known limitation in continuous batching support
# in neuronx-distributed-inference.
max_model_len=2048,
block_size=2048,
# The device can be automatically detected when AWS Neuron SDK is installed.
# The device argument can be either unspecified for automated detection,
# or explicitly assigned.
device="neuron",
tensor_parallel_size=32,
override_neuron_config={
"enable_eagle_speculation": True,
"enable_fused_speculation": True
},
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, \n\n\n\ Generated text: {generated_text!r}")

View File

@ -1,64 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with a speculative
decoding model on neuron.
"""
import os
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, I am a language model and I can help",
"The president of the United States is",
"The capital of France is",
]
def config_buckets():
"""Configure context length and token gen buckets."""
# creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
def initialize_model():
"""Create an LLM with speculative decoding."""
return LLM(
model="openlm-research/open_llama_7b",
speculative_config={
"model": "openlm-research/open_llama_3b",
"num_speculative_tokens": 4,
"max_model_len": 2048
},
max_num_seqs=4,
max_model_len=2048,
block_size=2048,
use_v2_block_manager=True,
device="neuron",
tensor_parallel_size=32,
)
def process_requests(model: LLM, sampling_params: SamplingParams):
"""Generate texts from prompts and print them."""
outputs = model.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
def main():
"""Main function that sets up the model and processes prompts."""
config_buckets()
model = initialize_model()
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, top_k=1)
process_requests(model, sampling_params)
if __name__ == '__main__':
main()

View File

@ -22,8 +22,7 @@ def main():
# In real workloads, `enforace_eager` should be `False`.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
max_num_batched_tokens=64,
max_num_seqs=4,
max_model_len=128)
max_num_seqs=4)
outputs = llm.generate(prompts, sampling_params)
print("-" * 50)
for output, answer in zip(outputs, answers):

View File

@ -45,7 +45,7 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=4096,
max_num_seqs=2,
dtype="bfloat16",
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
@ -71,7 +71,7 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=2048,
max_num_seqs=2,
mm_processor_kwargs={"crop_to_patches": True},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [
f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|><image>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
@ -92,7 +92,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
prompts = [f"Question: {question} Answer:" for question in questions]
engine_args = EngineArgs(
model="Salesforce/blip2-opt-6.7b",
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -110,7 +110,7 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
model="facebook/chameleon-7b",
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -130,7 +130,7 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=4096,
max_num_seqs=2,
hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [
@ -155,7 +155,7 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2,
trust_remote_code=True,
dtype="bfloat16",
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = ["<MORE_DETAILED_CAPTION>" for _ in questions]
@ -175,7 +175,7 @@ def run_fuyu(questions: list[str], modality: str) -> ModelRequestData:
model="adept/fuyu-8b",
max_model_len=2048,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -194,7 +194,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=2048,
max_num_seqs=2,
mm_processor_kwargs={"do_pan_and_scan": True},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [("<bos><start_of_turn>user\n"
@ -219,7 +219,7 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
trust_remote_code=True,
enforce_eager=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [
@ -246,7 +246,7 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
trust_remote_code=True,
max_model_len=8192,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -287,7 +287,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
"longest_edge": 3 * 364
},
},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [(
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@ -314,7 +314,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
"longest_edge": 384
},
},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [
(f"<|im_start|>User:<image>{question}<end_of_utterance>\nAssistant:")
@ -337,7 +337,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
trust_remote_code=True,
max_model_len=4096,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -378,7 +378,7 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
model="moonshotai/Kimi-VL-A3B-Instruct",
trust_remote_code=True,
max_model_len=4096,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -398,7 +398,7 @@ def run_llava(questions: list[str], modality: str) -> ModelRequestData:
engine_args = EngineArgs(
model="llava-hf/llava-1.5-7b-hf",
max_model_len=4096,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -415,7 +415,7 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
engine_args = EngineArgs(
model="llava-hf/llava-v1.6-mistral-7b-hf",
max_model_len=8192,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -437,7 +437,7 @@ def run_llava_next_video(questions: list[str],
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -465,7 +465,7 @@ def run_llava_onevision(questions: list[str],
engine_args = EngineArgs(
model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
max_model_len=16384,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -488,7 +488,7 @@ def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
model="TIGER-Lab/Mantis-8B-siglip-llama3",
max_model_len=4096,
hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
stop_token_ids = [128009]
@ -529,7 +529,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
max_model_len=4096,
max_num_seqs=2,
trust_remote_code=True,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
# NOTE The stop_token_ids are different for various versions of MiniCPM-V
# 2.0
@ -584,7 +584,7 @@ def run_mistral3(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=8192,
max_num_seqs=2,
tensor_parallel_size=2,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
@ -610,7 +610,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -645,7 +645,7 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=4,
tensor_parallel_size=8,
gpu_memory_utilization=0.4,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -680,7 +680,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
trust_remote_code=True,
dtype="bfloat16",
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [
@ -706,7 +706,7 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
trust_remote_code=True,
max_model_len=4096,
tensor_parallel_size=4,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
@ -738,7 +738,7 @@ def run_ovis2(questions: list[str], modality: str) -> ModelRequestData:
trust_remote_code=True,
dtype="half",
hf_overrides={"architectures": ["Ovis2ForConditionalGeneration"]},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
placeholder = "<image>\n"
@ -761,7 +761,7 @@ def run_paligemma(questions: list[str], modality: str) -> ModelRequestData:
prompts = ["caption en" for _ in questions]
engine_args = EngineArgs(
model="google/paligemma-3b-mix-224",
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -778,7 +778,7 @@ def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData:
prompts = ["caption en" for _ in questions]
engine_args = EngineArgs(
model="google/paligemma2-3b-ft-docci-448",
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -815,7 +815,7 @@ def run_phi3v(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"num_crops": 16},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -849,7 +849,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
max_lora_rank=320,
# Note - mm_processor_kwargs can also be passed to generate/chat calls
mm_processor_kwargs={"dynamic_hd": 16},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
return ModelRequestData(
@ -870,7 +870,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
max_model_len=6144,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [f"<s>[INST]{question}\n[IMG][/INST]" for question in questions]
@ -891,7 +891,7 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
max_model_len=1024,
max_num_seqs=2,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
prompts = [f"{question}Picture 1: <img></img>\n" for question in questions]
@ -916,7 +916,7 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
if modality == "image":
@ -951,7 +951,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
"max_pixels": 1280 * 28 * 28,
"fps": 1,
},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
if modality == "image":
@ -985,7 +985,7 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
"max_pixels": 1280 * 28 * 28,
"fps": [1],
},
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
if modality == "image":
@ -1018,7 +1018,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
model=model_name,
trust_remote_code=True,
max_model_len=4096,
limit_mm_per_prompt={modality: 1},
limit_mm_per_prompt={"image": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,

View File

@ -5,7 +5,7 @@ and run online serving with OpenAI client.
Launch the vLLM server with the following command:
(single image inference with Llava)
vllm serve llava-hf/llava-1.5-7b-hf
vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
(multi-image inference with Phi-3.5-vision-instruct)
vllm serve microsoft/Phi-3.5-vision-instruct --task generate \

View File

@ -138,7 +138,7 @@ def main():
api_key="-",
)
model = client.models.list().data[0].id
model = "Qwen/Qwen2.5-3B-Instruct"
print("Guided Choice Completion:")
print(guided_choice_completion(client, model))

Some files were not shown because too many files have changed in this diff Show More